sglang · Commits · d2e507df

Commit d2e507df (unverified), authored Apr 09, 2025 by yinfan98, committed via GitHub on Apr 09, 2025

[Misc] clean up vllm in sgl-kernel test (#5189)

Parent: 61970b08

Showing 4 changed files with 25 additions and 40 deletions (+25 −40)
sgl-kernel/tests/test_awq_dequant.py            +0   −15
sgl-kernel/tests/test_int8_gemm.py              +0   −3
sgl-kernel/tests/test_per_tensor_quant_fp8.py   +14  −14
sgl-kernel/tests/test_per_token_quant_fp8.py    +11  −8
sgl-kernel/tests/test_awq_dequant.py

@@ -4,7 +4,6 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from sgl_kernel import awq_dequantize
-from vllm import _custom_ops as ops


 def reverse_awq_order(t: torch.Tensor):
@@ -58,12 +57,6 @@ def awq_dequantize_torch(
     return (iweights - zeros) * scales


-def vllm_awq_dequantize(
-    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
-) -> torch.Tensor:
-    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
-
-
 def sglang_awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.Tensor:
@@ -110,7 +103,6 @@ def test_awq_dequant_compare_implementations(
     )

     # Run both implementations
-    vllm_out = vllm_awq_dequantize(qweight, scales.to(torch.float16), qzeros)
     torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
     sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
@@ -118,13 +110,6 @@ def test_awq_dequant_compare_implementations(
     torch.testing.assert_close(
         torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
     )
-    if not is_bf16_act:
-        torch.testing.assert_close(
-            vllm_out.to(torch.float32),
-            sglang_out.to(torch.float32),
-            rtol=1e-3,
-            atol=1e-5,
-        )


 if __name__ == "__main__":
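With the vllm cross-check removed, awq_dequantize_torch is the only oracle for the sgl_kernel AWQ path. For orientation, here is a hedged, illustrative sketch of the unpacking step that precedes the visible (iweights - zeros) * scales line; the helper name and shapes are assumptions for illustration, not code from this commit.

import torch

def unpack_int4_sketch(packed: torch.Tensor) -> torch.Tensor:
    # Each int32 packs eight 4-bit values; shift and mask to recover them.
    # (The real test additionally undoes AWQ's lane ordering via reverse_awq_order.)
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=packed.device)
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF
    return nibbles.reshape(*packed.shape[:-1], packed.shape[-1] * 8)

After unpacking and the order fix, per-group dequantization is just the subtraction and scaling shown in the context line of the hunk above.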
sgl-kernel/tests/test_int8_gemm.py

 import pytest
 import torch
 from sgl_kernel import int8_scaled_mm
-from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


 def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@@ -28,9 +27,7 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
         bias = None
     o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-    o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     torch.testing.assert_close(o, o1)
-    torch.testing.assert_close(o, o2)
     print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
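The test now checks int8_scaled_mm only against torch_scaled_mm, whose body is not part of this diff. As a hypothetical sketch of what such a pure-torch reference typically looks like (the function name, scale shapes, and fp32 accumulation are assumptions, not code from this commit):

import torch

def torch_scaled_mm_sketch(a, b, scale_a, scale_b, out_dtype, bias=None):
    # Assumed layout: a is (M, K) int8 with per-row scale_a (M, 1),
    # b is (K, N) int8 with per-column scale_b (1, N).
    o = torch.mm(a.to(torch.float32), b.to(torch.float32))  # matmul in fp32
    o = o * scale_a * scale_b                                # apply dequantization scales
    if bias is not None:
        o = o + bias.to(torch.float32)
    return o.to(out_dtype)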
sgl-kernel/tests/test_per_tensor_quant_fp8.py

@@ -4,7 +4,6 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from sgl_kernel import sgl_per_tensor_quant_fp8
-from vllm import _custom_ops as ops

 from sglang.srt.utils import is_hip
@@ -12,13 +11,6 @@ is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


-def vllm_scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    return ops.scaled_fp8_quant(input, scale)
-
-
 def sglang_scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
@@ -34,6 +26,16 @@ def sglang_scaled_fp8_quant(
     return output, scale


+def torch_scaled_fp8_quant(tensor, inv_scale):
+    # The reference implementation that fully aligns to
+    # the kernel being tested.
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    scale = inv_scale.reciprocal()
+    qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
+    qweight = qweight.to(torch.float8_e4m3fn)
+    return qweight
+
+
 @pytest.mark.parametrize(
     "num_tokens,hidden_dim",
     list(itertools.product([128, 256, 512], [512, 2048, 4096])),
@@ -45,21 +47,19 @@ def test_per_tensor_quant_compare_implementations(
     device = torch.device("cuda")
     x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)

-    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
     sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
+    torch_out = torch_scaled_fp8_quant(x, sglang_scale)

-    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(
-        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
+        sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
     )

     scale = torch.rand(1, dtype=torch.float32, device=device)

-    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x, scale)
     sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale)
+    torch_out = torch_scaled_fp8_quant(x, scale)

-    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(
-        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
+        sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
     )
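For intuition about the reference introduced above: its second argument is the dequantization scale returned by the kernel, so quantization divides by it and dequantization multiplies by it. A small round-trip sketch follows (illustrative only; assumes a recent PyTorch with float8_e4m3fn support and, mirroring the test, a CUDA device; the tolerances are arbitrary):

import torch

def torch_scaled_fp8_quant(tensor, inv_scale):
    # Same reference as in the diff above.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = inv_scale.reciprocal()
    qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)

x = torch.rand((4, 16), dtype=torch.float16, device="cuda")
# Per-tensor dequant scale: map the tensor's amax onto the fp8 representable range.
scale = x.abs().max().to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max
q = torch_scaled_fp8_quant(x, scale)
x_hat = q.to(torch.float32) * scale  # dequantize
torch.testing.assert_close(x.to(torch.float32), x_hat, rtol=5e-2, atol=5e-2)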
sgl-kernel/tests/test_per_token_quant_fp8.py

@@ -4,7 +4,6 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from sgl_kernel import sgl_per_token_quant_fp8
-from vllm import _custom_ops as ops

 from sglang.srt.utils import is_hip
@@ -12,10 +11,15 @@ is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


-def vllm_per_token_quant_fp8(
-    input: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
+def torch_per_token_quant_fp8(tensor, inv_scale):
+    # The reference implementation that fully aligns to
+    # the kernel being tested.
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    inv_scale = inv_scale.view(-1, 1)
+    scale = inv_scale.reciprocal()
+    qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
+    qweight = qweight.to(torch.float8_e4m3fn)
+    return qweight


 def sglang_per_token_quant_fp8(
@@ -41,12 +45,11 @@ def test_per_token_quant_compare_implementations(
     device = torch.device("cuda")
     x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)

-    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
     sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
+    torch_out = torch_per_token_quant_fp8(x, sglang_scale)

-    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(
-        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
+        sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
     )
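The per-token reference differs from the per-tensor one only in that the scale is a vector with one entry per row, hence the inv_scale.view(-1, 1) reshape before broadcasting. A hedged sketch of how such a per-row scale is typically derived (illustrative only; not code from this commit; assumes a PyTorch build with float8_e4m3fn):

import torch

x = torch.rand((8, 32), dtype=torch.float16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
# One dequant scale per token (row), from that row's absolute maximum.
per_token_scale = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / fp8_max  # (8, 1)
q = (x.to(torch.float32) / per_token_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
x_hat = q.to(torch.float32) * per_token_scale  # dequantize row by row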