sglang · Commit 6b39f9cf (unverified)
Authored Aug 28, 2025 by Rain Jiang; committed by GitHub on Aug 28, 2025
Support compile sgl-kernel on cuda 13.0 (#9721)
Parent: 07c9d8fb
Showing 13 changed files with 78 additions and 14 deletions (+78 −14)
sgl-kernel/CMakeLists.txt                                        +20 −9
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py         +25 −2
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h                     +1 −0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh          +1 −0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh        +1 −0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh      +1 −0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh          +1 −0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh        +1 −0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh      +1 −0
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh           +10 −0
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h            +2 −0
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu                       +1 −0
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu                  +13 −3
sgl-kernel/CMakeLists.txt

@@ -78,7 +78,7 @@ FetchContent_Populate(repo-triton)
 FetchContent_Declare(
   repo-flashinfer
   GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
-  GIT_TAG        9220fb3443b5a5d274f00ca5552f798e225239b7
+  GIT_TAG        018b551825c8e5579206e6eb9d3229fa679202b3
   GIT_SHALLOW    OFF
 )
 FetchContent_Populate(repo-flashinfer)

@@ -174,11 +174,28 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-gencode=arch=compute_100,code=sm_100"
     "-gencode=arch=compute_100a,code=sm_100a"
-    "-gencode=arch=compute_101,code=sm_101"
-    "-gencode=arch=compute_101a,code=sm_101a"
     "-gencode=arch=compute_120,code=sm_120"
     "-gencode=arch=compute_120a,code=sm_120a"
   )
+
+  # refer sm_121, sm_110 and sm_101 description https://github.com/pytorch/pytorch/pull/156176
+  if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
+    list(APPEND SGL_KERNEL_CUDA_FLAGS
+      "-gencode=arch=compute_103,code=sm_103"
+      "-gencode=arch=compute_103a,code=sm_103a"
+      "-gencode=arch=compute_110,code=sm_110"
+      "-gencode=arch=compute_110a,code=sm_110a"
+      "-gencode=arch=compute_121,code=sm_121"
+      "-gencode=arch=compute_121a,code=sm_121a"
+      "--compress-mode=size"
+    )
+  else()
+    list(APPEND SGL_KERNEL_CUDA_FLAGS
+      "-gencode=arch=compute_101,code=sm_101"
+      "-gencode=arch=compute_101a,code=sm_101a"
+    )
+  endif()
 else()
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-use_fast_math"

@@ -261,12 +278,6 @@ set(SOURCES
   "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
   "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
   "csrc/moe/marlin_moe_wna16/ops.cu"
-  "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
-  "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
-  "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
-  "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
-  "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
-  "csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
   "csrc/moe/moe_align_kernel.cu"
   "csrc/moe/moe_fused_gate.cu"
   "csrc/moe/moe_topk_softmax_kernels.cu"
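For context (not part of the patch): each -gencode pair above embeds SASS for one architecture, and the new CUDA 13.0 branch adds the sm_103/sm_110/sm_121 families. A quick way to see which sm_XX a given GPU needs is to query its compute capability at runtime; a minimal sketch with no sglang-specific assumptions:

// arch_check.cu — hypothetical helper, not part of the commit.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::fprintf(stderr, "no CUDA device visible\n");
    return 1;
  }
  // A device that needs the CUDA 13.0 branch above would report,
  // for example, 12.1 here (i.e. sm_121).
  std::printf("device 0: sm_%d%d\n", prop.major, prop.minor);
  return 0;
}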
sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py

@@ -9,6 +9,7 @@ import jinja2
 FILE_HEAD = """
 // auto generated by generate.py
 // clang-format off
+#pragma once
 #include "kernel.h"
 #include "marlin_template.h"

@@ -33,6 +34,17 @@ TEMPLATE = (
     "( MARLIN_KERNEL_PARAMS );"
 )

+KERNEL_FILE_TEMPLATE = (
+    "// auto generated by generate.py\n"
+    "// clang-format off\n"
+    "#pragma once\n\n"
+    "{% for kernel_file in kernel_files %}"
+    '#include "{{ kernel_file }}"\n'
+    "{% endfor %}"
+)
+KERNEL_FILE_NAME = "kernel_marlin.cuh"
+
 # int8 with zero point case (sglang::kU8) is also supported,
 # we don't add it to reduce wheel size.
 SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]

@@ -48,11 +60,12 @@ DTYPES = ["fp16", "bf16"]
 def remove_old_kernels():
-    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
+    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
         subprocess.call(["rm", "-f", filename])


 def generate_new_kernels():
+    kernel_files = set()
     for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
         has_zp = "B" not in scalar_type
         all_template_str_list = []

@@ -95,10 +108,20 @@ def generate_new_kernels():
         file_content = FILE_HEAD + "\n\n"
         file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
-        filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"
+        filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"

         with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
             f.write(file_content)
+        kernel_files.add(filename)
+
+    kernel_files = list(kernel_files)
+    kernel_files.sort()
+    file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
+        kernel_files=kernel_files
+    )
+    with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
+        f.write(file_content)


 if __name__ == "__main__":
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h

+#pragma once
 #ifndef MARLIN_NAMESPACE_NAME
 #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
 ...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu → sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh

 // auto generated by generate.py
 // clang-format off
+#pragma once
 #include "kernel.h"
 #include "marlin_template.h"
 ...

sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu → sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh

 // auto generated by generate.py
 // clang-format off
+#pragma once
 #include "kernel.h"
 #include "marlin_template.h"
 ...

sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu → sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh

 // auto generated by generate.py
 // clang-format off
+#pragma once
 #include "kernel.h"
 #include "marlin_template.h"
 ...

sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu → sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh

 // auto generated by generate.py
 // clang-format off
+#pragma once
 #include "kernel.h"
 #include "marlin_template.h"
 ...

sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu → sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh

 // auto generated by generate.py
 // clang-format off
+#pragma once
 #include "kernel.h"
 #include "marlin_template.h"
 ...

sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu → sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh

 // auto generated by generate.py
 // clang-format off
+#pragma once
 #include "kernel.h"
 #include "marlin_template.h"
 ...
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh (new file, 0 → 100644)

+// auto generated by generate.py
+// clang-format off
+#pragma once
+
+#include "kernel_bf16_ku4.cuh"
+#include "kernel_bf16_ku4b8.cuh"
+#include "kernel_bf16_ku8b128.cuh"
+#include "kernel_fp16_ku4.cuh"
+#include "kernel_fp16_ku4b8.cuh"
+#include "kernel_fp16_ku8b128.cuh"
sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h

@@ -18,6 +18,8 @@
 /*
  * Adapted from https://github.com/IST-DASLab/marlin
  */

+#pragma once
+
 #ifndef MARLIN_NAMESPACE_NAME
 #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
 #endif
 ...
sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu

@@ -24,6 +24,7 @@
 #endif

 #include "kernel.h"
+#include "kernel_marlin.cuh"

 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
   static_assert(                                  \
 ...
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu

@@ -23,6 +23,7 @@ limitations under the License.
 #ifndef USE_ROCM
 #include <cub/cub.cuh>
 #include <cub/util_type.cuh>
+#include <cuda/functional>
 #else
 #include <hipcub/hipcub.hpp>
 #include <hipcub/util_type.hpp>

@@ -33,6 +34,16 @@ limitations under the License.
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))

+// Define reduction operators based on CUDA version
+// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum
+#if CUDA_VERSION >= 12090
+using MaxReduceOp = cuda::maximum<>;
+using MinReduceOp = cuda::minimum<>;
+#else
+using MaxReduceOp = cub::Max;
+using MinReduceOp = cub::Min;
+#endif
+
 /// Aligned array type
 template <
     typename T,

@@ -72,7 +83,6 @@ __launch_bounds__(TPB) __global__
     const int thread_row_offset = blockIdx.x * num_cols;

-    cub::Sum sum;
     float threadData(-FLT_MAX);

     // Don't touch finished rows.

@@ -85,7 +95,7 @@ __launch_bounds__(TPB) __global__
         threadData = max(convert_to_float<T>(input[idx]), threadData);
     }

-    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
+    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
     if (threadIdx.x == 0)
     {
         float_max = maxElem;

@@ -99,7 +109,7 @@ __launch_bounds__(TPB) __global__
         threadData += exp((convert_to_float<T>(input[idx]) - float_max));
     }

-    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
+    const auto Z = BlockReduce(tmpStorage).Sum(threadData);

     if (threadIdx.x == 0)
     {
         normalizing_factor = 1.f / Z;
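To see the version-gated operator pattern from the hunks above in isolation, here is a minimal, self-contained CUDA sketch (the file name, kernel, and sizes are invented for illustration, not taken from the commit): a block-wide max reduction that compiles against both pre-12.9 and newer toolkits.

// demo_block_max.cu — hedged sketch, not part of the commit.
#include <cstdio>
#include <cuda.h>            // defines CUDA_VERSION
#include <cub/cub.cuh>
#if CUDA_VERSION >= 12090
#include <cuda/functional>   // cuda::maximum
using MaxReduceOp = cuda::maximum<>;
#else
using MaxReduceOp = cub::Max;
#endif

constexpr int TPB = 128;  // threads per block

__global__ void block_max_kernel(const float* in, float* out) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float v = in[blockIdx.x * TPB + threadIdx.x];
  // Reduce() accepts any binary functor; the alias resolves to whichever
  // operator the installed toolkit provides.
  float block_max = BlockReduce(tmp).Reduce(v, MaxReduceOp());
  if (threadIdx.x == 0) out[blockIdx.x] = block_max;
}

int main() {
  float h_in[TPB];
  for (int i = 0; i < TPB; ++i) h_in[i] = static_cast<float>(i);
  float *d_in, *d_out, h_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  block_max_kernel<<<1, TPB>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("block max = %.1f\n", h_out);  // expect 127.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

The same substitution drives the other two kernel hunks: the deprecated cub::Max() functor is replaced by the alias, and the cub::Sum object is dropped in favor of BlockReduce's built-in Sum() member, so no functor needs to be constructed at all.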