Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
006693ed
Commit
006693ed
authored
Dec 01, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.11.2' into v0.11.2-ori
parents
4b51e6f1
275de341
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2316 additions
and
794 deletions
+2316
-794
csrc/cumem_allocator.cpp
csrc/cumem_allocator.cpp
+392
-17
csrc/cumem_allocator_compat.h
csrc/cumem_allocator_compat.h
+109
-0
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+16
-21
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+29
-0
csrc/fused_qknorm_rope_kernel.cu
csrc/fused_qknorm_rope_kernel.cu
+428
-0
csrc/launch_bounds_utils.h
csrc/launch_bounds_utils.h
+29
-3
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+60
-269
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+51
-18
csrc/mamba/mamba_ssm/selective_scan.h
csrc/mamba/mamba_ssm/selective_scan.h
+7
-1
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+113
-21
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
+12
-21
csrc/moe/grouped_topk_kernels.cu
csrc/moe/grouped_topk_kernels.cu
+99
-54
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+24
-18
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+92
-0
csrc/moe/moe_lora_align_sum_kernels.cu
csrc/moe/moe_lora_align_sum_kernels.cu
+174
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+19
-4
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+216
-87
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+31
-3
csrc/ops.h
csrc/ops.h
+33
-20
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+382
-237
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
csrc/cumem_allocator.cpp
View file @
006693ed
...
...
@@ -3,14 +3,58 @@
// need to be unsigned long long
#include <iostream>
#include "cumem_allocator_compat.h"
#ifndef USE_ROCM
static
const
char
*
PYARGS_PARSE
=
"KKKK"
;
#else
#include <cstdlib>
#include <cerrno>
#include <climits>
// ROCm allocates sleep-mode memory in fixed-size chunks. The chunk size
// defaults to 256 MiB and may be overridden at runtime through the
// environment variable VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE, given as a whole
// number of megabytes. The text is handed to strtoull with base 0, so the
// usual C integer prefixes are honoured (e.g. "0x" for hexadecimal).
static const unsigned long long DEFAULT_MEMCREATE_CHUNK_SIZE =
    256ULL * (1ULL << 20);

// Returns the chunk size in bytes.
//
// Falls back to DEFAULT_MEMCREATE_CHUNK_SIZE when the variable is unset, when
// strtoull consumes no characters or reports an error through errno, when the
// parsed value is zero, or when converting megabytes to bytes would not fit
// in an unsigned long long.
static unsigned long long get_memcreate_chunk_size() {
  const unsigned long long bytes_per_mb = 1ULL << 20;
  // Largest megabyte count whose byte size is still representable.
  const unsigned long long max_mb = ULLONG_MAX / bytes_per_mb;

  const char* raw = getenv("VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE");
  if (raw != nullptr) {
    char* stop = nullptr;
    errno = 0;
    const unsigned long long megabytes = strtoull(raw, &stop, 0);
    const bool parsed_ok = (stop != raw) && (errno == 0);
    if (parsed_ok && megabytes != 0 && megabytes <= max_mb) {
      return megabytes * bytes_per_mb;
    }
  }
  return DEFAULT_MEMCREATE_CHUNK_SIZE;
}
// Returns the smaller of two unsigned 64-bit values.
static inline unsigned long long my_min(unsigned long long a,
                                        unsigned long long b) {
  if (b <= a) {
    return b;
  }
  return a;
}
static
const
char
*
PYARGS_PARSE
=
"KKKO"
;
#endif
extern
"C"
{
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <cuda.h>
char
error_msg
[
10240
];
// 10KB buffer to store error messages
CUresult
no_error
=
CUresult
(
0
);
...
...
@@ -49,7 +93,12 @@ void ensure_context(unsigned long long device) {
}
void
create_and_map
(
unsigned
long
long
device
,
ssize_t
size
,
CUdeviceptr
d_mem
,
#ifndef USE_ROCM
CUmemGenericAllocationHandle
*
p_memHandle
)
{
#else
CUmemGenericAllocationHandle
**
p_memHandle
,
unsigned
long
long
*
chunk_sizes
,
size_t
num_chunks
)
{
#endif
ensure_context
(
device
);
// Define memory allocation properties
CUmemAllocationProp
prop
=
{};
...
...
@@ -58,6 +107,7 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop
.
location
.
id
=
device
;
prop
.
allocFlags
.
compressionType
=
CU_MEM_ALLOCATION_COMP_NONE
;
#ifndef USE_ROCM
// Allocate memory using cuMemCreate
CUDA_CHECK
(
cuMemCreate
(
p_memHandle
,
size
,
&
prop
,
0
));
if
(
error_code
!=
0
)
{
...
...
@@ -67,6 +117,39 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
if
(
error_code
!=
0
)
{
return
;
}
#else
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
CUDA_CHECK
(
cuMemCreate
(
p_memHandle
[
i
],
chunk_sizes
[
i
],
&
prop
,
0
));
if
(
error_code
!=
0
)
{
// Clean up previously created handles
for
(
auto
j
=
0
;
j
<
i
;
++
j
)
{
cuMemRelease
(
*
(
p_memHandle
[
j
]));
}
return
;
}
}
unsigned
long
long
allocated_size
=
0
;
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
void
*
map_addr
=
(
void
*
)((
uintptr_t
)
d_mem
+
allocated_size
);
CUDA_CHECK
(
cuMemMap
(
map_addr
,
chunk_sizes
[
i
],
0
,
*
(
p_memHandle
[
i
]),
0
));
if
(
error_code
!=
0
)
{
// unmap previously mapped chunks
unsigned
long
long
unmapped_size
=
0
;
for
(
auto
j
=
0
;
j
<
i
;
++
j
)
{
void
*
unmap_addr
=
(
void
*
)((
uintptr_t
)
d_mem
+
unmapped_size
);
cuMemUnmap
(
unmap_addr
,
chunk_sizes
[
j
]);
unmapped_size
+=
chunk_sizes
[
j
];
}
// release all created handles
for
(
auto
j
=
0
;
j
<
num_chunks
;
++
j
)
{
cuMemRelease
(
*
(
p_memHandle
[
j
]));
}
return
;
}
allocated_size
+=
chunk_sizes
[
i
];
}
#endif
CUmemAccessDesc
accessDesc
=
{};
accessDesc
.
location
.
type
=
CU_MEM_LOCATION_TYPE_DEVICE
;
accessDesc
.
location
.
id
=
device
;
...
...
@@ -82,10 +165,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
void
unmap_and_release
(
unsigned
long
long
device
,
ssize_t
size
,
CUdeviceptr
d_mem
,
#ifndef USE_ROCM
CUmemGenericAllocationHandle
*
p_memHandle
)
{
#else
CUmemGenericAllocationHandle
**
p_memHandle
,
unsigned
long
long
*
chunk_sizes
,
size_t
num_chunks
)
{
#endif
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context
(
device
);
#ifndef USE_ROCM
CUDA_CHECK
(
cuMemUnmap
(
d_mem
,
size
));
if
(
error_code
!=
0
)
{
return
;
...
...
@@ -94,6 +183,30 @@ void unmap_and_release(unsigned long long device, ssize_t size,
if
(
error_code
!=
0
)
{
return
;
}
#else
unsigned
long
long
allocated_size
=
0
;
CUresult
first_error
=
no_error
;
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
void
*
map_addr
=
(
void
*
)((
uintptr_t
)
d_mem
+
allocated_size
);
CUresult
status
=
cuMemUnmap
(
map_addr
,
chunk_sizes
[
i
]);
if
(
status
!=
no_error
&&
first_error
==
no_error
)
{
first_error
=
status
;
}
allocated_size
+=
chunk_sizes
[
i
];
}
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
CUresult
status
=
cuMemRelease
(
*
(
p_memHandle
[
i
]));
if
(
status
!=
no_error
&&
first_error
==
no_error
)
{
first_error
=
status
;
}
}
if
(
first_error
!=
no_error
)
{
CUDA_CHECK
(
first_error
);
}
#endif
}
PyObject
*
create_tuple_from_c_integers
(
unsigned
long
long
a
,
...
...
@@ -120,6 +233,36 @@ PyObject* create_tuple_from_c_integers(unsigned long long a,
return
tuple
;
// Return the created tuple
}
// Builds the Python tuple (a, b, c, chunks), where `chunks` is a list with
// one (handle_pointer_value, chunk_size) pair per chunk: entry i holds the
// pointer stored in vec[i], converted to an integer, and chunk_sizes[i].
//
// Returns a new reference, or NULL with a Python exception set if any
// allocation fails. On failure every object created so far is released, so
// nothing leaks and no NULL is ever handed to PyTuple_SetItem/PyList_SetItem.
PyObject* create_tuple_from_c_mixed(unsigned long long a, unsigned long long b,
                                    unsigned long long c,
                                    CUmemGenericAllocationHandle** vec,
                                    unsigned long long* chunk_sizes,
                                    size_t num_chunks) {
  PyObject* tuple = PyTuple_New(4);
  if (!tuple) {
    return NULL;
  }

  PyObject* list = PyList_New(num_chunks);
  if (!list) {
    Py_DECREF(tuple);
    return NULL;
  }

  for (size_t i = 0; i < num_chunks; ++i) {
    // Create the three objects in sequence; stop at the first failure.
    PyObject* addr_size_pair = PyTuple_New(2);
    PyObject* addr =
        addr_size_pair
            ? PyLong_FromUnsignedLongLong((unsigned long long)(vec[i]))
            : NULL;
    PyObject* size =
        addr ? PyLong_FromUnsignedLongLong(
                   (unsigned long long)(chunk_sizes[i]))
             : NULL;
    if (!size) {
      Py_XDECREF(addr);
      Py_XDECREF(addr_size_pair);
      // Unfilled list/tuple slots are NULL and are skipped on deallocation.
      Py_DECREF(list);
      Py_DECREF(tuple);
      return NULL;
    }
    // The *_SetItem calls steal the reference to the stored item.
    PyTuple_SetItem(addr_size_pair, 0, addr);
    PyTuple_SetItem(addr_size_pair, 1, size);
    PyList_SetItem(list, i, addr_size_pair);
  }

  PyObject* py_a = PyLong_FromUnsignedLongLong(a);
  PyObject* py_b = py_a ? PyLong_FromUnsignedLongLong(b) : NULL;
  PyObject* py_c = py_b ? PyLong_FromUnsignedLongLong(c) : NULL;
  if (!py_c) {
    Py_XDECREF(py_b);
    Py_XDECREF(py_a);
    Py_DECREF(list);
    Py_DECREF(tuple);
    return NULL;
  }

  PyTuple_SetItem(tuple, 0, py_a);
  PyTuple_SetItem(tuple, 1, py_b);
  PyTuple_SetItem(tuple, 2, py_c);
  PyTuple_SetItem(tuple, 3, list);
  return tuple;
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
...
...
@@ -147,14 +290,55 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
size_t
alignedSize
=
((
size
+
granularity
-
1
)
/
granularity
)
*
granularity
;
CUdeviceptr
d_mem
;
#ifndef USE_ROCM
CUDA_CHECK
(
cuMemAddressReserve
(
&
d_mem
,
alignedSize
,
0
,
0
,
0
));
if
(
error_code
!=
0
)
{
return
nullptr
;
}
#else
CUDA_CHECK
(
cuMemAddressReserve
(
&
d_mem
,
alignedSize
,
granularity
,
0
,
0
));
if
(
error_code
!=
0
)
{
return
nullptr
;
}
#endif
#ifndef USE_ROCM
// allocate the CUmemGenericAllocationHandle
CUmemGenericAllocationHandle
*
p_memHandle
=
(
CUmemGenericAllocationHandle
*
)
malloc
(
sizeof
(
CUmemGenericAllocationHandle
));
#else
// Make sure chunk size is aligned with hardware granularity. The base
// chunk size can be configured via environment variable
// ``VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE``; otherwise
// DEFAULT_MEMCREATE_CHUNK_SIZE is used.
size_t
base_chunk
=
(
size_t
)
get_memcreate_chunk_size
();
size_t
aligned_chunk_size
=
((
base_chunk
+
granularity
-
1
)
/
granularity
)
*
granularity
;
size_t
num_chunks
=
(
alignedSize
+
aligned_chunk_size
-
1
)
/
aligned_chunk_size
;
CUmemGenericAllocationHandle
**
p_memHandle
=
(
CUmemGenericAllocationHandle
**
)
malloc
(
num_chunks
*
sizeof
(
CUmemGenericAllocationHandle
*
));
unsigned
long
long
*
chunk_sizes
=
(
unsigned
long
long
*
)
malloc
(
num_chunks
*
sizeof
(
unsigned
long
long
));
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
p_memHandle
[
i
]
=
(
CUmemGenericAllocationHandle
*
)
malloc
(
sizeof
(
CUmemGenericAllocationHandle
));
if
(
p_memHandle
[
i
]
==
nullptr
)
{
std
::
cerr
<<
"ERROR: malloc failed for p_memHandle["
<<
i
<<
"].
\n
"
;
for
(
auto
j
=
0
;
j
<
i
;
++
j
)
{
free
(
p_memHandle
[
j
]);
}
free
(
p_memHandle
);
free
(
chunk_sizes
);
return
nullptr
;
}
chunk_sizes
[
i
]
=
(
unsigned
long
long
)
my_min
(
(
unsigned
long
long
)(
alignedSize
-
i
*
aligned_chunk_size
),
(
unsigned
long
long
)
aligned_chunk_size
);
}
#endif
if
(
!
g_python_malloc_callback
)
{
std
::
cerr
<<
"ERROR: g_python_malloc_callback not set.
\n
"
;
...
...
@@ -164,9 +348,15 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE
gstate
=
PyGILState_Ensure
();
#ifndef USE_ROCM
PyObject
*
arg_tuple
=
create_tuple_from_c_integers
(
(
unsigned
long
long
)
device
,
(
unsigned
long
long
)
alignedSize
,
(
unsigned
long
long
)
d_mem
,
(
unsigned
long
long
)
p_memHandle
);
#else
PyObject
*
arg_tuple
=
create_tuple_from_c_mixed
(
(
unsigned
long
long
)
device
,
(
unsigned
long
long
)
alignedSize
,
(
unsigned
long
long
)
d_mem
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
#endif
// Call g_python_malloc_callback
PyObject
*
py_result
=
...
...
@@ -182,7 +372,27 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
PyGILState_Release
(
gstate
);
// do the final mapping
#ifndef USE_ROCM
create_and_map
(
device
,
alignedSize
,
d_mem
,
p_memHandle
);
#else
create_and_map
(
device
,
alignedSize
,
d_mem
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
free
(
chunk_sizes
);
#endif
if
(
error_code
!=
0
)
{
// free address and the handle
CUDA_CHECK
(
cuMemAddressFree
(
d_mem
,
alignedSize
));
#ifndef USE_ROCM
free
(
p_memHandle
);
#else
for
(
size_t
i
=
0
;
i
<
num_chunks
;
++
i
)
{
free
(
p_memHandle
[
i
]);
}
free
(
p_memHandle
);
#endif
return
nullptr
;
}
return
(
void
*
)
d_mem
;
}
...
...
@@ -206,36 +416,96 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
if
(
!
py_result
||
!
PyTuple_Check
(
py_result
)
||
PyTuple_Size
(
py_result
)
!=
4
)
{
PyErr_SetString
(
PyExc_TypeError
,
"Expected a tuple of size 4"
);
Py_XDECREF
(
py_result
);
Py_XDECREF
(
py_ptr
);
return
;
}
unsigned
long
long
recv_device
,
recv_size
;
unsigned
long
long
recv_d_mem
,
recv_p_memHandle
;
unsigned
long
long
recv_d_mem
;
#ifndef USE_ROCM
unsigned
long
long
recv_p_memHandle
;
#else
PyObject
*
recv_p_memHandle
;
#endif
// Unpack the tuple into four C integers
if
(
!
PyArg_ParseTuple
(
py_result
,
"KKKK"
,
&
recv_device
,
&
recv_size
,
if
(
!
PyArg_ParseTuple
(
py_result
,
PYARGS_PARSE
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
// PyArg_ParseTuple sets an error if it fails
Py_XDECREF
(
py_result
);
Py_XDECREF
(
py_ptr
);
return
;
}
PyGILState_Release
(
gstate
);
// For ROCm, copy the Python list of (addr,size) pairs into C arrays while
// holding the GIL. Then release the GIL and call the unmap/release helper
// using the copied arrays. This avoids calling PyList_* APIs without the
// GIL (which is undefined behavior and can crash when called from other
// threads).
CUdeviceptr
d_mem
=
(
CUdeviceptr
)
recv_d_mem
;
#ifdef USE_ROCM
Py_ssize_t
num_chunks
=
PyList_Size
(
recv_p_memHandle
);
CUmemGenericAllocationHandle
**
p_memHandle
=
(
CUmemGenericAllocationHandle
**
)
malloc
(
num_chunks
*
sizeof
(
CUmemGenericAllocationHandle
*
));
if
(
p_memHandle
==
nullptr
)
{
Py_DECREF
(
py_ptr
);
Py_DECREF
(
py_result
);
PyGILState_Release
(
gstate
);
std
::
cerr
<<
"ERROR: malloc failed for p_memHandle in my_free."
<<
std
::
endl
;
return
;
}
unsigned
long
long
*
chunk_sizes
=
(
unsigned
long
long
*
)
malloc
(
num_chunks
*
sizeof
(
unsigned
long
long
));
if
(
chunk_sizes
==
nullptr
)
{
free
(
p_memHandle
);
Py_DECREF
(
py_ptr
);
Py_DECREF
(
py_result
);
PyGILState_Release
(
gstate
);
std
::
cerr
<<
"ERROR: malloc failed for chunk_sizes in my_free."
<<
std
::
endl
;
return
;
}
for
(
Py_ssize_t
i
=
0
;
i
<
num_chunks
;
++
i
)
{
PyObject
*
item
=
PyList_GetItem
(
recv_p_memHandle
,
i
);
PyObject
*
addr_py
=
PyTuple_GetItem
(
item
,
0
);
PyObject
*
size_py
=
PyTuple_GetItem
(
item
,
1
);
p_memHandle
[
i
]
=
(
CUmemGenericAllocationHandle
*
)
PyLong_AsUnsignedLongLong
(
addr_py
);
chunk_sizes
[
i
]
=
(
unsigned
long
long
)
PyLong_AsUnsignedLongLong
(
size_py
);
}
// recv_size == size
// recv_device == device
// Drop temporary Python refs, then release the GIL before calling into
// non-Python APIs.
Py_DECREF
(
py_ptr
);
Py_DECREF
(
py_result
);
PyGILState_Release
(
gstate
);
// Free memory
unmap_and_release
(
device
,
size
,
d_mem
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
#else
// Non-ROCm path: simple integer handle already extracted; drop temporary
// Python refs while still holding the GIL, then release it.
Py_DECREF
(
py_ptr
);
Py_DECREF
(
py_result
);
PyGILState_Release
(
gstate
);
CUdeviceptr
d_mem
=
(
CUdeviceptr
)
recv_d_mem
;
CUmemGenericAllocationHandle
*
p_memHandle
=
(
CUmemGenericAllocationHandle
*
)
recv_p_memHandle
;
unmap_and_release
(
device
,
size
,
d_mem
,
p_memHandle
);
#endif
// free address and the handle
CUDA_CHECK
(
cuMemAddressFree
(
d_mem
,
size
));
if
(
error_code
!=
0
)
{
return
;
#ifndef USE_ROCM
free
(
p_memHandle
);
#else
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
free
(
p_memHandle
[
i
]);
}
free
(
p_memHandle
);
free
(
chunk_sizes
);
#endif
}
// ---------------------------------------------------------------------------
...
...
@@ -271,19 +541,87 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
}
unsigned
long
long
recv_device
,
recv_size
;
unsigned
long
long
recv_d_mem
,
recv_p_memHandle
;
unsigned
long
long
recv_d_mem
;
#ifndef USE_ROCM
unsigned
long
long
recv_p_memHandle
;
#else
PyObject
*
recv_p_memHandle
;
#endif
// Unpack the tuple into four C integers
if
(
!
PyArg_ParseTuple
(
args
,
"KKKK"
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
if
(
!
PyArg_ParseTuple
(
args
,
PYARGS_PARSE
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
// PyArg_ParseTuple sets an error if it fails
return
nullptr
;
}
CUdeviceptr
d_mem_ptr
=
(
CUdeviceptr
)
recv_d_mem
;
#ifndef USE_ROCM
CUmemGenericAllocationHandle
*
p_memHandle
=
(
CUmemGenericAllocationHandle
*
)
recv_p_memHandle
;
unmap_and_release
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
);
#else
if
(
!
PyList_Check
(
recv_p_memHandle
))
{
PyErr_SetString
(
PyExc_TypeError
,
"Expected a list for the 4th argument on ROCm"
);
return
nullptr
;
}
Py_ssize_t
num_chunks
=
PyList_Size
(
recv_p_memHandle
);
if
(
num_chunks
<
0
)
{
return
nullptr
;
// PyList_Size sets an exception on error.
}
CUmemGenericAllocationHandle
**
p_memHandle
=
(
CUmemGenericAllocationHandle
**
)
malloc
(
num_chunks
*
sizeof
(
CUmemGenericAllocationHandle
*
));
if
(
p_memHandle
==
nullptr
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"malloc failed for p_memHandle"
);
return
nullptr
;
}
unsigned
long
long
*
chunk_sizes
=
(
unsigned
long
long
*
)
malloc
(
num_chunks
*
sizeof
(
unsigned
long
long
));
if
(
chunk_sizes
==
nullptr
)
{
free
(
p_memHandle
);
PyErr_SetString
(
PyExc_MemoryError
,
"malloc failed for chunk_sizes"
);
return
nullptr
;
}
for
(
Py_ssize_t
i
=
0
;
i
<
num_chunks
;
++
i
)
{
PyObject
*
item
=
PyList_GetItem
(
recv_p_memHandle
,
i
);
if
(
item
==
nullptr
||
!
PyTuple_Check
(
item
)
||
PyTuple_Size
(
item
)
!=
2
)
{
free
(
p_memHandle
);
free
(
chunk_sizes
);
PyErr_SetString
(
PyExc_TypeError
,
"List items must be tuples of size 2 (handle_addr, size)"
);
return
nullptr
;
}
PyObject
*
addr_py
=
PyTuple_GetItem
(
item
,
0
);
PyObject
*
size_py
=
PyTuple_GetItem
(
item
,
1
);
if
(
addr_py
==
nullptr
||
size_py
==
nullptr
)
{
free
(
p_memHandle
);
free
(
chunk_sizes
);
return
nullptr
;
// PyTuple_GetItem sets an exception
}
p_memHandle
[
i
]
=
(
CUmemGenericAllocationHandle
*
)
PyLong_AsUnsignedLongLong
(
addr_py
);
if
(
PyErr_Occurred
())
{
free
(
p_memHandle
);
free
(
chunk_sizes
);
return
nullptr
;
}
chunk_sizes
[
i
]
=
(
unsigned
long
long
)
PyLong_AsUnsignedLongLong
(
size_py
);
if
(
PyErr_Occurred
())
{
free
(
p_memHandle
);
free
(
chunk_sizes
);
return
nullptr
;
}
}
unmap_and_release
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
free
(
p_memHandle
);
free
(
chunk_sizes
);
#endif
if
(
error_code
!=
0
)
{
error_code
=
no_error
;
...
...
@@ -301,19 +639,56 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
}
unsigned
long
long
recv_device
,
recv_size
;
unsigned
long
long
recv_d_mem
,
recv_p_memHandle
;
unsigned
long
long
recv_d_mem
;
#ifndef USE_ROCM
unsigned
long
long
recv_p_memHandle
;
#else
PyObject
*
recv_p_memHandle
;
#endif
// Unpack the tuple into four C integers
if
(
!
PyArg_ParseTuple
(
args
,
"KKKK"
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
if
(
!
PyArg_ParseTuple
(
args
,
PYARGS_PARSE
,
&
recv_device
,
&
recv_size
,
&
recv_d_mem
,
&
recv_p_memHandle
))
{
// PyArg_ParseTuple sets an error if it fails
return
nullptr
;
}
CUdeviceptr
d_mem_ptr
=
(
CUdeviceptr
)
recv_d_mem
;
#ifndef USE_ROCM
CUmemGenericAllocationHandle
*
p_memHandle
=
(
CUmemGenericAllocationHandle
*
)
recv_p_memHandle
;
create_and_map
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
);
#else
Py_ssize_t
num_chunks
=
PyList_Size
(
recv_p_memHandle
);
CUmemGenericAllocationHandle
**
p_memHandle
=
(
CUmemGenericAllocationHandle
**
)
malloc
(
num_chunks
*
sizeof
(
CUmemGenericAllocationHandle
*
));
if
(
p_memHandle
==
nullptr
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"malloc failed for p_memHandle"
);
return
nullptr
;
}
unsigned
long
long
*
chunk_sizes
=
(
unsigned
long
long
*
)
malloc
(
num_chunks
*
sizeof
(
unsigned
long
long
));
if
(
chunk_sizes
==
nullptr
)
{
free
(
p_memHandle
);
PyErr_SetString
(
PyExc_MemoryError
,
"malloc failed for chunk_sizes"
);
return
nullptr
;
}
for
(
auto
i
=
0
;
i
<
num_chunks
;
++
i
)
{
PyObject
*
item
=
PyList_GetItem
(
recv_p_memHandle
,
i
);
PyObject
*
addr_py
=
PyTuple_GetItem
(
item
,
0
);
PyObject
*
size_py
=
PyTuple_GetItem
(
item
,
1
);
p_memHandle
[
i
]
=
(
CUmemGenericAllocationHandle
*
)
PyLong_AsUnsignedLongLong
(
addr_py
);
chunk_sizes
[
i
]
=
PyLong_AsUnsignedLongLong
(
size_py
);
}
create_and_map
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
,
chunk_sizes
,
num_chunks
);
free
(
p_memHandle
);
free
(
chunk_sizes
);
#endif
if
(
error_code
!=
0
)
{
error_code
=
no_error
;
...
...
csrc/cumem_allocator_compat.h
0 → 100644
View file @
006693ed
#pragma once
#ifdef USE_ROCM
////////////////////////////////////////
// For compatibility with CUDA and ROCm
////////////////////////////////////////
#include <hip/hip_runtime_api.h>
// CUDA driver-API names implemented on top of HIP, so that the allocator
// source can be written once against the cu* spelling. Each wrapper simply
// forwards its arguments to the hip* function it is named after.
//
// The wrappers are defined in this header, so they are marked `inline`:
// without it, including the header from more than one translation unit
// produces multiple-definition link errors.
extern "C" {

#ifndef CUDA_SUCCESS
  #define CUDA_SUCCESS hipSuccess
#endif  // CUDA_SUCCESS

// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
typedef unsigned long long CUdevice;
typedef hipDeviceptr_t CUdeviceptr;
typedef hipError_t CUresult;
typedef hipCtx_t CUcontext;
typedef hipStream_t CUstream;
typedef hipMemGenericAllocationHandle_t CUmemGenericAllocationHandle;
typedef hipMemAllocationGranularity_flags CUmemAllocationGranularity_flags;
typedef hipMemAllocationProp CUmemAllocationProp;
typedef hipMemAccessDesc CUmemAccessDesc;

#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_MEM_ALLOC_GRANULARITY_MINIMUM hipMemAllocationGranularityMinimum

// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
#define CU_MEM_ALLOCATION_COMP_NONE 0x0

// Error Handling
// https://docs.nvidia.com/cuda/archive/11.4.4/cuda-driver-api/group__CUDA__ERROR.html
// Stores the message from hipGetErrorString in *pStr; always reports
// CUDA_SUCCESS.
inline CUresult cuGetErrorString(CUresult hipError, const char** pStr) {
  *pStr = hipGetErrorString(hipError);
  return CUDA_SUCCESS;
}

// Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
inline CUresult cuCtxGetCurrent(CUcontext* ctx) {
  // This API is deprecated on the AMD platform, only for equivalent cuCtx
  // driver API on the NVIDIA platform.
  return hipCtxGetCurrent(ctx);
}

inline CUresult cuCtxSetCurrent(CUcontext ctx) {
  // This API is deprecated on the AMD platform, only for equivalent cuCtx
  // driver API on the NVIDIA platform.
  return hipCtxSetCurrent(ctx);
}

// Primary Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html
inline CUresult cuDevicePrimaryCtxRetain(CUcontext* ctx, CUdevice dev) {
  return hipDevicePrimaryCtxRetain(ctx, dev);
}

// Virtual Memory Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html
inline CUresult cuMemAddressFree(CUdeviceptr ptr, size_t size) {
  return hipMemAddressFree(ptr, size);
}

inline CUresult cuMemAddressReserve(CUdeviceptr* ptr, size_t size,
                                    size_t alignment, CUdeviceptr addr,
                                    unsigned long long flags) {
  return hipMemAddressReserve(ptr, size, alignment, addr, flags);
}

inline CUresult cuMemCreate(CUmemGenericAllocationHandle* handle, size_t size,
                            const CUmemAllocationProp* prop,
                            unsigned long long flags) {
  return hipMemCreate(handle, size, prop, flags);
}

inline CUresult cuMemGetAllocationGranularity(
    size_t* granularity, const CUmemAllocationProp* prop,
    CUmemAllocationGranularity_flags option) {
  return hipMemGetAllocationGranularity(granularity, prop, option);
}

inline CUresult cuMemMap(CUdeviceptr dptr, size_t size, size_t offset,
                         CUmemGenericAllocationHandle handle,
                         unsigned long long flags) {
  return hipMemMap(dptr, size, offset, handle, flags);
}

inline CUresult cuMemRelease(CUmemGenericAllocationHandle handle) {
  return hipMemRelease(handle);
}

inline CUresult cuMemSetAccess(CUdeviceptr ptr, size_t size,
                               const CUmemAccessDesc* desc, size_t count) {
  return hipMemSetAccess(ptr, size, desc, count);
}

inline CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) {
  return hipMemUnmap(ptr, size);
}

}  // extern "C"
#else
////////////////////////////////////////
// Import CUDA headers for NVIDIA GPUs
////////////////////////////////////////
#include <cuda_runtime_api.h>
#include <cuda.h>
#endif
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
View file @
006693ed
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
enum
from
typing
import
Union
from
cutlass_library
import
*
...
...
@@ -22,31 +21,31 @@ class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperative
=
enum_auto
()
VLLMDataTypeNames
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
str
]
=
{
VLLMDataTypeNames
:
dict
[
VLLMDataType
|
DataType
,
str
]
=
{
**
DataTypeNames
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"u4b8"
,
VLLMDataType
.
u8b128
:
"u8b128"
,
}
}
,
}
VLLMDataTypeTag
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
str
]
=
{
VLLMDataTypeTag
:
dict
[
VLLMDataType
|
DataType
,
str
]
=
{
**
DataTypeTag
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"cutlass::vllm_uint4b8_t"
,
VLLMDataType
.
u8b128
:
"cutlass::vllm_uint8b128_t"
,
}
}
,
}
VLLMDataTypeSize
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
int
]
=
{
VLLMDataTypeSize
:
dict
[
VLLMDataType
|
DataType
,
int
]
=
{
**
DataTypeSize
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
4
,
VLLMDataType
.
u8b128
:
8
,
}
}
,
}
VLLMDataTypeVLLMScalarTypeTag
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
str
]
=
{
VLLMDataTypeVLLMScalarTypeTag
:
dict
[
VLLMDataType
|
DataType
,
str
]
=
{
VLLMDataType
.
u4b8
:
"vllm::kU4B8"
,
VLLMDataType
.
u8b128
:
"vllm::kU8B128"
,
DataType
.
u4
:
"vllm::kU4"
,
...
...
@@ -57,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType
.
bf16
:
"vllm::kBfloat16"
,
}
VLLMDataTypeTorchDataTypeTag
:
dict
[
Union
[
VLLMDataType
,
DataType
]
,
str
]
=
{
VLLMDataTypeTorchDataTypeTag
:
dict
[
VLLMDataType
|
DataType
,
str
]
=
{
DataType
.
u8
:
"at::ScalarType::Byte"
,
DataType
.
s8
:
"at::ScalarType::Char"
,
DataType
.
e4m3
:
"at::ScalarType::Float8_e4m3fn"
,
...
...
@@ -67,15 +66,11 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType
.
f32
:
"at::ScalarType::Float"
,
}
VLLMKernelScheduleTag
:
dict
[
Union
[
MixedInputKernelScheduleType
,
KernelScheduleType
],
str
]
=
{
**
KernelScheduleTag
,
# type: ignore
**
{
MixedInputKernelScheduleType
.
TmaWarpSpecialized
:
"cutlass::gemm::KernelTmaWarpSpecialized"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedPingpong
:
"cutlass::gemm::KernelTmaWarpSpecializedPingpong"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
:
"cutlass::gemm::KernelTmaWarpSpecializedCooperative"
,
}
}
VLLMKernelScheduleTag
:
dict
[
MixedInputKernelScheduleType
|
KernelScheduleType
,
str
]
=
{
**
KernelScheduleTag
,
# type: ignore
**
{
MixedInputKernelScheduleType
.
TmaWarpSpecialized
:
"cutlass::gemm::KernelTmaWarpSpecialized"
,
# noqa: E501
MixedInputKernelScheduleType
.
TmaWarpSpecializedPingpong
:
"cutlass::gemm::KernelTmaWarpSpecializedPingpong"
,
# noqa: E501
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
:
"cutlass::gemm::KernelTmaWarpSpecializedCooperative"
,
# noqa: E501
},
}
csrc/dispatch_utils.h
View file @
006693ed
...
...
@@ -88,3 +88,32 @@
// Dispatches on TYPE through AT_DISPATCH_SWITCH, labelled with NAME, using
// the case list that VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES builds
// from the remaining arguments.
// NOTE(review): the exact set of dtypes covered is decided by that case
// macro, which is defined outside this view - confirm there.
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
// Helper for VLLM_DISPATCH_VEC_SIZE: one switch-case body that exposes N as
// the compile-time constant `vec_size`, invokes the callable, and leaves the
// switch.
#define VLLM_DISPATCH_VEC_SIZE_BODY(N, ...) \
  {                                         \
    constexpr int vec_size = N;             \
    __VA_ARGS__();                          \
    break;                                  \
  }

// Turns the runtime value VEC_SIZE into a compile-time constant. The trailing
// argument is a callable (typically a lambda) that is invoked with a
// `constexpr int vec_size` in scope. 16, 8, 4 and 2 are dispatched as
// themselves; every other value is dispatched with vec_size == 1.
#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...)        \
  switch (VEC_SIZE) {                                \
    case 16:                                         \
      VLLM_DISPATCH_VEC_SIZE_BODY(16, __VA_ARGS__)   \
    case 8:                                          \
      VLLM_DISPATCH_VEC_SIZE_BODY(8, __VA_ARGS__)    \
    case 4:                                          \
      VLLM_DISPATCH_VEC_SIZE_BODY(4, __VA_ARGS__)    \
    case 2:                                          \
      VLLM_DISPATCH_VEC_SIZE_BODY(2, __VA_ARGS__)    \
    default:                                         \
      VLLM_DISPATCH_VEC_SIZE_BODY(1, __VA_ARGS__)    \
  }
csrc/fused_qknorm_rope_kernel.cu
0 → 100644
View file @
006693ed
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cuda_runtime.h>
#include <type_traits>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#define CHECK_TYPE(x, st) \
TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \
", while ", st, " is expected")
#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_TH_CUDA(x); \
CHECK_CONTIGUOUS(x)
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0
// Stand-in for __syncwarp, compiled only when the enclosing HIP_VERSION check
// selects it. It issues, in this order: a release fence, a wave barrier, and
// an acquire fence, each fence using the "wavefront" scope string. Takes no
// mask argument. Keep the three calls in the order copied from ROCm.
__device__
inline
void
__syncwarp
()
{
__builtin_amdgcn_fence
(
__ATOMIC_RELEASE
,
"wavefront"
);
__builtin_amdgcn_wave_barrier
();
__builtin_amdgcn_fence
(
__ATOMIC_ACQUIRE
,
"wavefront"
);
}
#endif
#else
#define FINAL_MASK 0xffffffff
#endif
namespace tensorrt_llm::common {

// Maps an element type and a count to a single type holding that many
// elements. Only the specializations below are provided.
template <typename T, int num>
struct packed_as;

// Specializations of packed_as used in this kernel.
template <>
struct packed_as<uint, 1> {
  using type = uint;
};

template <>
struct packed_as<uint, 2> {
  using type = uint2;
};

template <>
struct packed_as<uint, 4> {
  using type = uint4;
};

// Adds up `val` across lanes using XOR shuffles with lane distances
// 16, 8, 4, 2, 1 and a shuffle width of 32, and returns the accumulated
// value.
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int lane_xor = 16; lane_xor >= 1; lane_xor >>= 1) {
    val += __shfl_xor_sync(FINAL_MASK, val, lane_xor, 32);
  }
  return val;
}

// Integer division of m by n, rounded up: (m + n - 1) / n.
template <typename T>
inline __device__ __host__ T divUp(T m, T n) {
  // `auto` keeps whatever type the arithmetic promotes to, so the division
  // is carried out exactly as in the single-expression form.
  auto const padded = m + n - 1;
  return padded / n;
}

}  // namespace tensorrt_llm::common
namespace
tensorrt_llm
::
kernels
{
// NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation,
// with added support for passing the cos_sin_cache as an input.
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
//
// Perform per-head QK Norm and RoPE in a single kernel, in place on qkv.
// Every group of 32 consecutive threads handles one (token, head) pair, where
// the head is a Q head or a K head; V heads are only skipped over when
// computing offsets and are never read or written.
//
// Template parameters:
//   scalar_t_in:    data type of QKV and RMSNorm weights
//   scalar_t_cache: data type of cos/sin cache
//   head_dim:       the dimension of each head (must be a multiple of 64)
//   interleave:     interleave=!is_neox.
template <typename scalar_t_in, typename scalar_t_cache, int head_dim,
          bool interleave>
__global__ void fusedQKNormRopeKernel(
    void* qkv_void,                  // Combined QKV tensor
    int const num_heads_q,           // Number of query heads
    int const num_heads_k,           // Number of key heads
    int const num_heads_v,           // Number of value heads
    float const eps,                 // Epsilon for RMS normalization
    void const* q_weight_void,       // RMSNorm weights for query
    void const* k_weight_void,       // RMSNorm weights for key
    void const* cos_sin_cache_void,  // Pre-computed cos/sin cache
    int64_t const* position_ids,     // Position IDs for RoPE
    int const num_tokens             // Number of tokens
) {
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  // On this compilation pass the BFloat16 instantiations compile to an empty
  // body; every other instantiation takes the else-branch below.
  if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
                std::is_same_v<scalar_t_cache, c10::BFloat16>) {
    return;
  } else {
#endif

    using Converter = vllm::_typeConvert<scalar_t_in>;
    static_assert(Converter::exists,
                  "Input QKV data type is not supported for this CUDA "
                  "architecture or toolkit version.");
    using T_in = typename Converter::hip_type;
    using T2_in = typename Converter::packed_hip_type;

    using CacheConverter = vllm::_typeConvert<scalar_t_cache>;
    static_assert(CacheConverter::exists,
                  "Cache data type is not supported for this CUDA architecture "
                  "or toolkit version.");
    using T_cache = typename CacheConverter::hip_type;

    T_in* qkv = reinterpret_cast<T_in*>(qkv_void);
    T_in const* q_weight = reinterpret_cast<T_in const*>(q_weight_void);
    T_in const* k_weight = reinterpret_cast<T_in const*>(k_weight_void);
    T_cache const* cos_sin_cache =
        reinterpret_cast<T_cache const*>(cos_sin_cache_void);

    int const warpsPerBlock = blockDim.x / 32;
    int const warpId = threadIdx.x / 32;
    int const laneId = threadIdx.x % 32;

    // Calculate global warp index to determine which head/token this warp
    // processes
    int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;

    // Total number of attention heads (Q and K)
    int const total_qk_heads = num_heads_q + num_heads_k;

    // Determine which token and head type (Q or K) this warp processes
    int const tokenIdx = globalWarpIdx / total_qk_heads;
    int const localHeadIdx = globalWarpIdx % total_qk_heads;

    // Skip if this warp is assigned beyond the number of tokens. tokenIdx is
    // identical for all 32 lanes, so the whole group leaves together and the
    // shuffles below never see a partially exited group.
    if (tokenIdx >= num_tokens) return;

    bool const isQ = localHeadIdx < num_heads_q;
    int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;

    int const num_heads = num_heads_q + num_heads_k + num_heads_v;

    static_assert(head_dim % (32 * 2) == 0,
                  "head_dim must be divisible by 64 (each warp processes one "
                  "head, and each thread gets even number of "
                  "elements)");
    constexpr int numElemsPerThread = head_dim / 32;
    float elements[numElemsPerThread];
    // Bytes of qkv owned by one thread. Derived from the real element type
    // rather than a hard-coded bf16 so the vector width always matches T_in.
    constexpr int elemSizeBytes = numElemsPerThread * sizeof(T_in);
    static_assert(elemSizeBytes % 4 == 0,
                  "elemSizeBytes must be a multiple of 4");
    constexpr int vecSize = elemSizeBytes / 4;
    // Use packed_as<uint, vecSize> to perform loading/saving.
    using vec_T =
        typename tensorrt_llm::common::packed_as<uint, vecSize>::type;

    // Element offsets are computed in 64 bits: num_tokens * num_heads *
    // head_dim can exceed INT_MAX for large batches, and a wrapped 32-bit
    // offset would address memory outside the tensor.
    int64_t const tokenOffset =
        static_cast<int64_t>(tokenIdx) * num_heads * head_dim;
    int64_t offsetWarp;  // Offset for the warp
    if (isQ) {
      // Q segment: token offset + head offset within Q segment
      offsetWarp = tokenOffset + headIdx * head_dim;
    } else {
      // K segment: token offset + entire Q segment + head offset within K
      // segment
      offsetWarp = tokenOffset + num_heads_q * head_dim + headIdx * head_dim;
    }
    int64_t const offsetThread = offsetWarp + laneId * numElemsPerThread;

    // Sum of squares for RMSNorm
    float sumOfSquares = 0.0f;

    // Load.
    {
      vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
      constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
      for (int i = 0; i < num_packed_elems; i++) {
        // Interpret the generic vector chunk as the specific packed type
        T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
        // Convert to float2 for computation
        float2 vals = Converter::convert(packed_val);
        sumOfSquares += vals.x * vals.x;
        sumOfSquares += vals.y * vals.y;

        elements[2 * i] = vals.x;
        elements[2 * i + 1] = vals.y;
      }
    }

    // Reduce sum across warp using the utility function
    sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);

    // Compute RMS normalization factor
    float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

    // Normalize elements
#pragma unroll
    for (int i = 0; i < numElemsPerThread; i++) {
      int dim = laneId * numElemsPerThread + i;
      float weight = isQ ? Converter::convert(q_weight[dim])
                         : Converter::convert(k_weight[dim]);
      elements[i] *= rms_rcp * weight;
    }

    // Apply RoPE to normalized elements
    float elements2[numElemsPerThread];  // Additional buffer required for RoPE.

    int64_t pos_id = position_ids[tokenIdx];

    // Calculate cache pointer for this position - similar to
    // pos_encoding_kernels.cu
    // Each cache row holds head_dim / 2 cos values followed by head_dim / 2
    // sin values.
    T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
    int const embed_dim = head_dim / 2;
    T_cache const* cos_ptr = cache_ptr;
    T_cache const* sin_ptr = cache_ptr + embed_dim;

    if constexpr (interleave) {
      // Perform interleaving. Use pre-computed cos/sin values.
      // Adjacent elements (2i, 2i + 1) form one rotation pair, so both
      // partners already live in this thread's registers.
#pragma unroll
      for (int i = 0; i < numElemsPerThread / 2; ++i) {
        int const idx0 = 2 * i;
        int const idx1 = 2 * i + 1;

        float const val0 = elements[idx0];
        float const val1 = elements[idx1];

        int const dim_idx = laneId * numElemsPerThread + idx0;
        int const half_dim = dim_idx / 2;
        float const cos_val =
            CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
        float const sin_val =
            CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));

        elements[idx0] = val0 * cos_val - val1 * sin_val;
        elements[idx1] = val0 * sin_val + val1 * cos_val;
      }
    } else {
      // Before data exchange with in warp, we need to sync.
      __syncwarp();
      // Get the data from the other half of the warp. Use pre-computed cos/sin
      // values.
      // The rotation partner of element d is d +/- head_dim / 2, which is held
      // by the lane 16 positions away.
#pragma unroll
      for (int i = 0; i < numElemsPerThread; i++) {
        elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
        if (laneId < 16) {
          // Lanes owning the first half of the head use the negated partner.
          elements2[i] = -elements2[i];
        }

        int dim_idx = laneId * numElemsPerThread + i;
        // Fold the element index into [0, head_dim / 2) to index cos/sin.
        dim_idx = (dim_idx * 2) % head_dim;
        int half_dim = dim_idx / 2;
        // Use pre-computed cos/sin from cache
        float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
        float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));

        elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
      }
      // __shfl_xor_sync does not provide memfence. Need to sync again.
      __syncwarp();
    }

    // Store.
    {
      vec_T vec;
      constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
      for (int i = 0; i < num_packed_elems; i++) {
        // Convert from float2 back to the specific packed type
        T2_in packed_val = Converter::convert(
            make_float2(elements[2 * i], elements[2 * i + 1]));
        // Place it into the generic vector
        *(reinterpret_cast<T2_in*>(&vec) + i) = packed_val;
      }
      *reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
    }

#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  }
#endif
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
//
// Turns the runtime boolean `flag` into a compile-time constant: the trailing
// code is emitted twice, once per branch, and in each copy the identifier
// passed as CONST_NAME names a constant bool (true / false) that can be used
// as a template argument.
#define DISPATCH_INTERLEAVE(flag, CONST_NAME, ...) \
  if (flag) {                                      \
    constexpr bool CONST_NAME = true;              \
    __VA_ARGS__                                    \
  } else {                                         \
    constexpr bool CONST_NAME = false;             \
    __VA_ARGS__                                    \
  }
template
<
typename
scalar_t_in
,
typename
scalar_t_cache
>
void
launchFusedQKNormRope
(
void
*
qkv
,
int
const
num_tokens
,
int
const
num_heads_q
,
int
const
num_heads_k
,
int
const
num_heads_v
,
int
const
head_dim
,
float
const
eps
,
void
const
*
q_weight
,
void
const
*
k_weight
,
void
const
*
cos_sin_cache
,
bool
const
interleave
,
int64_t
const
*
position_ids
,
cudaStream_t
stream
)
{
constexpr
int
blockSize
=
256
;
int
const
warpsPerBlock
=
blockSize
/
32
;
int
const
totalQKHeads
=
num_heads_q
+
num_heads_k
;
int
const
totalWarps
=
num_tokens
*
totalQKHeads
;
int
const
gridSize
=
common
::
divUp
(
totalWarps
,
warpsPerBlock
);
dim3
gridDim
(
gridSize
);
dim3
blockDim
(
blockSize
);
switch
(
head_dim
)
{
case
64
:
DISPATCH_INTERLEAVE
(
interleave
,
INTERLEAVE
,
{
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
64
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
});
break
;
case
128
:
DISPATCH_INTERLEAVE
(
interleave
,
INTERLEAVE
,
{
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
128
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
});
break
;
case
256
:
DISPATCH_INTERLEAVE
(
interleave
,
INTERLEAVE
,
{
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
256
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
});
break
;
default:
TORCH_CHECK
(
false
,
"Unsupported head dimension for fusedQKNormRope: "
,
head_dim
);
}
}
}
// namespace tensorrt_llm::kernels
// Torch-facing entry point: applies per-head RMSNorm followed by RoPE to the
// Q and K heads of a packed QKV tensor, in place.
//
// Validates the arguments, then dispatches on the dtype of qkv (half types)
// and on the dtype of cos_sin_cache (floating types) and forwards everything
// to tensorrt_llm::kernels::launchFusedQKNormRope on the current CUDA stream
// of qkv's device. Any failed validation raises through TORCH_CHECK /
// CHECK_INPUT / CHECK_TYPE.
void fused_qk_norm_rope(
    torch::Tensor& qkv,   // Combined QKV tensor [num_tokens,
                          // (num_heads_q+num_heads_k+num_heads_v)*head_dim]
    int64_t num_heads_q,  // Number of query heads
    int64_t num_heads_k,  // Number of key heads
    int64_t num_heads_v,  // Number of value heads
    int64_t head_dim,     // Dimension per head
    double eps,           // Epsilon for RMS normalization
    torch::Tensor& q_weight,       // RMSNorm weights for query [head_dim]
    torch::Tensor& k_weight,       // RMSNorm weights for key [head_dim]
    torch::Tensor& cos_sin_cache,  // Cos/sin cache [max_position, head_dim]
    bool is_neox,                  // Whether RoPE is applied in Neox style
    torch::Tensor& position_ids    // Position IDs for RoPE [num_tokens]
) {
  // --- Per-tensor checks -----------------------------------------------
  CHECK_INPUT(qkv);
  CHECK_INPUT(position_ids);
  CHECK_INPUT(q_weight);
  CHECK_INPUT(k_weight);
  CHECK_INPUT(cos_sin_cache);
  CHECK_TYPE(position_ids, torch::kInt64);

  // --- Rank checks -----------------------------------------------------
  TORCH_CHECK(qkv.dim() == 2,
              "QKV tensor must be 2D: [num_tokens, "
              "(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
  TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
  TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
  TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
  TORCH_CHECK(cos_sin_cache.dim() == 2,
              "Cos/sin cache must be 2D: [max_position, head_dim]");

  // --- Extent checks against head_dim ----------------------------------
  TORCH_CHECK(q_weight.size(0) == head_dim,
              "Query weights size must match head dimension");
  TORCH_CHECK(k_weight.size(0) == head_dim,
              "Key weights size must match head dimension");
  TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
              "Cos/sin cache dimension must match head_dim");

  // --- Dtype agreement between qkv and both weight vectors -------------
  auto const input_dtype = qkv.scalar_type();
  TORCH_CHECK(input_dtype == q_weight.scalar_type() &&
                  input_dtype == k_weight.scalar_type(),
              "qkv, q_weight and k_weight must have the same dtype");

  // --- Shape agreement between qkv, position_ids and the head counts ---
  int64_t const token_count = qkv.size(0);
  TORCH_CHECK(position_ids.size(0) == token_count,
              "Number of tokens in position_ids must match QKV");

  int64_t const heads_per_token = num_heads_q + num_heads_k + num_heads_v;
  TORCH_CHECK(
      qkv.size(1) == heads_per_token * head_dim,
      "QKV tensor size must match total number of heads and head dimension");

  auto cuda_stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
  // The kernel is parameterised on "interleave", the negation of is_neox.
  bool const interleave = !is_neox;

  // Outer dispatch fixes the qkv/weight element type, inner dispatch fixes
  // the cos/sin cache element type.
  VLLM_DISPATCH_HALF_TYPES(input_dtype, "fused_qk_norm_rope_kernel", [&] {
    using in_elem_t = scalar_t;
    VLLM_DISPATCH_FLOATING_TYPES(
        cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
          using cache_elem_t = scalar_t;
          tensorrt_llm::kernels::launchFusedQKNormRope<in_elem_t,
                                                       cache_elem_t>(
              qkv.data_ptr(), static_cast<int>(token_count),
              static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
              static_cast<int>(num_heads_v), static_cast<int>(head_dim),
              static_cast<float>(eps), q_weight.data_ptr(),
              k_weight.data_ptr(), cos_sin_cache.data_ptr(), interleave,
              reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
              cuda_stream);
        });
  });
}
\ No newline at end of file
csrc/launch_bounds_utils.h
View file @
006693ed
...
...
@@ -8,11 +8,37 @@
#define VLLM_LAUNCH_BLOCKS_CAP 4
#endif
// compile-time estimate of max threads per SM for launch bounds.
// Compile-time estimate of max threads per SM for launch bounds.
// Families: 1024, 1536, 2048 threads/SM.
#ifndef VLLM_MAX_THREADS_PER_SM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
#define VLLM_MAX_THREADS_PER_SM 1536
#ifdef __CUDA_ARCH__
/* 1024 thr/SM: Turing (sm_75) */
#if (__CUDA_ARCH__ == 750)
#define VLLM_MAX_THREADS_PER_SM 1024
/* 1536 thr/SM: Ampere GA10x (sm_86/87), Ada (sm_89),
GB20x consumer (sm_120/121), Thor (sm_101 or sm_110) */
#elif (__CUDA_ARCH__ == 860) || (__CUDA_ARCH__ == 870) || \
(__CUDA_ARCH__ == 890) || (__CUDA_ARCH__ == 1010) || \
(__CUDA_ARCH__ == 1100) || (__CUDA_ARCH__ == 1200) || \
(__CUDA_ARCH__ == 1210)
#define VLLM_MAX_THREADS_PER_SM 1536
/* 2048 thr/SM: Volta (sm_70/72), Ampere GA100 (sm_80),
Hopper (sm_90), Blackwell (sm_100/103) */
#elif (__CUDA_ARCH__ == 700) || (__CUDA_ARCH__ == 720) || \
(__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900) || \
(__CUDA_ARCH__ == 1000) || (__CUDA_ARCH__ == 1030)
#define VLLM_MAX_THREADS_PER_SM 2048
/* Fallback: use 2048 for unknown future CCs */
#else
#define VLLM_MAX_THREADS_PER_SM 2048
#endif
#else
/* Host pass (no __CUDA_ARCH__): neutral default */
#define VLLM_MAX_THREADS_PER_SM 2048
#endif
#endif
...
...
csrc/layernorm_kernels.cu
View file @
006693ed
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -8,7 +10,7 @@
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
int
VEC_SIZE
>
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
...
...
@@ -17,11 +19,21 @@ __global__ void rms_norm_kernel(
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
const
scalar_t
*
input_row
=
input
+
blockIdx
.
x
*
input_stride
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
auto
vec_op
=
[
&
variance
](
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>&
vec
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
float
x
=
static_cast
<
float
>
(
vec
.
val
[
i
]);
variance
+=
x
*
x
;
}
};
auto
scalar_op
=
[
&
variance
](
const
scalar_t
&
val
)
{
float
x
=
static_cast
<
float
>
(
val
);
variance
+=
x
*
x
;
}
};
vllm
::
vectorize_read_with_alignment
<
VEC_SIZE
>
(
input_row
,
hidden_size
,
threadIdx
.
x
,
blockDim
.
x
,
vec_op
,
scalar_op
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
...
...
@@ -32,10 +44,20 @@ __global__ void rms_norm_kernel(
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
scalar_t
*
out_row
=
out
+
blockIdx
.
x
*
hidden_size
;
auto
*
v_in
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
input_row
);
auto
*
v_w
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
weight
);
auto
*
v_out
=
reinterpret_cast
<
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
out_row
);
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
/
VEC_SIZE
;
i
+=
blockDim
.
x
)
{
vec_n_t
<
scalar_t
,
VEC_SIZE
>
dst
;
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src1
=
v_in
[
i
];
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src2
=
v_w
[
i
];
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
float
x
=
static_cast
<
float
>
(
src1
.
val
[
j
]);
dst
.
val
[
j
]
=
((
scalar_t
)(
x
*
s_variance
))
*
src2
.
val
[
j
];
}
v_out
[
i
]
=
dst
;
}
}
...
...
@@ -135,211 +157,6 @@ fused_add_rms_norm_kernel(
}
}
/* Vector type for the FP16/BF16 poly-norm fast path.

   _f16VecPN derives from _f16Vec and adds the two operations polynomial
   normalization (poly norm) needs and the base type does not offer:
     - sum_pows():          sums of the 2nd, 4th and 6th powers of the lanes
     - poly_norm_inplace(): overwrites every lane with the normalized cubic
   Both walk the `data` array two lanes at a time and do their arithmetic in
   float via Converter. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
  using Base = _f16Vec<scalar_t, width>;
  using Converter = typename Base::Converter;
  using T1 = typename Base::T1;
  using T2 = typename Base::T2;
  using Base::data;

  // Returns a tuple (sum of x^2, sum of x^4, sum of x^6) over all lanes.
  __device__ auto sum_pows() const {
    float pow2_sum = 0.0f;
    float pow4_sum = 0.0f;
    float pow6_sum = 0.0f;

#pragma unroll
    for (int lane = 0; lane < width; lane += 2) {
      float2 pair = Converter::convert(T2{data[lane], data[lane + 1]});

      float a2 = pair.x * pair.x;
      float a4 = a2 * a2;
      float a6 = a4 * a2;

      float b2 = pair.y * pair.y;
      float b4 = b2 * b2;
      float b6 = b4 * b2;

      pow2_sum += a2 + b2;
      pow4_sum += a4 + b4;
      pow6_sum += a6 + b6;
    }
    return std::make_tuple(pow2_sum, pow4_sum, pow6_sum);
  }

  // Replaces every lane x with
  //   w2_inv_std * x + w1_inv_std2 * x^2 + w0_inv_std3 * x^3 + bias.
  __device__ void poly_norm_inplace(const float w2_inv_std,
                                    const float w1_inv_std2,
                                    const float w0_inv_std3,
                                    const float bias) {
#pragma unroll
    for (int lane = 0; lane < width; lane += 2) {
      float2 pair = Converter::convert(T2{data[lane], data[lane + 1]});

      float a2 = pair.x * pair.x;
      float a3 = a2 * pair.x;
      pair.x = w2_inv_std * pair.x + w1_inv_std2 * a2 + w0_inv_std3 * a3 + bias;

      float b2 = pair.y * pair.y;
      float b3 = b2 * pair.y;
      pair.y = w2_inv_std * pair.y + w1_inv_std2 * b2 + w0_inv_std3 * b3 + bias;

      // Back to the packed 16-bit representation.
      auto packed = Converter::convert(pair);
      data[lane] = packed.x;
      data[lane + 1] = packed.y;
    }
  }
};
/* Vectorized poly_norm_kernel, selected when width > 0 and scalar_t has a
   _typeConvert specialization.

   One thread block handles one row of hidden_size elements:
     1. every thread accumulates sums of x^2, x^4 and x^6 over its share of
        the row, reading `width` elements at a time;
     2. the three sums are reduced across the block with cub::BlockReduce;
     3. thread 0 turns them into the three scaled inverse-RMS factors and
        publishes them, together with the bias, through shared memory;
     4. every thread rewrites its share of the row into `out`.

   hidden_size is expected to be a multiple of `width`; any remainder
   elements are not processed. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
  static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);

  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
  const int vec_hidden_size = hidden_size / width;
  // Index of this block's first vector, in 64 bits: the product of the row
  // index and the row length must not be narrowed to a 32-bit int before it
  // is used as a subscript.
  const int64_t row_start = static_cast<int64_t>(blockIdx.x) * vec_hidden_size;
  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_start + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    auto [x2, x4, x6] = temp.sum_pows();

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise addition, used as the reduction operator.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  // blockDim.x is passed as the number of threads holding valid input.
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Only thread 0 consumes the reduced sums; it derives the per-row factors
  // and shares them with the rest of the block.
  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_start + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
    out_v[id] = temp;
  }
}
/* Generic poly_norm_kernel
   The width field is not used here but necessary for other specializations.

   Scalar fallback, selected when width == 0 or scalar_t has no _typeConvert
   specialization. One thread block handles one row of hidden_size elements:
   the block accumulates the sums of x^2, x^4 and x^6, thread 0 converts them
   into three scaled inverse-RMS factors, and every element is then rewritten
   as
     x * s_w2_inv_std + x^2 * s_w1_inv_std2 + x^3 * s_w0_inv_std3 + s_bias.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Offset of this block's row, in 64 bits: blockIdx.x * hidden_size would
  // otherwise be evaluated in 32-bit unsigned arithmetic and wrap around for
  // tensors with 2^32 or more elements.
  const int64_t row_start = static_cast<int64_t>(blockIdx.x) * hidden_size;

  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = static_cast<float>(input[row_start + idx]);
    float x2 = x * x;
    float x4 = x2 * x2;
    float x6 = x4 * x2;

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise addition, used as the reduction operator.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  // blockDim.x is passed as the number of threads holding valid input.
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Only thread 0 consumes the reduced sums; it derives the per-row factors
  // and shares them with the rest of the block.
  if (threadIdx.x == 0) {
    float w0 = static_cast<float>(weight[0]);
    float w1 = static_cast<float>(weight[1]);
    float w2 = static_cast<float>(weight[2]);
    s_bias = static_cast<float>(bias[0]);

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = static_cast<float>(input[row_start + idx]);
    float x2 = x * x;
    float x3 = x2 * x;

    out[row_start + idx] = static_cast<scalar_t>(
        x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 + s_bias);
  }
}
}
// namespace vllm
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
...
...
@@ -351,18 +168,34 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int64_t
input_stride
=
input
.
stride
(
-
2
);
// We cannot just use `input.stride(-2)` if the tensor is not row-major.
// Instead, we use a 2d view to get the second-innermost stride.
// That way the dimensions (except the last one) can be arbitrarily permuted.
torch
::
Tensor
input_view
=
input
.
view
({
-
1
,
hidden_size
});
int
num_tokens
=
input_view
.
numel
()
/
hidden_size
;
int64_t
input_stride
=
input_view
.
stride
(
-
2
);
// For large num_tokens, use smaller blocks to increase SM concurrency.
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input_view
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
VLLM_DISPATCH_FLOATING_TYPES
(
input_view
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
const
int
calculated_vec_size
=
std
::
gcd
(
16
/
sizeof
(
scalar_t
),
hidden_size
);
const
int
block_size
=
std
::
min
(
hidden_size
/
calculated_vec_size
,
max_block_size
);
dim3
block
(
block_size
);
VLLM_DISPATCH_VEC_SIZE
(
calculated_vec_size
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
,
vec_size
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input_view
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
...
...
@@ -379,6 +212,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
weight
.
scalar_type
()
==
input
.
scalar_type
());
TORCH_CHECK
(
input
.
scalar_type
()
==
residual
.
scalar_type
());
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
...
...
@@ -413,55 +248,11 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr
%
req_alignment_bytes
==
0
;
bool
offsets_are_multiple_of_vector_width
=
hidden_size
%
vector_width
==
0
&&
input_stride
%
vector_width
==
0
;
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
)
{
bool
batch_invariant_launch
=
vllm
::
vllm_is_batch_invariant
();
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
// Launches vllm::poly_norm_kernel<scalar_t, width>, with scalar_t chosen by
// VLLM_DISPATCH_FLOATING_TYPES from the dtype of `input`.
// The macro is not self-contained: it expands against names that must exist
// in the scope where it is used - `out`, `input`, `weight`, `bias` (tensors),
// `epsilon`, `hidden_size`, `grid`, `block` and `stream`.
// `width` selects the kernel overload; poly_norm passes 8 for the vectorized
// path and 0 for the scalar fallback.
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
// Host entry point for polynomial normalization: writes the normalized cubic
// of every row of `input` into `out`.
//
//   out     [..., hidden_size]  contiguous, must not share storage start
//                               with input
//   input   [..., hidden_size]  contiguous
//   weight  [3]                 the kernel reads elements 0, 1 and 2
//   bias    [1]                 the kernel reads element 0
//   epsilon                     added under the square root for stability
//
// One thread block is launched per row. Raises through TORCH_CHECK on invalid
// arguments; returns immediately for an empty input.
void poly_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.is_contiguous());

  // Empty input: nothing to compute. Returning here avoids a division by
  // zero below when hidden_size == 0 and an invalid zero-block launch when
  // there are no rows.
  if (input.numel() == 0) {
    return;
  }

  TORCH_CHECK(out.data_ptr() != input.data_ptr());
  // The kernels index weight[0..2] and bias[0] unconditionally.
  TORCH_CHECK(weight.numel() >= 3,
              "poly_norm: weight must hold at least 3 elements");
  TORCH_CHECK(bias.numel() >= 1,
              "poly_norm: bias must hold at least 1 element");

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /*If the tensor types are FP16/BF16, try to use the optimized kernel
    with packed + vectorized ops.
    Max optimization is achieved with a width-8 vector of FP16/BF16s
    since we can load at most 128 bits at once in a global memory op.
    However, this requires each tensor's data to be aligned to 16
    bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_POLY_NORM(8);
  } else {
    LAUNCH_FUSED_POLY_NORM(0);
  }
}
csrc/layernorm_quant_kernels.cu
View file @
006693ed
...
...
@@ -6,9 +6,11 @@
*/
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "quantization/
w8a8/
fp8/common.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -16,7 +18,7 @@
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
,
typename
fp8_type
>
template
<
typename
scalar_t
,
typename
fp8_type
,
int
VEC_SIZE
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
...
...
@@ -27,10 +29,21 @@ __global__ void rms_norm_static_fp8_quant_kernel(
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
const
scalar_t
*
input_row
=
input
+
blockIdx
.
x
*
input_stride
;
auto
vec_op
=
[
&
variance
](
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>&
vec
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
float
x
=
static_cast
<
float
>
(
vec
.
val
[
i
]);
variance
+=
x
*
x
;
}
};
auto
scalar_op
=
[
&
variance
](
const
scalar_t
&
val
)
{
float
x
=
static_cast
<
float
>
(
val
);
variance
+=
x
*
x
;
}
};
vllm
::
vectorize_read_with_alignment
<
VEC_SIZE
>
(
input_row
,
hidden_size
,
threadIdx
.
x
,
blockDim
.
x
,
vec_op
,
scalar_op
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
...
...
@@ -44,11 +57,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
// invert scale to avoid division
float
const
scale_inv
=
1.0
f
/
*
scale
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
auto
*
v_in
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
input_row
);
auto
*
v_w
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
weight
);
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
/
VEC_SIZE
;
idx
+=
blockDim
.
x
)
{
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src1
=
v_in
[
idx
];
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src2
=
v_w
[
idx
];
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
float
x
=
static_cast
<
float
>
(
src1
.
val
[
j
]);
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
src2
.
val
[
j
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
*
VEC_SIZE
+
j
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
...
...
@@ -174,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
int
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
// For large num_tokens, use smaller blocks to increase SM concurrency.
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
const
int
calculated_vec_size
=
std
::
gcd
(
16
/
sizeof
(
scalar_t
),
hidden_size
);
const
int
block_size
=
std
::
min
(
hidden_size
/
calculated_vec_size
,
max_block_size
);
dim3
block
(
block_size
);
VLLM_DISPATCH_VEC_SIZE
(
calculated_vec_size
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
,
vec_size
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
});
}
...
...
@@ -215,6 +244,8 @@ void fused_add_rms_norm_static_fp8_quant(
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
residual
.
scalar_type
()
==
input
.
scalar_type
());
TORCH_CHECK
(
weight
.
scalar_type
()
==
input
.
scalar_type
());
int
hidden_size
=
input
.
size
(
-
1
);
int
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -240,7 +271,9 @@ void fused_add_rms_norm_static_fp8_quant(
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
input_stride
%
8
==
0
)
{
bool
batch_invariant_launch
=
vllm
::
vllm_is_batch_invariant
();
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
input_stride
%
8
==
0
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
...
...
csrc/mamba/mamba_ssm/selective_scan.h
View file @
006693ed
...
...
@@ -24,6 +24,8 @@ struct SSMParamsBase {
int64_t
pad_slot_id
;
bool
delta_softplus
;
bool
cache_enabled
;
int
block_size
;
index_t
A_d_stride
;
index_t
A_dstate_stride
;
...
...
@@ -46,8 +48,9 @@ struct SSMParamsBase {
index_t
out_z_batch_stride
;
index_t
out_z_d_stride
;
index_t
ssm_states_batch_stride
;
index_t
ssm_states_dim_stride
;
index_t
ssm_states_dim_stride
;
index_t
ssm_states_dstate_stride
;
index_t
cache_indices_stride
;
// Common data pointers.
void
*
__restrict__
A_ptr
;
...
...
@@ -66,6 +69,9 @@ struct SSMParamsBase {
void
*
__restrict__
cache_indices_ptr
;
void
*
__restrict__
has_initial_state_ptr
;
void
*
__restrict__
block_idx_first_scheduled_token_ptr
;
// (batch,) - first block to write
void
*
__restrict__
block_idx_last_scheduled_token_ptr
;
// (batch,) - last block to write
void
*
__restrict__
initial_state_idx_ptr
;
// (batch,) - index of the initial state to use
};
...
...
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
View file @
006693ed
...
...
@@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const
int
*
cache_indices
=
params
.
cache_indices_ptr
==
nullptr
?
nullptr
:
reinterpret_cast
<
int
*>
(
params
.
cache_indices_ptr
);
const
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
const
int
cache_index
=
cache_indices
==
nullptr
?
batch_id
:
cache_indices
[
batch_id
];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if
(
cache_index
==
params
.
pad_slot_id
){
return
;
...
...
@@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t
*
Bvar
=
reinterpret_cast
<
input_t
*>
(
params
.
B_ptr
)
+
sequence_start_index
*
params
.
B_batch_stride
+
group_id
*
params
.
B_group_stride
;
weight_t
*
C
=
reinterpret_cast
<
weight_t
*>
(
params
.
C_ptr
)
+
dim_id
*
kNRows
*
params
.
C_d_stride
;
input_t
*
Cvar
=
reinterpret_cast
<
input_t
*>
(
params
.
C_ptr
)
+
sequence_start_index
*
params
.
C_batch_stride
+
group_id
*
params
.
C_group_stride
;
typename
Ktraits
::
state_t
*
ssm_states
=
reinterpret_cast
<
typename
Ktraits
::
state_t
*>
(
params
.
ssm_states_ptr
)
+
cache_index
*
params
.
ssm_states_batch_stride
+
dim_id
*
kNRows
*
params
.
ssm_states_dim_stride
;
typename
Ktraits
::
state_t
*
ssm_states
;
if
(
params
.
cache_enabled
)
{
// APC mode: ssm_states points to the base, we'll use absolute cache slots later
ssm_states
=
reinterpret_cast
<
typename
Ktraits
::
state_t
*>
(
params
.
ssm_states_ptr
)
+
dim_id
*
kNRows
*
params
.
ssm_states_dim_stride
;
}
else
{
// Non-APC mode: offset by cache_index as before
ssm_states
=
reinterpret_cast
<
typename
Ktraits
::
state_t
*>
(
params
.
ssm_states_ptr
)
+
cache_index
*
params
.
ssm_states_batch_stride
+
dim_id
*
kNRows
*
params
.
ssm_states_dim_stride
;
}
float
D_val
[
kNRows
]
=
{
0
};
if
(
params
.
D_ptr
!=
nullptr
)
{
...
...
@@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// }
constexpr
int
kChunkSize
=
kNThreads
*
kNItems
;
const
int
n_chunks
=
(
seqlen
+
2048
-
1
)
/
2048
;
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
const
int
iteration_chunk_size
=
params
.
cache_enabled
?
params
.
block_size
:
2048
;
const
int
n_chunks
=
(
seqlen
+
iteration_chunk_size
-
1
)
/
iteration_chunk_size
;
const
int
*
batch_cache_indices
=
cache_indices
!=
nullptr
?
cache_indices
+
batch_id
*
params
.
cache_indices_stride
:
nullptr
;
const
int
*
block_idx_first_scheduled
=
params
.
block_idx_first_scheduled_token_ptr
!=
nullptr
?
reinterpret_cast
<
const
int
*>
(
params
.
block_idx_first_scheduled_token_ptr
)
:
nullptr
;
const
int
*
block_idx_last_scheduled
=
params
.
block_idx_last_scheduled_token_ptr
!=
nullptr
?
reinterpret_cast
<
const
int
*>
(
params
.
block_idx_last_scheduled_token_ptr
)
:
nullptr
;
const
int
*
initial_state_idx
=
params
.
initial_state_idx_ptr
!=
nullptr
?
reinterpret_cast
<
const
int
*>
(
params
.
initial_state_idx_ptr
)
:
nullptr
;
const
size_t
load_cache_slot
=
params
.
cache_enabled
&&
batch_cache_indices
!=
nullptr
?
batch_cache_indices
[
initial_state_idx
[
batch_id
]]
:
cache_index
;
for
(
int
chunk
=
0
;
chunk
<
n_chunks
;
++
chunk
)
{
input_t
u_vals
[
kNRows
][
kNItems
],
delta_vals_load
[
kNRows
][
kNItems
];
...
...
@@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if
constexpr
(
kIsVariableC
)
{
auto
&
smem_load_weight_C
=
!
kIsVariableB
?
smem_load_weight
:
smem_load_weight1
;
load_weight
<
Ktraits
>
(
Cvar
+
state_idx
*
params
.
C_dstate_stride
,
C_vals
,
smem_load_weight_C
,
(
seqlen
-
chunk
*
kChunkSize
)
*
(
1
));
smem_load_weight_C
,
(
seqlen
-
chunk
*
kChunkSize
)
*
(
1
));
if
constexpr
(
!
kIsVariableB
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
kNRows
;
++
r
)
{
...
...
@@ -242,7 +266,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
for
(
int
i
=
0
;
i
<
kNItems
;
++
i
)
{
thread_data
[
i
]
=
make_float2
(
exp2f
(
delta_vals
[
r
][
i
]
*
A_val
[
r
]),
!
kIsVariableB
?
delta_u_vals
[
r
][
i
]
:
B_vals
[
i
]
*
delta_u_vals
[
r
][
i
]);
if
(
seqlen
%
(
kNItems
*
kNThreads
)
!=
0
)
{
// So that the last state is correct
if
(
threadIdx
.
x
*
kNItems
+
i
>=
seqlen
-
chunk
*
kChunkSize
)
{
thread_data
[
i
]
=
make_float2
(
1.
f
,
0.
f
);
...
...
@@ -250,8 +273,24 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
// Initialize running total
scan_t
running_prefix
=
chunk
>
0
?
smem_running_prefix
[
state_idx
+
r
*
MAX_DSTATE
]
:
make_float2
(
1.0
,
has_initial_state
?
float
(
ssm_states
[
state_idx
*
params
.
ssm_states_dstate_stride
])
:
0.0
);
scan_t
running_prefix
;
if
(
chunk
>
0
)
{
running_prefix
=
smem_running_prefix
[
state_idx
+
r
*
MAX_DSTATE
];
}
else
{
// Load initial state
if
(
params
.
cache_enabled
&&
has_initial_state
&&
batch_cache_indices
!=
nullptr
)
{
size_t
state_offset
=
load_cache_slot
*
params
.
ssm_states_batch_stride
+
r
*
params
.
ssm_states_dim_stride
+
state_idx
*
params
.
ssm_states_dstate_stride
;
running_prefix
=
make_float2
(
1.0
,
float
(
ssm_states
[
state_offset
]));
}
else
if
(
has_initial_state
)
{
// Non-APC mode: load from current batch position
running_prefix
=
make_float2
(
1.0
,
float
(
ssm_states
[
state_idx
*
params
.
ssm_states_dstate_stride
]));
}
else
{
// No initial state
running_prefix
=
make_float2
(
1.0
,
0.0
);
}
}
SSMScanPrefixCallbackOp
<
weight_t
>
prefix_op
(
running_prefix
);
typename
Ktraits
::
BlockScanT
(
smem_scan
).
InclusiveScan
(
...
...
@@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// There's a syncthreads in the scan op, so we don't need to sync here.
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
if
(
threadIdx
.
x
==
0
)
{
smem_running_prefix
[
state_idx
]
=
prefix_op
.
running_prefix
;
if
(
chunk
==
n_chunks
-
1
)
{
smem_running_prefix
[
state_idx
+
r
*
MAX_DSTATE
]
=
prefix_op
.
running_prefix
;
// Store state at the end of each chunk when cache is enabled
if
(
params
.
cache_enabled
&&
batch_cache_indices
!=
nullptr
)
{
size_t
cache_slot
;
if
(
chunk
==
n_chunks
-
1
)
{
cache_slot
=
batch_cache_indices
[
block_idx_last_scheduled
[
batch_id
]];
}
else
{
cache_slot
=
batch_cache_indices
[
block_idx_first_scheduled
[
batch_id
]
+
chunk
];
}
size_t
state_offset
=
cache_slot
*
params
.
ssm_states_batch_stride
+
r
*
params
.
ssm_states_dim_stride
+
state_idx
*
params
.
ssm_states_dstate_stride
;
ssm_states
[
state_offset
]
=
typename
Ktraits
::
state_t
(
prefix_op
.
running_prefix
.
y
);
}
else
if
(
!
params
.
cache_enabled
&&
chunk
==
n_chunks
-
1
)
{
// Non-APC mode: store only final state at current batch position
ssm_states
[
state_idx
*
params
.
ssm_states_dstate_stride
]
=
typename
Ktraits
::
state_t
(
prefix_op
.
running_prefix
.
y
);
}
}
...
...
@@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
}
input_t
*
out
=
reinterpret_cast
<
input_t
*>
(
params
.
out_ptr
)
+
sequence_start_index
*
params
.
out_batch_stride
+
dim_id
*
kNRows
*
params
.
out_d_stride
+
chunk
*
kChunkSize
;
__syncthreads
();
...
...
@@ -346,7 +401,9 @@ template<typename input_t, typename weight_t, typename state_t>
void
selective_scan_fwd_cuda
(
SSMParamsBase
&
params
,
cudaStream_t
stream
)
{
#ifndef USE_ROCM
if
(
params
.
seqlen
<=
128
)
{
if
(
params
.
cache_enabled
&&
params
.
block_size
==
1024
)
{
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
128
)
{
selective_scan_fwd_launch
<
32
,
4
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
256
)
{
selective_scan_fwd_launch
<
32
,
8
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
...
...
@@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
selective_scan_fwd_launch
<
128
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
#else
if
(
params
.
seqlen
<=
256
)
{
if
(
params
.
cache_enabled
&&
params
.
block_size
==
1024
)
{
selective_scan_fwd_launch
<
64
,
16
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
256
)
{
selective_scan_fwd_launch
<
64
,
4
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
}
else
if
(
params
.
seqlen
<=
512
)
{
selective_scan_fwd_launch
<
64
,
8
,
input_t
,
weight_t
,
state_t
>
(
params
,
stream
);
...
...
@@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
const
std
::
optional
<
at
::
Tensor
>&
D
,
const
std
::
optional
<
at
::
Tensor
>&
delta_bias
,
const
torch
::
Tensor
ssm_states
,
bool
has_z
,
bool
has_z
,
bool
delta_softplus
,
const
std
::
optional
<
at
::
Tensor
>&
query_start_loc
,
const
std
::
optional
<
at
::
Tensor
>&
cache_indices
,
const
std
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
varlen
,
int64_t
pad_slot_id
)
{
int64_t
pad_slot_id
,
int64_t
block_size
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_first_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_last_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
initial_state_idx
)
{
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
));
...
...
@@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
params
.
cache_indices_ptr
=
cache_indices
.
has_value
()
?
cache_indices
.
value
().
data_ptr
()
:
nullptr
;
params
.
has_initial_state_ptr
=
has_initial_state
.
has_value
()
?
has_initial_state
.
value
().
data_ptr
()
:
nullptr
;
// Set cache parameters - cache is enabled if we have direct cache writing params
params
.
cache_enabled
=
block_idx_first_scheduled_token
.
has_value
();
params
.
block_size
=
static_cast
<
int
>
(
block_size
);
// Set direct cache writing pointers
params
.
block_idx_first_scheduled_token_ptr
=
block_idx_first_scheduled_token
.
has_value
()
?
block_idx_first_scheduled_token
.
value
().
data_ptr
()
:
nullptr
;
params
.
block_idx_last_scheduled_token_ptr
=
block_idx_last_scheduled_token
.
has_value
()
?
block_idx_last_scheduled_token
.
value
().
data_ptr
()
:
nullptr
;
params
.
initial_state_idx_ptr
=
initial_state_idx
.
has_value
()
?
initial_state_idx
.
value
().
data_ptr
()
:
nullptr
;
// All stride are in elements, not bytes.
params
.
A_d_stride
=
A
.
stride
(
0
);
...
...
@@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
params
.
out_d_stride
=
out
.
stride
(
0
);
params
.
ssm_states_batch_stride
=
ssm_states
.
stride
(
0
);
params
.
ssm_states_dim_stride
=
ssm_states
.
stride
(
1
);
params
.
ssm_states_dim_stride
=
ssm_states
.
stride
(
1
);
params
.
ssm_states_dstate_stride
=
ssm_states
.
stride
(
2
);
params
.
cache_indices_stride
=
cache_indices
.
has_value
()
?
cache_indices
.
value
().
stride
(
0
)
:
0
;
}
else
{
if
(
!
is_variable_B
)
{
...
...
@@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
params
.
out_d_stride
=
out
.
stride
(
1
);
params
.
ssm_states_batch_stride
=
ssm_states
.
stride
(
0
);
params
.
ssm_states_dim_stride
=
ssm_states
.
stride
(
1
);
params
.
ssm_states_dim_stride
=
ssm_states
.
stride
(
1
);
params
.
ssm_states_dstate_stride
=
ssm_states
.
stride
(
2
);
params
.
cache_indices_stride
=
cache_indices
.
has_value
()
?
cache_indices
.
value
().
stride
(
0
)
:
0
;
}
}
...
...
@@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const
torch
::
Tensor
&
ssm_states
,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t
pad_slot_id
)
{
int64_t
pad_slot_id
,
int64_t
block_size
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_first_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
block_idx_last_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>
&
initial_state_idx
)
{
auto
input_type
=
u
.
scalar_type
();
auto
weight_type
=
A
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
...
...
@@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
auto
cache_indices_
=
cache_indices
.
value
();
TORCH_CHECK
(
cache_indices_
.
scalar_type
()
==
at
::
ScalarType
::
Int
);
TORCH_CHECK
(
cache_indices_
.
is_cuda
());
CHECK_SHAPE
(
cache_indices_
,
batch_size
);
// cache_indices can be either 1D (batch_size,) for non-APC mode
// or 2D (batch_size, max_positions) for APC mode
const
bool
is_apc_mode
=
block_idx_first_scheduled_token
.
has_value
();
if
(
is_apc_mode
)
{
TORCH_CHECK
(
cache_indices_
.
dim
()
==
2
,
"cache_indices must be 2D for APC mode"
);
TORCH_CHECK
(
cache_indices_
.
size
(
0
)
==
batch_size
,
"cache_indices first dimension must match batch_size"
);
}
else
{
CHECK_SHAPE
(
cache_indices_
,
batch_size
);
}
}
...
...
@@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
cache_indices
,
has_initial_state
,
varlen
,
pad_slot_id
pad_slot_id
,
block_size
,
block_idx_first_scheduled_token
,
block_idx_last_scheduled_token
,
initial_state_idx
);
...
...
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
View file @
006693ed
...
...
@@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
const
int64_t
g_eff_13
=
(
group_size
!=
-
1
)
?
group_size
:
H
;
const
int64_t
g_eff_2
=
(
group_size
!=
-
1
)
?
group_size
:
I
;
// Per-expert outputs filled in parallel
std
::
vector
<
torch
::
Tensor
>
y_list
(
E
);
y_list
.
resize
(
E
);
auto
X_all
=
x_c
.
index_select
(
/*dim=*/
0
,
expert_tokens
);
if
(
apply_router_weight_on_input
)
{
X_all
=
X_all
.
mul
(
expert_gates
.
unsqueeze
(
1
));
}
auto
Y_all
=
at
::
empty
({
offsets
[
E
],
H
},
x_c
.
options
());
at
::
parallel_for
(
0
,
E
,
1
,
[
&
](
int64_t
e_begin
,
int64_t
e_end
)
{
c10
::
InferenceMode
guard
;
for
(
int64_t
e
=
e_begin
;
e
<
e_end
;
++
e
)
{
const
int64_t
te
=
counts
[
e
];
if
(
te
==
0
)
{
y_list
[
e
]
=
at
::
empty
({
0
,
H
},
x_c
.
options
());
continue
;
}
const
int64_t
start
=
offsets
[
e
];
auto
sel_tokens
=
expert_tokens
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
);
auto
gates_e
=
expert_gates
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
);
auto
x_e
=
x_c
.
index_select
(
/*dim=*/
0
,
sel_tokens
);
if
(
apply_router_weight_on_input
)
{
x_e
=
x_e
.
mul
(
gates_e
.
unsqueeze
(
1
));
}
auto
x_e
=
X_all
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
);
auto
w13_e
=
w13_packed
.
select
(
/*dim=*/
0
,
e
);
auto
w2_e
=
w2_packed
.
select
(
/*dim=*/
0
,
e
);
...
...
@@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
// W2
auto
y
=
mm
(
act
,
w2_e
,
g_eff_2
,
/*in_features=*/
I
,
/*out_features=*/
H
);
if
(
!
apply_router_weight_on_input
)
{
y
=
y
.
mul
(
gates_e
.
unsqueeze
(
1
));
}
// Store per-expert result
y_list
[
e
]
=
y
;
Y_all
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
).
copy_
(
y
)
;
}
});
// Concatenate all expert outputs to match expert_tokens order
auto
Y_all
=
at
::
cat
(
y_list
,
/*dim=*/
0
);
if
(
!
apply_router_weight_on_input
)
{
Y_all
=
Y_all
.
mul
(
expert_gates
.
unsqueeze
(
1
));
}
auto
out
=
at
::
zeros
({
T
,
H
},
x
.
options
());
out
=
at
::
index_add
(
out
,
/*dim=*/
0
,
/*index=*/
expert_tokens
,
/*source=*/
Y_all
);
...
...
csrc/moe/grouped_topk_kernels.cu
View file @
006693ed
...
...
@@ -427,11 +427,29 @@ __device__ inline bool is_finite(const T val) {
#endif
}
// Selects the activation applied to a raw router score before the per-expert
// bias is added. The numeric values are part of the interface: callers hand
// the choice over as a plain integer and the kernels compare against these
// enumerators directly, so the values must stay 0 and 1.
enum ScoringFunc {
  SCORING_NONE = 0,     // use the score as-is (no activation)
  SCORING_SIGMOID = 1,  // pass the score through a sigmoid first
};
// Logistic sigmoid evaluated through the identity
//   sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5
// so only a single tanhf call is needed (no explicit exponential or division).
// The formulation is credited to TensorRT-LLM.
__device__ inline float sigmoid_accurate(float x) {
  const float half_x = 0.5f * x;
  return 0.5f * tanhf(half_x) + 0.5f;
}
template
<
typename
T
>
__device__
void
topk_with_k2
(
T
*
output
,
T
const
*
input
,
__device__ inline T apply_sigmoid(T val) {
  // Convert the element to float, evaluate the sigmoid there, then convert
  // the result back to the element type T.
  return cuda_cast<T, float>(sigmoid_accurate(cuda_cast<float, T>(val)));
}
template
<
typename
T
>
__device__
void
topk_with_k2
(
T
*
output
,
T
const
*
input
,
T
const
*
bias
,
cg
::
thread_block_tile
<
32
>
const
&
tile
,
int32_t
const
lane_id
,
int
const
num_experts_per_group
)
{
int
const
num_experts_per_group
,
int
const
scoring_func
)
{
// Get the top2 per thread
T
largest
=
neg_inf
<
T
>
();
T
second_largest
=
neg_inf
<
T
>
();
...
...
@@ -439,6 +457,12 @@ __device__ void topk_with_k2(T* output, T const* input,
if
(
num_experts_per_group
>
WARP_SIZE
)
{
for
(
int
i
=
lane_id
;
i
<
num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
value
=
input
[
i
];
// Apply scoring function if needed
if
(
scoring_func
==
SCORING_SIGMOID
)
{
value
=
apply_sigmoid
(
value
);
}
value
=
value
+
bias
[
i
];
if
(
value
>
largest
)
{
second_largest
=
largest
;
largest
=
value
;
...
...
@@ -448,7 +472,13 @@ __device__ void topk_with_k2(T* output, T const* input,
}
}
else
{
for
(
int
i
=
lane_id
;
i
<
num_experts_per_group
;
i
+=
WARP_SIZE
)
{
largest
=
input
[
i
];
T
value
=
input
[
i
];
// Apply scoring function if needed
if
(
scoring_func
==
SCORING_SIGMOID
)
{
value
=
apply_sigmoid
(
value
);
}
value
=
value
+
bias
[
i
];
largest
=
value
;
}
}
...
...
@@ -472,17 +502,21 @@ __device__ void topk_with_k2(T* output, T const* input,
}
template
<
typename
T
>
__global__
void
topk_with_k2_kernel
(
T
*
output
,
T
*
input
,
__global__
void
topk_with_k2_kernel
(
T
*
output
,
T
*
input
,
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
num_cases
,
int64_t
const
n_group
,
int64_t
const
num_experts_per_group
)
{
int64_t
const
num_experts_per_group
,
int
const
scoring_func
)
{
int32_t
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
int32_t
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
int32_t
case_id
=
blockIdx
.
x
*
NUM_WARPS_PER_BLOCK
+
warp_id
;
if
(
case_id
<
num_cases
)
{
input
+=
case_id
*
num_experts_per_group
;
// bias is per expert group, offset to current group
int32_t
group_id
=
case_id
%
n_group
;
T
const
*
group_bias
=
bias
+
group_id
*
num_experts_per_group
;
output
+=
case_id
;
cg
::
thread_block
block
=
cg
::
this_thread_block
();
...
...
@@ -491,7 +525,8 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.wait;"
);
#endif
topk_with_k2
(
output
,
input
,
tile
,
lane_id
,
num_experts_per_group
);
topk_with_k2
(
output
,
input
,
group_bias
,
tile
,
lane_id
,
num_experts_per_group
,
scoring_func
);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.launch_dependents;"
);
...
...
@@ -500,16 +535,15 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
template
<
typename
T
,
typename
IdxT
>
__global__
void
group_idx_and_topk_idx_kernel
(
T
*
scores
,
T
const
*
group_scores
,
T
*
topk_values
,
IdxT
*
topk_indices
,
T
*
scores_with_
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
T
*
scores
,
T
const
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
int64_t
const
num_experts
,
int64_t
const
num_experts_per_group
,
bool
renormalize
,
double
routed_scaling_factor
)
{
double
routed_scaling_factor
,
int
scoring_func
)
{
int32_t
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
int32_t
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
int32_t
case_id
=
blockIdx
.
x
*
NUM_WARPS_PER_BLOCK
+
warp_id
;
// one per token
scores_with_bias
+=
case_id
*
num_experts
;
scores
+=
case_id
*
num_experts
;
group_scores
+=
case_id
*
n_group
;
topk_values
+=
case_id
*
topk
;
...
...
@@ -577,10 +611,16 @@ __global__ void group_idx_and_topk_idx_kernel(
int32_t
offset
=
i_group
*
num_experts_per_group
;
for
(
int32_t
i
=
lane_id
;
i
<
align_num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
candidates
=
(
i
<
num_experts_per_group
)
&&
is_finite
(
scores_with_bias
[
offset
+
i
])
?
scores_with_bias
[
offset
+
i
]
:
neg_inf
<
T
>
();
T
candidates
=
neg_inf
<
T
>
();
if
(
i
<
num_experts_per_group
)
{
// Apply scoring function (if any) and add bias
T
input
=
scores
[
offset
+
i
];
if
(
is_finite
(
input
))
{
T
score
=
(
scoring_func
==
SCORING_SIGMOID
)
?
apply_sigmoid
(
input
)
:
input
;
candidates
=
score
+
bias
[
offset
+
i
];
}
}
queue
.
add
(
candidates
,
offset
+
i
);
}
if
(
group_scores
[
i_group
]
==
topk_group_value
)
{
...
...
@@ -602,11 +642,12 @@ __global__ void group_idx_and_topk_idx_kernel(
for
(
int
i
=
lane_id
;
i
<
warp_topk
::
round_up_to_multiple_of
<
WARP_SIZE
>
(
topk
);
i
+=
WARP_SIZE
)
{
T
value
=
i
<
topk
?
scores
[
s_topk_idx
[
i
]]
:
cuda_cast
<
T
,
float
>
(
0.0
f
);
// Load the valid value of expert
T
value
=
cuda_cast
<
T
,
float
>
(
0.0
f
);
if
(
i
<
topk
)
{
// Load the score value (without bias) for normalization
T
input
=
scores
[
s_topk_idx
[
i
]];
value
=
(
scoring_func
==
SCORING_SIGMOID
)
?
apply_sigmoid
(
input
)
:
input
;
s_topk_value
[
i
]
=
value
;
}
topk_sum
+=
...
...
@@ -627,12 +668,12 @@ __global__ void group_idx_and_topk_idx_kernel(
value
=
cuda_cast
<
float
,
T
>
(
s_topk_value
[
i
])
*
routed_scaling_factor
;
}
topk_indices
[
i
]
=
s_topk_idx
[
i
];
topk_values
[
i
]
=
cuda_cast
<
T
,
float
>
(
value
)
;
topk_values
[
i
]
=
value
;
}
}
else
{
for
(
int
i
=
lane_id
;
i
<
topk
;
i
+=
WARP_SIZE
)
{
topk_indices
[
i
]
=
i
;
topk_values
[
i
]
=
cuda_cast
<
T
,
float
>
(
1.0
f
/
topk
)
;
topk_values
[
i
]
=
1.0
f
/
topk
;
}
}
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
...
...
@@ -644,12 +685,12 @@ __global__ void group_idx_and_topk_idx_kernel(
}
template
<
typename
T
,
typename
IdxT
>
void
invokeNoAuxTc
(
T
*
scores
,
T
*
group_scores
,
T
*
topk_values
,
IdxT
*
topk_indices
,
T
*
scores_with_bia
s
,
int64_t
const
num_
token
s
,
int64_t
const
n
um_experts
,
int64_t
const
n
_group
,
int64_t
const
topk
_group
,
int64_t
const
topk
,
bool
const
renormalize
,
double
const
routed_scaling_factor
,
bool
enable_pdl
=
false
,
void
invokeNoAuxTc
(
T
*
scores
,
T
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
T
const
*
bias
,
int64_t
const
num_token
s
,
int64_t
const
num_
expert
s
,
int64_t
const
n
_group
,
int64_t
const
topk
_group
,
int64_t
const
topk
,
bool
const
renormalize
,
double
const
routed_scaling_factor
,
int
const
scoring_func
,
bool
enable_pdl
=
false
,
cudaStream_t
const
stream
=
0
)
{
int64_t
num_cases
=
num_tokens
*
n_group
;
int64_t
topk_with_k2_num_blocks
=
(
num_cases
-
1
)
/
NUM_WARPS_PER_BLOCK
+
1
;
...
...
@@ -664,8 +705,9 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
attrs
[
0
].
val
.
programmaticStreamSerializationAllowed
=
enable_pdl
;
config
.
numAttrs
=
1
;
config
.
attrs
=
attrs
;
cudaLaunchKernelEx
(
&
config
,
kernel_instance1
,
group_scores
,
scores_with_bias
,
num_tokens
,
num_cases
,
n_group
,
num_experts
/
n_group
);
cudaLaunchKernelEx
(
&
config
,
kernel_instance1
,
group_scores
,
scores
,
bias
,
num_tokens
,
num_cases
,
n_group
,
num_experts
/
n_group
,
scoring_func
);
int64_t
topk_with_k_group_num_blocks
=
(
num_tokens
-
1
)
/
NUM_WARPS_PER_BLOCK
+
1
;
...
...
@@ -682,19 +724,18 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
config
.
numAttrs
=
1
;
config
.
attrs
=
attrs
;
cudaLaunchKernelEx
(
&
config
,
kernel_instance2
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
scores_with_
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts
/
n_group
,
renormalize
,
routed_scaling_factor
);
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts
/
n_group
,
renormalize
,
routed_scaling_factor
,
scoring_func
);
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
template void invokeNoAuxTc<T, IdxT>( \
T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \
T * scores_with_bias, int64_t const num_tokens, \
int64_t const num_experts, int64_t const n_group, \
int64_t const topk_group, int64_t const topk, bool const renormalize, \
double const routed_scaling_factor, bool enable_pdl, \
cudaStream_t const stream);
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
T const* bias, int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC
(
float
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
half
,
int32_t
);
...
...
@@ -703,28 +744,32 @@ INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
}
// namespace vllm
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
grouped_topk
(
torch
::
Tensor
const
&
scores
,
torch
::
Tensor
const
&
scores_with_bias
,
int64_t
n_group
,
int64_t
topk_group
,
int64_t
topk
,
bool
renormalize
,
double
routed_scaling_factor
)
{
auto
data_type
=
scores
_with_bias
.
scalar_type
();
auto
input_size
=
scores
_with_bias
.
sizes
();
torch
::
Tensor
const
&
scores
,
int64_t
n_group
,
int64_t
topk_group
,
int64_t
topk
,
bool
renormalize
,
double
routed_scaling_factor
,
torch
::
Tensor
const
&
bias
,
int64_t
scoring_func
=
0
)
{
auto
data_type
=
scores
.
scalar_type
();
auto
input_size
=
scores
.
sizes
();
int64_t
num_tokens
=
input_size
[
0
];
int64_t
num_experts
=
input_size
[
1
];
TORCH_CHECK
(
input_size
.
size
()
==
2
,
"scores
_with_bias
must be a 2D Tensor"
);
TORCH_CHECK
(
input_size
.
size
()
==
2
,
"scores must be a 2D Tensor"
);
TORCH_CHECK
(
num_experts
%
n_group
==
0
,
"num_experts should be divisible by n_group"
);
TORCH_CHECK
(
n_group
<=
32
,
"n_group should be smaller than or equal to 32 for now"
);
TORCH_CHECK
(
topk
<=
32
,
"topk should be smaller than or equal to 32 for now"
);
TORCH_CHECK
(
scoring_func
==
vllm
::
moe
::
SCORING_NONE
||
scoring_func
==
vllm
::
moe
::
SCORING_SIGMOID
,
"scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)"
);
torch
::
Tensor
group_scores
=
torch
::
empty
(
{
num_tokens
,
n_group
},
torch
::
dtype
(
data_type
).
device
(
torch
::
kCUDA
));
// Always output float32 for topk_values (eliminates Python-side conversion)
torch
::
Tensor
topk_values
=
torch
::
empty
(
{
num_tokens
,
topk
},
torch
::
dtype
(
data_type
).
device
(
torch
::
kCUDA
));
{
num_tokens
,
topk
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
torch
::
Tensor
topk_indices
=
torch
::
empty
(
{
num_tokens
,
topk
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
(
scores
_with_bias
.
get_device
());
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
(
scores
.
get_device
());
switch
(
data_type
)
{
case
torch
::
kFloat16
:
...
...
@@ -732,11 +777,11 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
vllm
::
moe
::
invokeNoAuxTc
<
half
,
int32_t
>
(
reinterpret_cast
<
half
*>
(
scores
.
mutable_data_ptr
()),
reinterpret_cast
<
half
*>
(
group_scores
.
mutable_data_ptr
()),
reinterpret_cast
<
half
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
topk_indices
.
mutable_data_ptr
()),
reinterpret_cast
<
half
*>
(
scores_with_
bias
.
data_ptr
()),
num_tokens
,
reinterpret_cast
<
half
const
*>
(
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
false
,
stream
);
routed_scaling_factor
,
static_cast
<
int
>
(
scoring_func
),
false
,
stream
);
break
;
case
torch
::
kFloat32
:
// Handle Float32
...
...
@@ -745,20 +790,20 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
reinterpret_cast
<
float
*>
(
group_scores
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
topk_indices
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
scores_with_
bias
.
data_ptr
()),
num_tokens
,
reinterpret_cast
<
float
const
*>
(
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
false
,
stream
);
routed_scaling_factor
,
static_cast
<
int
>
(
scoring_func
),
false
,
stream
);
break
;
case
torch
::
kBFloat16
:
// Handle BFloat16
vllm
::
moe
::
invokeNoAuxTc
<
__nv_bfloat16
,
int32_t
>
(
reinterpret_cast
<
__nv_bfloat16
*>
(
scores
.
mutable_data_ptr
()),
reinterpret_cast
<
__nv_bfloat16
*>
(
group_scores
.
mutable_data_ptr
()),
reinterpret_cast
<
__nv_b
float
16
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
topk_indices
.
mutable_data_ptr
()),
reinterpret_cast
<
__nv_bfloat16
*>
(
scores_with_
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
false
,
stream
);
reinterpret_cast
<
__nv_bfloat16
const
*>
(
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
static_cast
<
int
>
(
scoring_func
),
false
,
stream
);
break
;
default:
// Handle other data types
...
...
csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
006693ed
...
...
@@ -17,25 +17,30 @@ FILE_HEAD = """
namespace MARLIN_NAMESPACE_NAME {
"""
.
strip
()
TEMPLATE
=
(
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
TEMPLATE
=
(
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
,
"vllm::kFE2M1f"
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
,
"vllm::kFE2M1f"
,
]
THREAD_CONFIGS
=
[(
128
,
128
,
256
),
(
64
,
256
,
256
),
(
64
,
128
,
128
)]
...
...
@@ -58,11 +63,12 @@ def generate_new_kernels():
all_template_str_list
=
[]
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
GROUP_BLOCKS
,
THREAD_M_BLOCKS
,
THREAD_CONFIGS
):
GROUP_BLOCKS
,
THREAD_M_BLOCKS
,
THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8
if
group_blocks
==
0
and
scalar_type
not
in
[
"vllm::kU4B8"
,
"vllm::kU8B128"
"vllm::kU4B8"
,
"vllm::kU8B128"
,
]:
continue
if
thread_configs
[
2
]
==
256
:
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
006693ed
...
...
@@ -8,12 +8,77 @@
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace
vllm
{
namespace
moe
{
namespace batched_moe_align_block_size {

// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
static constexpr int32_t num_threads = 1024;
static constexpr int32_t num_blocks = 1;

/**
 * Builds block-aligned token and block index tables for batched MoE inputs.
 *
 * Thread `threadIdx.x` handles batch `threadIdx.x`; threads whose index is
 * not below `num_batches` contribute a zero count to the block-wide scan.
 * The batch id is derived from `threadIdx.x` only, so the kernel is written
 * for the single-block launch described by `num_blocks` / `num_threads`.
 *
 * @param num_batches          Number of entries read from batch_num_tokens.
 * @param max_tokens_per_batch Token capacity of one batch; token ids of batch
 *                             b start at b * max_tokens_per_batch.
 * @param block_size           Granularity every batch's count is rounded up
 *                             to. Must be non-zero (used as a divisor).
 * @param batch_num_tokens     Per-batch token counts.
 * @param sorted_ids           Output. The first
 *                             ceil(max_tokens_per_batch / block_size) *
 *                             num_batches * block_size entries are filled
 *                             with a sentinel, then each batch writes its
 *                             token ids at its padded offset.
 * @param block_ids            Output. One entry per block: the owning batch
 *                             id, or -1 for unused blocks.
 * @param num_tokens_post_pad  Output scalar: sum of all padded token counts.
 */
__global__ void batched_moe_align_block_size_kernel(
    int32_t const num_batches, int32_t const max_tokens_per_batch,
    int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
    int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
    int32_t* __restrict__ num_tokens_post_pad) {
  // TODO(varun): This is a naive implementation. Could be optimized.

  size_t const batch_id = threadIdx.x;
  size_t const stride = blockDim.x * gridDim.x;
  int32_t const num_blocks_per_batch =
      CEILDIV(max_tokens_per_batch, block_size);
  int32_t const sorted_ids_size =
      num_blocks_per_batch * num_batches * block_size;
  int32_t const block_ids_size = sorted_ids_size / block_size;
  int32_t const SENTINEL =
      num_batches * max_tokens_per_batch;  // To denote invalid entries.

  // Initialize sorted_ids with the sentinel.
  for (size_t i = threadIdx.x; i < static_cast<size_t>(sorted_ids_size);
       i += stride) {
    sorted_ids[i] = SENTINEL;
  }
  // Initialize block_ids with -1.
  for (size_t i = threadIdx.x; i < static_cast<size_t>(block_ids_size);
       i += stride) {
    block_ids[i] = -1;
  }

  // Threads without a batch keep a zero count so they do not affect the scan.
  int32_t b_num_tokens = 0;
  if (batch_id < static_cast<size_t>(num_batches)) {
    b_num_tokens = batch_num_tokens[batch_id];
  }
  int32_t const ceil_b_num_tokens =
      CEILDIV(b_num_tokens, block_size) * block_size;

  // Exclusive prefix sum over the padded per-batch token counts. The scan
  // width is tied to num_threads so it cannot drift from the launch config.
  using BlockScan = cub::BlockScan<int32_t, num_threads>;
  __shared__ typename BlockScan::TempStorage temp_storage;
  int32_t cumsum_val = 0;
  BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
  // Also orders the sentinel/-1 initialization above before the per-batch
  // writes below, which may target entries initialized by other threads.
  __syncthreads();

  if (num_batches == 0) {
    // No thread owns a "last batch"; still publish a defined total.
    if (threadIdx.x == 0) {
      *num_tokens_post_pad = 0;
    }
  } else if (batch_id == static_cast<size_t>(num_batches - 1)) {
    // The last batch's offset plus its own padded size is the grand total.
    *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
  }

  if (batch_id < static_cast<size_t>(num_batches)) {
    int32_t const batch_index = static_cast<int32_t>(batch_id);
    int32_t const batch_offset = batch_index * max_tokens_per_batch;
    for (int32_t i = 0; i < b_num_tokens; ++i) {
      sorted_ids[cumsum_val + i] = batch_offset + i;
    }

    int32_t const block_start = cumsum_val / block_size;
    // Named differently from the namespace-scope `num_blocks` launch constant
    // to avoid shadowing it.
    int32_t const batch_block_count = ceil_b_num_tokens / block_size;
    for (int32_t i = 0; i < batch_block_count; ++i) {
      block_ids[block_start + i] = batch_index;
    }
  }
}
}  // namespace batched_moe_align_block_size
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_kernel
(
const
scalar_t
*
__restrict__
topk_ids
,
...
...
@@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
});
}
/**
 * Host launcher for batched_moe_align_block_size_kernel.
 *
 * Validates the output tensor sizes against the block-padded layout and
 * launches the kernel on the current CUDA stream with the fixed
 * num_blocks / num_threads configuration declared next to the kernel.
 *
 * @param max_tokens_per_batch Token capacity of a single batch.
 * @param block_size           Padding granularity; must be > 0.
 * @param batch_num_tokens     int32 tensor, one token count per batch. Its
 *                             first dimension B must not exceed num_threads.
 * @param sorted_ids           int32 output with
 *                             ceil(max_tokens_per_batch / block_size) * B *
 *                             block_size entries.
 * @param batch_ids            int32 output with one entry per block.
 * @param num_tokens_post_pad  int32 output with a single entry.
 */
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                  int64_t block_size,
                                  torch::Tensor const& batch_num_tokens,
                                  torch::Tensor sorted_ids,
                                  torch::Tensor batch_ids,
                                  torch::Tensor num_tokens_post_pad) {
  namespace batched_kernel = vllm::moe::batched_moe_align_block_size;

  // block_size is a divisor both below and inside the kernel; reject zero or
  // negative values up front instead of dividing by them.
  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  int32_t const B = batch_num_tokens.size(0);
  int32_t const num_blocks_per_batch =
      round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
  int32_t const num_blocks = num_blocks_per_batch * B;
  int64_t const sorted_ids_size = num_blocks * block_size;

  TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
  TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
  TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
  // The kernel maps one thread to one batch within a single thread block.
  TORCH_CHECK(B <= batched_kernel::num_threads);

  batched_kernel::batched_moe_align_block_size_kernel<<<
      batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
      B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
      sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
      num_tokens_post_pad.data_ptr<int32_t>());
}
void
moe_sum
(
torch
::
Tensor
&
input
,
// [num_tokens, topk, hidden_size]
torch
::
Tensor
&
output
)
// [num_tokens, hidden_size]
{
...
...
csrc/moe/moe_lora_align_sum_kernels.cu
0 → 100644
View file @
006693ed
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace {
// Converts a (row, col) coordinate into a flat offset for a buffer laid out
// as consecutive rows of `total_col` elements each.
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  const int32_t row_begin = row * total_col;
  return row_begin + col;
}
}  // namespace
// TODO: Refactor common parts with moe_align_sum_kernels
/**
 * Per-LoRA variant of the MoE block-alignment kernel.
 *
 * Each thread block reads one entry of `lora_ids` (indexed by blockIdx.x) and
 * exits immediately if that entry is -1 or the adapter is disabled. Otherwise
 * it counts, per expert, the entries of `topk_ids` whose token maps to this
 * LoRA, pads every expert's count up to a multiple of `block_size`, and
 * writes the results into the slices of the output buffers selected by
 * `lora_id`.
 *
 * Dynamic shared memory layout (as int32 words):
 *   [0, num_experts]   cumsum: padded start offset of every expert
 *   [num_experts + 1, ...)  tokens_cnts: (blockDim.x + 1) rows of num_experts
 *                      counters; row t + 1 belongs to thread t.
 *
 * @param topk_ids              Expert id for each of the `numel` entries.
 * @param token_lora_mapping    LoRA id per token, indexed by i / topk_num.
 * @param block_size            Padding granularity for per-expert counts.
 * @param num_experts           Number of experts (columns of tokens_cnts).
 * @param max_loras             Not referenced inside the kernel body.
 * @param numel                 Number of entries in topk_ids.
 * @param max_num_tokens_padded Stride of one LoRA slice in sorted_token_ids.
 * @param max_num_m_blocks      Stride of one LoRA slice in expert_ids.
 * @param sorted_token_ids      Output; each slice is pre-filled with `numel`.
 * @param expert_ids            Output; each slice is pre-filled with -1.
 * @param topk_num              Number of topk_ids entries per token.
 * @param total_tokens_post_pad Output; padded token total per LoRA id.
 * @param adapter_enabled       Per-LoRA flag; 0 skips the LoRA.
 * @param lora_ids              LoRA id handled by each thread block.
 */
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
    int64_t block_size, int num_experts, int max_loras, size_t numel,
    int max_num_tokens_padded, int max_num_m_blocks,
    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
    int32_t* lora_ids) {
  // Each thread owns a contiguous shard of topk_ids.
  const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  int lora_idx = blockIdx.x;
  int lora_id = lora_ids[lora_idx];
  // The condition depends only on blockIdx.x, so the whole block leaves
  // together and no thread is left waiting at a later __syncthreads().
  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
    return;
  }
  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;
  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);

  // Initialize sorted_token_ids with numel
  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
    sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
  }

  // Initialize expert_ids with -1
  for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
    expert_ids[lora_id * max_num_m_blocks + it] = -1;
  }

  // Initialize total_tokens_post_pad with 0
  if (threadIdx.x == 0) {
    total_tokens_post_pad[lora_id] = 0;
  }

  // Clear this thread's counter row (row threadIdx.x + 1).
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  // Count, per expert, the entries of this thread's shard that belong to the
  // current LoRA; `mask` is 1 for a match and 0 otherwise.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int mask = token_lora_mapping[i / topk_num] == lora_id;
    int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
    tokens_cnts[idx] += mask;
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  // Afterwards row i holds the running total of rows 1..i for that expert, so
  // row blockDim.x is the expert's total and row t is thread t's start rank.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  // Each expert's total is rounded up to a multiple of block_size first.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                           block_size) *
                      block_size;
    }
    total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
          threadIdx.x;
    }
  }

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];

    int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
    // The slot was pre-filled with numel, so adding (i - numel) turns it into
    // i for a matching token; a non-matching token adds 0 and leaves the slot
    // untouched.
    atomicAdd(
        &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
        (i - numel) * mask);
    // Advance this thread's write rank only for tokens that were placed.
    tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
  }
}
/**
 * Host launcher for moe_lora_align_sum_kernel.
 *
 * Launches `max_loras` thread blocks on the current stream of the device that
 * owns `topk_ids`. Each block uses max(num_experts, 128) threads and dynamic
 * shared memory for (num_thread + 1) * num_experts counters plus
 * (num_experts + 1) prefix-sum entries.
 *
 * @param topk_ids              Integral tensor of expert ids; size(1) is the
 *                              number of entries per token.
 * @param token_lora_mapping    int32 tensor, LoRA id per token.
 * @param num_experts           Number of experts.
 * @param block_size            Padding granularity; must be > 0.
 * @param max_loras             Grid size (one thread block per LoRA slot).
 * @param max_num_tokens_padded Per-LoRA stride of sorted_token_ids.
 * @param max_num_m_blocks      Per-LoRA stride of expert_ids.
 * @param sorted_token_ids      int32 output.
 * @param expert_ids            int32 output.
 * @param num_tokens_post_pad   int32 output, one padded total per LoRA id.
 * @param adapter_enabled       int32 per-LoRA enable flags.
 * @param lora_ids              int32 LoRA id for each thread block.
 *
 * Raises (via TORCH_CHECK) when block_size <= 0, when more than 1024 threads
 * would be required, or when the shared memory request exceeds the device's
 * opt-in limit.
 */
void moe_lora_align_block_size(
    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
    int64_t num_experts, int64_t block_size, int64_t max_loras,
    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
    torch::Tensor lora_ids) {
  const int topk_num = topk_ids.size(1);

  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");

  // Make the tensors' device current before querying the stream and
  // launching, so the kernel does not run on whichever device happened to be
  // active in the calling thread.
  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(topk_ids));

  int device_max_shared_mem = 0;
  auto dev = topk_ids.get_device();
  // Surface a failed attribute query instead of comparing against garbage.
  AT_CUDA_CHECK(cudaDeviceGetAttribute(
      &device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // The kernel needs at least one thread per expert.
  const int32_t num_thread = max((int32_t)num_experts, 128);  // WARP_SIZE,
  TORCH_CHECK(num_thread <= 1024,
              "num_thread must be less than 1024, "
              "and fallback is not implemented yet.");
  const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
                             (num_experts + 1) * sizeof(int32_t);

  if (shared_mem > device_max_shared_mem) {
    TORCH_CHECK(false,
                "Shared memory usage exceeds device limit, and global memory "
                "fallback is not implemented yet.");
  }

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
        dim3 blockDim(num_thread);
        auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
        // Opt in to the required amount of dynamic shared memory.
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<max_loras, blockDim, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
            max_loras, topk_ids.numel(), max_num_tokens_padded,
            max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
            expert_ids.data_ptr<int32_t>(), topk_num,
            num_tokens_post_pad.data_ptr<int32_t>(),
            adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
      });
}
\ No newline at end of file
csrc/moe/moe_ops.h
View file @
006693ed
...
...
@@ -4,7 +4,7 @@
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
torch
::
Tensor
&
gating_output
,
bool
renormalize
);
void
moe_sum
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output
);
...
...
@@ -12,6 +12,21 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
void
batched_moe_align_block_size
(
int64_t
max_tokens_per_batch
,
int64_t
block_size
,
torch
::
Tensor
const
&
expert_num_tokens
,
torch
::
Tensor
sorted_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_pad
);
void
moe_lora_align_block_size
(
torch
::
Tensor
topk_ids
,
torch
::
Tensor
token_lora_mapping
,
int64_t
num_experts
,
int64_t
block_size
,
int64_t
max_loras
,
int64_t
max_num_tokens_padded
,
int64_t
max_num_m_blocks
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
adapter_enabled
,
torch
::
Tensor
lora_ids
);
#ifndef USE_ROCM
torch
::
Tensor
moe_wna16_gemm
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
b_qweight
,
torch
::
Tensor
b_scales
,
...
...
@@ -24,9 +39,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
int64_t
BLOCK_SIZE_K
,
int64_t
bit
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
grouped_topk
(
torch
::
Tensor
const
&
scores
,
torch
::
Tensor
const
&
scores_with_bias
,
int64_t
n_group
,
int64_t
topk_group
,
int64_t
topk
,
bool
renormalize
,
double
routed_scaling_factor
);
torch
::
Tensor
const
&
scores
,
int64_t
n_group
,
int64_t
topk_group
,
int64_t
topk
,
bool
renormalize
,
double
routed_scaling_factor
,
torch
::
Tensor
const
&
bias
,
int64_t
scoring_func
);
#endif
bool
moe_permute_unpermute_supported
();
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
006693ed
...
...
@@ -16,12 +16,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <type_traits>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef
__hip_bfloat16
__nv_bfloat16
;
typedef
__hip_bfloat162
__nv_bfloat162
;
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
...
...
@@ -36,16 +47,27 @@ template <
/// Alignment requirement in bytes
int
Alignment
=
sizeof
(
T
)
*
N
>
class
alignas
(
Alignment
)
AlignedArray
{
float
data
[
N
];
struct
alignas
(
Alignment
)
AlignedArray
{
T
data
[
N
];
};
// Converts a supported gating-input element (float, __nv_bfloat16 or __half)
// to float. Any other type is rejected at compile time; previously such an
// instantiation would compile and fall off the end of a non-void function.
template <typename T>
__device__ __forceinline__ float toFloat(T value) {
  if constexpr (std::is_same_v<T, float>) {
    return value;
  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
    return __bfloat162float(value);
  } else if constexpr (std::is_same_v<T, __half>) {
    return __half2float(value);
  } else {
    // sizeof(T) == 0 is never true but depends on T, so the assertion only
    // fires when this branch is actually instantiated.
    static_assert(sizeof(T) == 0,
                  "toFloat supports only float, __nv_bfloat16, and __half");
  }
}
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template
<
int
TPB
>
template
<
int
TPB
,
typename
InputType
>
__launch_bounds__
(
TPB
)
__global__
void
moeSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_cols
)
void
moeSoftmax
(
const
InputType
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_cols
)
{
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmpStorage
;
...
...
@@ -66,7 +88,8 @@ __launch_bounds__(TPB) __global__
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
const
int
idx
=
thread_row_offset
+
ii
;
threadData
=
max
(
static_cast
<
float
>
(
input
[
idx
]),
threadData
);
const
float
val
=
toFloat
(
input
[
idx
]);
threadData
=
max
(
val
,
threadData
);
}
const
float
maxElem
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
CubMaxOp
());
...
...
@@ -81,7 +104,8 @@ __launch_bounds__(TPB) __global__
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
const
int
idx
=
thread_row_offset
+
ii
;
threadData
+=
exp
((
static_cast
<
float
>
(
input
[
idx
])
-
float_max
));
const
float
val
=
toFloat
(
input
[
idx
]);
threadData
+=
expf
(
val
-
float_max
);
}
const
auto
Z
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
CubAddOp
());
...
...
@@ -95,8 +119,9 @@ __launch_bounds__(TPB) __global__
for
(
int
ii
=
threadIdx
.
x
;
ii
<
num_cols
;
ii
+=
TPB
)
{
const
int
idx
=
thread_row_offset
+
ii
;
const
float
val
=
exp
((
static_cast
<
float
>
(
input
[
idx
])
-
float_max
))
*
normalizing_factor
;
output
[
idx
]
=
val
;
const
float
val
=
toFloat
(
input
[
idx
]);
const
float
softmax_val
=
expf
(
val
-
float_max
)
*
normalizing_factor
;
output
[
idx
]
=
softmax_val
;
}
}
...
...
@@ -110,7 +135,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
const
int
end_expert
,
const
bool
renormalize
)
{
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
...
...
@@ -125,6 +151,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const
bool
row_is_active
=
finished
?
!
finished
[
block_row
]
:
true
;
const
int
thread_read_offset
=
blockIdx
.
x
*
num_experts
;
float
selected_sum
=
0.
f
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
thread_kvp
.
key
=
0
;
...
...
@@ -163,9 +190,23 @@ __launch_bounds__(TPB) __global__ void moeTopK(
indices
[
idx
]
=
should_process_row
?
(
expert
-
start_expert
)
:
num_experts
;
assert
(
indices
[
idx
]
>=
0
);
source_rows
[
idx
]
=
k_idx
*
num_rows
+
block_row
;
if
(
renormalize
)
{
selected_sum
+=
result_kvp
.
value
;
}
}
__syncthreads
();
}
// Renormalize the k weights for this row to sum to 1, if requested.
if
(
renormalize
)
{
if
(
threadIdx
.
x
==
0
)
{
const
float
denom
=
selected_sum
>
0.
f
?
selected_sum
:
1.
f
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
const
int
idx
=
k
*
block_row
+
k_idx
;
output
[
idx
]
=
output
[
idx
]
/
denom
;
}
}
}
}
// ====================== TopK softmax things ===============================
...
...
@@ -184,21 +225,30 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k.
*/
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
int
WARP_SIZE_PARAM
,
typename
IndType
>
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
int
WARP_SIZE_PARAM
,
typename
IndType
,
typename
InputType
=
float
>
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE_PARAM
)
__global__
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
void
topkGatingSoftmax
(
const
InputType
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
const
bool
renormalize
)
{
static_assert
(
std
::
is_same_v
<
InputType
,
float
>
||
std
::
is_same_v
<
InputType
,
__nv_bfloat16
>
||
std
::
is_same_v
<
InputType
,
__half
>
,
"InputType must be float, __nv_bfloat16, or __half"
);
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert
(
BYTES_PER_LDG
==
(
BYTES_PER_LDG
&
-
BYTES_PER_LDG
),
"BYTES_PER_LDG must be power of 2"
);
static_assert
(
BYTES_PER_LDG
<=
16
,
"BYTES_PER_LDG must be leq 16"
);
// Number of bytes each thread pulls in per load
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
float
);
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
InputType
);
static
constexpr
int
ELTS_PER_ROW
=
NUM_EXPERTS
;
static
constexpr
int
THREADS_PER_ROW
=
ELTS_PER_ROW
/
VPT
;
static
constexpr
int
LDG_PER_THREAD
=
VPT
/
ELTS_PER_LDG
;
if
constexpr
(
std
::
is_same_v
<
InputType
,
__nv_bfloat16
>
||
std
::
is_same_v
<
InputType
,
__half
>
)
{
static_assert
(
ELTS_PER_LDG
==
1
||
ELTS_PER_LDG
%
2
==
0
,
"ELTS_PER_LDG must be 1 or even for 16-bit conversion"
);
}
// Restrictions based on previous section.
static_assert
(
VPT
%
ELTS_PER_LDG
==
0
,
"The elements per thread must be a multiple of the elements per ldg"
);
static_assert
(
WARP_SIZE_PARAM
%
THREADS_PER_ROW
==
0
,
"The threads per row must cleanly divide the threads per warp"
);
...
...
@@ -236,27 +286,71 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const
float
*
thread_row_ptr
=
input
+
thread_row
*
ELTS_PER_ROW
;
const
InputType
*
thread_row_ptr
=
input
+
thread_row
*
ELTS_PER_ROW
;
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
const
int
thread_group_idx
=
threadIdx
.
x
%
THREADS_PER_ROW
;
const
int
first_elt_read_by_thread
=
thread_group_idx
*
ELTS_PER_LDG
;
const
float
*
thread_read_ptr
=
thread_row_ptr
+
first_elt_read_by_thread
;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using
AccessType
=
AlignedArray
<
float
,
ELTS_PER_LDG
>
;
const
InputType
*
thread_read_ptr
=
thread_row_ptr
+
first_elt_read_by_thread
;
// Finally, we pull in the data from global mem
float
row_chunk
[
VPT
];
AccessType
*
row_chunk_vec_ptr
=
reinterpret_cast
<
AccessType
*>
(
&
row_chunk
);
const
AccessType
*
vec_thread_read_ptr
=
reinterpret_cast
<
const
AccessType
*>
(
thread_read_ptr
);
// NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float
if
constexpr
(
std
::
is_same_v
<
InputType
,
float
>
)
{
using
VecType
=
AlignedArray
<
float
,
ELTS_PER_LDG
>
;
VecType
*
row_chunk_vec_ptr
=
reinterpret_cast
<
VecType
*>
(
&
row_chunk
);
const
VecType
*
vec_thread_read_ptr
=
reinterpret_cast
<
const
VecType
*>
(
thread_read_ptr
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDG_PER_THREAD
;
++
ii
)
{
row_chunk_vec_ptr
[
ii
]
=
vec_thread_read_ptr
[
ii
*
THREADS_PER_ROW
];
for
(
int
ii
=
0
;
ii
<
LDG_PER_THREAD
;
++
ii
)
{
row_chunk_vec_ptr
[
ii
]
=
vec_thread_read_ptr
[
ii
*
THREADS_PER_ROW
];
}
}
else
if
constexpr
(
std
::
is_same_v
<
InputType
,
__nv_bfloat16
>
)
{
if
constexpr
(
ELTS_PER_LDG
>=
2
)
{
using
VecType
=
AlignedArray
<
__nv_bfloat16
,
ELTS_PER_LDG
>
;
float2
*
row_chunk_f2
=
reinterpret_cast
<
float2
*>
(
row_chunk
);
const
VecType
*
vec_thread_read_ptr
=
reinterpret_cast
<
const
VecType
*>
(
thread_read_ptr
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDG_PER_THREAD
;
++
ii
)
{
VecType
vec
=
vec_thread_read_ptr
[
ii
*
THREADS_PER_ROW
];
int
base_idx_f2
=
ii
*
ELTS_PER_LDG
/
2
;
#pragma unroll
for
(
int
jj
=
0
;
jj
<
ELTS_PER_LDG
/
2
;
++
jj
)
{
row_chunk_f2
[
base_idx_f2
+
jj
]
=
__bfloat1622float2
(
*
reinterpret_cast
<
const
__nv_bfloat162
*>
(
vec
.
data
+
jj
*
2
)
);
}
}
}
else
{
// ELTS_PER_LDG == 1
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDG_PER_THREAD
;
++
ii
)
{
const
__nv_bfloat16
*
scalar_ptr
=
thread_read_ptr
+
ii
*
THREADS_PER_ROW
;
row_chunk
[
ii
]
=
__bfloat162float
(
*
scalar_ptr
);
}
}
}
else
if
constexpr
(
std
::
is_same_v
<
InputType
,
__half
>
)
{
if
constexpr
(
ELTS_PER_LDG
>=
2
)
{
using
VecType
=
AlignedArray
<
__half
,
ELTS_PER_LDG
>
;
float2
*
row_chunk_f2
=
reinterpret_cast
<
float2
*>
(
row_chunk
);
const
VecType
*
vec_thread_read_ptr
=
reinterpret_cast
<
const
VecType
*>
(
thread_read_ptr
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDG_PER_THREAD
;
++
ii
)
{
VecType
vec
=
vec_thread_read_ptr
[
ii
*
THREADS_PER_ROW
];
int
base_idx_f2
=
ii
*
ELTS_PER_LDG
/
2
;
#pragma unroll
for
(
int
jj
=
0
;
jj
<
ELTS_PER_LDG
/
2
;
++
jj
)
{
row_chunk_f2
[
base_idx_f2
+
jj
]
=
__half22float2
(
*
reinterpret_cast
<
const
__half2
*>
(
vec
.
data
+
jj
*
2
)
);
}
}
}
else
{
// ELTS_PER_LDG == 1
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDG_PER_THREAD
;
++
ii
)
{
const
__half
*
scalar_ptr
=
thread_read_ptr
+
ii
*
THREADS_PER_ROW
;
row_chunk
[
ii
]
=
__half2float
(
*
scalar_ptr
);
}
}
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
...
...
@@ -310,6 +404,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
int
start_col
=
first_elt_read_by_thread
;
static
constexpr
int
COLS_PER_GROUP_LDG
=
ELTS_PER_LDG
*
THREADS_PER_ROW
;
float
selected_sum
=
0.
f
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
// First, each thread does the local argmax
...
...
@@ -363,6 +458,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
output
[
idx
]
=
max_val
;
indices
[
idx
]
=
should_process_row
?
(
expert
-
start_expert
)
:
NUM_EXPERTS
;
source_rows
[
idx
]
=
k_idx
*
num_rows
+
thread_row
;
if
(
renormalize
)
{
selected_sum
+=
max_val
;
}
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
...
...
@@ -380,15 +478,28 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
}
}
}
// Renormalize the k weights for this row to sum to 1, if requested.
if
(
renormalize
)
{
if
(
thread_group_idx
==
0
)
{
const
float
denom
=
selected_sum
>
0.
f
?
selected_sum
:
1.
f
;
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
const
int
idx
=
k
*
thread_row
+
k_idx
;
output
[
idx
]
=
output
[
idx
]
/
denom
;
}
}
}
}
namespace
detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template
<
int
EXPERTS
,
int
BYTES_PER_LDG
,
int
WARP_SIZE_PARAM
>
template
<
int
EXPERTS
,
int
BYTES_PER_LDG
,
int
WARP_SIZE_PARAM
,
typename
InputType
>
struct
TopkConstants
{
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
float
);
static
constexpr
int
ELTS_PER_LDG
=
BYTES_PER_LDG
/
sizeof
(
InputType
);
static_assert
(
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE_PARAM
)
==
0
||
EXPERTS
%
(
ELTS_PER_LDG
*
WARP_SIZE_PARAM
)
==
0
,
""
);
static
constexpr
int
VECs_PER_THREAD
=
MAX
(
1
,
EXPERTS
/
(
ELTS_PER_LDG
*
WARP_SIZE_PARAM
));
static
constexpr
int
VPT
=
VECs_PER_THREAD
*
ELTS_PER_LDG
;
...
...
@@ -397,20 +508,21 @@ struct TopkConstants
};
}
// namespace detail
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
int
WARP_SIZE_PARAM
,
int
MAX_BYTES_PER_LDG
,
typename
IndType
>
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
int
WARP_SIZE_PARAM
,
int
MAX_BYTES_PER_LDG
,
typename
IndType
,
typename
InputType
>
void
topkGatingSoftmaxLauncherHelper
(
const
InputType
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
const
bool
renormalize
,
cudaStream_t
stream
)
{
static
constexpr
int
BYTES_PER_LDG
=
MIN
(
MAX_BYTES_PER_LDG
,
sizeof
(
float
)
*
EXPERTS
);
using
Constants
=
detail
::
TopkConstants
<
EXPERTS
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
>
;
static
constexpr
int
BYTES_PER_LDG
=
MIN
(
MAX_BYTES_PER_LDG
,
sizeof
(
InputType
)
*
EXPERTS
);
using
Constants
=
detail
::
TopkConstants
<
EXPERTS
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
,
InputType
>
;
static
constexpr
int
VPT
=
Constants
::
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
Constants
::
ROWS_PER_WARP
;
const
int
num_warps
=
(
num_rows
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
;
const
int
num_blocks
=
(
num_warps
+
WARPS_PER_TB
-
1
)
/
WARPS_PER_TB
;
dim3
block_dim
(
WARP_SIZE_PARAM
,
WARPS_PER_TB
);
topkGatingSoftmax
<
VPT
,
EXPERTS
,
WARPS_PER_TB
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
,
finished
,
output
,
num_rows
,
indices
,
source_row
,
k
,
start_expert
,
end_expert
);
topkGatingSoftmax
<
VPT
,
EXPERTS
,
WARPS_PER_TB
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
,
IndType
,
InputType
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
,
finished
,
output
,
num_rows
,
indices
,
source_row
,
k
,
start_expert
,
end_expert
,
renormalize
);
}
#ifndef USE_ROCM
...
...
@@ -418,26 +530,26 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices,
\
token_expert_indices,
num_tokens, topk, 0, num_experts, stream);
gating_output, nullptr, topk_weights, topk_indices,
token_expert_indices,
\
num_tokens, topk, 0, num_experts,
renormalize,
stream);
#else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices,
\
token_expert_indices,
num_tokens, topk, 0, num_experts,
stream);
\
gating_output, nullptr, topk_weights, topk_indices,
token_expert_indices,
\
num_tokens, topk, 0, num_experts,
renormalize, stream);
\
} else if (WARP_SIZE == 32) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices,
\
token_expert_indices,
num_tokens, topk, 0, num_experts,
stream);
\
gating_output, nullptr, topk_weights, topk_indices,
token_expert_indices,
\
num_tokens, topk, 0, num_experts,
renormalize, stream);
\
} else { \
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
}
#endif
template
<
typename
IndType
>
template
<
typename
IndType
,
typename
InputType
>
void
topkGatingSoftmaxKernelLauncher
(
const
float
*
gating_output
,
const
InputType
*
gating_output
,
float
*
topk_weights
,
IndType
*
topk_indices
,
int
*
token_expert_indices
,
...
...
@@ -445,11 +557,15 @@ void topkGatingSoftmaxKernelLauncher(
const
int
num_tokens
,
const
int
num_experts
,
const
int
topk
,
const
bool
renormalize
,
cudaStream_t
stream
)
{
static
constexpr
int
WARPS_PER_TB
=
4
;
static
constexpr
int
BYTES_PER_LDG_POWER_OF_2
=
16
;
#ifndef USE_ROCM
static
constexpr
int
BYTES_PER_LDG_MULTIPLE_64
=
8
;
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts
// elements can be loaded by a warp
static
constexpr
int
BYTES_PER_LDG_MULTIPLE_64
=
(
std
::
is_same_v
<
InputType
,
__nv_bfloat16
>
||
std
::
is_same_v
<
InputType
,
__half
>
)
?
4
:
8
;
#endif
switch
(
num_experts
)
{
case
1
:
...
...
@@ -506,11 +622,11 @@ void topkGatingSoftmaxKernelLauncher(
TORCH_CHECK
(
softmax_workspace
!=
nullptr
,
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64."
);
static
constexpr
int
TPB
=
256
;
moeSoftmax
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
moeSoftmax
<
TPB
,
InputType
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
gating_output
,
nullptr
,
softmax_workspace
,
num_experts
);
moeTopK
<
TPB
><<<
num_tokens
,
TPB
,
0
,
stream
>>>
(
softmax_workspace
,
nullptr
,
topk_weights
,
topk_indices
,
token_expert_indices
,
num_experts
,
topk
,
0
,
num_experts
);
num_experts
,
topk
,
0
,
num_experts
,
renormalize
);
}
}
}
...
...
@@ -518,11 +634,50 @@ void topkGatingSoftmaxKernelLauncher(
}
// namespace moe
}
// namespace vllm
// Selects the index element type from the dtype of `topk_indices` (Int,
// UInt32, otherwise required to be Long) and forwards everything to
// topkGatingSoftmaxKernelLauncher. `gating_output` is reinterpreted as
// ComputeType; weights and the softmax workspace are accessed as float and
// token_expert_indices as int.
template <typename ComputeType>
void dispatch_topk_softmax_launch(
    torch::Tensor& gating_output, torch::Tensor& topk_weights,
    torch::Tensor& topk_indices, torch::Tensor& token_expert_indices,
    torch::Tensor& softmax_workspace, int num_tokens, int num_experts, int topk,
    bool renormalize, cudaStream_t stream) {
  // Single launch site; the index type is deduced from the pointer passed in.
  const auto launch_with = [&](auto* indices_ptr) {
    using IndType = std::remove_pointer_t<decltype(indices_ptr)>;
    vllm::moe::topkGatingSoftmaxKernelLauncher<IndType, ComputeType>(
        reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
        topk_weights.data_ptr<float>(), indices_ptr,
        token_expert_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(), num_tokens, num_experts, topk,
        renormalize, stream);
  };

  const auto index_dtype = topk_indices.scalar_type();
  if (index_dtype == at::ScalarType::Int) {
    launch_with(topk_indices.data_ptr<int>());
    return;
  }
  if (index_dtype == at::ScalarType::UInt32) {
    launch_with(topk_indices.data_ptr<uint32_t>());
    return;
  }

  // Any remaining dtype must be 64-bit integer indices.
  TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
  launch_with(topk_indices.data_ptr<int64_t>());
}
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
// [num_tokens, topk]
torch
::
Tensor
&
topk_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
token_expert_indices
,
// [num_tokens, topk]
torch
::
Tensor
&
gating_output
)
// [num_tokens, num_experts]
torch
::
Tensor
&
gating_output
,
// [num_tokens, num_experts]
bool
renormalize
)
{
const
int
num_experts
=
gating_output
.
size
(
-
1
);
const
auto
num_tokens
=
gating_output
.
numel
()
/
num_experts
;
...
...
@@ -534,45 +689,19 @@ void topk_softmax(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
gating_output
.
options
());
if
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
Int
)
{
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
else
if
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
UInt32
)
{
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
uint32_t
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
else
{
TORCH_CHECK
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
Long
);
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int64_t
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
const
auto
workspace_options
=
gating_output
.
options
().
dtype
(
at
::
ScalarType
::
Float
);
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
workspace_options
);
if
(
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
Float
)
{
dispatch_topk_softmax_launch
<
float
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
softmax_workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
}
else
if
(
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
dispatch_topk_softmax_launch
<
__half
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
softmax_workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
}
else
if
(
gating_output
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
dispatch_topk_softmax_launch
<
__nv_bfloat16
>
(
gating_output
,
topk_weights
,
topk_indices
,
token_expert_indices
,
softmax_workspace
,
num_tokens
,
num_experts
,
topk
,
renormalize
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported gating_output data type: "
,
gating_output
.
scalar_type
());
}
}
csrc/moe/torch_bindings.cpp
View file @
006693ed
...
...
@@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m
.
def
(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"
);
"token_expert_indices, Tensor gating_output
, bool renormalize
) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
// Calculate the result of moe by summing up the partial results
...
...
@@ -22,6 +22,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size, but for the batched case.
m
.
def
(
"batched_moe_align_block_size(int max_tokens_per_batch,"
" int block_size, Tensor expert_num_tokens,"
" Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()"
);
m
.
impl
(
"batched_moe_align_block_size"
,
torch
::
kCUDA
,
&
batched_moe_align_block_size
);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m
.
def
(
"moe_lora_align_block_size(Tensor topk_ids,"
" Tensor token_lora_mapping,"
" int num_experts,"
" int block_size, int max_loras, "
" int max_num_tokens_padded, "
" int max_num_m_blocks, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled,"
" Tensor !lora_ids) -> () "
);
m
.
impl
(
"moe_lora_align_block_size"
,
torch
::
kCUDA
,
&
moe_lora_align_block_size
);
#ifndef USE_ROCM
m
.
def
(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
...
...
@@ -80,9 +107,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply grouped topk routing to select experts.
m
.
def
(
"grouped_topk(Tensor scores,
Tensor scores_with_bias,
int n_group, int "
"grouped_topk(Tensor scores, int n_group, int "
"topk_group, int topk, bool renormalize, float "
"routed_scaling_factor) -> (Tensor, Tensor)"
);
"routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
"Tensor)"
);
m
.
impl
(
"grouped_topk"
,
torch
::
kCUDA
,
&
grouped_topk
);
#endif
}
...
...
csrc/ops.h
View file @
006693ed
...
...
@@ -92,14 +92,25 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
poly_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
bias
,
double
epsilon
);
void
fused_qk_norm_rope
(
torch
::
Tensor
&
qkv
,
int64_t
num_heads_q
,
int64_t
num_heads_k
,
int64_t
num_heads_v
,
int64_t
head_dim
,
double
eps
,
torch
::
Tensor
&
q_weight
,
torch
::
Tensor
&
k_weight
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
torch
::
Tensor
&
position_ids
);
void
apply_repetition_penalties_
(
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
prompt_mask
,
const
torch
::
Tensor
&
output_mask
,
const
torch
::
Tensor
&
repetition_penalties
);
void
top_k_per_row
(
const
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
rowStarts
,
const
torch
::
Tensor
&
rowEnds
,
torch
::
Tensor
&
indices
,
int64_t
numRows
,
int64_t
stride0
,
int64_t
stride1
);
void
top_k_per_row_decode
(
const
torch
::
Tensor
&
logits
,
int64_t
next_n
,
const
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
indices
,
int64_t
numRows
,
int64_t
stride0
,
int64_t
stride1
);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
...
...
@@ -133,12 +144,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
input_global_scale
);
#endif
void
silu_mul_
fp8_
quant
_deep_gemm_cuda
(
const
at
::
Tensor
&
input
,
// (E, T, 2*H)
const
at
::
Tensor
&
counts
,
// (E)
at
::
Tensor
&
y_q
,
// (E, T, H) [OUT]
at
::
Tensor
&
y_s
,
// (E, T, H//group_size) [OUT]
int64_t
group_size
,
bool
use_ue8m0
,
int64_t
num_parallel_tokens
);
//
void
persistent_masked_m_
silu_mul_quant(
//
const at::Tensor& input, // (E, T, 2*H)
//
const at::Tensor& counts, // (E)
//
at::Tensor& y_q, // (E, T, H) [OUT]
//
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
//
bool use_ue8m0
);
// Declaration only; the definition is not visible in this excerpt.
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
...
...
@@ -304,7 +315,7 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
int64_t
bit
);
bool
use_exllama
,
bool
use_v2_format
,
int64_t
bit
);
// Declaration only; the definition is not visible in this excerpt.
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
...
...
@@ -318,17 +329,19 @@ void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// std::optional<torch::Tensor> const& scale_ub);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
std
::
optional
<
torch
::
Tensor
>&
D_
,
const
std
::
optional
<
torch
::
Tensor
>&
z_
,
const
std
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
std
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
std
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
C
,
const
std
::
optional
<
torch
::
Tensor
>&
D_
,
const
std
::
optional
<
torch
::
Tensor
>&
z_
,
const
std
::
optional
<
torch
::
Tensor
>&
delta_bias_
,
bool
delta_softplus
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
const
std
::
optional
<
torch
::
Tensor
>&
cache_indices
,
const
std
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
,
int64_t
block_size
,
const
std
::
optional
<
torch
::
Tensor
>&
block_idx_first_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>&
block_idx_last_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>&
initial_state_idx
);
torch
::
Tensor
dynamic_4bit_int_moe_cpu
(
torch
::
Tensor
x
,
torch
::
Tensor
topk_ids
,
torch
::
Tensor
topk_weights
,
...
...
csrc/quantization/activation_kernels.cu
View file @
006693ed
...
...
@@ -7,7 +7,7 @@
#include "../cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#include "quantization/
w8a8/
fp8/common.cuh"
#include <c10/util/Float8_e4m3fn.h>
...
...
@@ -114,13 +114,22 @@ __global__ void act_and_mul_quant_kernel(
}
// SiLU activation: x / (1 + exp(-x)), with the division performed by the
// __fdividef intrinsic.
__device__ __forceinline__ float silu(float x) {
  return __fdividef(x, 1.f + expf(-x));
}
// Applies silu() independently to each component of a float2.
__device__ __forceinline__ float2 silu2(float2 x) {
  const float first = silu(x.x);
  const float second = silu(x.y);
  return make_float2(first, second);
}
// Applies silu() to both components of `x` and packs the two results into a
// __nv_bfloat162 using the *_rn float-to-bfloat16 conversion intrinsics.
// The ROCm build converts the pair in one call; other builds convert each
// component separately and then pack them.
__device__ __forceinline__ __nv_bfloat162 silu2_v2(float2 x) {
  const float first = silu(x.x);
  const float second = silu(x.y);
#ifdef USE_ROCM
  return __float22bfloat162_rn(make_float2(first, second));
#else
  return make_bfloat162(__float2bfloat16_rn(first),
                        __float2bfloat16_rn(second));
#endif
}
#ifndef USE_ROCM
__device__
__forceinline__
float
warp_max
(
float
v
)
{
static
constexpr
unsigned
FULL_MASK
=
0xffffffffu
;
...
...
@@ -223,224 +232,337 @@ constexpr __nv_bfloat16 get_fp8_min() {
return
__nv_bfloat16
(
__nv_bfloat16_raw
{.
x
=
50032
});
}
}
#ifndef USE_ROCM
template
<
typename
fp8_type
,
int32_t
NUM_WARPS
,
typename
Idx_t
,
int
NUM_PARALLEL_TOKENS
,
bool
USE_UE8M0
,
int
GROUP_SIZE
=
128
,
int
NUM_STAGES
=
3
>
// Warp-cooperative scan over `input[0..n)`, one element per lane per 32-wide
// step. Each lane tests whether its own index `idx` is in range and whether
// `input[idx] <= val`; the votes are combined with a full-mask __ballot_sync.
// The first step in which not every lane votes "yes" ends the search, and
// the result is (32 * number of completed steps) + (highest voting lane).
//
// All 32 lanes of the warp must call this together because the ballot uses
// the full-warp mask.
// NOTE(review): presumably `input` is non-decreasing so the voting lanes form
// a prefix of the warp, and `idx` is the caller's lane id; confirm at the
// call site. If no lane votes in a step the result is one less than that
// step's base (assuming __clz(0) == 32).
template <typename Idx_t>
__device__ __forceinline__ int warp_expert_search(
    int idx, int n, const Idx_t* __restrict__ input, Idx_t val) {
  int chunk_base = 0;
  while (true) {
    // Short-circuit keeps out-of-range lanes from dereferencing `input`.
    const bool keeps_going = (idx < n) && (input[idx] <= val);
    const unsigned votes = __ballot_sync(0xffffffff, keeps_going);
    if (votes != 0xffffffffu) {
      const int highest_lane = 31 - __clz(votes);
      return chunk_base + highest_lane;
    }
    // Every lane passed: advance the whole warp to the next 32 elements.
    chunk_base += 32;
    idx += 32;
  }
}
// Computes the half-open token range [n_tokens_lower, n_tokens_upper) that
// worker `worker_id` handles when `n_tokens` tokens are divided among
// `num_parallel_tokens` workers.
//
// With chunk = n_tokens / num_parallel_tokens and residual the remainder of
// that division, the first `residual` workers receive chunk + 1 tokens and
// the others receive chunk tokens. Boundaries are clamped to n_tokens, so a
// worker past the end gets an empty range (lower == upper).
//
// Both outputs are always written.
template <int num_parallel_tokens>
__device__ __forceinline__ void token_bounds(int32_t n_tokens,
                                             int32_t worker_id,
                                             int32_t& n_tokens_lower,
                                             int32_t& n_tokens_upper) {
  if (n_tokens < num_parallel_tokens && worker_id < n_tokens) {
    // Fewer tokens than workers: worker i takes exactly token i. Here
    // worker_id < n_tokens < num_parallel_tokens, so no extra bound check on
    // worker_id is needed.
    n_tokens_lower = worker_id;
    n_tokens_upper = worker_id + 1;
    return;
  }

  const int32_t chunk_size = n_tokens / num_parallel_tokens;
  const int32_t residual = n_tokens - chunk_size * num_parallel_tokens;

  // Start offset of worker `id`, clamped so it never exceeds n_tokens.
  auto boundary = [&](int32_t id) {
    const int32_t start = (id < residual) ? id * (chunk_size + 1)
                                          : id * chunk_size + residual;
    return min(n_tokens, start);
  };

  n_tokens_lower = boundary(worker_id);
  n_tokens_upper = boundary(worker_id + 1);
}
template
<
int
BLOCK_COUNT
,
int
SMEM_SIZE_BYTES_Y
,
typename
fp8_type
,
typename
scale_t
,
int
THREADS
,
typename
Idx_t
,
bool
CEIL_UE8M0
,
int
GROUP_SIZE
=
128
,
int
NUM_STAGES
=
3
>
__global__
void
silu_mul_fp8_quant_deep_gemm_kernel
(
const
__nv_bfloat16
*
__restrict__
_input
,
fp8_type
*
__restrict__
_y_q
,
float
*
__restrict__
_y_s
,
const
int32_t
*
__restrict__
counts
,
scale_t
*
__restrict__
_y_s
,
const
int32_t
*
__restrict__
tokens_per_expert
,
// sizes
int
H
,
int
G
,
Idx_t
E
,
Idx_t
T
,
Idx_t
H
,
// strides (in elements)
Idx_t
stride_i_e
,
Idx_t
stride_i_t
,
Idx_t
stride_i_h
,
Idx_t
stride_yq_e
,
Idx_t
stride_yq_t
,
Idx_t
stride_yq_h
,
Idx_t
stride_ys_e
,
Idx_t
stride_ys_t
,
Idx_t
stride_ys_g
,
Idx_t
stride_counts_e
)
{
static
constexpr
__nv_bfloat16
fp8_min
=
get_fp8_min
<
fp8_type
>
();
static
constexpr
__nv_bfloat16
fp8_max
=
get_fp8_max
<
fp8_type
>
();
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
static
constexpr
__nv_bfloat16
EPS
=
(
__nv_bfloat16_raw
{.
x
=
11996
});
Idx_t
stride_ys_g
,
Idx_t
stride_ys_p
,
Idx_t
stride_counts_e
)
{
#ifndef USE_ROCM
static
constexpr
int
NUM_WARPS
=
THREADS
/
WARP_SIZE
;
// We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
static
constexpr
int
32_t
BFLOAT16_PER_GROUP
=
8
;
static
constexpr
int
LOAD_STAGE_SIZE
=
2
*
GROUP_SIZE
/
8
;
static
constexpr
int
LOAD_STAGE_MOD
=
NUM_STAGES
*
LOAD_STAGE_SIZE
;
// We split the shared memory in half, corresponding to gate and up matrices:
// [...gate_i, ...up_i] where 0 <= i < stages.
static
constexpr
int32_t
S_NUM_128
=
2u
*
(
GROUP_SIZE
/
BFLOAT16_PER_GROUP
)
*
NUM_WARPS
*
NUM_STAGES
;
static
constexpr
auto
THREAD_COUNT
=
NUM_WARPS
*
WARP_SIZE
;
static
constexpr
int
HALF_THREAD_COUNT
=
THREAD_COUNT
/
2
;
static
constexpr
int32_t
S_NUM_64
=
S_NUM_128
*
2
;
__shared__
__int128_t
__align__
(
16
)
s_buff_128
[
S_NUM_128
];
static
constexpr
int
COMPUTE_STAGE_SIZE
=
2
*
GROUP_SIZE
/
4
;
static
constexpr
int
COMPUTE_STAGE_MOD
=
COMPUTE_STAGE_SIZE
*
NUM_STAGES
;
const
int32_t
tid
=
threadIdx
.
x
;
const
int32_t
warp_id
=
tid
/
WARP_SIZE
;
const
int32_t
lane_id
=
tid
%
WARP_SIZE
;
extern
__shared__
__align__
(
16
)
__int128_t
smem_128
[];
auto
s_buff_compute_32
=
reinterpret_cast
<
__nv_bfloat162
*>
(
s_buff_128
);
int
*
s_expert_offsets
=
reinterpret_cast
<
int
*>
(
smem_128
+
(
SMEM_SIZE_BYTES_Y
/
16
));
// block handles one (expert e, group g)
int32_t
pid
=
blockIdx
.
x
;
int32_t
e
=
pid
/
G
;
int32_t
g
=
pid
%
G
;
static
constexpr
__nv_bfloat16
fp8_min
=
get_fp8_min
<
fp8_type
>
();
static
constexpr
__nv_bfloat16
fp8_max
=
get_fp8_max
<
fp8_type
>
();
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
static
constexpr
__nv_bfloat16
EPS
=
(
__nv_bfloat16_raw
{.
x
=
11996
});
int
tid
=
threadIdx
.
x
;
int
warp_id
=
tid
>>
5
;
int
lane_id
=
tid
&
0x1f
;
int
running_sum
{};
if
(
!
warp_id
)
{
for
(
int
i
=
0
;
i
<
E
;
i
+=
WARP_SIZE
)
{
bool
valid
=
(
i
+
threadIdx
.
x
)
<
E
;
int
value
=
(
valid
?
tokens_per_expert
[
i
+
threadIdx
.
x
*
stride_counts_e
]
:
0
)
+
(
!
lane_id
?
running_sum
:
0
);
for
(
int
offset
=
1
;
offset
<
32
;
offset
*=
2
)
{
int
n
=
__shfl_up_sync
(
0xFFFFFFFFu
,
value
,
offset
);
if
(
lane_id
>=
offset
)
value
+=
n
;
}
const
int32_t
n_tokens
=
counts
[
e
*
stride_counts_e
];
if
(
valid
)
{
s_expert_offsets
[
i
+
threadIdx
.
x
+
1
]
=
value
;
}
if
(
!
n_tokens
)
{
return
;
// Exit ASAP.
running_sum
=
__shfl_sync
(
0xFFFFFFFFu
,
value
,
WARP_SIZE
-
1
);
}
if
(
!
lane_id
)
{
s_expert_offsets
[
0
]
=
0
;
}
}
const
Idx_t
stride_i_t_128
=
stride_i_t
/
8u
;
__syncthreads
()
;
int32_t
n_tokens_lower
,
n_tokens_upper
;
int32_t
total_tokens
=
s_expert_offsets
[
E
]
;
const
int
warp_position_yq
=
warp_id
*
(
H
/
NUM_WARPS
);
const
int
warp_position_scales
=
warp_id
*
(
H
/
(
GROUP_SIZE
*
NUM_WARPS
));
// A single block will handle tokens_per_block tokens.
// Each block i iterates over tokens of a slice of n_tokens =
// expert_counts[i], with the size of chunk being
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
if
(
n_tokens
<
NUM_PARALLEL_TOKENS
&&
blockIdx
.
y
<
n_tokens
)
{
// Specialize this, but can be likely fused.
if
(
blockIdx
.
y
>=
NUM_PARALLEL_TOKENS
)
{
return
;
}
n_tokens_lower
=
blockIdx
.
y
;
n_tokens_upper
=
blockIdx
.
y
+
1
;
}
else
{
auto
chunk_size
=
n_tokens
/
NUM_PARALLEL_TOKENS
;
auto
residual
=
n_tokens
-
chunk_size
*
NUM_PARALLEL_TOKENS
;
auto
calc_id
=
[
&
](
int32_t
id
)
{
if
(
id
<
residual
)
{
return
min
(
n_tokens
,
id
*
(
chunk_size
+
1
));
}
else
{
return
min
(
n_tokens
,
id
*
chunk_size
+
residual
);
}
};
n_tokens_lower
=
calc_id
(
blockIdx
.
y
);
n_tokens_upper
=
calc_id
(
blockIdx
.
y
+
1
);
}
if
(
n_tokens_lower
>=
n_tokens_upper
)
{
// Each warp will get space to store its hidden dim for gate and up.
__int128_t
*
s_hidden_load
=
smem_128
+
warp_id
*
((
2
*
128
/
8
)
*
NUM_STAGES
);
__int128_t
*
smem_load_ptr
=
s_hidden_load
+
lane_id
;
const
__nv_bfloat16
fp8_inv
=
__hdiv
(
__float2bfloat16
(
1.
f
),
fp8_max
);
int32_t
compute_pipeline_offset_64
=
0
;
int32_t
load_stage_offset
{};
const
__nv_bfloat16
one_bf16
=
__float2bfloat16_rn
(
1.
f
);
__int64_t
*
smem_compute_ptr
=
reinterpret_cast
<
__int64_t
*>
(
smem_128
)
+
warp_id
*
(
2
*
(
GROUP_SIZE
/
4
)
*
NUM_STAGES
)
+
lane_id
;
__int64_t
*
s_gate64_ptr
=
smem_compute_ptr
;
__int64_t
*
s_up64_ptr
=
smem_compute_ptr
+
GROUP_SIZE
/
4
;
int
tokens_lower
,
tokens_upper
;
token_bounds
<
BLOCK_COUNT
>
(
total_tokens
,
blockIdx
.
x
,
tokens_lower
,
tokens_upper
);
Idx_t
expert_id
{},
expert_offset
{},
next_expert_offset
{};
int
token_id
=
tokens_lower
;
int32_t
t_load
{};
if
(
token_id
<
tokens_upper
)
{
expert_id
=
warp_expert_search
<
int
>
(
lane_id
,
E
,
s_expert_offsets
,
token_id
);
expert_offset
=
s_expert_offsets
[
expert_id
];
next_expert_offset
=
s_expert_offsets
[
expert_id
+
1
];
}
else
{
// This thread block has no work to do.
return
;
}
// We do calculations here, using constexpr wherever possible.
const
Idx_t
base_i
=
e
*
stride_i_e
+
NUM_WARPS
*
g
*
GROUP_SIZE
*
stride_i_h
;
const
Idx_t
base_ys
=
e
*
stride_ys_e
+
NUM_WARPS
*
g
*
stride_ys_g
;
const
Idx_t
base_yq
=
e
*
stride_yq_e
+
NUM_WARPS
*
g
*
GROUP_SIZE
*
stride_yq_h
;
Idx_t
gate_off_128
=
(
base_i
/
static_cast
<
Idx_t
>
(
8u
));
auto
input_128_ptr
=
reinterpret_cast
<
const
__int128_t
*>
(
_input
);
auto
gate_128_ptr
=
input_128_ptr
+
gate_off_128
+
(
tid
%
HALF_THREAD_COUNT
)
+
stride_i_t_128
*
n_tokens_lower
;
auto
up_128_ptr
=
gate_128_ptr
+
(
H
*
stride_i_h
)
/
8u
;
auto
y_s_ptr
=
_y_s
+
base_ys
+
warp_id
*
stride_ys_g
+
n_tokens_lower
*
stride_ys_t
;
auto
y_q_ptr
=
_y_q
+
base_yq
+
warp_id
*
GROUP_SIZE
+
stride_yq_t
*
n_tokens_lower
+
4
*
lane_id
;
int32_t
t_load
=
n_tokens_lower
,
load_stage_id
=
0
;
auto
s_buff_gate_load_128
=
s_buff_128
+
(
tid
%
HALF_THREAD_COUNT
);
auto
s_buff_up_load_128
=
s_buff_gate_load_128
+
S_NUM_128
/
2u
;
int32_t
stage_offset
{};
static
constexpr
int32_t
LOAD_STAGE_SIZE
=
(
NUM_WARPS
*
WARP_SIZE
/
2
);
static
constexpr
int32_t
LOAD_STAGE_MOD
=
NUM_STAGES
*
(
NUM_WARPS
*
WARP_SIZE
/
2
);
// Two halves of all threads in a block conduct global loads for gate and up,
// respectively.
int
t_load_bound
=
H
/
(
GROUP_SIZE
*
NUM_WARPS
);
Idx_t
base_i
=
((
expert_id
*
stride_i_e
)
/
8
)
+
(
token_id
-
expert_offset
)
*
stride_i_t
/
8
;
const
Idx_t
gate_warp_offset
=
warp_id
*
((
stride_i_h
*
H
)
/
(
8
*
NUM_WARPS
))
+
(
lane_id
&
0b1111
);
const
__int128_t
*
input_128_ptr
=
reinterpret_cast
<
const
__int128_t
*>
(
_input
)
+
gate_warp_offset
+
((
lane_id
<
16
)
?
0
:
((
H
*
stride_i_h
)
/
8
));
__int128_t
*
load_ptr
=
const_cast
<
__int128_t
*>
(
input_128_ptr
+
base_i
);
auto
token_offset
=
token_id
-
expert_offset
;
auto
load_and_advance_y_pred
=
[
&
]
{
if
(
t_load
<
n_tokens_upper
)
{
auto
s_gate_stage_128_staged_ptr
=
s_buff_gate_load_128
+
stage_offset
;
auto
s_up_stage_128_staged_ptr
=
s_buff_up_load_128
+
stage_offset
;
if
(
t_load
<
t_load_bound
)
{
// Here we are simply continuing to load data
// from the current token.
auto
smem_load_ptr_staged
=
smem_load_ptr
+
load_stage_offset
;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
stage_offset
+=
LOAD_STAGE_SIZE
;
stage_offset
%=
LOAD_STAGE_MOD
;
load_
stage_offset
+=
LOAD_STAGE_SIZE
;
load_
stage_offset
%=
LOAD_STAGE_MOD
;
if
(
tid
<
HALF_THREAD_COUNT
)
{
cp_async4
(
s_gate_stage_128_staged_ptr
,
gate_128_ptr
);
gate_128_ptr
+=
stride_i_t_128
;
cp_async4
(
smem_load_ptr_staged
,
load_ptr
);
load_ptr
+=
GROUP_SIZE
/
8
;
++
t_load
;
}
else
if
(
token_id
+
1
<
tokens_upper
)
{
// We loaded everything from the current token, let's move on
// to the next one, and we checked that we have more tokens to load.
++
token_id
;
t_load
=
0
;
if
(
token_id
>=
next_expert_offset
)
{
// We need to find the next expert.
do
{
// This is a loop because it's possible
// that some experts are assigned 0 tokens.
// NOTE: We are guaranteed that there's at least
// one more token left so we don't have to check for
// expert_id bounds.
++
expert_id
;
// This skips 1 memory read.
expert_offset
=
next_expert_offset
;
next_expert_offset
=
s_expert_offsets
[
expert_id
+
1
];
}
while
(
next_expert_offset
==
expert_offset
);
base_i
=
expert_id
*
(
stride_i_e
/
8
);
token_offset
=
0
;
load_ptr
=
const_cast
<
__int128_t
*>
(
input_128_ptr
+
base_i
);
}
else
{
cp_async4
(
s_up_stage_128_staged_ptr
,
up_128_ptr
);
up_128_ptr
+=
stride_i_t_128
;
// We remain within the same expert, so just
// move by H/4 __int128_t (2 * H/8).
base_i
+=
stride_yq_t
/
4
;
token_offset
++
;
}
load_ptr
=
const_cast
<
__int128_t
*>
(
input_128_ptr
+
base_i
);
auto
smem_load_ptr_staged
=
smem_load_ptr
+
load_stage_offset
;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
load_stage_offset
+=
LOAD_STAGE_SIZE
;
load_stage_offset
%=
LOAD_STAGE_MOD
;
cp_async4
(
smem_load_ptr_staged
,
load_ptr
);
load_ptr
+=
GROUP_SIZE
/
8
;
++
t_load
;
++
load_stage_id
;
}
// We fence even if there is nothing to load to simplify pipelining.
cp_async_fence
();
};
// We need to warm-up the pipeline.
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_STAGES
-
1
;
i
++
)
{
load_and_advance_y_pred
();
}
__int64_t
*
s_gate_ptr
=
reinterpret_cast
<
__int64_t
*>
(
s_buff_compute_32
+
warp_id
*
(
GROUP_SIZE
/
2
))
+
lane_id
;
__int64_t
*
s_up_ptr
=
s_gate_ptr
+
S_NUM_64
/
2
;
__nv_fp8x4_e4m3
*
y_q_base_ptr
=
reinterpret_cast
<
__nv_fp8x4_e4m3
*>
(
_y_q
)
+
lane_id
;
static
constexpr
int32_t
STAGE_SIZE
=
(
GROUP_SIZE
*
NUM_WARPS
)
/
4u
;
static
constexpr
int32_t
STAGE_MOD
=
STAGE_SIZE
*
NUM_STAGES
;
Idx_t
scale_group_offset
=
0
;
if
constexpr
(
std
::
is_same
<
scale_t
,
uint8_t
>::
value
)
{
// packed int32_t format
int
pack_id
=
warp_position_scales
/
4
;
int
scale_in_pack
=
warp_position_scales
%
4
;
scale_group_offset
=
pack_id
*
stride_ys_p
+
scale_in_pack
*
stride_ys_g
;
}
else
{
scale_group_offset
=
warp_position_scales
*
stride_ys_g
;
}
int32
_t
co
mpute_pipeline_offset_64
=
0
;
scale
_t
*
co
nst
y_scale_base_ptr
=
_y_s
+
scale_group_offset
;
for
(
int32_t
t
=
n_tokens_lower
;
t
<
n_tokens_upper
;
++
t
)
{
__nv_bfloat162
results_bf162
[
2
];
for
(
auto
j
=
tokens_lower
;
j
<
tokens_upper
;
j
++
)
{
int
current_group_id
=
warp_position_scales
;
// Running count of which
// group is being processed
const
Idx_t
base_ys
=
expert_id
*
stride_ys_e
;
auto
y_s_ptr
=
y_scale_base_ptr
+
base_ys
+
token_offset
*
stride_ys_t
;
__nv_fp8x4_e4m3
*
y_q_ptr
=
y_q_base_ptr
+
(
expert_id
*
stride_yq_e
+
token_offset
*
stride_yq_t
+
warp_position_yq
*
stride_yq_h
)
/
4
;
const
int
COMPUTE_LIMIT
=
H
/
(
GROUP_SIZE
*
NUM_WARPS
);
cp_async_wait
<
NUM_STAGES
-
2
>
();
__syncthreads
();
for
(
int
i
=
0
;
i
<
COMPUTE_LIMIT
;
i
++
)
{
cp_async_wait
<
NUM_STAGES
-
2
>
();
__syncthreads
();
load_and_advance_y_pred
();
// We double-buffer pipelined loads so that the next load will
// concurrently run with compute without overwrites.
load_and_advance_y_pred
();
__int64_t
*
gate64_ptr
=
s_gate64_ptr
+
compute_pipeline_offset_64
;
__int64_t
*
up64_ptr
=
s_up64_ptr
+
compute_pipeline_offset_64
;
auto
s_gate_compute_64
=
s_gate_ptr
+
compute_pipeline_offset_64
;
auto
s_up_compute_64
=
s_up_ptr
+
compute_pipeline_offset_64
;
// COMPUTE_STAGE_SIZE/MOD must also be constexpr!
compute_pipeline_offset_64
+=
COMPUTE_STAGE_SIZE
;
compute_pipeline_offset_64
%=
COMPUTE_STAGE_MOD
;
// STAGE_SIZE must also be constexpr!
compute_pipeline_offset_64
+=
STAGE_SIZE
;
compute_pipeline_offset_64
%=
STAGE_MOD
;
__int64_t
gate64
=
*
gate64_ptr
;
__int64_t
up64
=
*
up64_ptr
;
// Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
__int64_t
gate64
=
*
s_gate_compute_64
;
__nv_bfloat162
*
s_gate_compute_32
=
reinterpret_cast
<
__nv_bfloat162
*>
(
&
gate64
);
__int64_t
up64
=
*
s_up_compute_64
;
__nv_bfloat162
*
s_up_compute_32
=
reinterpret_cast
<
__nv_bfloat162
*>
(
&
up64
);
// Compute
__nv_bfloat162
res
[
2
];
__nv_bfloat162
*
s_up_comp
=
reinterpret_cast
<
__nv_bfloat162
*>
(
&
up64
);
__nv_bfloat162
*
s_gate_comp
=
reinterpret_cast
<
__nv_bfloat162
*>
(
&
gate64
);
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
// For silu, we make sure that div is emitted.
float2
gate
=
silu2
(
__bfloat1622float2
(
s_gate_compute_32
[
i
]));
results_bf162
[
i
]
=
__float22bfloat162_rn
(
gate
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
results_bf162
[
i
]
=
__hmul2
(
results_bf162
[
i
],
s_up_compute_32
[
i
]);
}
for
(
int32_t
k
=
0
;
k
<
2
;
++
k
)
{
__nv_bfloat162
gate
=
silu2_v2
(
__bfloat1622float2
(
s_gate_comp
[
k
]));
res
[
k
]
=
__hmul2
(
gate
,
s_up_comp
[
k
]);
}
auto
_y_max2
=
__hmax2
(
__habs2
(
results_bf162
[
0
]),
__habs2
(
results_bf162
[
1
]));
auto
_y_max2
=
__hmax2
(
__habs2
(
res
[
0
]),
__habs2
(
res
[
1
]));
__nv_bfloat16
y_max_bf16
=
__hmax
(
EPS
,
__hmax
(
_y_max2
.
x
,
_y_max2
.
y
));
_y_max2
.
x
=
__hmax
(
__hmax
(
_y_max2
.
x
,
_y_max2
.
y
)
,
EPS
);
// An entire group is assigned to a single warp, so a simple warp reduce
// is used.
__nv_bfloat16
y_s
=
warp_max
(
y_max_bf16
)
/
fp8_max
;
__nv_bfloat16
y_s
=
__hmul
(
warp_max
(
_y_max2
.
x
),
fp8_inv
);
if
constexpr
(
USE
_UE8M0
)
{
y_s
=
hexp2
(
hceil
(
hlog2
(
y_s
)));
}
if
constexpr
(
CEIL
_UE8M0
)
{
y_s
=
hexp2
(
hceil
(
hlog2
(
y_s
)));
}
auto
inv_y
=
__float2bfloat16_rn
(
1.
f
)
/
y_s
;
__nv_bfloat16
inv_y
=
__hdiv
(
one_bf16
,
y_s
)
;
auto
y_s2
=
make_bfloat162
(
inv_y
,
inv_y
);
auto
y_s2
=
make_bfloat162
(
inv_y
,
inv_y
);
#pragma unroll
for
(
int32_t
i
=
0
;
i
<
2
;
++
i
)
{
results_bf162
[
i
]
=
clip
(
__hmul2
(
results_bf162
[
i
],
y_s2
),
__bfloat162bfloat162
(
fp8_min
),
__bfloat162bfloat162
(
fp8_max
));
}
for
(
int32_t
k
=
0
;
k
<
2
;
++
k
)
{
res
[
k
]
=
clip
(
__hmul2
(
res
[
k
],
y_s2
),
__bfloat162bfloat162
(
fp8_min
),
__bfloat162bfloat162
(
fp8_max
));
}
*
y_q_ptr
=
__nv_fp8x4_e4m3
(
res
[
0
],
res
[
1
]);
y_q_ptr
+=
WARP_SIZE
*
stride_yq_h
;
if
(
!
lane_id
)
{
// Store scales.
if
constexpr
(
std
::
is_same
<
scale_t
,
uint8_t
>::
value
)
{
// Packed UE8M0 format. Remove mantissa.
*
y_s_ptr
=
reinterpret_cast
<
int16_t
&>
(
y_s
)
>>
7
;
bool
const
jump_pack
=
(
current_group_id
+
1
)
%
4
==
0
;
// Minus 3 because we need to get to the first group in the
// next pack.
y_s_ptr
+=
jump_pack
?
(
stride_ys_p
-
3
)
:
stride_ys_g
;
auto
fp8x4
=
__nv_fp8x4_e4m3
(
results_bf162
[
0
],
results_bf162
[
1
]);
*
reinterpret_cast
<
__nv_fp8x4_e4m3
*>
(
y_q_ptr
)
=
fp8x4
;
y_q_ptr
+=
stride_yq_t
;
}
else
{
// float32 format
static_assert
(
std
::
is_same
<
scale_t
,
float
>::
value
);
*
y_s_ptr
=
y_s
;
y_s_ptr
+=
stride_ys_g
;
}
if
(
lane_id
==
0
)
{
*
y_s_ptr
=
y_s
;
y_s_ptr
+=
stride_ys_t
;
current_group_id
+=
1
;
}
}
}
}
#endif
}
}
// namespace vllm
...
...
@@ -475,25 +597,26 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
}
void
silu_mul_
fp8_
quant
_deep_gemm_cuda
(
const
at
::
Tensor
&
input
,
// (E, T, 2*H)
const
at
::
Tensor
&
counts
,
// (E)
at
::
Tensor
&
y_q
,
// (E, T, H) [OUT]
at
::
Tensor
&
y_s
,
// (E, T, H//group_size) [OUT]
int64_t
group_size
,
bool
use_ue8m0
,
int64_t
num_parallel_tokens
)
{
void
persistent_masked_m_
silu_mul_quant
(
const
at
::
Tensor
&
input
,
// (E, T, 2*H)
const
at
::
Tensor
&
tokens_per_expert
,
// (E)
at
::
Tensor
&
y_q
,
// (E, T, H) [OUT]
at
::
Tensor
&
y_s
,
// (E, T, H//group_size) [OUT]
bool
cast_scale_ue8m0
)
{
#ifndef USE_ROCM
// This kernel relies heavily on cp.async and fp8 support.
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
static
constexpr
int
GROUP_SIZE
=
128
;
TORCH_CHECK
(
input
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
y_q
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
y_q
.
dtype
()
==
torch
::
kFloat8_e4m3fnuz
);
TORCH_CHECK
(
y_s
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
input
.
size
(
-
1
)
%
256
==
0
);
TORCH_CHECK
(
input
.
size
(
-
1
)
%
(
GROUP_SIZE
*
2
)
==
0
);
// Check that num_parallel_toke
ns is
of power of 2 and between 1 and 64.
TORCH_CHECK
(
1
<=
num_parallel_tokens
&&
num_parallel_tokens
<=
64
);
TORCH_CHECK
(
!
(
num_parallel_tokens
&
(
num_parallel_tokens
-
1
))
);
bool
co
ns
t
is
_packed_ue8m0
=
(
y_s
.
dtype
()
==
torch
::
kInt32
&&
cast_scale_ue8m0
);
TORCH_CHECK
(
y_s
.
dtype
()
==
torch
::
kFloat32
||
is_packed_ue8m0
);
using
Idx_t
=
int64_t
;
...
...
@@ -506,85 +629,107 @@ void silu_mul_fp8_quant_deep_gemm_cuda(
Idx_t
stride_yq_e
=
y_q
.
stride
(
0
);
Idx_t
stride_yq_t
=
y_q
.
stride
(
1
);
Idx_t
stride_yq_h
=
y_q
.
stride
(
2
);
Idx_t
stride_ys_e
=
y_s
.
stride
(
0
);
Idx_t
stride_ys_t
=
y_s
.
stride
(
1
);
Idx_t
stride_ys_g
=
y_s
.
stride
(
2
);
Idx_t
stride_counts_e
=
counts
.
stride
(
0
);
Idx_t
stride_counts_e
=
tokens_per_expert
.
stride
(
0
);
static
constexpr
int
GROUP_SIZE
=
128
;
int
const
NUM_GROUPS
=
H
/
GROUP_SIZE
;
#define KERNEL_FN \
if (use_ue8m0) { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, true> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
} else { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, false> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
}
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
#define KERNEL_CALL_H \
if (H % (4 * GROUP_SIZE) == 0) { \
static constexpr int NUM_WARPS = 4; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
} else { \
static constexpr int NUM_WARPS = 1; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
// TODO: Get this from cuda_arch ?
static
constexpr
int
SILU_V2_BLOCK_COUNT
=
132
*
32
;
#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
int sms = SILU_V2_BLOCK_COUNT; \
static constexpr int max_shared_mem_bytes = \
GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
dim3 grid(sms), block(THREAD_COUNT); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), \
reinterpret_cast<scale_t*>(y_s.data_ptr()), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
});
#define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0) \
if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
/* 8 warp config */
\
static constexpr int NUM_STAGES = 4; \
static constexpr int THREAD_COUNT = 256; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
} else { \
/* 1 warp config */
\
static constexpr int THREAD_COUNT = 32; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
}
#define KERNEL_CALL_TOP_LEVEL \
if (num_parallel_tokens == 1) { \
static constexpr int NUM_PARALLEL_TOKENS = 1; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 2) { \
static constexpr int NUM_PARALLEL_TOKENS = 2; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 4) { \
static constexpr int NUM_PARALLEL_TOKENS = 4; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 8) { \
static constexpr int NUM_PARALLEL_TOKENS = 8; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 16) { \
static constexpr int NUM_PARALLEL_TOKENS = 16; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 32) { \
static constexpr int NUM_PARALLEL_TOKENS = 32; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 64) { \
static constexpr int NUM_PARALLEL_TOKENS = 64; \
KERNEL_CALL_H \
}
Idx_t
stride_ys_e
=
y_s
.
stride
(
0
);
Idx_t
stride_ys_t
=
y_s
.
stride
(
1
);
Idx_t
stride_ys_g
=
y_s
.
stride
(
2
);
Idx_t
stride_ys_p
=
0
;
if
(
!
cast_scale_ue8m0
)
{
TORCH_CHECK
(
!
is_packed_ue8m0
);
LAUNCH_ON_H
(
float
,
stride_ys_e
,
stride_ys_t
,
stride_ys_g
,
stride_ys_p
,
false
);
return
;
}
Idx_t
G
;
dim3
block
,
grid
;
auto
populate_launch_params
=
[
&
](
int
num_warps
,
int
_num_parallel_tokens
)
{
G
=
H
/
Idx_t
(
group_size
*
num_warps
);
grid
=
dim3
(
E
*
G
,
_num_parallel_tokens
);
block
=
dim3
(
num_warps
*
WARP_SIZE
);
};
if
(
!
is_packed_ue8m0
)
{
// UE8M0 but not packed
LAUNCH_ON_H
(
float
,
stride_ys_e
,
stride_ys_t
,
stride_ys_g
,
stride_ys_p
,
true
);
return
;
}
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
VLLM_DISPATCH_FP8_TYPES
(
y_q
.
scalar_type
(),
"silu_mul_fp8_quant_deep_gemm_kernel"
,
[
&
]
{
KERNEL_CALL_TOP_LEVEL
});
TORCH_CHECK
(
cast_scale_ue8m0
&&
is_packed_ue8m0
);
TORCH_CHECK
(
y_s
.
dtype
()
==
torch
::
kInt32
);
// Int32 packed ue8m0 scales tensor.
// Let E, T, G be the number of experts, number of tokens and number of groups
// respectively. Let E = 1, T = 4, G = 6; in this case the int32 scales
// tensor is of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
// to be arranged as follows,
// [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
// [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
// [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
// [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
// where, TxGy is the ue8m0 scale value of Token x, Group y.
//
// In memory (in bytes) the scale values are arranged as,
// [T0G0, T0G1, T0G2, T0G3, T1G0, T1G1, T1G2, T1G3, T2G0, T2G1, T2G2, T2G3,
// T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
// X, X, T3G4, T3G5, X, X]
//
// An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
// as an uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
// English, ignoring the Experts dimension, the original int32 tensor is
// simply treated as two packed [4, 4] uint8 tensors (or two [4, 1] int32
// tensors). The following stride settings reflect this change. Caveat: This
// means that the G dimension is no longer contiguous, i.e. note that to move
// from G3 to G4, we need to jump along the packing dimension. The kernel
// handles this case.
stride_ys_e
*=
sizeof
(
int32_t
);
stride_ys_p
=
T
*
sizeof
(
int32_t
);
// Packing dimension
stride_ys_t
=
sizeof
(
int32_t
);
stride_ys_g
=
1
;
LAUNCH_ON_H
(
uint8_t
,
stride_ys_e
,
stride_ys_t
,
stride_ys_g
,
stride_ys_p
,
true
);
#endif
}
Prev
1
…
5
6
7
8
9
10
11
12
13
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment