sglang · Commits · 42c87045

Commit 42c87045 (unverified), authored Aug 20, 2025 by fzyzcjy, committed via GitHub on Aug 20, 2025
Parent: c9bf3877

Add PDL support for quant kernel and rope kernel (#9106)
Showing 7 changed files with 80 additions and 33 deletions (+80 −33).
Changed files:

- python/sglang/srt/entrypoints/engine.py (+2 −0)
- python/sglang/srt/server_args.py (+0 −1)
- sgl-kernel/csrc/common_extension.cc (+1 −1)
- sgl-kernel/csrc/elementwise/pos_enc.cuh (+67 −31)
- sgl-kernel/csrc/elementwise/rope.cu (+3 −0)
- sgl-kernel/include/sgl_kernel_ops.h (+1 −0)
- sgl-kernel/python/sgl_kernel/elementwise.py (+6 −0)
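For context: Programmatic Dependent Launch (PDL) is a CUDA feature of Hopper-class GPUs (sm_90 and newer) that lets a kernel begin launching while the previous kernel on the same stream is still running. The dependent kernel opts in through the `cudaLaunchAttributeProgrammaticStreamSerialization` launch attribute and must guard its reads of upstream results with the PTX instruction `griddepcontrol.wait`; the upstream kernel can signal early with `griddepcontrol.launch_dependents` (otherwise the signal is implied when it completes). Reading the diff below, the commit enables PDL in two ways: the quant/MoE kernels get it indirectly via flashinfer's `TRTLLM_ENABLE_PDL` environment variable, while the rope kernels implement it directly. The following is a minimal, standalone sketch of the mechanism, with hypothetical `producer`/`consumer` kernels that are not code from this commit:

```cuda
// Minimal PDL sketch (hypothetical producer/consumer kernels; not from this
// commit). Compile for Hopper, e.g.: nvcc -arch=sm_90 pdl_sketch.cu
#include <cuda_runtime.h>

__global__ void producer(float* buf) {
  buf[threadIdx.x] = static_cast<float>(threadIdx.x);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Signal that dependent kernels may begin running.
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

__global__ void consumer(float* buf) {
  // Prologue work that does not touch the producer's output could go here.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Block until the producer has signalled (or finished) before reading.
  asm volatile("griddepcontrol.wait;");
#endif
  buf[threadIdx.x] += 1.0f;
}

int main() {
  float* buf = nullptr;
  cudaMalloc(&buf, 32 * sizeof(float));

  producer<<<1, 32>>>(buf);

  // Launch the dependent kernel with programmatic stream serialization
  // allowed, mirroring the LAUNCH_KERNEL_RAW macro added in pos_enc.cuh.
  cudaLaunchConfig_t config = {};
  config.gridDim = dim3(1);
  config.blockDim = dim3(32);
  cudaLaunchAttribute attrs[1] = {};
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = 1;
  config.attrs = attrs;
  config.numAttrs = 1;
  cudaLaunchKernelEx(&config, consumer, buf);

  cudaDeviceSynchronize();
  cudaFree(buf);
  return 0;
}
```

On pre-sm_90 devices the `#if` guards compile the `griddepcontrol` instructions out, and a launch with the attribute value set to 0 behaves like an ordinary stream-ordered launch.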
python/sglang/srt/entrypoints/engine.py

```diff
@@ -635,6 +635,8 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
+    # flashinfer uses this environment variable for various kernels from MoE to quant kernels
+    os.environ["TRTLLM_ENABLE_PDL"] = "1"
 
     # Set prometheus env vars
     if server_args.enable_metrics:
```
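With this change `TRTLLM_ENABLE_PDL` is set unconditionally at engine startup rather than only when the Flashinfer-MoE path is configured; the corresponding assignment in server_args.py is removed in the next file.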
python/sglang/srt/server_args.py

```diff
@@ -550,7 +550,6 @@ class ServerArgs:
             assert (
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
-            os.environ["TRTLLM_ENABLE_PDL"] = "1"
             assert self.ep_size in [
                 1,
                 self.tp_size,
```
sgl-kernel/csrc/common_extension.cc

```diff
@@ -90,7 +90,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream, "
+      "Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
       "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
```
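Since the op schema now places `enable_pdl` between `interleave` and `cuda_stream`, every caller of `apply_rope_pos_ids_cos_sin_cache` must pass the new argument in that position; the Python wrapper in elementwise.py (last file below) is updated accordingly.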
sgl-kernel/csrc/elementwise/pos_enc.cuh

```diff
@@ -104,6 +104,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel
   uint32_t by = blockIdx.y;
   const uint32_t bdy = blockDim.y;
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+
   vec_t<float, vec_size> cos, sin;
   if (bx * bdy + ty < nnz) {
     const uint32_t idx = bx * bdy + ty;
@@ -178,6 +182,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel
       }
     }
   }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }
 
 template <
@@ -220,6 +228,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
   uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
   const uint32_t bdy = blockDim.y;
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+
   vec_t<float, vec_size> cos, sin;
   if (bx * bdy + ty < nnz) {
     const uint32_t idx = bx * bdy + ty;
@@ -296,6 +308,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
       }
     }
   }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }
 
 #define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
@@ -340,12 +356,59 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
     IdType* kv_cache_loc,
     bool interleave,
     bool save_kv_cache,
+    bool enable_pdl,
     cudaStream_t stream = nullptr) {
   int dev_id = 0;
   int num_sms = 0;
   FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
   FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
 
+#define LAUNCH_KERNEL_RAW(kernel_name)                                  \
+  do {                                                                  \
+    cudaLaunchConfig_t config = {};                                     \
+    config.gridDim = nblks;                                             \
+    config.blockDim = nthrs;                                            \
+    config.dynamicSmemBytes = 0;                                        \
+    config.stream = stream;                                             \
+    cudaLaunchAttribute attrs[1] = {};                                  \
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;   \
+    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;   \
+    config.numAttrs = 1;                                                \
+    config.attrs = attrs;                                               \
+                                                                        \
+    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(                            \
+        &config,                                                        \
+        kernel_name,                                                    \
+        q,                                                              \
+        k,                                                              \
+        v,                                                              \
+        q_rope,                                                         \
+        k_rope,                                                         \
+        k_buffer,                                                       \
+        v_buffer,                                                       \
+        cos_sin_cache,                                                  \
+        pos_ids,                                                        \
+        nnz,                                                            \
+        num_qo_heads,                                                   \
+        num_kv_heads,                                                   \
+        rotary_dim,                                                     \
+        q_stride_n,                                                     \
+        q_stride_h,                                                     \
+        k_stride_n,                                                     \
+        k_stride_h,                                                     \
+        v_stride_n,                                                     \
+        v_stride_h,                                                     \
+        q_rope_stride_n,                                                \
+        q_rope_stride_h,                                                \
+        k_rope_stride_n,                                                \
+        k_rope_stride_h,                                                \
+        k_buffer_stride_n,                                              \
+        k_buffer_stride_h,                                              \
+        v_buffer_stride_n,                                              \
+        v_buffer_stride_h,                                              \
+        kv_cache_loc));                                                 \
+  } while (0)
+
   DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
     DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
       DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
@@ -359,35 +422,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
         uint32_t bdy = num_threads / bdx;
         // how many blocks needed to process all tokens
         uint32_t nblks_x = (nnz + bdy - 1) / bdy;
-        void* args[] = {(void*)&q,
-                        (void*)&k,
-                        (void*)&v,
-                        (void*)&q_rope,
-                        (void*)&k_rope,
-                        (void*)&k_buffer,
-                        (void*)&v_buffer,
-                        (void*)&cos_sin_cache,
-                        (void*)&pos_ids,
-                        (void*)&nnz,
-                        (void*)&num_qo_heads,
-                        (void*)&num_kv_heads,
-                        (void*)&rotary_dim,
-                        (void*)&q_stride_n,
-                        (void*)&q_stride_h,
-                        (void*)&k_stride_n,
-                        (void*)&k_stride_h,
-                        (void*)&v_stride_n,
-                        (void*)&v_stride_h,
-                        (void*)&q_rope_stride_n,
-                        (void*)&q_rope_stride_h,
-                        (void*)&k_rope_stride_n,
-                        (void*)&k_rope_stride_h,
-                        (void*)&k_buffer_stride_n,
-                        (void*)&k_buffer_stride_h,
-                        (void*)&v_buffer_stride_n,
-                        (void*)&v_buffer_stride_h,
-                        (void*)&kv_cache_loc};
         auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<SAVE_KV_CACHE, INTERLEAVE,
@@ -405,7 +440,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
         if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
           dim3 nblks(nblks_x);
           dim3 nthrs(bdx, bdy);
-          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
+          LAUNCH_KERNEL_RAW(kernel_0);
         } else {
           dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
           dim3 nthrs(bdx, bdy);
@@ -417,11 +452,12 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
               bdx, DType, IdType>;
-          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
+          LAUNCH_KERNEL_RAW(kernel_1);
         }
       });
     });
   });
+#undef LAUNCH_KERNEL_RAW
   return cudaSuccess;
 }
```
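Two details of the pos_enc.cuh changes are worth noting. First, the placement of the `griddepcontrol` instructions: `griddepcontrol.wait` is issued early in each kernel, before the cos/sin cache and q/k data are loaded, while `griddepcontrol.launch_dependents` comes at the very end after all stores, which keeps the overlap window as wide as possible on both sides. Second, the launch path: switching from `cudaLaunchKernel` with a `void* args[]` array to `cudaLaunchKernelEx` passes the kernel arguments directly and lets the PDL attribute be set per launch. Because `programmaticStreamSerializationAllowed` takes the runtime `enable_pdl` value, the same `LAUNCH_KERNEL_RAW` macro serves both modes: with the attribute at 0 the launch is an ordinary serialized one, and `griddepcontrol.wait` then returns immediately since the upstream kernel has already completed.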
sgl-kernel/csrc/elementwise/rope.cu

```diff
@@ -27,6 +27,7 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
+    bool enable_pdl,
     int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
@@ -124,12 +125,14 @@ void apply_rope_pos_ids_cos_sin_cache(
         kv_cache_loc_ptr,
         interleave,
         save_kv_cache,
+        enable_pdl,
         stream);
     TORCH_CHECK(
         status == cudaSuccess,
         "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " + std::string(cudaGetErrorString(status)));
   } else {
+    TORCH_CHECK(!enable_pdl);
     cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
         static_cast<c_type*>(q.data_ptr()),
         static_cast<c_type*>(k.data_ptr()),
```
sgl-kernel/include/sgl_kernel_ops.h

```diff
@@ -151,6 +151,7 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
+    bool enable_pdl,
     int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
```
sgl-kernel/python/sgl_kernel/elementwise.py

```diff
@@ -271,6 +271,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     cos_sin_cache: torch.Tensor,
     is_neox: bool = True,
     fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
+    enable_pdl: Optional[bool] = None,
 ) -> None:
     r"""
     Apply rotary embedding to keys and queries with precomputed cos/sin values.
@@ -307,6 +308,10 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")
 
+    if enable_pdl is None:
+        # the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
+        enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)
+
     if (a := fused_set_kv_buffer_arg) is not None:
         assert a.k_scale is None, "k_scale is not yet supported"
         assert a.v_scale is None, "v_scale is not yet supported"
@@ -323,6 +328,7 @@ def apply_rope_with_cos_sin_cache_inplace(
         cos_sin_cache,
         positions.long(),
         (not is_neox),
+        enable_pdl,
         get_cuda_stream(),
         (_view_3d(fused_set_kv_buffer_arg.value)
```
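The default resolution above means PDL is requested only when the GPU architecture supports it and the fused KV-cache branch (the code path that implements PDL today) is taken; callers can pass `enable_pdl` explicitly to override the heuristic, and the `TORCH_CHECK(!enable_pdl)` in rope.cu guards against enabling it on the unsupported branch.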