sglang · Commits · Commit 15fabcc0 (unverified)

Authored Apr 23, 2025 by Yineng Zhang; committed via GitHub on Apr 23, 2025.

fix sgl-kernel unit tests (#5666)

Parent: e62c4955
Showing 9 changed files with 313 additions and 0 deletions (+313 / -0):
  sgl-kernel/CMakeLists.txt                                     +1    -0
  sgl-kernel/csrc/common_extension.cc                           +6    -0
  sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu   +251  -0
  sgl-kernel/include/sgl_kernel_ops.h                           +5    -0
  sgl-kernel/python/sgl_kernel/__init__.py                      +1    -0
  sgl-kernel/python/sgl_kernel/grammar.py                       +15   -0
  sgl-kernel/tests/test_apply_token_bitmask_inplace.py          +23   -0
  sgl-kernel/tests/test_fp8_blockwise_moe.py                    +10   -0
  sgl-kernel/tests/test_moe_fused_gate.py                       +1    -0
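Before the per-file diffs, it helps to pin down the contract of the new kernel that most of these files wire up. As the unit test below confirms, each int32 word of the bitmask covers 32 consecutive tokens: bit (token % 32) of word (token // 32) controls a token, a set bit keeps the logit, and a cleared bit overwrites it in place with -inf. A minimal pure-PyTorch sketch of that contract (an illustrative reference only, not part of the commit; it assumes the bitmask covers the whole vocabulary):

import torch

def apply_token_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    # Bit (token % 32) of int32 word (token // 32) decides each token:
    # 1 keeps the logit, 0 replaces it with -inf, in place.
    vocab_size = logits.shape[-1]
    token_ids = torch.arange(vocab_size, device=logits.device)
    keep = (bitmask[..., token_ids // 32] >> (token_ids % 32)) & 1
    logits.masked_fill_(keep == 0, float("-inf"))

The CUDA kernel in this commit implements the same contract, but reads and writes the logits through vectorized (float4) loads and stores where alignment allows.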
sgl-kernel/CMakeLists.txt
@@ -199,6 +199,7 @@ set(SOURCES
    "csrc/speculative/eagle_utils.cu"
    "csrc/speculative/speculative_sampling.cu"
    "csrc/speculative/packbit.cu"
    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
    "csrc/common_extension.cc"
    "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
    "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
sgl-kernel/csrc/common_extension.cc (mode 100755 → 100644)
@@ -233,6 +233,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "bool is_causal, float softcap, bool return_softmax, "
      "Generator? gen) -> Tensor[]");
  m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);

  /*
   * From XGrammar
   */
  m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
  m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
}

REGISTER_EXTENSION(common_ops)
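The m.def/m.impl pair registers the schema and binds it to the CUDA entry point, so after the extension is built the op is reachable as torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda. A minimal sketch of a direct call (assuming a CUDA build of sgl-kernel is installed and that importing the package loads the extension; the sgl_kernel.grammar wrapper added below makes exactly this call):

import torch
import sgl_kernel  # assumed to load the compiled common_ops extension and register its ops

logits = torch.randn(4, 128, device="cuda")
bitmask = torch.full((4, 4), -1, dtype=torch.int32, device="cuda")  # all bits set: keep every token
torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, None)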
sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu (new file, mode 100644)

// Adapted from
// https://github.com/mlc-ai/xgrammar/blob/v0.1.18/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// clang-format off
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format on

#ifndef CUDART_INF_FP16
#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
#endif

#ifndef CUDART_INF_BF16
#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
#endif

constexpr int32_t BITS_PER_BLOCK = 32;
constexpr int32_t THREADS_PER_THREAD_BLOCK = 256;
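The two macros spell out the IEEE bit patterns of +inf in half (0x7C00) and bfloat16 (0x7F80); negating them gives the sentinel value the kernel writes into masked logits. A quick CPU-only sanity check of those bit patterns (illustrative, not part of the commit):

import torch

# Reinterpret the raw 16-bit patterns as half and bfloat16, respectively.
assert torch.tensor([0x7C00], dtype=torch.int16).view(torch.float16).item() == float("inf")
assert torch.tensor([0x7F80], dtype=torch.int16).view(torch.bfloat16).item() == float("inf")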
template <typename T>
__device__ T NegativeInfinity() {
  return -INFINITY;
}

template <>
__device__ __half NegativeInfinity<__half>() {
  return -CUDART_INF_FP16;
}

template <>
__device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>() {
  return -CUDART_INF_BF16;
}

template <typename T, typename PackedT>
__device__ PackedT PackedNegativeInfinity() {
  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
  T packed[kAlignment];
#pragma unroll
  for (int i = 0; i < kAlignment; i++) {
    packed[i] = NegativeInfinity<T>();
  }
  return *reinterpret_cast<PackedT*>(packed);
}
template <typename T, typename PackedT, int32_t kBitsPerThread>
__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(
    T* __restrict__ logits,
    const int32_t* __restrict__ bitmask,
    const int32_t* __restrict__ indices,
    int32_t vocab_size,
    int32_t logits_stride,
    int32_t bitmask_stride) {
  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
  constexpr uint32_t kPackedMask = (1 << kAlignment) - 1;

  const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y];

  const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread;
  T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset;
  const int32_t* bitmask_gmem_ptr = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK;
  const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment);
  T logits_reg[kAlignment];

#pragma unroll
  for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread;
       offset += THREADS_PER_THREAD_BLOCK * kAlignment) {
    if (block_offset + offset >= vocab_size) {
      break;
    }

    const uint32_t bitmask_val =
        (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask;

    if (bitmask_val == 0) {
      continue;
    }

    if (bitmask_val == kPackedMask) {
      *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = PackedNegativeInfinity<T, PackedT>();
      continue;
    }

    *reinterpret_cast<PackedT*>(logits_reg) = *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset);
#pragma unroll
    for (int i = 0; i < kAlignment; i++) {
      if (((bitmask_val >> i) & 1)) {
        logits_reg[i] = NegativeInfinity<T>();
      }
    }
    *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = *reinterpret_cast<PackedT*>(logits_reg);
  }
}
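The kernel handles kAlignment logits per packed load/store and short-circuits the two common cases: a group whose bitmask bits are all 1 (nothing to mask, no store at all) and a group whose bits are all 0 (a single packed -inf store). A small Python mirror of that per-group decision (illustrative only, not part of the commit; kAlignment = 8 corresponds to fp16/bf16 logits packed into float4):

K_ALIGNMENT = 8                       # e.g. fp16/bf16 elements per float4
GROUP_FULL = (1 << K_ALIGNMENT) - 1   # the kernel's kPackedMask

def mask_group(logits_group, bitmask_word, group_idx_in_word):
    # The kernel inverts the word first, so a cleared bitmask bit becomes a set "mask me" bit.
    bitmask_val = (~bitmask_word >> (group_idx_in_word * K_ALIGNMENT)) & GROUP_FULL
    if bitmask_val == 0:            # every token in the group is allowed: skip the store
        return list(logits_group)
    if bitmask_val == GROUP_FULL:   # every token is masked: one packed -inf store
        return [float("-inf")] * K_ALIGNMENT
    return [float("-inf") if (bitmask_val >> i) & 1 else x for i, x in enumerate(logits_group)]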
template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
constexpr auto CeilDiv(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;
}
template <typename T, typename PackedT>
void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
    T* __restrict__ logits,
    const int32_t* __restrict__ bitmask,
    const int32_t* __restrict__ indices,
    int32_t vocab_size,
    int32_t logits_stride,
    int32_t bitmask_stride,
    int32_t num_rows) {
  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
  const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);
  const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);

  const dim3 block(THREADS_PER_THREAD_BLOCK);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  if (num_bits_per_thread <= 4 && kAlignment <= 4) {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);
    LogitsBitmaskKernel<T, PackedT, 4>
        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  } else if (num_bits_per_thread <= 8 && kAlignment <= 8) {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);
    LogitsBitmaskKernel<T, PackedT, 8>
        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  } else if (num_bits_per_thread <= 16 && kAlignment <= 16) {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);
    LogitsBitmaskKernel<T, PackedT, 16>
        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  } else {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);
    LogitsBitmaskKernel<T, PackedT, 32>
        <<<grid, block, 0, stream>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  }
}
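The dispatcher sizes the launch so that roughly 2048 / 256 * 128 = 1024 blocks are split across the rows, then picks the smallest kBitsPerThread in {4, 8, 16, 32} that covers both the per-thread workload and the packing width. A worked Python sketch of that arithmetic (the concrete numbers are illustrative, not from the commit):

THREADS_PER_THREAD_BLOCK = 256

def pick_launch(vocab_size: int, num_rows: int, k_alignment: int):
    ceil_div = lambda a, b: (a + b - 1) // b
    num_blocks_per_row = ceil_div(2048 // THREADS_PER_THREAD_BLOCK * 128, num_rows)
    num_bits_per_thread = ceil_div(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row)
    for k_bits in (4, 8, 16, 32):
        if (num_bits_per_thread <= k_bits and k_alignment <= k_bits) or k_bits == 32:
            grid_x = ceil_div(vocab_size, THREADS_PER_THREAD_BLOCK * k_bits)
            return k_bits, (grid_x, num_rows)

# e.g. a 128K vocabulary, 8 rows, fp16 packed into float4 (k_alignment = 8):
# num_blocks_per_row = 128 and num_bits_per_thread = 4, but k_alignment forces
# kBitsPerThread = 8, so the grid is (ceil(131072 / 2048), 8) = (64, 8).
print(pick_launch(131072, 8, 8))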
template <typename T>
void ApplyTokenBitmaskInplaceDispatchToPackedT(
    T* __restrict__ logits,
    const int32_t* __restrict__ bitmask,
    const int32_t* __restrict__ indices,
    int32_t vocab_size,
    int32_t logits_stride,
    int32_t bitmask_stride,
    int32_t num_rows) {
  if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) {
    ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, float4>(
        logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
  } else {
    ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, T>(
        logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
  }
}
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt) {
  TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
  TORCH_CHECK(logits.dim() == 1 || logits.dim() == 2, "logits must be a 1D or 2D tensor.");
  std::pair<int32_t, int32_t> logits_shape =
      logits.dim() == 2
          ? std::make_pair(static_cast<int32_t>(logits.size(0)), static_cast<int32_t>(logits.size(1)))
          : std::make_pair(1, static_cast<int32_t>(logits.size(0)));

  TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
  TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
  TORCH_CHECK(bitmask.dim() == 1 || bitmask.dim() == 2, "bitmask must be a 1D or 2D tensor.");
  std::pair<int32_t, int32_t> bitmask_shape =
      bitmask.dim() == 2
          ? std::make_pair(static_cast<int32_t>(bitmask.size(0)), static_cast<int32_t>(bitmask.size(1)))
          : std::make_pair(1, static_cast<int32_t>(bitmask.size(0)));

  TORCH_CHECK(bitmask.dtype() == torch::kInt32, "bitmask must be of type int32.");

  TORCH_CHECK(
      (logits_shape.second + BITS_PER_BLOCK - 1) / BITS_PER_BLOCK >= bitmask_shape.second,
      "The provided logits's vocab size should be no less than the bitmask's vocab size "
      "(converted from bitmask size). But got vocab size ",
      logits_shape.second,
      " vs bitmask size ",
      bitmask_shape.second);

  int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK);

  int32_t num_rows = logits_shape.first;
  int32_t* indices_ptr = nullptr;
  if (indices) {
    TORCH_CHECK(indices->is_cuda(), "indices must be a CUDA tensor.");
    TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous.");
    TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor.");
    TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32.");
    num_rows = indices->size(0);
    indices_ptr = indices->data_ptr<int32_t>();
  } else {
    TORCH_CHECK(logits_shape.first == bitmask_shape.first, "logits and bitmask must have the same batch size.");
  }

  switch (logits.scalar_type()) {
    case torch::kFloat32: {
      ApplyTokenBitmaskInplaceDispatchToPackedT(
          logits.data_ptr<float>(),
          bitmask.data_ptr<int32_t>(),
          indices_ptr,
          vocab_size,
          logits_shape.second,
          bitmask_shape.second,
          num_rows);
      break;
    }
    case torch::kFloat16: {
      ApplyTokenBitmaskInplaceDispatchToPackedT(
          reinterpret_cast<__half*>(logits.data_ptr<torch::Half>()),
          bitmask.data_ptr<int32_t>(),
          indices_ptr,
          vocab_size,
          logits_shape.second,
          bitmask_shape.second,
          num_rows);
      break;
    }
    case torch::kBFloat16: {
      ApplyTokenBitmaskInplaceDispatchToPackedT(
          reinterpret_cast<__nv_bfloat16*>(logits.data_ptr<torch::BFloat16>()),
          bitmask.data_ptr<int32_t>(),
          indices_ptr,
          vocab_size,
          logits_shape.second,
          bitmask_shape.second,
          num_rows);
      break;
    }
    default:
      TORCH_CHECK(false, "logits dtype must be float, half or bfloat16.");
      break;
  }
}
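Note the indexing contract of the host entry point: when indices is supplied it determines num_rows, and the same index selects both the logits row and the bitmask row; without indices the two batch sizes must match. A hedged usage sketch through the Python wrapper added below (tensor contents are illustrative; requires a CUDA build of sgl-kernel):

import torch
from sgl_kernel import apply_token_bitmask_inplace_cuda

logits = torch.randn(4, 64, device="cuda")                          # 4 requests, vocab 64
bitmask = torch.full((4, 2), -1, dtype=torch.int32, device="cuda")  # two int32 words cover 64 tokens
bitmask[1, 0] = 0                                                   # request 1: forbid tokens 0-31

# Only rows 1 and 3 are grammar-constrained this step; rows 0 and 2 stay untouched.
apply_token_bitmask_inplace_cuda(logits, bitmask, indices=[1, 3])
assert torch.isinf(logits[1, :32]).all() and torch.isfinite(logits[0]).all()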
sgl-kernel/include/sgl_kernel_ops.h (mode 100755 → 100644)
@@ -352,3 +352,8 @@ std::vector<at::Tensor> mha_varlen_fwd_sparse(
    const bool return_softmax,
    c10::optional<at::Generator> gen_);
}  // namespace flash

/*
 * From XGrammar
 */
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
sgl-kernel/python/sgl_kernel/__init__.py
@@ -41,6 +41,7 @@ from sgl_kernel.gemm import (
    sgl_per_token_group_quant_int8,
    sgl_per_token_quant_fp8,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.moe import (
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
sgl-kernel/python/sgl_kernel/grammar.py (new file, mode 100644)

from typing import List, Optional, Union

import torch


def apply_token_bitmask_inplace_cuda(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
    if isinstance(indices, list):
        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
    if indices is not None:
        indices = indices.to(logits.device)
    torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
sgl-kernel/tests/test_apply_token_bitmask_inplace.py (new file, mode 100644)

import pytest
import torch
from sgl_kernel import apply_token_bitmask_inplace_cuda


def test_apply_token_bitmask_inplace_kernel():
    neginf = float("-inf")
    bool_mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
    logits = torch.tensor(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], dtype=torch.float32
    )
    expected = torch.where(bool_mask, logits, neginf)

    logits_gpu = logits.to("cuda")
    bitmask = torch.tensor([0b1010101010], dtype=torch.int32).to("cuda")
    apply_token_bitmask_inplace_cuda(logits_gpu, bitmask)
    torch.cuda.synchronize()
    torch.testing.assert_close(logits_gpu, expected.to("cuda"))


if __name__ == "__main__":
    test_apply_token_bitmask_inplace_kernel()
    pytest.main([__file__])
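The new test exercises a single fp32 vector. Since the kernel also dispatches on half and bfloat16, a parametrized variant along the following lines could widen coverage; this is a hypothetical extension, not part of the commit:

import pytest
import torch
from sgl_kernel import apply_token_bitmask_inplace_cuda


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_apply_token_bitmask_inplace_dtypes(dtype):
    bool_mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
    logits = torch.arange(1.0, 11.0).to(dtype)
    expected = torch.where(bool_mask, logits, torch.tensor(float("-inf"), dtype=dtype))

    logits_gpu = logits.to("cuda")
    bitmask = torch.tensor([0b1010101010], dtype=torch.int32, device="cuda")
    apply_token_bitmask_inplace_cuda(logits_gpu, bitmask)
    torch.cuda.synchronize()
    torch.testing.assert_close(logits_gpu, expected.to("cuda"))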
sgl-kernel/tests/test_fp8_blockwise_moe.py
@@ -47,6 +47,16 @@ def baseline_scaled_mm(
    ).to(out_dtype)


def is_sm100_supported(device=None) -> bool:
    return (torch.cuda.get_device_capability(device)[0] == 10) and (
        torch.version.cuda >= "12.8"
    )


@pytest.mark.skipif(
    not is_sm100_supported(),
    reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
)
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
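The new is_sm100_supported guard skips the test unless the GPU reports compute capability major 10 (sm100, Blackwell class) and the CUDA build is at least 12.8. One caveat worth noting: torch.version.cuda is a string, so the >= "12.8" comparison is lexicographic (for example, a hypothetical "12.10" would compare as older than "12.8"). A numeric-tuple variant, offered only as a hypothetical alternative (assumes a CUDA build of PyTorch, where torch.version.cuda is not None):

import torch

def is_sm100_supported_strict(device=None) -> bool:
    # Same intent as the guard above, but compare the CUDA version numerically
    # instead of lexicographically.
    major, _minor = torch.cuda.get_device_capability(device)
    cuda_version = tuple(int(x) for x in torch.version.cuda.split("."))
    return major == 10 and cuda_version >= (12, 8)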
sgl-kernel/tests/test_moe_fused_gate.py
@@ -48,6 +48,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
        topk_group=topk_group,
        compiled=False,
        n_share_experts_fusion=n_share_experts_fusion,
        routed_scaling_factor=2.5,
    )
    # When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension