Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
41199996
Commit
41199996
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.12.0' into v0.12.0-dev
parents
31021d81
4fd9d6a8
Changes
380
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1846 additions
and
608 deletions
+1846
-608
csrc/cpu/shm.cpp
csrc/cpu/shm.cpp
+1
-1
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+76
-65
csrc/cpu/utils.cpp
csrc/cpu/utils.cpp
+47
-13
csrc/cpu/utils.hpp
csrc/cpu/utils.hpp
+73
-0
csrc/cub_helpers.h
csrc/cub_helpers.h
+3
-2
csrc/cuda_view.cu
csrc/cuda_view.cu
+3
-8
csrc/cumem_allocator.cpp
csrc/cumem_allocator.cpp
+392
-17
csrc/cumem_allocator_compat.h
csrc/cumem_allocator_compat.h
+109
-0
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+16
-21
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+50
-0
csrc/fused_qknorm_rope_kernel.cu
csrc/fused_qknorm_rope_kernel.cu
+428
-0
csrc/launch_bounds_utils.h
csrc/launch_bounds_utils.h
+29
-3
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+89
-270
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+51
-18
csrc/mamba/mamba_ssm/selective_scan.h
csrc/mamba/mamba_ssm/selective_scan.h
+7
-1
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+113
-21
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
+19
-28
csrc/moe/grouped_topk_kernels.cu
csrc/moe/grouped_topk_kernels.cu
+99
-54
csrc/moe/marlin_moe_wna16/.gitignore
csrc/moe/marlin_moe_wna16/.gitignore
+2
-1
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+239
-85
No files found.
Too many changes to show.
To preserve performance only
380 of 380+
files are displayed.
Plain diff
Email patch
csrc/cpu/shm.cpp
View file @
41199996
...
...
@@ -192,7 +192,7 @@ class SHMManager {
const
int
group_size
)
:
_rank
(
rank
),
_group_size
(
group_size
),
_thread_num
(
torch
::
get_
num
_threads
()),
_thread_num
(
omp_
get_
max
_threads
()),
_shm_names
({
""
}),
_shared_mem_ptrs
({
nullptr
}),
_shm_ctx
(
nullptr
)
{
...
...
csrc/cpu/torch_bindings.cpp
View file @
41199996
...
...
@@ -27,6 +27,8 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
void
onednn_mm
(
torch
::
Tensor
&
c
,
const
torch
::
Tensor
&
a
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
,
int64_t
handler
);
bool
is_onednn_acl_supported
();
void
mla_decode_kvcache
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
kv_cache
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
);
...
...
@@ -72,25 +74,45 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const
std
::
optional
<
at
::
Tensor
>&
bias
,
at
::
ScalarType
out_dtype
,
bool
is_vnni
);
torch
::
Tensor
get_scheduler_metadata
(
const
int64_t
num_req
,
const
int64_t
num_heads_q
,
const
int64_t
num_heads_kv
,
const
int64_t
head_dim
,
const
torch
::
Tensor
&
seq_lens
,
at
::
ScalarType
dtype
,
const
torch
::
Tensor
&
query_start_loc
,
const
bool
casual
,
const
int64_t
window_size
,
const
std
::
string
&
isa_hint
,
const
bool
enable_kv_split
);
void
cpu_attn_reshape_and_cache
(
const
torch
::
Tensor
&
key
,
const
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
const
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
isa
);
void
cpu_attention_with_kv_cache
(
const
torch
::
Tensor
&
query
,
const
torch
::
Tensor
&
key_cache
,
const
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
query_start_loc
,
const
torch
::
Tensor
&
seq_lens
,
const
double
scale
,
const
bool
causal
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
int64_t
sliding_window_left
,
const
int64_t
sliding_window_right
,
const
torch
::
Tensor
&
block_table
,
const
double
softcap
,
const
torch
::
Tensor
&
scheduler_metadata
,
const
std
::
optional
<
torch
::
Tensor
>&
s_aux
);
// Stub bound to op names that must exist so the extension imports cleanly but
// that have no implementation here. Calling it always fails the TORCH_CHECK
// with the message "Unimplemented".
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
void
cpu_gemm_wna16
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
q_weight
,
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
scales
,
const
std
::
optional
<
torch
::
Tensor
>&
zeros
,
const
std
::
optional
<
torch
::
Tensor
>&
g_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
,
const
int64_t
pack_factor
,
const
std
::
string
&
isa_hint
);
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
// Attention ops
// Compute the attention between an input query and the cached keys/values
// using PagedAttention.
ops
.
def
(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCPU
,
&
paged_attention_v1
);
ops
.
def
(
"dynamic_4bit_int_moe("
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
...
...
@@ -100,20 +122,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"dynamic_4bit_int_moe"
,
torch
::
kCPU
,
&
dynamic_4bit_int_moe_cpu
);
// PagedAttention V2.
ops
.
def
(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCPU
,
&
paged_attention_v2
);
// Activation ops
// Activation function used in SwiGLU.
...
...
@@ -164,7 +172,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
defined(__powerpc64__)
at
::
Tag
stride_tag
=
at
::
Tag
::
needs_fixed_stride_order
;
// Helper function to release oneDNN handlers
ops
.
def
(
"release_dnnl_matmul_handler(int handler) -> ()"
,
&
release_dnnl_matmul_handler
);
...
...
@@ -181,6 +188,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int handler) -> ()"
);
ops
.
impl
(
"onednn_mm"
,
torch
::
kCPU
,
&
onednn_mm
);
// Check if oneDNN was built with ACL backend
ops
.
def
(
"is_onednn_acl_supported() -> bool"
,
&
is_onednn_acl_supported
);
// Create oneDNN W8A8 handler
ops
.
def
(
"create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
...
...
@@ -197,15 +207,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()"
,
{
stride_tag
});
"Tensor? azp) -> ()"
);
ops
.
impl
(
"static_scaled_int8_quant"
,
torch
::
kCPU
,
&
static_scaled_int8_quant
);
// Compute int8 quantized tensor and scaling factor
ops
.
def
(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()"
,
{
stride_tag
});
"Tensor!? azp) -> ()"
);
ops
.
impl
(
"dynamic_scaled_int8_quant"
,
torch
::
kCPU
,
&
dynamic_scaled_int8_quant
);
#endif
...
...
@@ -254,37 +262,40 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"int8_scaled_mm_with_quant"
,
torch
::
kCPU
,
&
int8_scaled_mm_with_quant
);
#endif
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cache_ops
),
cache_ops
)
{
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops
.
def
(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"
);
cache_ops
.
impl
(
"swap_blocks"
,
torch
::
kCPU
,
&
swap_blocks
);
// Copy the cache blocks from src to dst.
cache_ops
.
def
(
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"Tensor block_mapping) -> ()"
);
cache_ops
.
impl
(
"copy_blocks"
,
torch
::
kCPU
,
&
copy_blocks
);
// Reshape the key and value tensors and cache them.
cache_ops
.
def
(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()"
);
cache_ops
.
impl
(
"reshape_and_cache"
,
torch
::
kCPU
,
&
reshape_and_cache
);
cache_ops
.
def
(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()"
);
cache_ops
.
impl
(
"concat_and_cache_mla"
,
torch
::
kCPU
,
&
concat_and_cache_mla
);
// CPU attention kernels
ops
.
def
(
"get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
"int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
"query_start_loc, bool casual, int window_size, str isa_hint, bool "
"enable_kv_split) -> Tensor"
,
&
get_scheduler_metadata
);
ops
.
def
(
"cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
"key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
"isa) -> ()"
,
&
cpu_attn_reshape_and_cache
);
ops
.
def
(
"cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
"value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
"seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
"sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
"float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()"
,
&
cpu_attention_with_kv_cache
);
// placeholders
ops
.
def
(
"static_scaled_fp8_quant() -> ()"
,
placeholder_op
);
ops
.
def
(
"dynamic_scaled_fp8_quant() -> ()"
,
placeholder_op
);
ops
.
def
(
"dynamic_per_token_scaled_fp8_quant() -> ()"
,
placeholder_op
);
// WNA16
#if defined(__AVX512F__)
ops
.
def
(
"cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
"Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
"pack_factor, str isa_hint) -> ()"
);
ops
.
impl
(
"cpu_gemm_wna16"
,
torch
::
kCPU
,
&
cpu_gemm_wna16
);
#endif
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_utils
),
utils
)
{
...
...
csrc/cpu/utils.cpp
View file @
41199996
...
...
@@ -45,21 +45,55 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
// Memory node binding
if
(
numa_available
()
!=
-
1
)
{
int
mem_node_id
=
numa_node_of_cpu
(
omp_cpu_ids
.
front
());
bitmask
*
mask
=
numa_parse_nodestring
(
std
::
to_string
(
mem_node_id
).
c_str
());
bitmask
*
src_mask
=
numa_get_membind
();
int
pid
=
getpid
();
// move all existing pages to the specified numa node.
*
(
src_mask
->
maskp
)
=
*
(
src_mask
->
maskp
)
^
*
(
mask
->
maskp
);
int
page_num
=
numa_migrate_pages
(
pid
,
src_mask
,
mask
);
if
(
page_num
==
-
1
)
{
TORCH_WARN
(
"numa_migrate_pages failed. errno: "
+
std
::
to_string
(
errno
));
std
::
set
<
int
>
node_ids
;
for
(
const
auto
&
cpu_id
:
omp_cpu_ids
)
{
int
node_id
=
numa_node_of_cpu
(
cpu_id
);
if
(
node_id
!=
-
1
)
{
node_ids
.
insert
(
node_id
);
}
if
(
node_id
!=
mem_node_id
)
{
TORCH_WARN
(
"CPU "
,
cpu_id
,
" is on NUMA node "
,
node_id
,
", but CPU "
,
omp_cpu_ids
.
front
(),
" is on NUMA node "
,
mem_node_id
,
". All CPUs should be on the same NUMA node for optimal "
"performance. Memory will be bound to NUMA node "
,
mem_node_id
,
"."
);
}
}
// Concatenate all node_ids into a single comma-separated string
if
(
!
node_ids
.
empty
())
{
std
::
string
node_ids_str
;
for
(
const
int
node_id
:
node_ids
)
{
if
(
!
node_ids_str
.
empty
())
{
node_ids_str
+=
","
;
}
node_ids_str
+=
std
::
to_string
(
node_id
);
}
// restrict memory allocation node.
numa_set_membind
(
mask
);
numa_set_strict
(
1
);
bitmask
*
mask
=
numa_parse_nodestring
(
node_ids_str
.
c_str
());
bitmask
*
src_mask
=
numa_get_membind
();
int
pid
=
getpid
();
if
(
mask
&&
src_mask
)
{
// move all existing pages to the specified numa node.
*
(
src_mask
->
maskp
)
=
*
(
src_mask
->
maskp
)
^
*
(
mask
->
maskp
);
int
page_num
=
numa_migrate_pages
(
pid
,
src_mask
,
mask
);
if
(
page_num
==
-
1
)
{
TORCH_WARN
(
"numa_migrate_pages failed. errno: "
+
std
::
to_string
(
errno
));
}
// restrict memory allocation node.
numa_set_membind
(
mask
);
numa_set_strict
(
1
);
numa_free_nodemask
(
mask
);
numa_free_nodemask
(
src_mask
);
}
else
{
TORCH_WARN
(
"numa_parse_nodestring or numa_get_membind failed. errno: "
+
std
::
to_string
(
errno
));
}
}
}
// OMP threads binding
...
...
csrc/cpu/utils.hpp
0 → 100644
View file @
41199996
#ifndef UTILS_HPP
#define UTILS_HPP
#include <atomic>
#include <cassert>
#include <cstdint>
#include <unistd.h>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
namespace
cpu_utils
{
// Kernel flavour selector.
// NOTE(review): presumably AMX = tile-matrix kernels and VEC = generic SIMD
// kernels; confirm against the kernels that consume this enum.
enum class ISA { AMX, VEC };

// Compile-time map from a scalar element type T to the vec_op vector wrapper
// that holds it. The primary template yields `void`, which marks "no vector
// type registered for T".
template <typename T>
struct VecTypeTrait {
  using vec_t = void;
};

// float -> FP32Vec16
template <>
struct VecTypeTrait<float> {
  using vec_t = vec_op::FP32Vec16;
};

// c10::BFloat16 -> BF16Vec16. On aarch64 the mapping only exists when
// ARM_BF16_SUPPORT is defined; otherwise the primary template (void) applies.
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template <>
struct VecTypeTrait<c10::BFloat16> {
  using vec_t = vec_op::BF16Vec16;
};
#endif

// c10::Half -> FP16Vec16
template <>
struct VecTypeTrait<c10::Half> {
  using vec_t = vec_op::FP16Vec16;
};
// Atomic 64-bit ticket counter followed by 56 bytes of padding.
// NOTE(review): the padding presumably rounds the struct up to one 64-byte
// cache line so adjacent counters do not share a line; confirm with the users.
struct Counter {
  std::atomic<int64_t> counter;
  char _padding[56];

  // Starts the counter at zero.
  Counter() : counter{0} {}

  // Sets the counter back to zero.
  void reset_counter() { counter = 0; }

  // Atomically increments the counter and returns the value it held before
  // the increment.
  int64_t acquire_counter() { return counter.fetch_add(1); }
};
// Returns half of the L2 cache size in bytes.
//
// The size is queried once, on the first call, and cached in a function-local
// static; later calls return the cached value.
//
// Query source:
//   - macOS:  sysctlbyname("hw.l2cachesize")
//   - others: sysconf(_SC_LEVEL2_CACHE_SIZE)
//
// If the query fails or reports a non-positive size, a 128 KiB L2 is assumed
// instead, so the result is always strictly positive. (Previously the
// non-macOS path relied on an assert, which is compiled out under NDEBUG and
// did not cover a reported size of 0.)
inline int64_t get_l2_size() {
  static int64_t size = []() -> int64_t {
    // L2 size assumed when the platform cannot report a usable one.
    constexpr int64_t kFallbackL2Bytes = 128LL * 1024;

    int64_t l2_cache_size = 0;
#if defined(__APPLE__)
    // macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
    size_t len = sizeof(l2_cache_size);
    if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) != 0) {
      l2_cache_size = 0;  // query failed: force the fallback below
    }
#else
    // sysconf returns -1 on error; a 0 result carries no usable size either.
    l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
#endif
    if (l2_cache_size <= 0) {
      l2_cache_size = kFallbackL2Bytes;
    }
    return l2_cache_size >> 1;  // use 50% of L2 cache
  }();
  return size;
}
}
// namespace cpu_utils
#endif
csrc/cub_helpers.h
View file @
41199996
...
...
@@ -12,6 +12,7 @@ using CubMaxOp = cub::Max;
#endif // CUB_VERSION
#else
#include <hipcub/hipcub.hpp>
using
CubAddOp
=
cub
::
Sum
;
using
CubMaxOp
=
cub
::
Max
;
namespace
cub
=
hipcub
;
using
CubAddOp
=
hipcub
::
Sum
;
using
CubMaxOp
=
hipcub
::
Max
;
#endif // USE_ROCM
csrc/cuda_view.cu
View file @
41199996
...
...
@@ -22,15 +22,10 @@ torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
auto
strides
=
cpu_tensor
.
strides
();
auto
options
=
cpu_tensor
.
options
().
device
(
torch
::
kCUDA
);
// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
// memory, so we don't free it here.
auto
deleter
=
[](
void
*
)
{
// no-op, since the memory is owned by the original CPU tensor
};
// use default no-op deleter, since the memory is owned by the original CPU
// tensor
torch
::
Tensor
cuda_tensor
=
torch
::
from_blob
(
device_ptr
,
sizes
,
strides
,
deleter
,
options
);
torch
::
from_blob
(
device_ptr
,
sizes
,
strides
,
options
);
TORCH_CHECK
(
cuda_tensor
.
device
().
is_cuda
(),
"Resulting tensor is not on CUDA device"
);
...
...
csrc/cumem_allocator.cpp
View file @
41199996
...
...
@@ -3,14 +3,58 @@
// need to be unsigned long long
#include <iostream>
#include "cumem_allocator_compat.h"
#ifndef USE_ROCM
static
const
char
*
PYARGS_PARSE
=
"KKKK"
;
#else
#include <cstdlib>
#include <cerrno>
#include <climits>
// Chunk size used when VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE does not supply a usable
// value: 256 MB, expressed in bytes.
static const unsigned long long DEFAULT_MEMCREATE_CHUNK_SIZE = 256ULL << 20;

// Resolves the chunk size in bytes.
//
// Reads the environment variable VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE, interpreted
// as a whole number of megabytes by strtoull with base 0 (so decimal and
// 0x-prefixed hex are both accepted). The default is returned when:
//   - the variable is unset,
//   - strtoull consumes no characters or reports an error through errno,
//   - the parsed value is 0, or
//   - converting megabytes to bytes would overflow unsigned long long.
static unsigned long long get_memcreate_chunk_size() {
  const char* raw = getenv("VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE");
  if (raw == nullptr) {
    return DEFAULT_MEMCREATE_CHUNK_SIZE;
  }

  char* parse_end = nullptr;
  errno = 0;  // strtoull only sets errno on failure, so clear it first
  const unsigned long long megabytes = strtoull(raw, &parse_end, 0);
  const bool parsed = (parse_end != raw) && (errno == 0);

  const unsigned long long bytes_per_mb = 1ULL << 20;
  // Largest MB count whose byte value still fits in unsigned long long.
  const unsigned long long max_megabytes = ULLONG_MAX / bytes_per_mb;

  const bool usable = parsed && megabytes != 0 && megabytes <= max_megabytes;
  return usable ? megabytes * bytes_per_mb : DEFAULT_MEMCREATE_CHUNK_SIZE;
}
// Returns the smaller of two unsigned 64-bit values (either one when equal).
static inline unsigned long long my_min(unsigned long long a,
                                        unsigned long long b) {
  if (b < a) {
    return b;
  }
  return a;
}
static
const
char
*
PYARGS_PARSE
=
"KKKO"
;
#endif
extern
"C"
{
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <cuda.h>
char
error_msg
[
10240
];
// 10KB buffer to store error messages
CUresult
no_error
=
CUresult
(
0
);
...
...
@@ -49,7 +93,12 @@ void ensure_context(unsigned long long device) {
}
void
create_and_map
(
unsigned
long
long
device
,
ssize_t
size
,
CUdeviceptr
d_mem
,
#ifndef USE_ROCM
CUmemGenericAllocationHandle
*
p_memHandle
)
{
#else
CUmemGenericAllocationHandle
**
p_memHandle
,
unsigned
long
long
*
chunk_sizes
,
size_t
num_chunks
)
{
#endif
ensure_context
(
device
);
// Define memory allocation properties
CUmemAllocationProp
prop
=
{};
...
...
@@ -58,6 +107,7 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop
.
location
.
id
=
device
;
prop
.
allocFlags
.
compressionType
=
CU_MEM_ALLOCATION_COMP_NONE
;
#ifndef USE_ROCM
// Allocate memory using cuMemCreate
CUDA_CHECK
(
cuMemCreate
(
p_memHandle
,
size
,
&
prop
,
0
));
if
(
error_code
!=
0
)
{
...
...
@@ -67,6 +117,39 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
if
(
error_code
!=
0
)
{
return
;
}
#else
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
CUDA_CHECK
(
cuMemCreate
(
p_memHandle
[
i
],
chunk_sizes
[
i
],
&
prop
,
0
));
if
(
error_code
!=
0
)
{
// Clean up previously created handles
for
(
auto
j
=
0
;
j
<
i
;
++
j
)
{
cuMemRelease
(
*
(
p_memHandle
[
j
]));
}
return
;
}
}
unsigned
long
long
allocated_size
=
0
;
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
void
*
map_addr
=
(
void
*
)((
uintptr_t
)
d_mem
+
allocated_size
);
CUDA_CHECK
(
cuMemMap
(
map_addr
,
chunk_sizes
[
i
],
0
,
*
(
p_memHandle
[
i
]),
0
));
if
(
error_code
!=
0
)
{
// unmap previously mapped chunks
unsigned
long
long
unmapped_size
=
0
;
for
(
auto
j
=
0
;
j
<
i
;
++
j
)
{
void
*
unmap_addr
=
(
void
*
)((
uintptr_t
)
d_mem
+
unmapped_size
);
cuMemUnmap
(
unmap_addr
,
chunk_sizes
[
j
]);
unmapped_size
+=
chunk_sizes
[
j
];
}
// release all created handles
for
(
auto
j
=
0
;
j
<
num_chunks
;
++
j
)
{
cuMemRelease
(
*
(
p_memHandle
[
j
]));
}
return
;
}
allocated_size
+=
chunk_sizes
[
i
];
}
#endif
CUmemAccessDesc
accessDesc
=
{};
accessDesc
.
location
.
type
=
CU_MEM_LOCATION_TYPE_DEVICE
;
accessDesc
.
location
.
id
=
device
;
...
...
@@ -82,10 +165,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
void
unmap_and_release
(
unsigned
long
long
device
,
ssize_t
size
,
CUdeviceptr
d_mem
,
#ifndef USE_ROCM
CUmemGenericAllocationHandle
*
p_memHandle
)
{
#else
CUmemGenericAllocationHandle
**
p_memHandle
,
unsigned
long
long
*
chunk_sizes
,
size_t
num_chunks
)
{
#endif
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context
(
device
);
#ifndef USE_ROCM
CUDA_CHECK
(
cuMemUnmap
(
d_mem
,
size
));
if
(
error_code
!=
0
)
{
return
;
...
...
@@ -94,6 +183,30 @@ void unmap_and_release(unsigned long long device, ssize_t size,
if
(
error_code
!=
0
)
{
return
;
}
#else
unsigned
long
long
allocated_size
=
0
;
CUresult
first_error
=
no_error
;
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
void
*
map_addr
=
(
void
*
)((
uintptr_t
)
d_mem
+
allocated_size
);
CUresult
status
=
cuMemUnmap
(
map_addr
,
chunk_sizes
[
i
]);
if
(
status
!=
no_error
&&
first_error
==
no_error
)
{
first_error
=
status
;
}
allocated_size
+=
chunk_sizes
[
i
];
}
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
CUresult
status
=
cuMemRelease
(
*
(
p_memHandle
[
i
]));
if
(
status
!=
no_error
&&
first_error
==
no_error
)
{
first_error
=
status
;
}
}
if
(
first_error
!=
no_error
)
{
CUDA_CHECK
(
first_error
);
}
#endif
}
PyObject
*
create_tuple_from_c_integers
(
unsigned
long
long
a
,
...
...
@@ -120,6 +233,36 @@ PyObject* create_tuple_from_c_integers(unsigned long long a,
return
tuple
;
// Return the created tuple
}
// Builds the Python 4-tuple
//     (a, b, c, [(handle_addr_0, size_0), ..., (handle_addr_{n-1}, size_{n-1})])
// where handle_addr_i is the pointer value vec[i] converted to an integer and
// size_i is chunk_sizes[i].
//
// Parameters:
//   a, b, c      integers stored as the first three tuple items.
//   vec          array of num_chunks handle pointers; only the pointer values
//                are recorded, the handles are not dereferenced.
//   chunk_sizes  array of num_chunks sizes paired with the handles.
//   num_chunks   number of entries in vec / chunk_sizes.
//
// Returns a new reference to the tuple, or NULL if any Python object could
// not be created. On failure every partially built object is released and the
// exception raised by the failing CPython call is left set.
//
// Like every CPython API user, this must run with the GIL held.
PyObject* create_tuple_from_c_mixed(unsigned long long a, unsigned long long b,
                                    unsigned long long c,
                                    CUmemGenericAllocationHandle** vec,
                                    unsigned long long* chunk_sizes,
                                    size_t num_chunks) {
  PyObject* tuple = PyTuple_New(4);
  if (!tuple) {
    return NULL;
  }

  PyObject* list = PyList_New((Py_ssize_t)num_chunks);
  if (!list) {
    Py_DECREF(tuple);
    return NULL;
  }

  for (size_t i = 0; i < num_chunks; ++i) {
    PyObject* addr_size_pair = PyTuple_New(2);
    PyObject* addr = PyLong_FromUnsignedLongLong((unsigned long long)(vec[i]));
    PyObject* size =
        PyLong_FromUnsignedLongLong((unsigned long long)(chunk_sizes[i]));
    if (!addr_size_pair || !addr || !size) {
      // Nothing has been handed over to a container yet, so release each
      // piece individually; the containers tolerate their unfilled slots.
      Py_XDECREF(addr_size_pair);
      Py_XDECREF(addr);
      Py_XDECREF(size);
      Py_DECREF(list);
      Py_DECREF(tuple);
      return NULL;
    }
    // PyTuple_SetItem / PyList_SetItem steal the reference to the item.
    PyTuple_SetItem(addr_size_pair, 0, addr);
    PyTuple_SetItem(addr_size_pair, 1, size);
    PyList_SetItem(list, (Py_ssize_t)i, addr_size_pair);
  }

  PyObject* py_a = PyLong_FromUnsignedLongLong(a);
  PyObject* py_b = PyLong_FromUnsignedLongLong(b);
  PyObject* py_c = PyLong_FromUnsignedLongLong(c);
  if (!py_a || !py_b || !py_c) {
    Py_XDECREF(py_a);
    Py_XDECREF(py_b);
    Py_XDECREF(py_c);
    Py_DECREF(list);
    Py_DECREF(tuple);
    return NULL;
  }

  PyTuple_SetItem(tuple, 0, py_a);
  PyTuple_SetItem(tuple, 1, py_b);
  PyTuple_SetItem(tuple, 2, py_c);
  PyTuple_SetItem(tuple, 3, list);
  return tuple;
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
...
...
@@ -147,14 +290,55 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
size_t
alignedSize
=
((
size
+
granularity
-
1
)
/
granularity
)
*
granularity
;
CUdeviceptr
d_mem
;
#ifndef USE_ROCM
CUDA_CHECK
(
cuMemAddressReserve
(
&
d_mem
,
alignedSize
,
0
,
0
,
0
));
if
(
error_code
!=
0
)
{
return
nullptr
;
}
#else
CUDA_CHECK
(
cuMemAddressReserve
(
&
d_mem
,
alignedSize
,
granularity
,
0
,
0
));
if
(
error_code
!=
0
)
{
return
nullptr
;
}
#endif
#ifndef USE_ROCM
// allocate the CUmemGenericAllocationHandle
CUmemGenericAllocationHandle
*
p_memHandle
=
(
CUmemGenericAllocationHandle
*
)
malloc
(
sizeof
(
CUmemGenericAllocationHandle
));
#else
// Make sure chunk size is aligned with hardware granularity. The base
// chunk size can be configured via environment variable
// ``VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE``; otherwise
// DEFAULT_MEMCREATE_CHUNK_SIZE is used.
size_t
base_chunk
=
(
size_t
)
get_memcreate_chunk_size
();
size_t
aligned_chunk_size
=
((
base_chunk
+
granularity
-
1
)
/
granularity
)
*
granularity
;
size_t
num_chunks
=
(
alignedSize
+
aligned_chunk_size
-
1
)
/
aligned_chunk_size
;
CUmemGenericAllocationHandle
**
p_memHandle
=
(
CUmemGenericAllocationHandle
**
)
malloc
(
num_chunks
*
sizeof
(
CUmemGenericAllocationHandle
*
));
unsigned
long
long
*
chunk_sizes
=
(
unsigned
long
long
*
)
malloc
(
num_chunks
*
sizeof
(
unsigned
long
long
));
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
p_memHandle
[
i
]
=
(
CUmemGenericAllocationHandle
*
)
malloc
(
sizeof
(
CUmemGenericAllocationHandle
));
if
(
p_memHandle
[
i
]
==
nullptr
)
{
std
::
cerr
<<
"ERROR: malloc failed for p_memHandle["
<<
i
<<
"].
\n
"
;
for
(
auto
j
=
0
;
j
<
i
;
++
j
)
{
free
(
p_memHandle
[
j
]);
}
free
(
p_memHandle
);
free
(
chunk_sizes
);
return
nullptr
;
}
chunk_sizes
[
i
]
=
(
unsigned
long
long
)
my_min
(
(
unsigned
long
long
)(
alignedSize
-
i
*
aligned_chunk_size
),
(
unsigned
long
long
)
aligned_chunk_size
);
}
#endif
if
(
!
g_python_malloc_callback
)
{
std
::
cerr
<<
"ERROR: g_python_malloc_callback not set.
\n
"
;
...
...
@@ -164,9 +348,15 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE
gstate
=
PyGILState_Ensure
();
#ifndef USE_ROCM
PyObject
*
arg_tuple
=
create_tuple_from_c_integers
(
(
unsigned
long
long
)
device
,
(
unsigned
long
long
)
alignedSize
,
(
unsigned
long
long
)
d_mem
,
(
unsigned
long
long
)
p_memHandle
);
#else
PyObject
*
arg_tuple
=
create_tuple_from_c_mixed
(
(
unsigned
long
long
)
device
,
(
unsigned
long
long
)
alignedSize
,
(
unsigned
long
long
)
d_mem
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
#endif
// Call g_python_malloc_callback
PyObject
*
py_result
=
...
...
@@ -182,7 +372,27 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
PyGILState_Release
(
gstate
);
// do the final mapping
#ifndef USE_ROCM
create_and_map
(
device
,
alignedSize
,
d_mem
,
p_memHandle
);
#else
create_and_map
(
device
,
alignedSize
,
d_mem
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
free
(
chunk_sizes
);
#endif
if
(
error_code
!=
0
)
{
// free address and the handle
CUDA_CHECK
(
cuMemAddressFree
(
d_mem
,
alignedSize
));
#ifndef USE_ROCM
free
(
p_memHandle
);
#else
for
(
size_t
i
=
0
;
i
<
num_chunks
;
++
i
)
{
free
(
p_memHandle
[
i
]);
}
free
(
p_memHandle
);
#endif
return
nullptr
;
}
return
(
void
*
)
d_mem
;
}
...
...
@@ -206,36 +416,96 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
if
(
!
py_result
||
!
PyTuple_Check
(
py_result
)
||
PyTuple_Size
(
py_result
)
!=
4
)
{
PyErr_SetString
(
PyExc_TypeError
,
"Expected a tuple of size 4"
);
Py_XDECREF
(
py_result
);
Py_XDECREF
(
py_ptr
);
return
;
}
unsigned
long
long
recv_device
,
recv_size
;
unsigned
long
long
recv_d_mem
,
recv_p_memHandle
;
unsigned
long
long
recv_d_mem
;
#ifndef USE_ROCM
unsigned
long
long
recv_p_memHandle
;
#else
PyObject
*
recv_p_memHandle
;
#endif
// Unpack the tuple into four C integers
if
(
!
PyArg_ParseTuple
(
py_result
,
"KKKK"
,
&
recv_device
,
&
recv_size
,
if
(
!
PyArg_ParseTuple
(
py_result
,
PYARGS_PARSE
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
// PyArg_ParseTuple sets an error if it fails
Py_XDECREF
(
py_result
);
Py_XDECREF
(
py_ptr
);
return
;
}
PyGILState_Release
(
gstate
);
// For ROCm, copy the Python list of (addr,size) pairs into C arrays while
// holding the GIL. Then release the GIL and call the unmap/release helper
// using the copied arrays. This avoids calling PyList_* APIs without the
// GIL (which is undefined behavior and can crash when called from other
// threads).
CUdeviceptr
d_mem
=
(
CUdeviceptr
)
recv_d_mem
;
#ifdef USE_ROCM
Py_ssize_t
num_chunks
=
PyList_Size
(
recv_p_memHandle
);
CUmemGenericAllocationHandle
**
p_memHandle
=
(
CUmemGenericAllocationHandle
**
)
malloc
(
num_chunks
*
sizeof
(
CUmemGenericAllocationHandle
*
));
if
(
p_memHandle
==
nullptr
)
{
Py_DECREF
(
py_ptr
);
Py_DECREF
(
py_result
);
PyGILState_Release
(
gstate
);
std
::
cerr
<<
"ERROR: malloc failed for p_memHandle in my_free."
<<
std
::
endl
;
return
;
}
unsigned
long
long
*
chunk_sizes
=
(
unsigned
long
long
*
)
malloc
(
num_chunks
*
sizeof
(
unsigned
long
long
));
if
(
chunk_sizes
==
nullptr
)
{
free
(
p_memHandle
);
Py_DECREF
(
py_ptr
);
Py_DECREF
(
py_result
);
PyGILState_Release
(
gstate
);
std
::
cerr
<<
"ERROR: malloc failed for chunk_sizes in my_free."
<<
std
::
endl
;
return
;
}
for
(
Py_ssize_t
i
=
0
;
i
<
num_chunks
;
++
i
)
{
PyObject
*
item
=
PyList_GetItem
(
recv_p_memHandle
,
i
);
PyObject
*
addr_py
=
PyTuple_GetItem
(
item
,
0
);
PyObject
*
size_py
=
PyTuple_GetItem
(
item
,
1
);
p_memHandle
[
i
]
=
(
CUmemGenericAllocationHandle
*
)
PyLong_AsUnsignedLongLong
(
addr_py
);
chunk_sizes
[
i
]
=
(
unsigned
long
long
)
PyLong_AsUnsignedLongLong
(
size_py
);
}
// recv_size == size
// recv_device == device
// Drop temporary Python refs, then release the GIL before calling into
// non-Python APIs.
Py_DECREF
(
py_ptr
);
Py_DECREF
(
py_result
);
PyGILState_Release
(
gstate
);
// Free memory
unmap_and_release
(
device
,
size
,
d_mem
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
#else
// Non-ROCm path: simple integer handle already extracted; drop temporary
// Python refs while still holding the GIL, then release it.
Py_DECREF
(
py_ptr
);
Py_DECREF
(
py_result
);
PyGILState_Release
(
gstate
);
CUdeviceptr
d_mem
=
(
CUdeviceptr
)
recv_d_mem
;
CUmemGenericAllocationHandle
*
p_memHandle
=
(
CUmemGenericAllocationHandle
*
)
recv_p_memHandle
;
unmap_and_release
(
device
,
size
,
d_mem
,
p_memHandle
);
#endif
// free address and the handle
CUDA_CHECK
(
cuMemAddressFree
(
d_mem
,
size
));
if
(
error_code
!=
0
)
{
return
;
#ifndef USE_ROCM
free
(
p_memHandle
);
#else
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
free
(
p_memHandle
[
i
]);
}
free
(
p_memHandle
);
free
(
chunk_sizes
);
#endif
}
// ---------------------------------------------------------------------------
...
...
@@ -271,19 +541,87 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
}
unsigned
long
long
recv_device
,
recv_size
;
unsigned
long
long
recv_d_mem
,
recv_p_memHandle
;
unsigned
long
long
recv_d_mem
;
#ifndef USE_ROCM
unsigned
long
long
recv_p_memHandle
;
#else
PyObject
*
recv_p_memHandle
;
#endif
// Unpack the tuple into four C integers
if
(
!
PyArg_ParseTuple
(
args
,
"KKKK"
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
if
(
!
PyArg_ParseTuple
(
args
,
PYARGS_PARSE
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
// PyArg_ParseTuple sets an error if it fails
return
nullptr
;
}
CUdeviceptr
d_mem_ptr
=
(
CUdeviceptr
)
recv_d_mem
;
#ifndef USE_ROCM
CUmemGenericAllocationHandle
*
p_memHandle
=
(
CUmemGenericAllocationHandle
*
)
recv_p_memHandle
;
unmap_and_release
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
);
#else
if
(
!
PyList_Check
(
recv_p_memHandle
))
{
PyErr_SetString
(
PyExc_TypeError
,
"Expected a list for the 4th argument on ROCm"
);
return
nullptr
;
}
Py_ssize_t
num_chunks
=
PyList_Size
(
recv_p_memHandle
);
if
(
num_chunks
<
0
)
{
return
nullptr
;
// PyList_Size sets an exception on error.
}
CUmemGenericAllocationHandle
**
p_memHandle
=
(
CUmemGenericAllocationHandle
**
)
malloc
(
num_chunks
*
sizeof
(
CUmemGenericAllocationHandle
*
));
if
(
p_memHandle
==
nullptr
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"malloc failed for p_memHandle"
);
return
nullptr
;
}
unsigned
long
long
*
chunk_sizes
=
(
unsigned
long
long
*
)
malloc
(
num_chunks
*
sizeof
(
unsigned
long
long
));
if
(
chunk_sizes
==
nullptr
)
{
free
(
p_memHandle
);
PyErr_SetString
(
PyExc_MemoryError
,
"malloc failed for chunk_sizes"
);
return
nullptr
;
}
for
(
Py_ssize_t
i
=
0
;
i
<
num_chunks
;
++
i
)
{
PyObject
*
item
=
PyList_GetItem
(
recv_p_memHandle
,
i
);
if
(
item
==
nullptr
||
!
PyTuple_Check
(
item
)
||
PyTuple_Size
(
item
)
!=
2
)
{
free
(
p_memHandle
);
free
(
chunk_sizes
);
PyErr_SetString
(
PyExc_TypeError
,
"List items must be tuples of size 2 (handle_addr, size)"
);
return
nullptr
;
}
PyObject
*
addr_py
=
PyTuple_GetItem
(
item
,
0
);
PyObject
*
size_py
=
PyTuple_GetItem
(
item
,
1
);
if
(
addr_py
==
nullptr
||
size_py
==
nullptr
)
{
free
(
p_memHandle
);
free
(
chunk_sizes
);
return
nullptr
;
// PyTuple_GetItem sets an exception
}
p_memHandle
[
i
]
=
(
CUmemGenericAllocationHandle
*
)
PyLong_AsUnsignedLongLong
(
addr_py
);
if
(
PyErr_Occurred
())
{
free
(
p_memHandle
);
free
(
chunk_sizes
);
return
nullptr
;
}
chunk_sizes
[
i
]
=
(
unsigned
long
long
)
PyLong_AsUnsignedLongLong
(
size_py
);
if
(
PyErr_Occurred
())
{
free
(
p_memHandle
);
free
(
chunk_sizes
);
return
nullptr
;
}
}
unmap_and_release
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
free
(
p_memHandle
);
free
(
chunk_sizes
);
#endif
if
(
error_code
!=
0
)
{
error_code
=
no_error
;
...
...
@@ -301,19 +639,56 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
}
unsigned
long
long
recv_device
,
recv_size
;
unsigned
long
long
recv_d_mem
,
recv_p_memHandle
;
unsigned
long
long
recv_d_mem
;
#ifndef USE_ROCM
unsigned
long
long
recv_p_memHandle
;
#else
PyObject
*
recv_p_memHandle
;
#endif
// Unpack the tuple into four C integers
if
(
!
PyArg_ParseTuple
(
args
,
"KKKK"
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
if
(
!
PyArg_ParseTuple
(
args
,
PYARGS_PARSE
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
// PyArg_ParseTuple sets an error if it fails
return
nullptr
;
}
CUdeviceptr
d_mem_ptr
=
(
CUdeviceptr
)
recv_d_mem
;
#ifndef USE_ROCM
CUmemGenericAllocationHandle
*
p_memHandle
=
(
CUmemGenericAllocationHandle
*
)
recv_p_memHandle
;
create_and_map
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
);
#else
Py_ssize_t
num_chunks
=
PyList_Size
(
recv_p_memHandle
);
CUmemGenericAllocationHandle
**
p_memHandle
=
(
CUmemGenericAllocationHandle
**
)
malloc
(
num_chunks
*
sizeof
(
CUmemGenericAllocationHandle
*
));
if
(
p_memHandle
==
nullptr
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"malloc failed for p_memHandle"
);
return
nullptr
;
}
unsigned
long
long
*
chunk_sizes
=
(
unsigned
long
long
*
)
malloc
(
num_chunks
*
sizeof
(
unsigned
long
long
));
if
(
chunk_sizes
==
nullptr
)
{
free
(
p_memHandle
);
PyErr_SetString
(
PyExc_MemoryError
,
"malloc failed for chunk_sizes"
);
return
nullptr
;
}
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
PyObject
*
item
=
PyList_GetItem
(
recv_p_memHandle
,
i
);
PyObject
*
addr_py
=
PyTuple_GetItem
(
item
,
0
);
PyObject
*
size_py
=
PyTuple_GetItem
(
item
,
1
);
p_memHandle
[
i
]
=
(
CUmemGenericAllocationHandle
*
)
PyLong_AsUnsignedLongLong
(
addr_py
);
chunk_sizes
[
i
]
=
PyLong_AsUnsignedLongLong
(
size_py
);
}
create_and_map
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
free
(
p_memHandle
);
free
(
chunk_sizes
);
#endif
if
(
error_code
!=
0
)
{
error_code
=
no_error
;
...
...
csrc/cumem_allocator_compat.h
0 → 100644
View file @
41199996
#pragma once
#ifdef USE_ROCM
////////////////////////////////////////
// For compatibility with CUDA and ROCm
////////////////////////////////////////
#include <hip/hip_runtime_api.h>
extern
"C"
{
#ifndef CUDA_SUCCESS
#define CUDA_SUCCESS hipSuccess
#endif // CUDA_SUCCESS
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
typedef
unsigned
long
long
CUdevice
;
typedef
hipDeviceptr_t
CUdeviceptr
;
typedef
hipError_t
CUresult
;
typedef
hipCtx_t
CUcontext
;
typedef
hipStream_t
CUstream
;
typedef
hipMemGenericAllocationHandle_t
CUmemGenericAllocationHandle
;
typedef
hipMemAllocationGranularity_flags
CUmemAllocationGranularity_flags
;
typedef
hipMemAllocationProp
CUmemAllocationProp
;
typedef
hipMemAccessDesc
CUmemAccessDesc
;
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_MEM_ALLOC_GRANULARITY_MINIMUM hipMemAllocationGranularityMinimum
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
#define CU_MEM_ALLOCATION_COMP_NONE 0x0
// Error Handling
// https://docs.nvidia.com/cuda/archive/11.4.4/cuda-driver-api/group__CUDA__ERROR.html
// CUDA-driver-named shim for ROCm builds: stores the string returned by
// hipGetErrorString(hipError) into *pStr and reports CUDA_SUCCESS.
//
// Declared `static inline` because this definition lives in a header; with
// external linkage, including the header from a second translation unit would
// produce duplicate-symbol link errors.
//
// A null `pStr` is rejected with hipErrorInvalidValue instead of being
// dereferenced.
static inline CUresult cuGetErrorString(CUresult hipError, const char** pStr) {
  if (pStr == nullptr) {
    return hipErrorInvalidValue;
  }
  *pStr = hipGetErrorString(hipError);
  return CUDA_SUCCESS;
}
// Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
// CUDA-driver-named shim for ROCm builds: forwards to hipCtxGetCurrent.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuCtxGetCurrent(CUcontext* ctx) {
  // This API is deprecated on the AMD platform, only for equivalent cuCtx
  // driver API on the NVIDIA platform.
  return hipCtxGetCurrent(ctx);
}
// CUDA-driver-named shim for ROCm builds: forwards to hipCtxSetCurrent.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuCtxSetCurrent(CUcontext ctx) {
  // This API is deprecated on the AMD platform, only for equivalent cuCtx
  // driver API on the NVIDIA platform.
  return hipCtxSetCurrent(ctx);
}
// Primary Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html
// CUDA-driver-named shim for ROCm builds: forwards to
// hipDevicePrimaryCtxRetain. `dev` is passed through unchanged (CUdevice is
// typedef'd to unsigned long long in this header, so it is implicitly
// converted to whatever device type the HIP function takes).
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuDevicePrimaryCtxRetain(CUcontext* ctx, CUdevice dev) {
  return hipDevicePrimaryCtxRetain(ctx, dev);
}
// Virtual Memory Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html
// CUDA-driver-named shim for ROCm builds: forwards to hipMemAddressFree.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuMemAddressFree(CUdeviceptr ptr, size_t size) {
  return hipMemAddressFree(ptr, size);
}
// CUDA-driver-named shim for ROCm builds: forwards all arguments, in order,
// to hipMemAddressReserve.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuMemAddressReserve(CUdeviceptr* ptr, size_t size,
                                           size_t alignment, CUdeviceptr addr,
                                           unsigned long long flags) {
  return hipMemAddressReserve(ptr, size, alignment, addr, flags);
}
// CUDA-driver-named shim for ROCm builds: forwards all arguments, in order,
// to hipMemCreate.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuMemCreate(CUmemGenericAllocationHandle* handle,
                                   size_t size,
                                   const CUmemAllocationProp* prop,
                                   unsigned long long flags) {
  return hipMemCreate(handle, size, prop, flags);
}
// CUDA-driver-named shim for ROCm builds: forwards all arguments, in order,
// to hipMemGetAllocationGranularity.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuMemGetAllocationGranularity(
    size_t* granularity, const CUmemAllocationProp* prop,
    CUmemAllocationGranularity_flags option) {
  return hipMemGetAllocationGranularity(granularity, prop, option);
}
// CUDA-driver-named shim for ROCm builds: forwards all arguments, in order,
// to hipMemMap.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuMemMap(CUdeviceptr dptr, size_t size, size_t offset,
                                CUmemGenericAllocationHandle handle,
                                unsigned long long flags) {
  return hipMemMap(dptr, size, offset, handle, flags);
}
// CUDA-driver-named shim for ROCm builds: forwards to hipMemRelease.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuMemRelease(CUmemGenericAllocationHandle handle) {
  return hipMemRelease(handle);
}
// CUDA-driver-named shim for ROCm builds: forwards all arguments, in order,
// to hipMemSetAccess.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuMemSetAccess(CUdeviceptr ptr, size_t size,
                                      const CUmemAccessDesc* desc,
                                      size_t count) {
  return hipMemSetAccess(ptr, size, desc, count);
}
// CUDA-driver-named shim for ROCm builds: forwards to hipMemUnmap.
// `static inline` gives the header-defined function internal linkage so the
// header can be included from more than one translation unit.
static inline CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) {
  return hipMemUnmap(ptr, size);
}
}
// extern "C"
#else
////////////////////////////////////////
// Import CUDA headers for NVIDIA GPUs
////////////////////////////////////////
#include <cuda_runtime_api.h>
#include <cuda.h>
#endif
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
View file @
41199996
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
enum
from
typing
import
Union
from
cutlass_library
import
*
...
...
@@ -22,31 +21,31 @@ class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperative
=
enum_auto
()
VLLMDataTypeNames
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
str
]
=
{
VLLMDataTypeNames
:
dict
[
VLLMDataType
|
DataType
,
str
]
=
{
**
DataTypeNames
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"u4b8"
,
VLLMDataType
.
u8b128
:
"u8b128"
,
}
}
,
}
VLLMDataTypeTag
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
str
]
=
{
VLLMDataTypeTag
:
dict
[
VLLMDataType
|
DataType
,
str
]
=
{
**
DataTypeTag
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"cutlass::vllm_uint4b8_t"
,
VLLMDataType
.
u8b128
:
"cutlass::vllm_uint8b128_t"
,
}
}
,
}
VLLMDataTypeSize
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
int
]
=
{
VLLMDataTypeSize
:
dict
[
VLLMDataType
|
DataType
,
int
]
=
{
**
DataTypeSize
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
4
,
VLLMDataType
.
u8b128
:
8
,
}
}
,
}
VLLMDataTypeVLLMScalarTypeTag
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
str
]
=
{
VLLMDataTypeVLLMScalarTypeTag
:
dict
[
VLLMDataType
|
DataType
,
str
]
=
{
VLLMDataType
.
u4b8
:
"vllm::kU4B8"
,
VLLMDataType
.
u8b128
:
"vllm::kU8B128"
,
DataType
.
u4
:
"vllm::kU4"
,
...
...
@@ -57,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType
.
bf16
:
"vllm::kBfloat16"
,
}
VLLMDataTypeTorchDataTypeTag
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
str
]
=
{
VLLMDataTypeTorchDataTypeTag
:
dict
[
VLLMDataType
|
DataType
,
str
]
=
{
DataType
.
u8
:
"at::ScalarType::Byte"
,
DataType
.
s8
:
"at::ScalarType::Char"
,
DataType
.
e4m3
:
"at::ScalarType::Float8_e4m3fn"
,
...
...
@@ -67,15 +66,11 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType
.
f32
:
"at::ScalarType::Float"
,
}
VLLMKernelScheduleTag
:
dict
[
Union
[
MixedInputKernelScheduleType
,
KernelScheduleType
],
str
]
=
{
**
KernelScheduleTag
,
# type: ignore
**
{
MixedInputKernelScheduleType
.
TmaWarpSpecialized
:
"cutlass::gemm::KernelTmaWarpSpecialized"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedPingpong
:
"cutlass::gemm::KernelTmaWarpSpecializedPingpong"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
:
"cutlass::gemm::KernelTmaWarpSpecializedCooperative"
,
}
}
VLLMKernelScheduleTag
:
dict
[
MixedInputKernelScheduleType
|
KernelScheduleType
,
str
]
=
{
**
KernelScheduleTag
,
# type: ignore
**
{
MixedInputKernelScheduleType
.
TmaWarpSpecialized
:
"cutlass::gemm::KernelTmaWarpSpecialized"
,
# noqa: E501
MixedInputKernelScheduleType
.
TmaWarpSpecializedPingpong
:
"cutlass::gemm::KernelTmaWarpSpecializedPingpong"
,
# noqa: E501
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
:
"cutlass::gemm::KernelTmaWarpSpecializedCooperative"
,
# noqa: E501
},
}
csrc/dispatch_utils.h
View file @
41199996
...
...
@@ -88,3 +88,53 @@
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
// Runtime-to-compile-time dispatch on a vector width.
//
// Switches on VEC_SIZE and, in the matching case, declares
// `constexpr int vec_size` (16, 8, 4 or 2) before calling `__VA_ARGS__()`, so
// the callable passed as the variadic argument can use `vec_size` as a
// constant expression (e.g. as a template argument).
//
// Any value other than 16, 8, 4 or 2 takes the `default` branch and runs the
// callable with `vec_size = 1`; no error is raised for unexpected values.
#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \
switch (VEC_SIZE) { \
case 16: { \
constexpr int vec_size = 16; \
__VA_ARGS__(); \
break; \
} \
case 8: { \
constexpr int vec_size = 8; \
__VA_ARGS__(); \
break; \
} \
case 4: { \
constexpr int vec_size = 4; \
__VA_ARGS__(); \
break; \
} \
case 2: { \
constexpr int vec_size = 2; \
__VA_ARGS__(); \
break; \
} \
default: { \
constexpr int vec_size = 1; \
__VA_ARGS__(); \
break; \
} \
}
// Runtime-to-compile-time dispatch on a tensor rank.
//
// Switches on NUM_DIMS and, for 2, 3 or 4, declares
// `constexpr int tensor_rank` with that value before calling `__VA_ARGS__()`,
// so the callable passed as the variadic argument can use `tensor_rank` as a
// constant expression.
//
// Any other rank fails via TORCH_CHECK(false, ...) with a message that
// includes the offending NUM_DIMS value.
#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
switch (NUM_DIMS) { \
case 2: { \
constexpr int tensor_rank = 2; \
__VA_ARGS__(); \
break; \
} \
case 3: { \
constexpr int tensor_rank = 3; \
__VA_ARGS__(); \
break; \
} \
case 4: { \
constexpr int tensor_rank = 4; \
__VA_ARGS__(); \
break; \
} \
default: \
TORCH_CHECK(false, "Expects rank 2, 3 or 4 tensors but got ", NUM_DIMS); \
}
csrc/fused_qknorm_rope_kernel.cu
0 → 100644
View file @
41199996
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cuda_runtime.h>
#include <type_traits>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#define CHECK_TYPE(x, st) \
TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \
", while ", st, " is expected")
#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_TH_CUDA(x); \
CHECK_CONTIGUOUS(x)
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0
// Fallback definition of __syncwarp() for ROCm toolchains that lack it (the
// surrounding preprocessor guard compiles this only when
// HIP_VERSION < 70000000).
//
// Sequence: release fence at "wavefront" scope, wave barrier, then acquire
// fence at "wavefront" scope. The order of these three builtins is
// significant and must not be changed.
__device__ inline void __syncwarp() {
  __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
  __builtin_amdgcn_wave_barrier();
  __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
}
#endif
#else
#define FINAL_MASK 0xffffffff
#endif
namespace
tensorrt_llm
::
common
{
// Type-level map from (element type, element count) to a vector type that
// packs `num` elements. The primary template is intentionally left undefined:
// only the combinations specialized below may be instantiated.
template <typename T, int num>
struct packed_as;

// Specializations used by the fused QK-norm + RoPE kernel: 1, 2 or 4 packed
// `uint` lanes map to `uint`, `uint2` and `uint4` respectively.
template <>
struct packed_as<uint, 1> {
  using type = uint;
};

template <>
struct packed_as<uint, 2> {
  using type = uint2;
};

template <>
struct packed_as<uint, 4> {
  using type = uint4;
};
// Sums `val` across a 32-lane group using XOR shuffles.
//
// Each step adds the value held by the lane whose index differs in one bit
// (lane masks 16, 8, 4, 2, 1), with a shuffle width of 32 and FINAL_MASK as
// the participation mask. Returns the accumulated value held by the calling
// lane after the last step.
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int laneMask = 16; laneMask >= 1; laneMask /= 2) {
    val += __shfl_xor_sync(FINAL_MASK, val, laneMask, 32);
  }
  return val;
}
// Integer ceiling division: returns (m + n - 1) / n.
// Callable from both host and device code. No guard against n == 0 or
// against overflow of m + n - 1.
template <typename T>
inline __device__ __host__ T divUp(T m, T n) {
  // `auto` keeps the promoted type of the sum, exactly as in a single
  // expression.
  auto const biased = m + n - 1;
  return biased / n;
}
}
// namespace tensorrt_llm::common
namespace
tensorrt_llm
::
kernels
{
// NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation,
// with added support for passing the cos_sin_cache as an input.
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
//
// Perform per-head QK Norm and RoPE in a single kernel, in place on `qkv`.
//
// Template parameters:
//   scalar_t_in:    data type of QKV and RMSNorm weights
//   scalar_t_cache: data type of cos/sin cache
//   head_dim:       the dimension of each head (must be a multiple of 64)
//   interleave:     interleave=!is_neox.
//
// Work decomposition: each group of 32 threads handles one (token, head) pair
// for the Q and K heads only; V heads are skipped but still counted when
// computing the per-token stride. Each thread owns head_dim / 32 consecutive
// elements of its head.
//
// Element offsets into `qkv` are computed in 64-bit arithmetic so that
// tensors with more than 2^31 elements are addressed correctly.
template <typename scalar_t_in, typename scalar_t_cache, int head_dim,
          bool interleave>
__global__ void fusedQKNormRopeKernel(
    void* qkv_void,                  // Combined QKV tensor
    int const num_heads_q,           // Number of query heads
    int const num_heads_k,           // Number of key heads
    int const num_heads_v,           // Number of value heads
    float const eps,                 // Epsilon for RMS normalization
    void const* q_weight_void,       // RMSNorm weights for query
    void const* k_weight_void,       // RMSNorm weights for key
    void const* cos_sin_cache_void,  // Pre-computed cos/sin cache
    int64_t const* position_ids,     // Position IDs for RoPE
    int const num_tokens             // Number of tokens
) {
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  // On pre-sm_80 CUDA targets (and in the host pass) the BFloat16
  // instantiations compile to an empty kernel.
  if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
                std::is_same_v<scalar_t_cache, c10::BFloat16>) {
    return;
  } else {
#endif

  using Converter = vllm::_typeConvert<scalar_t_in>;
  static_assert(Converter::exists,
                "Input QKV data type is not supported for this CUDA "
                "architecture or toolkit version.");
  using T_in = typename Converter::hip_type;
  using T2_in = typename Converter::packed_hip_type;

  using CacheConverter = vllm::_typeConvert<scalar_t_cache>;
  static_assert(CacheConverter::exists,
                "Cache data type is not supported for this CUDA architecture "
                "or toolkit version.");
  using T_cache = typename CacheConverter::hip_type;

  T_in* qkv = reinterpret_cast<T_in*>(qkv_void);
  T_in const* q_weight = reinterpret_cast<T_in const*>(q_weight_void);
  T_in const* k_weight = reinterpret_cast<T_in const*>(k_weight_void);
  T_cache const* cos_sin_cache =
      reinterpret_cast<T_cache const*>(cos_sin_cache_void);

  int const warpsPerBlock = blockDim.x / 32;
  int const warpId = threadIdx.x / 32;
  int const laneId = threadIdx.x % 32;

  // Calculate global warp index to determine which head/token this warp
  // processes
  int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;

  // Total number of attention heads (Q and K)
  int const total_qk_heads = num_heads_q + num_heads_k;

  // Determine which token and head type (Q or K) this warp processes
  int const tokenIdx = globalWarpIdx / total_qk_heads;
  int const localHeadIdx = globalWarpIdx % total_qk_heads;

  // Skip if this warp is assigned beyond the number of tokens
  if (tokenIdx >= num_tokens) return;

  bool const isQ = localHeadIdx < num_heads_q;
  int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;

  int const num_heads = num_heads_q + num_heads_k + num_heads_v;

  static_assert(head_dim % (32 * 2) == 0,
                "head_dim must be divisible by 64 (each warp processes one "
                "head, and each thread gets even number of "
                "elements)");
  constexpr int numElemsPerThread = head_dim / 32;
  float elements[numElemsPerThread];
  constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
  static_assert(elemSizeBytes % 4 == 0, "numSizeBytes must be a multiple of 4");
  constexpr int vecSize = elemSizeBytes / 4;
  // Use packed_as<uint, vecSize> to perform loading/saving.
  using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;

  // Offset for the warp, in elements. Widened to 64 bits before multiplying:
  // tokenIdx * num_heads * head_dim exceeds INT_MAX for large inputs.
  int64_t const tokenOffset =
      static_cast<int64_t>(tokenIdx) * num_heads * head_dim;
  int64_t offsetWarp;
  if (isQ) {
    // Q segment: token offset + head offset within Q segment
    offsetWarp = tokenOffset + static_cast<int64_t>(headIdx) * head_dim;
  } else {
    // K segment: token offset + entire Q segment + head offset within K
    // segment
    offsetWarp = tokenOffset + static_cast<int64_t>(num_heads_q) * head_dim +
                 static_cast<int64_t>(headIdx) * head_dim;
  }
  int64_t const offsetThread = offsetWarp + laneId * numElemsPerThread;

  // Sum of squares for RMSNorm
  float sumOfSquares = 0.0f;

  // Load.
  {
    vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
    constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
    for (int i = 0; i < num_packed_elems; i++) {
      // Interpret the generic vector chunk as the specific packed type
      T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
      // Convert to float2 for computation
      float2 vals = Converter::convert(packed_val);
      sumOfSquares += vals.x * vals.x;
      sumOfSquares += vals.y * vals.y;

      elements[2 * i] = vals.x;
      elements[2 * i + 1] = vals.y;
    }
  }

  // Reduce sum across warp using the utility function
  sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);

  // Compute RMS normalization factor
  float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

  // Normalize elements
#pragma unroll
  for (int i = 0; i < numElemsPerThread; i++) {
    int dim = laneId * numElemsPerThread + i;
    float weight = isQ ? Converter::convert(q_weight[dim])
                       : Converter::convert(k_weight[dim]);
    elements[i] *= rms_rcp * weight;
  }

  // Apply RoPE to normalized elements
  int64_t pos_id = position_ids[tokenIdx];

  // Calculate cache pointer for this position - similar to
  // pos_encoding_kernels.cu
  T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
  int const embed_dim = head_dim / 2;
  T_cache const* cos_ptr = cache_ptr;
  T_cache const* sin_ptr = cache_ptr + embed_dim;

  if constexpr (interleave) {
    // Perform interleaving. Use pre-computed cos/sin values.
#pragma unroll
    for (int i = 0; i < numElemsPerThread / 2; ++i) {
      int const idx0 = 2 * i;
      int const idx1 = 2 * i + 1;

      float const val0 = elements[idx0];
      float const val1 = elements[idx1];

      int const dim_idx = laneId * numElemsPerThread + idx0;
      int const half_dim = dim_idx / 2;
      float const cos_val =
          CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
      float const sin_val =
          CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));

      elements[idx0] = val0 * cos_val - val1 * sin_val;
      elements[idx1] = val0 * sin_val + val1 * cos_val;
    }
  } else {
    // Additional buffer required for RoPE (only this branch uses it).
    float elements2[numElemsPerThread];

    // Before data exchange within the warp, we need to sync.
    __syncwarp();
    // Get the data from the other half of the warp. Use pre-computed cos/sin
    // values.
#pragma unroll
    for (int i = 0; i < numElemsPerThread; i++) {
      elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
      if (laneId < 16) {
        elements2[i] = -elements2[i];
      }

      int dim_idx = laneId * numElemsPerThread + i;
      dim_idx = (dim_idx * 2) % head_dim;
      int half_dim = dim_idx / 2;
      // Use pre-computed cos/sin from cache
      float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
      float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));

      elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
    }
    // __shfl_xor_sync does not provide memfence. Need to sync again.
    __syncwarp();
  }

  // Store.
  {
    vec_T vec;
    constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
    for (int i = 0; i < num_packed_elems; i++) {
      // Convert from float2 back to the specific packed type
      T2_in packed_val = Converter::convert(
          make_float2(elements[2 * i], elements[2 * i + 1]));
      // Place it into the generic vector
      *(reinterpret_cast<T2_in*>(&vec) + i) = packed_val;
    }
    *reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
  }

#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  }
#endif
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
//
// Turns the runtime boolean `interleave` into a compile-time constant:
// the statement(s) passed as the variadic argument are pasted into both
// branches of an if/else, once with `const bool INTERLEAVE = true` in scope
// and once with `const bool INTERLEAVE = false`. INTERLEAVE is the identifier
// the pasted code should use (e.g. as a template argument).
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}
template
<
typename
scalar_t_in
,
typename
scalar_t_cache
>
void
launchFusedQKNormRope
(
void
*
qkv
,
int
const
num_tokens
,
int
const
num_heads_q
,
int
const
num_heads_k
,
int
const
num_heads_v
,
int
const
head_dim
,
float
const
eps
,
void
const
*
q_weight
,
void
const
*
k_weight
,
void
const
*
cos_sin_cache
,
bool
const
interleave
,
int64_t
const
*
position_ids
,
cudaStream_t
stream
)
{
constexpr
int
blockSize
=
256
;
int
const
warpsPerBlock
=
blockSize
/
32
;
int
const
totalQKHeads
=
num_heads_q
+
num_heads_k
;
int
const
totalWarps
=
num_tokens
*
totalQKHeads
;
int
const
gridSize
=
common
::
divUp
(
totalWarps
,
warpsPerBlock
);
dim3
gridDim
(
gridSize
);
dim3
blockDim
(
blockSize
);
switch
(
head_dim
)
{
case
64
:
DISPATCH_INTERLEAVE
(
interleave
,
INTERLEAVE
,
{
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
64
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
});
break
;
case
128
:
DISPATCH_INTERLEAVE
(
interleave
,
INTERLEAVE
,
{
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
128
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
});
break
;
case
256
:
DISPATCH_INTERLEAVE
(
interleave
,
INTERLEAVE
,
{
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
256
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
});
break
;
default:
TORCH_CHECK
(
false
,
"Unsupported head dimension for fusedQKNormRope: "
,
head_dim
);
}
}
}
// namespace tensorrt_llm::kernels
// Torch entry point: applies per-head RMSNorm to the Q and K heads of a
// packed QKV tensor and then RoPE, in place.
//
// Validates that all tensors are contiguous CUDA tensors with the ranks and
// sizes noted on each parameter, that `position_ids` is int64, and that
// `qkv`, `q_weight` and `k_weight` share a dtype. Any violated check raises
// via TORCH_CHECK.
//
// The kernel is launched on the current stream of `qkv`'s device; a device
// guard makes that device current for the duration of the call so the launch
// targets the device that owns the data.
void fused_qk_norm_rope(
    torch::Tensor& qkv,   // Combined QKV tensor [num_tokens,
                          // (num_heads_q+num_heads_k+num_heads_v)*head_dim]
    int64_t num_heads_q,  // Number of query heads
    int64_t num_heads_k,  // Number of key heads
    int64_t num_heads_v,  // Number of value heads
    int64_t head_dim,     // Dimension per head
    double eps,           // Epsilon for RMS normalization
    torch::Tensor& q_weight,       // RMSNorm weights for query [head_dim]
    torch::Tensor& k_weight,       // RMSNorm weights for key [head_dim]
    torch::Tensor& cos_sin_cache,  // Cos/sin cache [max_position, head_dim]
    bool is_neox,                  // Whether RoPE is applied in Neox style
    torch::Tensor& position_ids    // Position IDs for RoPE [num_tokens]
) {
  // Input validation
  CHECK_INPUT(qkv);
  CHECK_INPUT(position_ids);
  CHECK_INPUT(q_weight);
  CHECK_INPUT(k_weight);
  CHECK_INPUT(cos_sin_cache);
  CHECK_TYPE(position_ids, torch::kInt64);

  TORCH_CHECK(qkv.dim() == 2,
              "QKV tensor must be 2D: [num_tokens, "
              "(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
  TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
  TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
  TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
  TORCH_CHECK(cos_sin_cache.dim() == 2,
              "Cos/sin cache must be 2D: [max_position, head_dim]");
  TORCH_CHECK(q_weight.size(0) == head_dim,
              "Query weights size must match head dimension");
  TORCH_CHECK(k_weight.size(0) == head_dim,
              "Key weights size must match head dimension");
  TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
              "Cos/sin cache dimension must match head_dim");
  TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
                  qkv.scalar_type() == k_weight.scalar_type(),
              "qkv, q_weight and k_weight must have the same dtype");

  int64_t num_tokens = qkv.size(0);
  TORCH_CHECK(position_ids.size(0) == num_tokens,
              "Number of tokens in position_ids must match QKV");

  int64_t total_heads = num_heads_q + num_heads_k + num_heads_v;
  TORCH_CHECK(
      qkv.size(1) == total_heads * head_dim,
      "QKV tensor size must match total number of heads and head dimension");

  // Make qkv's device current so the launch below runs on the device that
  // owns the data, even if another device was current in the caller.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(qkv));
  auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());

  // Outer dispatch: dtype of qkv / norm weights. Inner dispatch: dtype of the
  // cos/sin cache, which may differ.
  VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
    using qkv_scalar_t = scalar_t;
    VLLM_DISPATCH_FLOATING_TYPES(
        cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
          using cache_scalar_t = scalar_t;
          tensorrt_llm::kernels::launchFusedQKNormRope<qkv_scalar_t,
                                                       cache_scalar_t>(
              qkv.data_ptr(), static_cast<int>(num_tokens),
              static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
              static_cast<int>(num_heads_v), static_cast<int>(head_dim),
              static_cast<float>(eps), q_weight.data_ptr(),
              k_weight.data_ptr(), cos_sin_cache.data_ptr(),
              // The kernel's `interleave` flag is the negation of is_neox.
              !is_neox,
              reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
              stream);
        });
  });
}
\ No newline at end of file
csrc/launch_bounds_utils.h
View file @
41199996
...
...
@@ -8,11 +8,37 @@
#define VLLM_LAUNCH_BLOCKS_CAP 4
#endif
// compile-time estimate of max threads per SM for launch bounds.
// Compile-time estimate of max threads per SM for launch bounds.
// Families: 1024, 1536, 2048 threads/SM.
#ifndef VLLM_MAX_THREADS_PER_SM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
#define VLLM_MAX_THREADS_PER_SM 1536
#ifdef __CUDA_ARCH__
/* 1024 thr/SM: Turing (sm_75) */
#if (__CUDA_ARCH__ == 750)
#define VLLM_MAX_THREADS_PER_SM 1024
/* 1536 thr/SM: Ampere GA10x (sm_86/87), Ada (sm_89),
GB20x consumer (sm_120/121), Thor (sm_101 or sm_110) */
#elif (__CUDA_ARCH__ == 860) || (__CUDA_ARCH__ == 870) || \
(__CUDA_ARCH__ == 890) || (__CUDA_ARCH__ == 1010) || \
(__CUDA_ARCH__ == 1100) || (__CUDA_ARCH__ == 1200) || \
(__CUDA_ARCH__ == 1210)
#define VLLM_MAX_THREADS_PER_SM 1536
/* 2048 thr/SM: Volta (sm_70/72), Ampere GA100 (sm_80),
Hopper (sm_90), Blackwell (sm_100/103) */
#elif (__CUDA_ARCH__ == 700) || (__CUDA_ARCH__ == 720) || \
(__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900) || \
(__CUDA_ARCH__ == 1000) || (__CUDA_ARCH__ == 1030)
#define VLLM_MAX_THREADS_PER_SM 2048
/* Fallback: use 2048 for unknown future CCs */
#else
#define VLLM_MAX_THREADS_PER_SM 2048
#endif
#else
/* Host pass (no __CUDA_ARCH__): neutral default */
#define VLLM_MAX_THREADS_PER_SM 2048
#endif
#endif
...
...
csrc/layernorm_kernels.cu
View file @
41199996
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -8,20 +10,52 @@
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
int
VEC_SIZE
,
int
NUM_DIMS
>
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride_d2
,
// input.stride(-2)
const
int64_t
input_stride_d3
,
// input.stride(-3)
const
int64_t
input_stride_d4
,
// input.stride(-4)
const
int64_t
input_shape_d2
,
// input.size(-2)
const
int64_t
input_shape_d3
,
// input.size(-3)
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
const
scalar_t
*
input_row
;
if
constexpr
(
NUM_DIMS
==
2
)
{
// 2D for layernorm normal case [batch_size, hidden]
input_row
=
input
+
blockIdx
.
x
*
input_stride_d2
;
}
else
if
constexpr
(
NUM_DIMS
==
3
)
{
// 3D for q/k norm [batch_size, num_heads, head_size]
int
batch_idx
=
blockIdx
.
x
/
input_shape_d2
;
int
head_idx
=
blockIdx
.
x
%
input_shape_d2
;
input_row
=
input
+
batch_idx
*
input_stride_d3
+
head_idx
*
input_stride_d2
;
}
else
if
constexpr
(
NUM_DIMS
==
4
)
{
// 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
int
batch_idx
=
blockIdx
.
x
/
(
input_shape_d3
*
input_shape_d2
);
int
remaining
=
blockIdx
.
x
%
(
input_shape_d3
*
input_shape_d2
);
int
seq_idx
=
remaining
/
input_shape_d2
;
int
head_idx
=
remaining
%
input_shape_d2
;
input_row
=
input
+
batch_idx
*
input_stride_d4
+
seq_idx
*
input_stride_d3
+
head_idx
*
input_stride_d2
;
}
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
auto
vec_op
=
[
&
variance
](
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>&
vec
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
float
x
=
static_cast
<
float
>
(
vec
.
val
[
i
]);
variance
+=
x
*
x
;
}
};
auto
scalar_op
=
[
&
variance
](
const
scalar_t
&
val
)
{
float
x
=
static_cast
<
float
>
(
val
);
variance
+=
x
*
x
;
}
};
vllm
::
vectorize_read_with_alignment
<
VEC_SIZE
>
(
input_row
,
hidden_size
,
threadIdx
.
x
,
blockDim
.
x
,
vec_op
,
scalar_op
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
...
...
@@ -32,10 +66,20 @@ __global__ void rms_norm_kernel(
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
scalar_t
*
out_row
=
out
+
blockIdx
.
x
*
hidden_size
;
auto
*
v_in
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
input_row
);
auto
*
v_w
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
weight
);
auto
*
v_out
=
reinterpret_cast
<
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
out_row
);
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
/
VEC_SIZE
;
i
+=
blockDim
.
x
)
{
vec_n_t
<
scalar_t
,
VEC_SIZE
>
dst
;
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src1
=
v_in
[
i
];
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src2
=
v_w
[
i
];
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
float
x
=
static_cast
<
float
>
(
src1
.
val
[
j
]);
dst
.
val
[
j
]
=
((
scalar_t
)(
x
*
s_variance
))
*
src2
.
val
[
j
];
}
v_out
[
i
]
=
dst
;
}
}
...
...
@@ -135,211 +179,6 @@ fused_add_rms_norm_kernel(
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck.
_f16VecPN struct extends _f16Vec to add operations specifically required for
polynomial normalization (poly norm).
The original _f16Vec does not include the sum-of-powers computation or
in-place polynomial normalization logic. */
template
<
typename
scalar_t
,
int
width
>
struct
alignas
(
16
)
_f16VecPN
:
_f16Vec
<
scalar_t
,
width
>
{
using
Base
=
_f16Vec
<
scalar_t
,
width
>
;
using
Converter
=
typename
Base
::
Converter
;
using
T1
=
typename
Base
::
T1
;
using
T2
=
typename
Base
::
T2
;
using
Base
::
data
;
__device__
auto
sum_pows
()
const
{
float
s2
=
0.0
f
,
s4
=
0.0
f
,
s6
=
0.0
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float
x2
=
z
.
x
*
z
.
x
;
float
x4
=
x2
*
x2
;
float
x6
=
x4
*
x2
;
float
y2
=
z
.
y
*
z
.
y
;
float
y4
=
y2
*
y2
;
float
y6
=
y4
*
y2
;
s2
+=
x2
+
y2
;
s4
+=
x4
+
y4
;
s6
+=
x6
+
y6
;
}
return
std
::
make_tuple
(
s2
,
s4
,
s6
);
}
__device__
void
poly_norm_inplace
(
const
float
w2_inv_std
,
const
float
w1_inv_std2
,
const
float
w0_inv_std3
,
const
float
bias
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float
x2
=
z
.
x
*
z
.
x
;
float
x3
=
x2
*
z
.
x
;
z
.
x
=
w2_inv_std
*
z
.
x
+
w1_inv_std2
*
x2
+
w0_inv_std3
*
x3
+
bias
;
float
y2
=
z
.
y
*
z
.
y
;
float
y3
=
y2
*
z
.
y
;
z
.
y
=
w2_inv_std
*
z
.
y
+
w1_inv_std2
*
y2
+
w0_inv_std3
*
y3
+
bias
;
auto
out
=
Converter
::
convert
(
z
);
data
[
i
]
=
out
.
x
;
data
[
i
+
1
]
=
out
.
y
;
}
}
};
/* Vectorized poly norm kernel, selected when width > 0 and
   _typeConvert<scalar_t>::exists.

   Each block handles the row selected by blockIdx.x:
     1. every thread accumulates sum(x^2), sum(x^4), sum(x^6) over the
        width-element vectors it visits,
     2. the three partial sums are reduced across the block,
     3. thread 0 turns them into the scaled coefficients
          w2 * rsqrt(mean(x^2) + eps), w1 * rsqrt(mean(x^4) + eps),
          w0 * rsqrt(mean(x^6) + eps)
        and publishes them (plus the bias) through shared memory,
     4. every thread evaluates the cubic polynomial on its vectors and stores
        the result to `out`.

   hidden_size is divided by width with integer division, so elements beyond
   the last full vector are not visited. The reduction is instantiated for
   1024 threads, so blockDim.x is expected not to exceed that. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  using VecT = _f16VecPN<scalar_t, width>;

  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<VecT>);
  static_assert(sizeof(VecT) == sizeof(scalar_t) * width);

  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. `input` and `out` are only accessed through
     their vector-typed views below. */
  auto* __restrict__ input_v = reinterpret_cast<const VecT*>(input);
  const int vec_hidden_size = hidden_size / width;

  // 64-bit row offset: blockIdx.x * vec_hidden_size can exceed the 32-bit
  // range for very large inputs.
  const int64_t row_offset =
      static_cast<int64_t>(blockIdx.x) * vec_hidden_size;

  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_offset + idx;
    VecT temp = input_v[id];
    auto [x2, x4, x6] = temp.sum_pows();

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise sum so the three accumulators share one block reduction.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances = BlockReduce(reduceStore)
                               .Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Only thread 0 consumes the reduced sums; it broadcasts the derived
  // coefficients to the rest of the block through shared memory.
  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  auto* __restrict__ out_v = reinterpret_cast<VecT*>(out);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_offset + idx;
    VecT temp = input_v[id];
    temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
    out_v[id] = temp;
  }
}
/* Generic poly_norm_kernel
   The width field is not used here but necessary for other specializations.

   Selected when width == 0 or _typeConvert<scalar_t>::exists is false.
   Each block handles the row selected by blockIdx.x, one element at a time:
   the sums of x^2, x^4 and x^6 are accumulated per thread and reduced across
   the block, thread 0 derives the scaled coefficients, and every thread then
   writes
     x * w2 * rsqrt(mean(x^2) + eps) + x^2 * w1 * rsqrt(mean(x^4) + eps)
       + x^3 * w0 * rsqrt(mean(x^6) + eps) + bias
   to `out`. The reduction is instantiated for 1024 threads, so blockDim.x is
   expected not to exceed that.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // 64-bit row offset: blockIdx.x * hidden_size can exceed the 32-bit range
  // for very large inputs.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;

  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_offset + idx];
    float x2 = x * x;
    float x4 = x2 * x2;
    float x6 = x4 * x2;

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise sum so the three accumulators share one block reduction.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances = BlockReduce(reduceStore)
                               .Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Only thread 0 consumes the reduced sums; it broadcasts the derived
  // coefficients to the rest of the block through shared memory.
  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_offset + idx];
    float x2 = x * x;
    float x3 = x2 * x;

    out[row_offset + idx] =
        (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
                   s_bias);
  }
}
}
// namespace vllm
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
...
...
@@ -347,21 +186,43 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
if
(
input
.
stride
(
-
1
)
!=
1
)
{
input
=
input
.
contiguous
();
}
TORCH_CHECK
(
input
.
stride
(
-
1
)
==
1
);
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int64_t
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
num_dims
=
input
.
dim
();
int64_t
input_stride_d2
=
input
.
stride
(
-
2
);
int64_t
input_stride_d3
=
(
num_dims
>=
3
)
?
input
.
stride
(
-
3
)
:
0
;
int64_t
input_stride_d4
=
(
num_dims
>=
4
)
?
input
.
stride
(
-
4
)
:
0
;
int64_t
input_shape_d2
=
(
num_dims
>=
3
)
?
input
.
size
(
-
2
)
:
0
;
int64_t
input_shape_d3
=
(
num_dims
>=
4
)
?
input
.
size
(
-
3
)
:
0
;
// For large num_tokens, use smaller blocks to increase SM concurrency.
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
VLLM_DISPATCH_RANK234
(
num_dims
,
[
&
]
{
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
const
int
calculated_vec_size
=
std
::
gcd
(
16
/
sizeof
(
scalar_t
),
hidden_size
);
const
int
block_size
=
std
::
min
(
hidden_size
/
calculated_vec_size
,
max_block_size
);
dim3
block
(
block_size
);
VLLM_DISPATCH_VEC_SIZE
(
calculated_vec_size
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
,
vec_size
,
tensor_rank
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride_d2
,
input_stride_d3
,
input_stride_d4
,
input_shape_d2
,
input_shape_d3
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
});
}
...
...
@@ -379,6 +240,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
weight
.
scalar_type
()
==
input
.
scalar_type
());
TORCH_CHECK
(
input
.
scalar_type
()
==
residual
.
scalar_type
());
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
...
...
@@ -413,55 +276,11 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr
%
req_alignment_bytes
==
0
;
bool
offsets_are_multiple_of_vector_width
=
hidden_size
%
vector_width
==
0
&&
input_stride
%
vector_width
==
0
;
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
)
{
bool
batch_invariant_launch
=
vllm
::
vllm_is_batch_invariant
();
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
// Dispatches on the floating-point dtype of `input` and launches
// vllm::poly_norm_kernel<scalar_t, width>. The macro expands in place, so the
// call site must have `grid`, `block`, `stream`, `out`, `input`, `weight`,
// `bias`, `epsilon` and `hidden_size` in scope under exactly those names.
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
// Host entry point for poly norm.
//
// Launches one block per row of `input` (numel / hidden_size rows) on the
// current CUDA stream and writes the result to `out`.
//
//   out     [..., hidden_size]  contiguous, must not share storage start
//                               with `input`
//   input   [..., hidden_size]  contiguous
//   weight  [3]                 polynomial coefficients (read as w0, w1, w2)
//   bias    [1]                 additive bias
//   epsilon                     added under the rsqrt for stability
//
// Fails a TORCH_CHECK if a precondition above is violated. An input with no
// elements is a no-op.
void poly_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.data_ptr() != input.data_ptr());
  // The kernels read weight[0..2] and bias[0] unconditionally.
  TORCH_CHECK(weight.numel() >= 3,
              "poly_norm: weight must hold at least 3 elements");
  TORCH_CHECK(bias.numel() >= 1,
              "poly_norm: bias must hold at least 1 element");

  int hidden_size = input.size(-1);
  // Nothing to do for an empty input. This also avoids dividing by a zero
  // hidden_size below and launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_POLY_NORM(8);
  } else {
    // width 0 selects the generic element-wise kernel.
    LAUNCH_FUSED_POLY_NORM(0);
  }
}
csrc/layernorm_quant_kernels.cu
View file @
41199996
...
...
@@ -7,10 +7,12 @@
#include "type_convert.cuh"
#ifndef USE_ROCM
#include "quantization/fp8/common.cuh"
#include "quantization/
w8a8/
fp8/common.cuh"
#endif
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -18,7 +20,7 @@
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
,
typename
fp8_type
>
template
<
typename
scalar_t
,
typename
fp8_type
,
int
VEC_SIZE
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
...
...
@@ -29,10 +31,21 @@ __global__ void rms_norm_static_fp8_quant_kernel(
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
const
scalar_t
*
input_row
=
input
+
blockIdx
.
x
*
input_stride
;
auto
vec_op
=
[
&
variance
](
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>&
vec
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
float
x
=
static_cast
<
float
>
(
vec
.
val
[
i
]);
variance
+=
x
*
x
;
}
};
auto
scalar_op
=
[
&
variance
](
const
scalar_t
&
val
)
{
float
x
=
static_cast
<
float
>
(
val
);
variance
+=
x
*
x
;
}
};
vllm
::
vectorize_read_with_alignment
<
VEC_SIZE
>
(
input_row
,
hidden_size
,
threadIdx
.
x
,
blockDim
.
x
,
vec_op
,
scalar_op
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
...
...
@@ -46,11 +59,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
// invert scale to avoid division
float
const
scale_inv
=
1.0
f
/
*
scale
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
auto
*
v_in
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
input_row
);
auto
*
v_w
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
weight
);
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
/
VEC_SIZE
;
idx
+=
blockDim
.
x
)
{
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src1
=
v_in
[
idx
];
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src2
=
v_w
[
idx
];
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
float
x
=
static_cast
<
float
>
(
src1
.
val
[
j
]);
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
src2
.
val
[
j
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
*
VEC_SIZE
+
j
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
...
...
@@ -176,20 +196,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
int
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
// For large num_tokens, use smaller blocks to increase SM concurrency.
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
const
int
calculated_vec_size
=
std
::
gcd
(
16
/
sizeof
(
scalar_t
),
hidden_size
);
const
int
block_size
=
std
::
min
(
hidden_size
/
calculated_vec_size
,
max_block_size
);
dim3
block
(
block_size
);
VLLM_DISPATCH_VEC_SIZE
(
calculated_vec_size
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
,
vec_size
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
});
}
...
...
@@ -217,6 +246,8 @@ void fused_add_rms_norm_static_fp8_quant(
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
residual
.
scalar_type
()
==
input
.
scalar_type
());
TORCH_CHECK
(
weight
.
scalar_type
()
==
input
.
scalar_type
());
int
hidden_size
=
input
.
size
(
-
1
);
int
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -242,7 +273,9 @@ void fused_add_rms_norm_static_fp8_quant(
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
input_stride
%
8
==
0
)
{
bool
batch_invariant_launch
=
vllm
::
vllm_is_batch_invariant
();
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
input_stride
%
8
==
0
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
...
...
csrc/mamba/mamba_ssm/selective_scan.h
View file @
41199996
...
...
@@ -24,6 +24,8 @@ struct SSMParamsBase {
int64_t
pad_slot_id
;
bool
delta_softplus
;
bool
cache_enabled
;
int
block_size
;
index_t
A_d_stride
;
index_t
A_dstate_stride
;
...
...
@@ -46,8 +48,9 @@ struct SSMParamsBase {
index_t
out_z_batch_stride
;
index_t
out_z_d_stride
;
index_t
ssm_states_batch_stride
;
index_t
ssm_states_dim_stride
;
index_t
ssm_states_dim_stride
;
index_t
ssm_states_dstate_stride
;
index_t
cache_indices_stride
;
// Common data pointers.
void
*
__restrict__
A_ptr
;
...
...
@@ -66,6 +69,9 @@ struct SSMParamsBase {
void
*
__restrict__
cache_indices_ptr
;
void
*
__restrict__
has_initial_state_ptr
;
void
*
__restrict__
block_idx_first_scheduled_token_ptr
;
// (batch,) - first block to write
void
*
__restrict__
block_idx_last_scheduled_token_ptr
;
// (batch,) - last block to write
void
*
__restrict__
initial_state_idx_ptr
;
// (batch,) - index of the initial state to use
};
...
...
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
View file @
41199996
...
...
@@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
const
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
const
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if
(
cache_index
==
params
.
pad_slot_id
){
return
;
...
...
@@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t
*
Bvar
=
reinterpret_cast
<
input_t
*>
(
params
.
B_ptr
)
+
sequence_start_index
*
params
.
B_batch_stride
+
group_id
*
params
.
B_group_stride
;
weight_t
*
C
=
reinterpret_cast
<
weight_t
*>
(
params
.
C_ptr
)
+
dim_id
*
kNRows
*
params
.
C_d_stride
;
input_t
*
Cvar
=
reinterpret_cast
<
input_t
*>
(
params
.
C_ptr
)
+
sequence_start_index
*
params
.
C_batch_stride
+
group_id
*
params
.
C_group_stride
;
typename
Ktraits
::
state_t
*
ssm_states
=
reinterpret_cast
<
typename
Ktraits
::
state_t
*>
(
params
.
ssm_states_ptr
)
+
cache_index
*
params
.
ssm_states_batch_stride
+
dim_id
*
kNRows
*
params
.
ssm_states_dim_stride
;
typename
Ktraits
::
state_t
*
ssm_states
;
if
(
params
.
cache_enabled
)
{
// APC mode: ssm_states points to the base, we'll use absolute cache slots later
ssm_states
=
reinterpret_cast
<
typename
Ktraits
::
state_t
*>
(
params
.
ssm_states_ptr
)
+
dim_id
*
kNRows
*
params
.
ssm_states_dim_stride
;
}
else
{
// Non-APC mode: offset by cache_index as before
ssm_states
=
reinterpret_cast
<
typename
Ktraits
::
state_t
*>
(
params
.
ssm_states_ptr
)
+
cache_index
*
params
.
ssm_states_batch_stride
+
dim_id
*
kNRows
*
params
.
ssm_states_dim_stride
;
}
float
D_val
[
kNRows
]
=
{
0
};
if
(
params
.
D_ptr
!=
nullptr
)
{
...
...
@@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// }
constexpr
int
kChunkSize
=
kNThreads
*
kNItems
;
const
int
n_chunks
=
(
seqlen
+
2048
-
1
)
/
2048
;
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
const
int
iteration_chunk_size
=
params
.
cache_enabled
?
params
.
block_size
:
2048
;
const
int
n_chunks
=
(
seqlen
+
iteration_chunk_size
-
1
)
/
iteration_chunk_size
;
const
int
*
batch_cache_indices
=
cache_indices
!=
nullptr
?
cache_indices
+
batch_id
*
params
.
cache_indices_stride
:
nullptr
;
const
int
*
block_idx_first_scheduled
=
params
.
block_idx_first_scheduled_token_ptr
!=
nullptr
?
reinterpret_cast
<
const
int
*>
(
params
.
block_idx_first_scheduled_token_ptr
)
:
nullptr
;
const
int
*
block_idx_last_scheduled
=
params
.
block_idx_last_scheduled_token_ptr
!=
nullptr
?
reinterpret_cast
<
const
int
*>
(
params
.
block_idx_last_scheduled_token_ptr
)
:
nullptr
;
const
int
*
initial_state_idx
=
params
.
initial_state_idx_ptr
!=
nullptr
?
reinterpret_cast
<
const
int
*>
(
params
.
initial_state_idx_ptr
)
:
nullptr
;
const
size_t
load_cache_slot
=
params
.
cache_enabled
&&
batch_cache_indices
!=
nullptr
?
batch_cache_indices
[
initial_state_idx
[
batch_id
]]
:
cache_index
;
for
(
int
chunk
=
0
;
chunk
<
n_chunks
;
++
chunk
)
{
input_t
u_vals
[
kNRows
][
kNItems
],
delta_vals_load
[
kNRows
][
kNItems
];
...
...
@@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if
constexpr
(
kIsVariableC
)
{
auto
&
smem_load_weight_C
=
!
kIsVariableB
?
smem_load_weight
:
smem_load_weight1
;
load_weight
<
Ktraits
>
(
Cvar
+
state_idx
*
params
.
C_dstate_stride
,
C_vals
,
smem_load_weight_C
,
(
seqlen
-
chunk
*
kChunkSize
)
*
(
1
));
smem_load_weight_C
,
(
seqlen
-
chunk
*
kChunkSize
)
*
(
1
));
if
constexpr
(
!
kIsVariableB
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
...
...
@@ -242,7 +266,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
for
(
int
i
=
0
;
i
<
kNItems
;
++
i
)
{
thread_data
[
i
]
=
make_float2
(
exp2f
(
delta_vals
[
r
][
i
]
*
A_val
[
r
]),
!
kIsVariableB
?
delta_u_vals
[
r
][
i
]
:
B_vals
[
i
]
*
delta_u_vals
[
r
][
i
]);
if
(
seqlen
%
(
kNItems
*
kNThreads
)
!=
0
)
{
// So that the last state is correct
if
(
threadIdx
.
x
*
kNItems
+
i
>=
seqlen
-
chunk
*
kChunkSize
)
{
thread_data
[
i
]
=
make_float2
(
1.
f
,
0.
f
);
...
...
@@ -250,8 +273,24 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
// Initialize running total
scan_t
running_prefix
=
chunk
>
0
?
smem_running_prefix
[
state_idx
+
r
*
MAX_DSTATE
]
:
make_float2
(
1.0
,
has_initial_state
?
float
(
ssm_states
[
state_idx
*
params
.
ssm_states_dstate_stride
])
:
0.0
);
scan_t
running_prefix
;
if
(
chunk
>
0
)
{
running_prefix
=
smem_running_prefix
[
state_idx
+
r
*
MAX_DSTATE
];
}
else
{
// Load initial state
if
(
params
.
cache_enabled
&&
has_initial_state
&&
batch_cache_indices
!=
nullptr
)
{
size_t
state_offset
=
load_cache_slot
*
params
.
ssm_states_batch_stride
+
r
*
params
.
ssm_states_dim_stride
+
state_idx
*
params
.
ssm_states_dstate_stride
;
running_prefix
=
make_float2
(
1.0
,
float
(
ssm_states
[
state_offset
]));
}
else
if
(
has_initial_state
)
{
// Non-APC mode: load from current batch position
running_prefix
=
make_float2
(
1.0
,
float
(
ssm_states
[
state_idx
*
params
.
ssm_states_dstate_stride
]));
}
else
{
// No initial state
running_prefix
=
make_float2
(
1.0
,
0.0
);
}
}
SSMScanPrefixCallbackOp
<
weight_t
>
prefix_op
(
running_prefix
);
typename
Ktraits
::
BlockScanT
(
smem_scan
).
InclusiveScan
(
...
...
@@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// There's a syncthreads in the scan op, so we don't need to sync here.
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
if
(
threadIdx
.
x
==
0
)
{
smem_running_prefix
[
state_idx
]
=
prefix_op
.
running_prefix
;
if
(
chunk
==
n_chunks
-
1
)
{
smem_running_prefix
[
state_idx
+
r
*
MAX_DSTATE
]
=
prefix_op
.
running_prefix
;
// Store state at the end of each chunk when cache is enabled
if
(
params
.
cache_enabled
&&
batch_cache_indices
!=
nullptr
)
{
size_t
cache_slot
;
if
(
chunk
==
n_chunks
-
1
)
{
cache_slot
=
batch_cache_indices
[
block_idx_last_scheduled
[
batch_id
]];
}
else
{
cache_slot
=
batch_cache_indices
[
block_idx_first_scheduled
[
batch_id
]
+
chunk
];
}
size_t
state_offset
=
cache_slot
*
params
.
ssm_states_batch_stride
+
r
*
params
.
ssm_states_dim_stride
+
state_idx
*
params
.
ssm_states_dstate_stride
;
ssm_states
[
state_offset
]
=
typename
Ktraits
::
state_t
(
prefix_op
.
running_prefix
.
y
);
}
else
if
(
!
params
.
cache_enabled
&&
chunk
==
n_chunks
-
1
)
{
// Non-APC mode: store only final state at current batch position
ssm_states
[
state_idx
*
params
.
ssm_states_dstate_stride
]
=
typename
Ktraits
::
state_t
(
prefix_op
.
running_prefix
.
y
);
}
}
...
...
@@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
}
input_t
*
out
=
reinterpret_cast
<
input_t
*>
(
params
.
out_ptr
)
+
sequence_start_index
*
params
.
out_batch_stride
+
dim_id
*
kNRows
*
params
.
out_d_stride
+
chunk
*
kChunkSize
;
__syncthreads
();
...
...
@@ -346,7 +401,9 @@ template<typename input_t, typename weight_t, typename state_t>
void
selective_scan_fwd_cuda
(
SSMParamsBase
&
params
,
cudaStream_t
stream
)
{
#ifndef USE_ROCM
if
(
params
.
seqlen
<=
128
)
{
if
(
params
.
cache_enabled
&&
params
.
block_size
==
1024
)
{
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
128
)
{
selective_scan_fwd_launch
<
32
,
4
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
256
)
{
selective_scan_fwd_launch
<
32
,
8
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
...
...
@@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
selective_scan_fwd_launch
<
128
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
#else
if
(
params
.
seqlen
<=
256
)
{
if
(
params
.
cache_enabled
&&
params
.
block_size
==
1024
)
{
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
256
)
{
selective_scan_fwd_launch
<
64
,
4
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
512
)
{
selective_scan_fwd_launch
<
64
,
8
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
...
...
@@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
const
std
::
optional
<
at
::
Tensor
>&
D
,
const
std
::
optional
<
at
::
Tensor
>&
delta_bias
,
const
torch
::
Tensor
ssm_states
,
bool
has_z
,
bool
has_z
,
bool
delta_softplus
,
const
std
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
std
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
std
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
varlen
,
int64_t
pad_slot_id
)
{
int64_t
pad_slot_id
,
int64_t
block_size
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_first_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_last_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
initial_state_idx
)
{
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
));
...
...
@@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
params
.
cache_indices_ptr
=
cache_indices
.
has_value
()
?
cache_indices
.
value
().
data_ptr
()
:
nullptr
;
params
.
has_initial_state_ptr
=
has_initial_state
.
has_value
()
?
has_initial_state
.
value
().
data_ptr
()
:
nullptr
;
// Set cache parameters - cache is enabled if we have direct cache writing params
params
.
cache_enabled
=
block_idx_first_scheduled_token
.
has_value
();
params
.
block_size
=
static_cast
<
int
>
(
block_size
);
// Set direct cache writing pointers
params
.
block_idx_first_scheduled_token_ptr
=
block_idx_first_scheduled_token
.
has_value
()
?
block_idx_first_scheduled_token
.
value
().
data_ptr
()
:
nullptr
;
params
.
block_idx_last_scheduled_token_ptr
=
block_idx_last_scheduled_token
.
has_value
()
?
block_idx_last_scheduled_token
.
value
().
data_ptr
()
:
nullptr
;
params
.
initial_state_idx_ptr
=
initial_state_idx
.
has_value
()
?
initial_state_idx
.
value
().
data_ptr
()
:
nullptr
;
// All strides are in elements, not bytes.
params
.
A_d_stride
=
A
.
stride
(
0
);
...
...
@@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
params
.
out_d_stride
=
out
.
stride
(
0
);
params
.
ssm_states_batch_stride
=
ssm_states
.
stride
(
0
);
params
.
ssm_states_dim_stride
=
ssm_states
.
stride
(
1
);
params
.
ssm_states_dim_stride
=
ssm_states
.
stride
(
1
);
params
.
ssm_states_dstate_stride
=
ssm_states
.
stride
(
2
);
params
.
cache_indices_stride
=
cache_indices
.
has_value
()
?
cache_indices
.
value
().
stride
(
0
)
:
0
;
}
else
{
if
(
!
is_variable_B
)
{
...
...
@@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
params
.
out_d_stride
=
out
.
stride
(
1
);
params
.
ssm_states_batch_stride
=
ssm_states
.
stride
(
0
);
params
.
ssm_states_dim_stride
=
ssm_states
.
stride
(
1
);
params
.
ssm_states_dim_stride
=
ssm_states
.
stride
(
1
);
params
.
ssm_states_dstate_stride
=
ssm_states
.
stride
(
2
);
params
.
cache_indices_stride
=
cache_indices
.
has_value
()
?
cache_indices
.
value
().
stride
(
0
)
:
0
;
}
}
...
...
@@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const
torch
::
Tensor
&
ssm_states
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
int64_t
pad_slot_id
,
int64_t
block_size
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_first_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_last_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
initial_state_idx
)
{
auto
input_type
=
u
.
scalar_type
();
auto
weight_type
=
A
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
...
@@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
auto
cache_indices_
=
cache_indices
.
value
();
TORCH_CHECK
(
cache_indices_
.
scalar_type
()
==
at
::
ScalarType
::
Int
);
TORCH_CHECK
(
cache_indices_
.
is_cuda
());
CHECK_SHAPE
(
cache_indices_
,
batch_size
);
// cache_indices can be either 1D (batch_size,) for non-APC mode
// or 2D (batch_size, max_positions) for APC mode
const
bool
is_apc_mode
=
block_idx_first_scheduled_token
.
has_value
();
if
(
is_apc_mode
)
{
TORCH_CHECK
(
cache_indices_
.
dim
()
==
2
,
"cache_indices must be 2D for APC mode"
);
TORCH_CHECK
(
cache_indices_
.
size
(
0
)
==
batch_size
,
"cache_indices first dimension must match batch_size"
);
}
else
{
CHECK_SHAPE
(
cache_indices_
,
batch_size
);
}
}
...
...
@@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
cache_indices
,
has_initial_state
,
varlen
,
pad_slot_id
pad_slot_id
,
block_size
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
initial_state_idx
);
...
...
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
View file @
41199996
...
...
@@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
const
int64_t
g_eff_13
=
(
group_size
!=
-
1
)
?
group_size
:
H
;
const
int64_t
g_eff_2
=
(
group_size
!=
-
1
)
?
group_size
:
I
;
// Per-expert outputs filled in parallel
std
::
vector
<
torch
::
Tensor
>
y_list
(
E
);
y_list
.
resize
(
E
);
at
::
parallel_for
(
0
,
E
,
1
,
[
&
](
int64_t
e_begin
,
int64_t
e_end
)
{
for
(
int64_t
e
=
e_begin
;
e
<
e_end
;
++
e
)
{
const
int64_t
te
=
counts
[
e
];
if
(
te
==
0
)
{
y_list
[
e
]
=
at
::
empty
({
0
,
H
},
x_c
.
options
());
auto
X_all
=
x_c
.
index_select
(
/*dim=*/
0
,
expert_tokens
);
if
(
apply_router_weight_on_input
)
{
X_all
=
X_all
.
mul
(
expert_gates
.
unsqueeze
(
1
));
}
auto
Y_all
=
at
::
empty
({
offsets
[
E
],
H
},
x_c
.
options
());
at
::
parallel_for
(
0
,
offsets
[
E
],
0
,
[
&
](
int64_t
idx_begin
,
int64_t
idx_end
)
{
c10
::
InferenceMode
guard
;
for
(
int64_t
e
=
0
;
e
<
E
;
++
e
)
{
int64_t
start
=
std
::
max
(
offsets
[
e
],
idx_begin
);
int64_t
end
=
std
::
min
(
offsets
[
e
+
1
],
idx_end
);
int64_t
te
=
end
-
start
;
if
(
te
<=
0
)
{
continue
;
}
const
int64_t
start
=
offsets
[
e
];
auto
sel_tokens
=
expert_tokens
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
);
auto
gates_e
=
expert_gates
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
);
auto
x_e
=
x_c
.
index_select
(
/*dim=*/
0
,
sel_tokens
);
if
(
apply_router_weight_on_input
)
{
x_e
=
x_e
.
mul
(
gates_e
.
unsqueeze
(
1
));
}
auto
x_e
=
X_all
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
);
auto
w13_e
=
w13_packed
.
select
(
/*dim=*/
0
,
e
);
auto
w2_e
=
w2_packed
.
select
(
/*dim=*/
0
,
e
);
...
...
@@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
// W2
auto
y
=
mm
(
act
,
w2_e
,
g_eff_2
,
/*in_features=*/
I
,
/*out_features=*/
H
);
if
(
!
apply_router_weight_on_input
)
{
y
=
y
.
mul
(
gates_e
.
unsqueeze
(
1
));
}
// Store per-expert result
y_list
[
e
]
=
y
;
Y_all
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
).
copy_
(
y
)
;
}
});
// Concatenate all expert outputs to match expert_tokens order
auto
Y_all
=
at
::
cat
(
y_list
,
/*dim=*/
0
);
if
(
!
apply_router_weight_on_input
)
{
Y_all
=
Y_all
.
mul
(
expert_gates
.
unsqueeze
(
1
));
}
auto
out
=
at
::
zeros
({
T
,
H
},
x
.
options
());
out
=
at
::
index_add
(
out
,
/*dim=*/
0
,
/*index=*/
expert_tokens
,
/*source=*/
Y_all
);
...
...
csrc/moe/grouped_topk_kernels.cu
View file @
41199996
...
...
@@ -427,11 +427,29 @@ __device__ inline bool is_finite(const T val) {
#endif
}
// Selects the activation applied to a raw score before the bias is added.
// The enumerators are compared against plain integer arguments, so their
// numeric values must stay fixed.
enum ScoringFunc {
  SCORING_NONE = 0,    // use the score as-is (no activation function)
  SCORING_SIGMOID = 1  // pass the score through a sigmoid first
};
// Sigmoid evaluated through the tanh identity
//   sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5
// (formulation taken from TensorRT-LLM). The identity is mathematically
// exact, so the only error is float rounding in tanhf and in the two
// multiply/add steps.
//
// x: input value.
// Returns the sigmoid of x; because tanhf is bounded by [-1, 1] the result
// lies in [0, 1].
__device__ inline float sigmoid_accurate(float x) {
  return 0.5f * tanhf(0.5f * x) + 0.5f;
}
template
<
typename
T
>
__device__
void
topk_with_k2
(
T
*
output
,
T
const
*
input
,
// Applies the sigmoid to a single value of type T.
// The value is converted to float with cuda_cast<float, T>, passed through
// sigmoid_accurate, and the float result is converted back to T with
// cuda_cast<T, float>.
// NOTE(review): T is not declared in this signature; it presumably comes from
// a `template <typename T>` header in the original file - confirm.
__device__
inline
T
apply_sigmoid
(
T
val
)
{
float
f
=
cuda_cast
<
float
,
T
>
(
val
);
return
cuda_cast
<
T
,
float
>
(
sigmoid_accurate
(
f
));
}
template
<
typename
T
>
__device__
void
topk_with_k2
(
T
*
output
,
T
const
*
input
,
T
const
*
bias
,
cg
::
thread_block_tile
<
32
>
const
&
tile
,
int32_t
const
lane_id
,
int
const
num_experts_per_group
)
{
int
const
num_experts_per_group
,
int
const
scoring_func
)
{
// Get the top2 per thread
T
largest
=
neg_inf
<
T
>
();
T
second_largest
=
neg_inf
<
T
>
();
...
...
@@ -439,6 +457,12 @@ __device__ void topk_with_k2(T* output, T const* input,
if
(
num_experts_per_group
>
WARP_SIZE
)
{
for
(
int
i
=
lane_id
;
i
<
num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
value
=
input
[
i
];
// Apply scoring function if needed
if
(
scoring_func
==
SCORING_SIGMOID
)
{
value
=
apply_sigmoid
(
value
);
}
value
=
value
+
bias
[
i
];
if
(
value
>
largest
)
{
second_largest
=
largest
;
largest
=
value
;
...
...
@@ -448,7 +472,13 @@ __device__ void topk_with_k2(T* output, T const* input,
}
}
else
{
for
(
int
i
=
lane_id
;
i
<
num_experts_per_group
;
i
+=
WARP_SIZE
)
{
largest
=
input
[
i
];
T
value
=
input
[
i
];
// Apply scoring function if needed
if
(
scoring_func
==
SCORING_SIGMOID
)
{
value
=
apply_sigmoid
(
value
);
}
value
=
value
+
bias
[
i
];
largest
=
value
;
}
}
...
...
@@ -472,17 +502,21 @@ __device__ void topk_with_k2(T* output, T const* input,
}
template
<
typename
T
>
__global__
void
topk_with_k2_kernel
(
T
*
output
,
T
*
input
,
__global__
void
topk_with_k2_kernel
(
T
*
output
,
T
*
input
,
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
num_cases
,
int64_t
const
n_group
,
int64_t
const
num_experts_per_group
)
{
int64_t
const
num_experts_per_group
,
int
const
scoring_func
)
{
int32_t
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
int32_t
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
int32_t
case_id
=
blockIdx
.
x
*
NUM_WARPS_PER_BLOCK
+
warp_id
;
if
(
case_id
<
num_cases
)
{
input
+=
case_id
*
num_experts_per_group
;
// bias is per expert group, offset to current group
int32_t
group_id
=
case_id
%
n_group
;
T
const
*
group_bias
=
bias
+
group_id
*
num_experts_per_group
;
output
+=
case_id
;
cg
::
thread_block
block
=
cg
::
this_thread_block
();
...
...
@@ -491,7 +525,8 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.wait;"
);
#endif
topk_with_k2
(
output
,
input
,
tile
,
lane_id
,
num_experts_per_group
);
topk_with_k2
(
output
,
input
,
group_bias
,
tile
,
lane_id
,
num_experts_per_group
,
scoring_func
);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.launch_dependents;"
);
...
...
@@ -500,16 +535,15 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
template
<
typename
T
,
typename
IdxT
>
__global__
void
group_idx_and_topk_idx_kernel
(
T
*
scores
,
T
const
*
group_scores
,
T
*
topk_values
,
IdxT
*
topk_indices
,
T
*
scores_with_
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
T
*
scores
,
T
const
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
int64_t
const
num_experts
,
int64_t
const
num_experts_per_group
,
bool
renormalize
,
double
routed_scaling_factor
)
{
double
routed_scaling_factor
,
int
scoring_func
)
{
int32_t
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
int32_t
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
int32_t
case_id
=
blockIdx
.
x
*
NUM_WARPS_PER_BLOCK
+
warp_id
;
// one per token
scores_with_bias
+=
case_id
*
num_experts
;
scores
+=
case_id
*
num_experts
;
group_scores
+=
case_id
*
n_group
;
topk_values
+=
case_id
*
topk
;
...
...
@@ -577,10 +611,16 @@ __global__ void group_idx_and_topk_idx_kernel(
int32_t
offset
=
i_group
*
num_experts_per_group
;
for
(
int32_t
i
=
lane_id
;
i
<
align_num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
candidates
=
(
i
<
num_experts_per_group
)
&&
is_finite
(
scores_with_bias
[
offset
+
i
])
?
scores_with_bias
[
offset
+
i
]
:
neg_inf
<
T
>
();
T
candidates
=
neg_inf
<
T
>
();
if
(
i
<
num_experts_per_group
)
{
// Apply scoring function (if any) and add bias
T
input
=
scores
[
offset
+
i
];
if
(
is_finite
(
input
))
{
T
score
=
(
scoring_func
==
SCORING_SIGMOID
)
?
apply_sigmoid
(
input
)
:
input
;
candidates
=
score
+
bias
[
offset
+
i
];
}
}
queue
.
add
(
candidates
,
offset
+
i
);
}
if
(
group_scores
[
i_group
]
==
topk_group_value
)
{
...
...
@@ -602,11 +642,12 @@ __global__ void group_idx_and_topk_idx_kernel(
for
(
int
i
=
lane_id
;
i
<
warp_topk
::
round_up_to_multiple_of
<
WARP_SIZE
>
(
topk
);
i
+=
WARP_SIZE
)
{
T
value
=
i
<
topk
?
scores
[
s_topk_idx
[
i
]]
:
cuda_cast
<
T
,
float
>
(
0.0
f
);
// Load the valid value of expert
T
value
=
cuda_cast
<
T
,
float
>
(
0.0
f
);
if
(
i
<
topk
)
{
// Load the score value (without bias) for normalization
T
input
=
scores
[
s_topk_idx
[
i
]];
value
=
(
scoring_func
==
SCORING_SIGMOID
)
?
apply_sigmoid
(
input
)
:
input
;
s_topk_value
[
i
]
=
value
;
}
topk_sum
+=
...
...
@@ -627,12 +668,12 @@ __global__ void group_idx_and_topk_idx_kernel(
value
=
cuda_cast
<
float
,
T
>
(
s_topk_value
[
i
])
*
routed_scaling_factor
;
}
topk_indices
[
i
]
=
s_topk_idx
[
i
];
topk_values
[
i
]
=
cuda_cast
<
T
,
float
>
(
value
)
;
topk_values
[
i
]
=
value
;
}
}
else
{
for
(
int
i
=
lane_id
;
i
<
topk
;
i
+=
WARP_SIZE
)
{
topk_indices
[
i
]
=
i
;
topk_values
[
i
]
=
cuda_cast
<
T
,
float
>
(
1.0
f
/
topk
)
;
topk_values
[
i
]
=
1.0
f
/
topk
;
}
}
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
...
...
@@ -644,12 +685,12 @@ __global__ void group_idx_and_topk_idx_kernel(
}
template
<
typename
T
,
typename
IdxT
>
void
invokeNoAuxTc
(
T
*
scores
,
T
*
group_scores
,
T
*
topk_values
,
IdxT
*
topk_indices
,
T
*
scores_with_bia
s
,
int64_t
const
num_
token
s
,
int64_t
const
n
um_experts
,
int64_t
const
n
_group
,
int64_t
const
topk
_group
,
int64_t
const
topk
,
bool
const
renormalize
,
double
const
routed_scaling_factor
,
bool
enable_pdl
=
false
,
void
invokeNoAuxTc
(
T
*
scores
,
T
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
T
const
*
bias
,
int64_t
const
num_token
s
,
int64_t
const
num_
expert
s
,
int64_t
const
n
_group
,
int64_t
const
topk
_group
,
int64_t
const
topk
,
bool
const
renormalize
,
double
const
routed_scaling_factor
,
int
const
scoring_func
,
bool
enable_pdl
=
false
,
cudaStream_t
const
stream
=
0
)
{
int64_t
num_cases
=
num_tokens
*
n_group
;
int64_t
topk_with_k2_num_blocks
=
(
num_cases
-
1
)
/
NUM_WARPS_PER_BLOCK
+
1
;
...
...
@@ -664,8 +705,9 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
attrs
[
0
].
val
.
programmaticStreamSerializationAllowed
=
enable_pdl
;
config
.
numAttrs
=
1
;
config
.
attrs
=
attrs
;
cudaLaunchKernelEx
(
&
config
,
kernel_instance1
,
group_scores
,
scores_with_bias
,
num_tokens
,
num_cases
,
n_group
,
num_experts
/
n_group
);
cudaLaunchKernelEx
(
&
config
,
kernel_instance1
,
group_scores
,
scores
,
bias
,
num_tokens
,
num_cases
,
n_group
,
num_experts
/
n_group
,
scoring_func
);
int64_t
topk_with_k_group_num_blocks
=
(
num_tokens
-
1
)
/
NUM_WARPS_PER_BLOCK
+
1
;
...
...
@@ -682,19 +724,18 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
config
.
numAttrs
=
1
;
config
.
attrs
=
attrs
;
cudaLaunchKernelEx
(
&
config
,
kernel_instance2
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
scores_with_
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts
/
n_group
,
renormalize
,
routed_scaling_factor
);
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts
/
n_group
,
renormalize
,
routed_scaling_factor
,
scoring_func
);
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
template void invokeNoAuxTc<T, IdxT>( \
T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \
T * scores_with_bias, int64_t const num_tokens, \
int64_t const num_experts, int64_t const n_group, \
int64_t const topk_group, int64_t const topk, bool const renormalize, \
double const routed_scaling_factor, bool enable_pdl, \
cudaStream_t const stream);
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
T const* bias, int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC
(
float
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
half
,
int32_t
);
...
...
@@ -703,28 +744,32 @@ INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
}
// namespace vllm
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
grouped_topk
(
torch
::
Tensor
const
&
scores
,
torch
::
Tensor
const
&
scores_with_bias
,
int64_t
n_group
,
int64_t
topk_group
,
int64_t
topk
,
bool
renormalize
,
double
routed_scaling_factor
)
{
auto
data_type
=
scores
_with_bias
.
scalar_type
();
auto
input_size
=
scores
_with_bias
.
sizes
();
torch
::
Tensor
const
&
scores
,
int64_t
n_group
,
int64_t
topk_group
,
int64_t
topk
,
bool
renormalize
,
double
routed_scaling_factor
,
torch
::
Tensor
const
&
bias
,
int64_t
scoring_func
=
0
)
{
auto
data_type
=
scores
.
scalar_type
();
auto
input_size
=
scores
.
sizes
();
int64_t
num_tokens
=
input_size
[
0
];
int64_t
num_experts
=
input_size
[
1
];
TORCH_CHECK
(
input_size
.
size
()
==
2
,
"scores
_with_bias
must be a 2D Tensor"
);
TORCH_CHECK
(
input_size
.
size
()
==
2
,
"scores must be a 2D Tensor"
);
TORCH_CHECK
(
num_experts
%
n_group
==
0
,
"num_experts should be divisible by n_group"
);
TORCH_CHECK
(
n_group
<=
32
,
"n_group should be smaller than or equal to 32 for now"
);
TORCH_CHECK
(
topk
<=
32
,
"topk should be smaller than or equal to 32 for now"
);
TORCH_CHECK
(
scoring_func
==
vllm
::
moe
::
SCORING_NONE
||
scoring_func
==
vllm
::
moe
::
SCORING_SIGMOID
,
"scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)"
);
torch
::
Tensor
group_scores
=
torch
::
empty
(
{
num_tokens
,
n_group
},
torch
::
dtype
(
data_type
).
device
(
torch
::
kCUDA
));
// Always output float32 for topk_values (eliminates Python-side conversion)
torch
::
Tensor
topk_values
=
torch
::
empty
(
{
num_tokens
,
topk
},
torch
::
dtype
(
data_type
).
device
(
torch
::
kCUDA
));
{
num_tokens
,
topk
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
torch
::
Tensor
topk_indices
=
torch
::
empty
(
{
num_tokens
,
topk
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
(
scores
_with_bias
.
get_device
());
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
(
scores
.
get_device
());
switch
(
data_type
)
{
case
torch
::
kFloat16
:
...
...
@@ -732,11 +777,11 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
vllm
::
moe
::
invokeNoAuxTc
<
half
,
int32_t
>
(
reinterpret_cast
<
half
*>
(
scores
.
mutable_data_ptr
()),
reinterpret_cast
<
half
*>
(
group_scores
.
mutable_data_ptr
()),
reinterpret_cast
<
half
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
topk_indices
.
mutable_data_ptr
()),
reinterpret_cast
<
half
*>
(
scores_with_
bias
.
data_ptr
()),
num_tokens
,
reinterpret_cast
<
half
const
*>
(
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
false
,
stream
);
routed_scaling_factor
,
static_cast
<
int
>
(
scoring_func
),
false
,
stream
);
break
;
case
torch
::
kFloat32
:
// Handle Float32
...
...
@@ -745,20 +790,20 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
reinterpret_cast
<
float
*>
(
group_scores
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
topk_indices
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
scores_with_
bias
.
data_ptr
()),
num_tokens
,
reinterpret_cast
<
float
const
*>
(
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
false
,
stream
);
routed_scaling_factor
,
static_cast
<
int
>
(
scoring_func
),
false
,
stream
);
break
;
case
torch
::
kBFloat16
:
// Handle BFloat16
vllm
::
moe
::
invokeNoAuxTc
<
__nv_bfloat16
,
int32_t
>
(
reinterpret_cast
<
__nv_bfloat16
*>
(
scores
.
mutable_data_ptr
()),
reinterpret_cast
<
__nv_bfloat16
*>
(
group_scores
.
mutable_data_ptr
()),
reinterpret_cast
<
__nv_b
float
16
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
topk_indices
.
mutable_data_ptr
()),
reinterpret_cast
<
__nv_bfloat16
*>
(
scores_with_
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
false
,
stream
);
reinterpret_cast
<
__nv_bfloat16
const
*>
(
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
static_cast
<
int
>
(
scoring_func
),
false
,
stream
);
break
;
default:
// Handle other data types
...
...
csrc/moe/marlin_moe_wna16/.gitignore
View file @
41199996
kernel_*.cu
\ No newline at end of file
sm*_kernel_*.cu
kernel_selector.h
csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
41199996
...
...
@@ -4,128 +4,282 @@ import glob
import
itertools
import
os
import
subprocess
import
sys
import
jinja2
FILE_HEAD
=
"""
// auto generated by generate.py
ARCHS
=
[]
SUPPORT_FP8
=
False
for
arch
in
sys
.
argv
[
1
].
split
(
","
):
arch
=
arch
[:
arch
.
index
(
"."
)
+
2
].
replace
(
"."
,
""
)
arch
=
int
(
arch
)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if
arch
in
[
89
,
120
]:
SUPPORT_FP8
=
True
FILE_HEAD_COMMENT
=
"""
// auto generated by generate_kernels.py
// clang-format off
"""
.
lstrip
()
FILE_HEAD
=
(
FILE_HEAD_COMMENT
+
"""
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
"""
.
strip
()
TEMPLATE
=
(
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
,
"vllm::kFE2M1f"
]
"""
)
TEMPLATE
=
(
"template __global__ void Marlin<"
"{{a_type_id}}, "
"{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{m_block_size_8}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );"
)
THREAD_CONFIGS
=
[(
128
,
128
,
256
),
(
64
,
256
,
256
),
(
64
,
128
,
128
)]
THREAD_M_BLOCKS
=
[
0.5
,
1
,
2
,
3
,
4
]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS
=
[
0
,
-
1
,
1
,
2
,
4
,
8
]
DTYPES
=
[
"fp16"
,
"bf16"
]
QUANT_CONFIGS
=
[
# AWQ-INT4
{
"b_type"
:
"kU4"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
-
1
,
2
,
4
,
8
],
},
# GPTQ-INT4
{
"b_type"
:
"kU4B8"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
-
1
,
0
,
2
,
4
,
8
],
},
# AWQ-INT8
{
"b_type"
:
"kU8B128"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
-
1
,
0
,
2
,
4
,
8
],
},
# FP8
{
"b_type"
:
"kFE4M3fn"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
-
1
,
8
],
},
# NVFP4
{
"b_type"
:
"kFE2M1f"
,
"s_type"
:
"kFE4M3fn"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
1
],
},
# MXFP4
{
"a_type"
:
[
"kBFloat16"
],
"b_type"
:
"kFE2M1f"
,
"s_type"
:
"kFE8M0fnu"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
2
],
},
# AWQ-INT4 with INT8 activation
{
"a_type"
:
[
"kS8"
],
"b_type"
:
"kU4"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
[
1
,
2
,
3
,
4
],
"group_blocks"
:
[
-
1
,
2
,
4
,
8
],
},
# GPTQ-INT4 with INT8 activation
{
"a_type"
:
[
"kS8"
],
"b_type"
:
"kU4B8"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
[
1
,
2
,
3
,
4
],
"group_blocks"
:
[
-
1
,
2
,
4
,
8
],
},
# GPTQ-INT4 with FP8 activation
{
"a_type"
:
[
"kFE4M3fn"
],
"b_type"
:
"kU4B8"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
[
1
,
2
,
3
,
4
],
"group_blocks"
:
[
-
1
,
2
,
4
,
8
],
},
# AWQ-INT4 with FP8 activation
{
"a_type"
:
[
"kFE4M3fn"
],
"b_type"
:
"kU4"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
[
1
,
2
,
3
,
4
],
"group_blocks"
:
[
-
1
,
2
,
4
,
8
],
},
# MXFP4 with FP8 activation
{
"a_type"
:
[
"kFE4M3fn"
],
"b_type"
:
"kFE2M1f"
,
"c_type"
:
[
"kBFloat16"
],
"s_type"
:
"kFE8M0fnu"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
[
1
,
2
,
3
,
4
],
"group_blocks"
:
[
2
],
},
]
def remove_old_kernels():
    """Delete the files written by a previous run of this generator.

    Removes every file located next to this script whose name matches
    ``*kernel_*.cu``, plus ``kernel_selector.h``.

    A file that is already absent is ignored, which matches the ``rm -f``
    behavior this replaces. Any other failure to delete (for example a
    permission error) raises ``OSError`` instead of being silently dropped.
    """
    # os.path.join keeps the pattern relative when dirname(__file__) is empty
    # (script started from its own directory); plain string concatenation
    # would produce "/*kernel_*.cu" and glob the filesystem root instead.
    script_dir = os.path.dirname(__file__)
    stale_files = glob.glob(os.path.join(script_dir, "*kernel_*.cu"))
    stale_files.append(os.path.join(script_dir, "kernel_selector.h"))

    for filename in stale_files:
        try:
            os.remove(filename)
        except FileNotFoundError:
            # Already gone: same end state as `rm -f`.
            pass
def
generate_new_kernels
():
for
scalar_type
,
dtype
in
itertools
.
product
(
SCALAR_TYPES
,
DTYPES
):
all_template_str_list
=
[]
result_dict
=
{}
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
GROUP_BLOCKS
,
THREAD_M_BLOCKS
,
THREAD_CONFIGS
):
for
quant_config
in
QUANT_CONFIGS
:
c_types
=
quant_config
.
get
(
"c_type"
,
[
"kFloat16"
,
"kBFloat16"
])
a_types
=
quant_config
.
get
(
"a_type"
,
[
"kFloat16"
,
"kBFloat16"
])
b_type
=
quant_config
[
"b_type"
]
all_group_blocks
=
quant_config
[
"group_blocks"
]
all_m_blocks
=
quant_config
[
"thread_m_blocks"
]
all_thread_configs
=
quant_config
[
"thread_configs"
]
# act order case only support gptq-int4 and gptq-int8
if
group_blocks
==
0
and
scalar_type
not
in
[
"vllm::kU4B8"
,
"vllm::kU8B128"
]:
continue
if
thread_configs
[
2
]
==
256
:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if
m_blocks
<=
1
and
thread_configs
[
0
]
!=
128
:
continue
if
m_blocks
>
1
and
thread_configs
[
0
]
!=
64
:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if
scalar_type
==
"vllm::kFE4M3fn"
and
group_blocks
not
in
[
-
1
,
8
]:
for
a_type
,
c_type
in
itertools
.
product
(
a_types
,
c_types
):
if
not
SUPPORT_FP8
and
a_type
==
"kFE4M3fn"
:
continue
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if
scalar_type
==
"vllm::kFE2M1f"
and
group_blocks
not
in
[
1
,
2
]:
continue
# other quantization methods don't support group_size = 16
if
scalar_type
!=
"vllm::kFE2M1f"
and
group_blocks
==
1
:
if
"16"
in
a_type
and
"16"
in
c_type
and
a_type
!=
c_type
:
continue
s_type
=
quant_config
.
get
(
"s_type"
,
c_type
)
if
(
a_type
,
b_type
,
c_type
)
not
in
result_dict
:
result_dict
[(
a_type
,
b_type
,
c_type
)]
=
[]
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
all_group_blocks
,
all_m_blocks
,
all_thread_configs
):
thread_k
,
thread_n
,
threads
=
thread_configs
if
threads
==
256
:
# for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if
m_blocks
<=
1
and
(
thread_k
,
thread_n
)
!=
(
128
,
128
):
continue
if
m_blocks
>
1
and
(
thread_k
,
thread_n
)
!=
(
64
,
256
):
continue
k_blocks
=
thread_configs
[
0
]
//
16
n_blocks
=
thread_configs
[
1
]
//
16
threads
=
thread_configs
[
2
]
config
=
{
"threads"
:
threads
,
"s_type"
:
s_type
,
"thread_m_blocks"
:
max
(
m_blocks
,
1
),
"thread_k_blocks"
:
thread_k
//
16
,
"thread_n_blocks"
:
thread_n
//
16
,
"m_block_size_8"
:
"true"
if
m_blocks
==
0.5
else
"false"
,
"stages"
:
"pipe_stages"
,
"group_blocks"
:
group_blocks
,
"is_zp_float"
:
"false"
,
}
c_dtype
=
"half"
if
dtype
==
"fp16"
else
"nv_bfloat16"
result_dict
[(
a_type
,
b_type
,
c_type
)].
append
(
config
)
if
scalar_type
==
"vllm::kFE2M1f"
and
group_blocks
==
1
:
s_type
=
"vllm::kFE4M3fn"
elif
scalar_type
==
"vllm::kFE2M1f"
and
group_blocks
==
2
:
s_type
=
"vllm::kFE8M0fnu"
if
dtype
==
"fp16"
:
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif
dtype
==
"fp16"
:
s_type
=
"vllm::kFloat16"
elif
dtype
==
"bf16"
:
s_type
=
"vllm::kBFloat16"
kernel_selector_str
=
FILE_HEAD_COMMENT
for
(
a_type
,
b_type
,
c_type
),
config_list
in
result_dict
.
items
():
all_template_str_list
=
[]
for
config
in
config_list
:
s_type
=
config
[
"s_type"
]
template_str
=
jinja2
.
Template
(
TEMPLATE
).
render
(
scalar_t
=
c_dtype
,
w_type_id
=
scalar_type
+
".id()"
,
s_type_id
=
s_type
+
".id()"
,
threads
=
threads
,
thread_m_blocks
=
max
(
m_blocks
,
1
),
thread_n_blocks
=
n_blocks
,
thread_k_blocks
=
k_blocks
,
m_block_size_8
=
m_blocks
==
0.5
,
stages
=
"pipe_stages"
,
group_blocks
=
group_blocks
,
is_zp_float
=
False
,
a_type_id
=
f
"vllm::
{
a_type
}
.id()"
,
b_type_id
=
f
"vllm::
{
b_type
}
.id()"
,
c_type_id
=
f
"vllm::
{
c_type
}
.id()"
,
s_type_id
=
f
"vllm::
{
s_type
}
.id()"
,
**
config
,
)
all_template_str_list
.
append
(
template_str
)
conditions
=
[
f
"a_type == vllm::
{
a_type
}
"
,
f
"b_type == vllm::
{
b_type
}
"
,
f
"c_type == vllm::
{
c_type
}
"
,
f
"s_type == vllm::
{
s_type
}
"
,
f
"threads ==
{
config
[
'threads'
]
}
"
,
f
"thread_m_blocks ==
{
config
[
'thread_m_blocks'
]
}
"
,
f
"thread_n_blocks ==
{
config
[
'thread_n_blocks'
]
}
"
,
f
"thread_k_blocks ==
{
config
[
'thread_k_blocks'
]
}
"
,
f
"m_block_size_8 ==
{
config
[
'm_block_size_8'
]
}
"
,
f
"group_blocks ==
{
config
[
'group_blocks'
]
}
"
,
f
"is_zp_float ==
{
config
[
'is_zp_float'
]
}
"
,
]
conditions
=
" && "
.
join
(
conditions
)
if
kernel_selector_str
==
FILE_HEAD_COMMENT
:
kernel_selector_str
+=
f
"if (
{
conditions
}
)
\n
kernel = "
else
:
kernel_selector_str
+=
f
"else if (
{
conditions
}
)
\n
kernel = "
kernel_template2
=
(
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str
+=
(
jinja2
.
Template
(
kernel_template2
).
render
(
a_type_id
=
f
"vllm::
{
a_type
}
.id()"
,
b_type_id
=
f
"vllm::
{
b_type
}
.id()"
,
c_type_id
=
f
"vllm::
{
c_type
}
.id()"
,
s_type_id
=
f
"vllm::
{
s_type
}
.id()"
,
**
config
,
)
+
"
\n
"
)
file_content
=
FILE_HEAD
+
"
\n\n
"
file_content
+=
"
\n\n
"
.
join
(
all_template_str_list
)
+
"
\n\n
}
\n
"
filename
=
f
"kernel_
{
dtype
}
_
{
scalar_type
[
6
:].
lower
()
}
.cu"
if
a_type
==
"kFE4M3fn"
:
filename
=
f
"sm89_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
else
:
filename
=
f
"sm80_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
filename
=
filename
.
lower
()
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
filename
),
"w"
)
as
f
:
f
.
write
(
file_content
)
if
not
SUPPORT_FP8
and
kernel_selector_str
!=
FILE_HEAD_COMMENT
:
kernel_selector_str
+=
(
"else if (a_type == vllm::kFE4M3fn)
\n
"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"kernel_selector.h"
),
"w"
)
as
f
:
f
.
write
(
kernel_selector_str
)
if
__name__
==
"__main__"
:
remove_old_kernels
()
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment