sglang commit 2937387a (unverified)

fix accuracy issue (#4376)

Authored by Yineng Zhang, Mar 13, 2025; committed by GitHub, Mar 13, 2025.
Parent: cf721fde
4 changed files with 17 additions and 5 deletions:

sgl-kernel/benchmark/bench_per_token_quant_fp8.py   +5 -1
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu         +3 -1
sgl-kernel/setup.py                                 +4 -0
sgl-kernel/tests/test_per_token_quant_fp8.py        +5 -3
sgl-kernel/benchmark/bench_per_token_quant_fp8.py

@@ -22,9 +22,10 @@ def vllm_per_token_quant_fp8(
 def sglang_per_token_quant_fp8(
     input: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    scale = torch.zeros((input.size(0), 1), device=input.device, dtype=torch.float32)
+    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
     output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
     sgl_per_token_quant_fp8(input, output, scale)
     return output, scale

@@ -36,6 +37,9 @@ def calculate_diff(batch_size: int, seq_len: int):
     vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
     sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

+    scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
+    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
+
     if torch.allclose(
         vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
     ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
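For reference, per-token FP8 (E4M3) quantization of the kind benchmarked here is conventionally defined as scaling each row by its own max-abs value. A minimal pure-PyTorch sketch, assuming that usual definition and the 448.0 E4M3 max (the names and constants below are illustrative, not taken from this diff):

import torch

FP8_E4M3_MAX = 448.0  # assumed E4M3 max; the kernel defines its own FP8_E4M3_MAX

def ref_per_token_quant_fp8(x: torch.Tensor):
    """Illustrative reference: quantize each row (token) with its own scale."""
    # Per-row scale: max(|x|) / FP8_MAX, kept in float32 like the kernel's scale output.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / FP8_E4M3_MAX
    scale = scale.clamp(min=torch.finfo(torch.float32).tiny)  # guard against all-zero rows
    q = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return q, scale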
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu

@@ -49,6 +49,8 @@ __global__ void per_token_quant_fp8_kernel(
   }
   __syncthreads();

+  const float scale_val = 1.0f / block_max;
+
   // Quantize using vectorized loads
   for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
     vec_t input_vec;

@@ -57,7 +59,7 @@ __global__ void per_token_quant_fp8_kernel(
     FP8_TYPE output_arr[vec_size];
 #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
-      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) / block_max, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
 #ifndef USE_ROCM
       output_arr[j] = static_cast<FP8_TYPE>(val);
 #else
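The kernel change hoists the per-token reciprocal out of the inner loop and multiplies by it instead of dividing each element by block_max. Dividing and multiplying by a precomputed reciprocal are not bit-identical in floating point, so quantized values can shift by an ulp or two; a small PyTorch probe (illustrative only, not part of the commit) that measures the discrepancy:

import torch

# Compare x / s against x * (1.0 / s) elementwise in float32.
torch.manual_seed(0)
x = torch.randn(1_000_000, dtype=torch.float32)
s = torch.rand(1_000_000, dtype=torch.float32) + 0.5  # positive scales

div = x / s
recip_mul = x * (1.0 / s)

# The two results agree to within a couple of ulps but are not always bit-identical,
# which matters once downstream checks use tight absolute tolerances.
mismatch = (div != recip_mul).float().mean().item()
max_abs_diff = (div - recip_mul).abs().max().item()
print(f"fraction differing: {mismatch:.4%}, max |diff|: {max_abs_diff:.3e}")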
sgl-kernel/setup.py

@@ -178,6 +178,8 @@ if torch.cuda.is_available():
     if cuda_version >= (12, 8) and sm_version >= 100:
         nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
         nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
+    else:
+        nvcc_flags.append("-use_fast_math")
     if sm_version >= 90:
         nvcc_flags.extend(nvcc_flags_fp8)
     if sm_version >= 80:

@@ -188,6 +190,8 @@ else:
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
     if enable_sm100a:
         nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
+    else:
+        nvcc_flags.append("-use_fast_math")
     if enable_fp8:
         nvcc_flags.extend(nvcc_flags_fp8)
     if enable_bf16:
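The build change only passes nvcc's -use_fast_math when the sm_100/sm_100a (Blackwell) targets are not being compiled, which keeps fast-math approximations out of builds for that architecture. A condensed, hypothetical helper expressing the same selection logic (the function name and example versions are mine, not the repo's):

from typing import List, Tuple

def select_nvcc_flags(cuda_version: Tuple[int, int], sm_version: int) -> List[str]:
    """Sketch of the gating added in setup.py: emit sm_100 gencode flags on
    CUDA >= 12.8 with SM >= 100, otherwise fall back to -use_fast_math."""
    flags: List[str] = []
    if cuda_version >= (12, 8) and sm_version >= 100:
        flags.append("-gencode=arch=compute_100,code=sm_100")
        flags.append("-gencode=arch=compute_100a,code=sm_100a")
    else:
        flags.append("-use_fast_math")
    return flags

# An H100 (SM 90) build keeps fast math; a B200 (SM 100) build on CUDA 12.8 does not.
assert "-use_fast_math" in select_nvcc_flags((12, 4), 90)
assert "-use_fast_math" not in select_nvcc_flags((12, 8), 100)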
sgl-kernel/tests/test_per_token_quant_fp8.py

@@ -21,16 +21,18 @@ def vllm_per_token_quant_fp8(
 def sglang_per_token_quant_fp8(
     input: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    scale = torch.zeros((input.size(0), 1), device=input.device, dtype=torch.float32)
+    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
     output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
     sgl_per_token_quant_fp8(input, output, scale)
+    scale = scale.reshape(-1, 1)
     return output, scale


 @pytest.mark.parametrize(
     "num_tokens,hidden_dim",
-    list(itertools.product([32, 64, 128, 256, 512], [128, 256, 512, 2048, 4096])),
+    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
 )
 def test_per_token_quant_compare_implementations(
     num_tokens: int,

@@ -42,7 +44,7 @@ def test_per_token_quant_compare_implementations(
     vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
     sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

-    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5)
+    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3)
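The test wrapper now hands the kernel a flat (num_tokens,) scale buffer and reshapes it to a column vector only for the comparison against vLLM's (num_tokens, 1) scale. A short illustrative aside (not part of the commit) on why that reshape matters:

import torch

vllm_scale = torch.rand(4, 1)          # vLLM returns a column vector of per-token scales
sglang_scale = vllm_scale.reshape(-1)  # the sgl kernel fills a flat (num_tokens,) tensor

# torch.allclose silently broadcasts (4, 1) against (4,), comparing a 4x4 grid of
# pairs instead of matching scales elementwise.
print(torch.allclose(vllm_scale, sglang_scale))  # generally False despite equal values

# torch.testing.assert_close rejects mismatched shapes, so the test reshapes first.
torch.testing.assert_close(vllm_scale, sglang_scale.reshape(-1, 1), rtol=1e-3, atol=1e-3)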