sglang / Commits / db7343c9

Commit db7343c9 (unverified)
Authored Aug 01, 2025 by Stefan He; committed by GitHub on Aug 01, 2025
Parent: 533cb5b2

fix per token cuda kernel hidden dim cannot divide by 16 (#8543)

Showing 3 changed files, with 167 additions and 47 deletions:

  sgl-kernel/benchmark/bench_per_token_quant_fp8.py   +107  -20
  sgl-kernel/csrc/gemm/per_token_quant_fp8.cu           +59  -26
  sgl-kernel/tests/test_per_token_quant_fp8.py           +1   -1
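Context for the change: the per-token FP8 quantization kernels write their output with 128-bit (uint4) vectorized stores, which assumes the hidden dimension is a multiple of 16 FP8 elements; shapes such as hidden_dim = 1368 (a multiple of 8 but not of 16) were previously rejected by a TORCH_CHECK. The sketch below is a minimal PyTorch model of the per-token quantization scheme and of the shape constraint, not the kernel itself; the helper name reference_per_token_quant is illustrative only, and it assumes a CUDA build of PyTorch with float8_e4m3fn support.

    import torch

    FP8_MAX = float(torch.finfo(torch.float8_e4m3fn).max)  # 448.0 for CUDA E4M3

    def reference_per_token_quant(x: torch.Tensor):
        """Per-token (row-wise) dynamic FP8 quantization, mirroring the kernel's math."""
        max_vals = x.abs().max(dim=1).values          # one max per token (row)
        scales = max_vals / FP8_MAX                   # scale = row_max / FP8_MAX
        q = (x * (1.0 / scales).unsqueeze(1)).clamp(-FP8_MAX, FP8_MAX)
        return q.to(torch.float8_e4m3fn), scales

    # hidden_dim = 1368 is a multiple of 8 but not of 16: the kernel cannot issue
    # one 128-bit (16 x FP8) store per vector, so the fix falls back to an
    # 8-element vector width for such shapes instead of rejecting them.
    assert 1368 % 16 != 0 and 1368 % 8 == 0
    x = torch.randn(4, 1368, dtype=torch.float16, device="cuda")
    q, s = reference_per_token_quant(x)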
sgl-kernel/benchmark/bench_per_token_quant_fp8.py  (view file @ db7343c9)
@@ -12,6 +12,39 @@ from sglang.srt.utils import is_hip

 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

+# Get correct FP8 E4M3 maximum value
+if _is_hip:
+    FP8_E4M3_MAX = 224.0  # ROCM uses 224.0
+else:
+    # For CUDA, get the actual max value from the type
+    FP8_E4M3_MAX = float(torch.finfo(fp8_type_).max)
+
+
+def torch_per_token_quant_fp8(
+    input: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Pure PyTorch reference implementation for per-token FP8 quantization."""
+    device = input.device
+    dtype = input.dtype
+
+    # Find max absolute value per token (row) - exactly like CUDA kernel
+    max_vals = torch.abs(input).max(dim=1)[0]  # [num_tokens]
+
+    # Calculate scale per token - exactly like CUDA kernel: scale = max_value / FP8_E4M3_MAX
+    scales = max_vals / FP8_E4M3_MAX  # [num_tokens]
+
+    # No special zero handling - directly compute 1.0 / scale like CUDA kernel
+    scale_inv = 1.0 / scales  # [num_tokens]
+
+    # Quantize: input * scale_inv, then clamp to FP8 range
+    quantized_float = input * scale_inv.unsqueeze(1)  # Broadcast scale_inv
+    quantized_float = torch.clamp(quantized_float, -FP8_E4M3_MAX, FP8_E4M3_MAX)
+
+    # Convert to FP8 - use more explicit conversion
+    quantized_fp8 = quantized_float.to(fp8_type_)
+
+    return quantized_fp8, scales
+
+
 def vllm_per_token_quant_fp8(
     input: torch.Tensor,
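For readers skimming the diff, here is a hedged usage sketch of the new torch_per_token_quant_fp8 reference: it quantizes a toy tensor and dequantizes it back to check the round-trip error. It is not part of the commit; the import path is illustrative, and it assumes a CUDA build of PyTorch with float8_e4m3fn support.

    import torch

    # Assumes the benchmark module above is importable from the working directory.
    from bench_per_token_quant_fp8 import torch_per_token_quant_fp8

    x = torch.randn(8, 1368, dtype=torch.float16, device="cuda")
    q, scales = torch_per_token_quant_fp8(x)

    # Dequantize: q is FP8, scales holds one scale per token (row).
    x_hat = q.to(torch.float32) * scales.unsqueeze(1)
    max_err = (x_hat - x.float()).abs().max().item()
    print(f"max round-trip error: {max_err:.4f}")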
@@ -29,53 +62,100 @@ def sglang_per_token_quant_fp8(
     return output, scale


-def calculate_diff(batch_size: int, seq_len: int):
-    """Calculate difference between VLLM and SGLang implementations."""
+def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
+    """Compare Torch reference, VLLM, and SGLang implementations."""
     device = torch.device("cuda")
-    x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
+    x = torch.rand(
+        (batch_size * seq_len, hidden_dim), dtype=torch.float16, device=device
+    )

+    # Get all three implementations
+    torch_out, torch_scale = torch_per_token_quant_fp8(x)
     vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
     sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

-    scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
-    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
+    print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")

-    if torch.allclose(
-        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
-    ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
-        print("✅ All implementations match")
-    else:
-        print("❌ Implementations differ")
+    # Compare scales
+    torch_vllm_scale_diff = torch.abs(torch_scale - vllm_scale).mean().item()
+    torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
+    vllm_sglang_scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
+
+    print(f"Scale differences:")
+    print(f"  Torch vs VLLM:   {torch_vllm_scale_diff:.8f}")
+    print(f"  Torch vs SGLang: {torch_sglang_scale_diff:.8f}")
+    print(f"  VLLM vs SGLang:  {vllm_sglang_scale_diff:.8f}")
+
+    # Compare outputs
+    torch_vllm_out_diff = torch.abs(torch_out.float() - vllm_out.float()).mean().item()
+    torch_sglang_out_diff = (
+        torch.abs(torch_out.float() - sglang_out.float()).mean().item()
+    )
+    vllm_sglang_out_diff = (
+        torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
+    )
+
+    print(f"Output differences:")
+    print(f"  Torch vs VLLM:   {torch_vllm_out_diff:.8f}")
+    print(f"  Torch vs SGLang: {torch_sglang_out_diff:.8f}")
+    print(f"  VLLM vs SGLang:  {vllm_sglang_out_diff:.8f}")
+
+    # Check tolerances
+    rtol, atol = 1e-3, 1e-5
+    torch_vllm_match = torch.allclose(
+        torch_out.float(), vllm_out.float(), rtol=rtol, atol=atol
+    ) and torch.allclose(torch_scale, vllm_scale, rtol=rtol, atol=atol)
+    torch_sglang_match = torch.allclose(
+        torch_out.float(), sglang_out.float(), rtol=rtol, atol=atol
+    ) and torch.allclose(torch_scale, sglang_scale, rtol=rtol, atol=atol)
+
+    if hidden_dim == 1368:
+        rtol = 1e-2
+        # We found that VLLM and SGLang differ when the hidden dim is not divisible by 16,
+        # and we believe SGLang is closer to the Torch reference implementation.
+    vllm_sglang_match = torch.allclose(
+        vllm_out.float(), sglang_out.float(), rtol=rtol, atol=atol
+    ) and torch.allclose(vllm_scale, sglang_scale, rtol=rtol, atol=atol)
+
+    print(f"Matches (rtol={rtol}, atol={atol}):")
+    print(f"  Torch vs VLLM:   {'✅' if torch_vllm_match else '❌'}")
+    print(f"  Torch vs SGLang: {'✅' if torch_sglang_match else '❌'}")
+    print(f"  VLLM vs SGLang:  {'✅' if vllm_sglang_match else '❌'}")


 batch_size_range = [16, 32, 64, 128]
 seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
+hidden_dim_range = [1368, 2048, 4096]

-configs = list(itertools.product(batch_size_range, seq_len_range))
+configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))


 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len"],
+        x_names=["batch_size", "seq_len", "hidden_dim"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["vllm", "sglang"],
-        line_names=["VLLM", "SGL Kernel"],
-        styles=[("blue", "-"), ("green", "-")],
+        line_vals=["torch", "vllm", "sglang"],
+        line_names=["Torch Reference", "VLLM", "SGL Kernel"],
+        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
         ylabel="us",
         plot_name="per-token-dynamic-quant-fp8-performance",
         args={},
     )
 )
-def benchmark_quantization(batch_size, seq_len, provider):
+def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
     dtype = torch.float16
     device = torch.device("cuda")

-    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
+    x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=dtype)

     quantiles = [0.5, 0.2, 0.8]

-    if provider == "vllm":
+    if provider == "torch":
+        fn = lambda: torch_per_token_quant_fp8(x.clone())
+    elif provider == "vllm":
         fn = lambda: vllm_per_token_quant_fp8(x.clone())
     elif provider == "sglang":
         fn = lambda: sglang_per_token_quant_fp8(x.clone())
@@ -86,5 +166,12 @@ def benchmark_quantization(batch_size, seq_len, provider):

 if __name__ == "__main__":
-    calculate_diff(batch_size=4, seq_len=4096)
+    # Test various hidden dimensions for correctness
+    test_dims = [1368, 2048, 4096]
+    for dim in test_dims:
+        calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)
+
+    print("\n" + "=" * 60)
+    print("Starting performance benchmark...")

     benchmark_quantization.run(print_data=True)
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu  (view file @ db7343c9)
@@ -75,14 +75,21 @@ __global__ void per_token_quant_fp8_kernel(
           c10::Float8_e4m3fnuz::from_bits());
 #endif
     }

-    *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
+    if constexpr (kVecSize == 16) {
+      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
+    } else {
+      // Use element-wise copy for vector size 8 to ensure correctness
+      for (int k = 0; k < kVecSize; ++k) {
+        token_output[i * kVecSize + k] = output_arr[k];
+      }
+    }
   }
 }

 // ---------------------------------------------------------------------------
 // 2. Baseline kernel (1 token / CTA, CUB block reduce)
 // ---------------------------------------------------------------------------
-template <typename T, typename DST_DTYPE>
+template <typename T, typename DST_DTYPE, int kVecSize = 16>
 __global__ void per_token_quant_fp8_small_batch_kernel(
     const T* __restrict__ input,
     DST_DTYPE* __restrict__ output_q,
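Why the element-wise fallback: each FP8 value is one byte, so a 16-element vector is exactly 16 bytes = 128 bits and can be written with a single uint4 store, while an 8-element vector is only 64 bits, and reinterpreting it as uint4 would write 8 bytes past the data that was computed. A tiny Python check of that arithmetic (not part of the commit):

    # Width of one vectorized FP8 store, in bits (1 byte per FP8 value).
    for vec_size in (16, 8):
        bits = vec_size * 8
        fits_uint4 = bits == 128  # uint4 is a 128-bit (4 x 32-bit) CUDA type
        print(f"kVecSize={vec_size}: {bits}-bit store, single uint4 store ok: {fits_uint4}")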
@@ -100,19 +107,17 @@ __global__ void per_token_quant_fp8_small_batch_kernel(

   float max_value = 0.0f;

-  // We want to store 128 bits of data at a time. 16 = 128 / 8 bits
-  // Load is already vectorized, so 16 elements work for T.
-  const uint32_t VEC_SIZE = 16;
-  using vec_t = flashinfer::vec_t<T, VEC_SIZE>;
-  const int32_t num_vec_elems = hidden_dim / VEC_SIZE;
+  // Use template parameter for vector size
+  using vec_t = flashinfer::vec_t<T, kVecSize>;
+  const int32_t num_vec_elems = hidden_dim / kVecSize;

   // Find max using vectorized loads
   for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
     vec_t input_vec;
-    input_vec.cast_load(token_input + i * VEC_SIZE);
+    input_vec.cast_load(token_input + i * kVecSize);

 #pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+    for (uint32_t j = 0; j < kVecSize; ++j) {
       float val = static_cast<float>(input_vec[j]);
       max_value = fmaxf(max_value, fabsf(val));
     }
@@ -132,11 +137,11 @@ __global__ void per_token_quant_fp8_small_batch_kernel(

   // Quantize using vectorized loads
   for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
     vec_t input_vec;
-    input_vec.cast_load(token_input + i * VEC_SIZE);
+    input_vec.cast_load(token_input + i * kVecSize);

-    DST_DTYPE output_arr[VEC_SIZE];
+    DST_DTYPE output_arr[kVecSize];
 #pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+    for (uint32_t j = 0; j < kVecSize; ++j) {
       float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
 #ifndef USE_ROCM
       output_arr[j] = static_cast<DST_DTYPE>(val);
@@ -147,7 +152,14 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
 #endif
     }

-    *(uint4*)(token_output + i * VEC_SIZE) = *(uint4*)output_arr;
+    if constexpr (kVecSize == 16) {
+      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
+    } else {
+      // Use element-wise copy for vector size 8 to ensure correctness
+      for (int k = 0; k < kVecSize; ++k) {
+        token_output[i * kVecSize + k] = output_arr[k];
+      }
+    }
   }
 }
@@ -158,13 +170,14 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
   const auto input_sizes = input.sizes();
   const int64_t num_tokens = input_sizes[0];
   const int64_t hidden_dim = input_sizes[1];
-  TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim);
+  TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // Hard-code sm_count
   int sm_count = 132;

   constexpr int TOKENS_PER_CTA = 8;
   const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
+  const bool use_vec16 = (hidden_dim % 16 == 0);

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
     if (use_warp_kernel) {
@@ -172,23 +185,43 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
       constexpr int THREADS = TOKENS_PER_CTA * kWarpSize;  // 256
       dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
       dim3 block(THREADS);

-      per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
-          static_cast<const scalar_t*>(input.data_ptr()),
-          static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
-          static_cast<float*>(output_s.data_ptr()),
-          hidden_dim,
-          num_tokens);
+      if (use_vec16) {
+        per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
+            static_cast<const scalar_t*>(input.data_ptr()),
+            static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+            static_cast<float*>(output_s.data_ptr()),
+            hidden_dim,
+            num_tokens);
+      } else {
+        per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
+            static_cast<const scalar_t*>(input.data_ptr()),
+            static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+            static_cast<float*>(output_s.data_ptr()),
+            hidden_dim,
+            num_tokens);
+      }
     } else {
       // -------- baseline -----------------------------------------------------
       constexpr int THREADS = 256;
       dim3 grid(num_tokens);
       dim3 block(THREADS);

-      per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3><<<grid, block, 0, stream>>>(
-          static_cast<const scalar_t*>(input.data_ptr()),
-          static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
-          static_cast<float*>(output_s.data_ptr()),
-          hidden_dim,
-          num_tokens);
+      if (use_vec16) {
+        per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
+            static_cast<const scalar_t*>(input.data_ptr()),
+            static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+            static_cast<float*>(output_s.data_ptr()),
+            hidden_dim,
+            num_tokens);
+      } else {
+        per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
+            static_cast<const scalar_t*>(input.data_ptr()),
+            static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+            static_cast<float*>(output_s.data_ptr()),
+            hidden_dim,
+            num_tokens);
+      }
     }

     return true;
   });
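To summarize the new host-side dispatch: the launcher now only requires hidden_dim % 8 == 0, and it selects the 16-wide vectorized path when hidden_dim is also a multiple of 16, otherwise the 8-wide path with element-wise stores. Below is a hedged Python model of that selection logic; the function name select_kernel_config is illustrative and is not part of the extension's API.

    def select_kernel_config(num_tokens: int, hidden_dim: int, sm_count: int = 132):
        """Model of the dispatch in sgl_per_token_quant_fp8 (illustrative only)."""
        if hidden_dim % 8 != 0:
            raise ValueError(f"Hidden dimension must be divisible by 8, but got {hidden_dim}")
        tokens_per_cta = 8
        use_warp_kernel = num_tokens >= sm_count * 2 * tokens_per_cta
        vec_size = 16 if hidden_dim % 16 == 0 else 8
        kernel = "per_token_quant_fp8_kernel" if use_warp_kernel else "per_token_quant_fp8_small_batch_kernel"
        return kernel, vec_size

    print(select_kernel_config(num_tokens=4 * 4096, hidden_dim=1368))
    # -> ('per_token_quant_fp8_kernel', 8)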
sgl-kernel/tests/test_per_token_quant_fp8.py  (view file @ db7343c9)
@@ -36,7 +36,7 @@ def sglang_per_token_quant_fp8(

 @pytest.mark.parametrize(
     "num_tokens,hidden_dim",
-    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
+    list(itertools.product([128, 256, 512], [512, 1368, 2048, 4096])),
 )
 def test_per_token_quant_compare_implementations(
     num_tokens: int,
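The parametrization now includes hidden_dim = 1368, so the previously failing non-multiple-of-16 shape is exercised by the test suite. As a hedged usage note (not part of the commit; the path and keyword filter are illustrative), running just these cases locally could look like:

    import pytest

    # Run the per-token FP8 quantization comparison tests; "1368" selects the
    # newly added non-multiple-of-16 hidden dimension cases.
    pytest.main(["-q", "sgl-kernel/tests/test_per_token_quant_fp8.py", "-k", "1368"])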