Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
79e6a8a6
"tools/vscode:/vscode.git/clone" did not exist on "91c63155e9159221f0a8378452462fdde91c09a6"
Unverified
Commit
79e6a8a6
authored
Aug 26, 2025
by
Rain Jiang
Committed by
GitHub
Aug 26, 2025
Browse files
support cuda 13.0 and trtllm kernel by Aug 25 2025 (#9495)
parent
8f7b1c31
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
81 additions
and
14 deletions
+81
-14
sgl-kernel/CMakeLists.txt
sgl-kernel/CMakeLists.txt
+23
-9
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
+25
-2
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
+1
-0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh
+1
-0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh
+1
-0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh
+1
-0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh
+1
-0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh
+1
-0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh
+1
-0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
+10
-0
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
+2
-0
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
+1
-0
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
+13
-3
No files found.
sgl-kernel/CMakeLists.txt
View file @
79e6a8a6
...
...
@@ -57,6 +57,9 @@ if("${CUDA_VERSION}" VERSION_EQUAL "12.8")
elseif
(
"
${
CUDA_VERSION
}
"
VERSION_EQUAL
"12.9"
)
set
(
DeepGEMM_REPO
"https://github.com/sgl-project/DeepGEMM"
)
set
(
DeepGEMM_TAG
"blackwell"
)
elseif
(
"
${
CUDA_VERSION
}
"
VERSION_EQUAL
"13.0"
)
set
(
DeepGEMM_REPO
"https://github.com/sgl-project/DeepGEMM"
)
set
(
DeepGEMM_TAG
"blackwell"
)
else
()
set
(
DeepGEMM_REPO
"https://github.com/deepseek-ai/DeepGEMM"
)
set
(
DeepGEMM_TAG
"391755ada0ffefa9a6a52b6f14dcaf22d1a463e0"
)
...
...
@@ -83,7 +86,7 @@ FetchContent_Populate(repo-triton)
FetchContent_Declare
(
repo-flashinfer
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
GIT_TAG
9220fb3443b5a5d274f00ca5552f798e225239b7
GIT_TAG
018b551825c8e5579206e6eb9d3229fa679202b3
GIT_SHALLOW OFF
)
FetchContent_Populate
(
repo-flashinfer
)
...
...
@@ -179,11 +182,28 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
list
(
APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_100,code=sm_100"
"-gencode=arch=compute_100a,code=sm_100a"
"-gencode=arch=compute_10
1
,code=sm_10
1
"
"-gencode=arch=compute_10
1
a,code=sm_10
1
a"
"-gencode=arch=compute_10
3
,code=sm_10
3
"
"-gencode=arch=compute_10
3
a,code=sm_10
3
a"
"-gencode=arch=compute_120,code=sm_120"
"-gencode=arch=compute_120a,code=sm_120a"
)
# For the sm_121, sm_110 and sm_101 descriptions, see https://github.com/pytorch/pytorch/pull/156176
if
(
"
${
CUDA_VERSION
}
"
VERSION_GREATER_EQUAL
"13.0"
)
list
(
APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_110,code=sm_110"
"-gencode=arch=compute_110a,code=sm_110a"
"-gencode=arch=compute_121,code=sm_121"
"-gencode=arch=compute_121a,code=sm_121a"
"--compress-mode=size"
)
else
()
list
(
APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_101,code=sm_101"
"-gencode=arch=compute_101a,code=sm_101a"
)
endif
()
else
()
list
(
APPEND SGL_KERNEL_CUDA_FLAGS
"-use_fast_math"
...
...
@@ -266,12 +286,6 @@ set(SOURCES
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
79e6a8a6
...
...
@@ -9,6 +9,7 @@ import jinja2
FILE_HEAD
=
"""
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
...
...
@@ -33,6 +34,17 @@ TEMPLATE = (
"( MARLIN_KERNEL_PARAMS );"
)
KERNEL_FILE_TEMPLATE
=
(
"// auto generated by generate.py
\n
"
"// clang-format off
\n
"
"#pragma once
\n\n
"
"{% for kernel_file in kernel_files %}"
'#include "{{ kernel_file }}"
\n
'
"{% endfor %}"
)
KERNEL_FILE_NAME
=
"kernel_marlin.cuh"
# The int8 with zero point case (sglang::kU8) is also supported,
# but we don't generate kernels for it, to reduce wheel size.
SCALAR_TYPES
=
[
"sglang::kU4"
,
"sglang::kU4B8"
,
"sglang::kU8B128"
]
...
...
@@ -48,11 +60,12 @@ DTYPES = ["fp16", "bf16"]
def
remove_old_kernels
():
for
filename
in
glob
.
glob
(
os
.
path
.
dirname
(
__file__
)
+
"/kernel_*.cu"
):
for
filename
in
glob
.
glob
(
os
.
path
.
dirname
(
__file__
)
+
"/kernel_*.cu
h
"
):
subprocess
.
call
([
"rm"
,
"-f"
,
filename
])
def
generate_new_kernels
():
kernel_files
=
set
()
for
scalar_type
,
dtype
in
itertools
.
product
(
SCALAR_TYPES
,
DTYPES
):
has_zp
=
"B"
not
in
scalar_type
all_template_str_list
=
[]
...
...
@@ -95,10 +108,20 @@ def generate_new_kernels():
file_content
=
FILE_HEAD
+
"
\n\n
"
file_content
+=
"
\n\n
"
.
join
(
all_template_str_list
)
+
"
\n\n
}
\n
"
filename
=
f
"kernel_
{
dtype
}
_
{
scalar_type
[
8
:].
lower
()
}
.cu"
filename
=
f
"kernel_
{
dtype
}
_
{
scalar_type
[
8
:].
lower
()
}
.cu
h
"
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
filename
),
"w"
)
as
f
:
f
.
write
(
file_content
)
kernel_files
.
add
(
filename
)
kernel_files
=
list
(
kernel_files
)
kernel_files
.
sort
()
file_content
=
jinja2
.
Template
(
KERNEL_FILE_TEMPLATE
).
render
(
kernel_files
=
kernel_files
)
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
KERNEL_FILE_NAME
),
"w"
)
as
f
:
f
.
write
(
file_content
)
if
__name__
==
"__main__"
:
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h
View file @
79e6a8a6
#pragma once
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu
→
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu
h
View file @
79e6a8a6
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu
→
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu
h
View file @
79e6a8a6
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu
→
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu
h
View file @
79e6a8a6
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu
→
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu
h
View file @
79e6a8a6
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu
→
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu
h
View file @
79e6a8a6
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu
→
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu
h
View file @
79e6a8a6
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel.h"
#include "marlin_template.h"
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
0 → 100644
View file @
79e6a8a6
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel_bf16_ku4.cuh"
#include "kernel_bf16_ku4b8.cuh"
#include "kernel_bf16_ku8b128.cuh"
#include "kernel_fp16_ku4.cuh"
#include "kernel_fp16_ku4b8.cuh"
#include "kernel_fp16_ku8b128.cuh"
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
79e6a8a6
...
...
@@ -18,6 +18,8 @@
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#pragma once
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
...
...
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu
View file @
79e6a8a6
...
...
@@ -24,6 +24,7 @@
#endif
#include "kernel.h"
#include "kernel_marlin.cuh"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert( \
...
...
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
View file @
79e6a8a6
...
...
@@ -23,6 +23,7 @@ limitations under the License.
#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#include <cuda/functional>
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
...
...
@@ -33,6 +34,16 @@ limitations under the License.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Define reduction operators based on CUDA version
// CUDA 12.9+ (including CUDA 13) deprecates cub::Max/Min in favor of cuda::maximum/minimum
#if CUDA_VERSION >= 12090
using
MaxReduceOp
=
cuda
::
maximum
<>
;
using
MinReduceOp
=
cuda
::
minimum
<>
;
#else
using
MaxReduceOp
=
cub
::
Max
;
using
MinReduceOp
=
cub
::
Min
;
#endif
/// Aligned array type
template
<
typename
T
,
...
...
@@ -72,7 +83,6 @@ __launch_bounds__(TPB) __global__
const
int
thread_row_offset
=
blockIdx
.
x
*
num_cols
;
cub
::
Sum
sum
;
float
threadData
(
-
FLT_MAX
);
// Don't touch finished rows.
...
...
@@ -85,7 +95,7 @@ __launch_bounds__(TPB) __global__
threadData
=
max
(
convert_to_float
<
T
>
(
input
[
idx
]),
threadData
);
}
const
float
maxElem
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
cub
::
Max
());
const
float
maxElem
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
MaxReduceOp
());
if
(
threadIdx
.
x
==
0
)
{
float_max
=
maxElem
;
...
...
@@ -99,7 +109,7 @@ __launch_bounds__(TPB) __global__
threadData
+=
exp
((
convert_to_float
<
T
>
(
input
[
idx
])
-
float_max
));
}
const
auto
Z
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
sum
);
const
auto
Z
=
BlockReduce
(
tmpStorage
).
Sum
(
threadData
);
if
(
threadIdx
.
x
==
0
)
{
normalizing_factor
=
1.
f
/
Z
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment