Commit 5ca07eed (unverified), authored Jun 17, 2025 by JieXin Liang; committed by GitHub on Jun 16, 2025
[fix] fix DeepGEMM blackwell input quant & ut & fix style and log (#7247)
Parent: e30ef368

Showing 10 changed files with 285 additions and 31 deletions
Changed files:
  python/sglang/srt/layers/moe/ep_moe/layer.py (+2, -2)
  python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py (+2, -2)
  python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py (+6, -9)
  python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py (+7, -7)
  python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py (+3, -3)
  python/sglang/srt/layers/quantization/fp8_kernel.py (+4, -3)
  python/sglang/srt/layers/quantization/fp8_utils.py (+4, -3)
  python/sglang/srt/models/deepseek_v2.py (+4, -2)
  python/sglang/test/test_block_fp8.py (+1, -0)
  python/sglang/test/test_block_fp8_deep_gemm_blackwell.py (+252, -0)
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -1201,7 +1201,7 @@ class DeepEPMoE(EPMoE):
                 gateup_output,
                 masked_m,
                 expected_m,
-                recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
+                recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
             )
         dispose_tensor(hidden_states_fp8[0])

@@ -1256,7 +1256,7 @@ class DeepEPMoE(EPMoE):
                 down_output,
                 masked_m,
                 expected_m,
-                recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
+                recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
             )
         return down_output
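Editor's note: the (1, 128, 128) recipe passed to the masked grouped GEMM corresponds to the scaling granularity used throughout this FP8 path, i.e. 1x128 per-token groups for activations and 128x128 blocks for weights. As a rough, hedged illustration of what that granularity implies for the scale tensors (not sglang or DeepGEMM code; fp8_scale_shapes is a hypothetical helper):

# Hedged sketch: scale-tensor shapes for 1x128 activation / 128x128 weight granularity.
def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

def fp8_scale_shapes(m: int, n: int, k: int):
    a_scales = (m, ceil_div(k, 128))                 # one scale per 1x128 row group of A (m, k)
    b_scales = (ceil_div(n, 128), ceil_div(k, 128))  # one scale per 128x128 block of B (n, k)
    return a_scales, b_scales

print(fp8_scale_shapes(64, 2112, 7168))  # ((64, 56), (17, 56))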
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -553,9 +553,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
                 round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-                and deep_gemm_wrapper.DEEPGEMM_V202506,
+                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
                 use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-                and deep_gemm_wrapper.DEEPGEMM_V202506,
+                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
             )
         )
         return packed_recv_hidden, packed_recv_count, event, hook
python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py

@@ -8,7 +8,7 @@ from typing import Callable, Dict, List, Optional, Tuple
 from tqdm.contrib.concurrent import thread_map

 from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
-    DEEPGEMM_V202506,
+    DEEPGEMM_BLACKWELL,
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs

@@ -16,13 +16,11 @@ from sglang.srt.utils import get_bool_env_var, get_int_env_var
 logger = logging.getLogger(__name__)

-try:
+if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
     from deep_gemm import get_num_sms
     from deep_gemm.jit import build
     from deep_gemm.jit_kernels.gemm import get_best_configs
     from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
-except ImportError:
-    pass

 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))

@@ -313,7 +311,8 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
         ret = origin_func(self, *args, **kwargs)
         if ret is None:
             kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-            _compile_warning_2()
+            if not DEEPGEMM_BLACKWELL:
+                _compile_warning_2()
             logger.warning(
                 f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
             )

@@ -329,10 +328,8 @@ def deep_gemm_execution_hook(
     m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
 ):
-    # not supported yet
-    if DEEPGEMM_V202506:
-        yield
-        return
-
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+    if not DEEPGEMM_BLACKWELL:
+        _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)

     with _log_jit_build(m, n, k, kernel_type):
         yield
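Editor's note: deep_gemm_execution_hook follows the generator-as-context-manager pattern, doing shape-dependent work (optional pre-compilation, JIT-build logging) before yielding to the actual kernel launch. A minimal, self-contained sketch of that pattern under illustrative names (execution_hook and _compiled_shapes are not sglang APIs):

# Minimal sketch of a context-manager execution hook; illustrative only.
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)
_compiled_shapes = set()

@contextmanager
def execution_hook(m, n, k, precompile=True):
    if precompile and (n, k) not in _compiled_shapes:
        logger.warning("Compiling for N=%d, K=%d. Please wait.", n, k)
        _compiled_shapes.add((n, k))  # stand-in for a real JIT build step
    yield  # the caller launches the GEMM inside this context

# Usage:
# with execution_hook(m, n, k):
#     run_gemm(...)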
python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py

@@ -6,16 +6,16 @@ logger = logging.getLogger(__name__)
 def _compute_enable_deep_gemm():
+    sm_version = get_device_sm()
+    if sm_version < 90:
+        return False
+
     try:
         import deep_gemm
     except ImportError:
         logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
         return False

-    sm_version = get_device_sm()
-    if sm_version < 90:
-        return False
-
     return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")

@@ -25,8 +25,8 @@ try:
     from deep_gemm import fp8_gemm_nt

     # They have not given a name to this breaking change
-    DEEPGEMM_V202506 = True
+    DEEPGEMM_BLACKWELL = True
 except ImportError:
-    DEEPGEMM_V202506 = False
+    DEEPGEMM_BLACKWELL = False

-DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
+DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py

@@ -6,8 +6,8 @@ import torch
 from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
 from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_BLACKWELL,
     DEEPGEMM_SCALE_UE8M0,
-    DEEPGEMM_V202506,
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs

@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm

-    if DEEPGEMM_V202506:
+    if DEEPGEMM_BLACKWELL:
         from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
         from deep_gemm import (
             fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,

@@ -57,7 +57,7 @@ def grouped_gemm_nt_f8f8bf16_masked(
         out,
         masked_m,
         expected_m,
-        **({"recipe": recipe} if DEEPGEMM_V202506 else {})
+        **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
     )
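Editor's note: the `**({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})` expression forwards the keyword argument only when the installed deep_gemm build accepts it; older builds never see the kwarg at all. A generic sketch of the same idiom with hypothetical stand-ins (`new_style`, `kernel`, `launch` are not sglang or DeepGEMM names):

# Hedged sketch of the conditional-kwarg idiom used above.
new_style = True

def kernel(x, *, recipe=None):
    # Pretend this is the library call; older versions would reject `recipe`.
    return (x, recipe)

def launch(x, recipe=(1, 128, 128)):
    # Splice the kwarg into the call only when the feature flag says it is understood.
    return kernel(x, **({"recipe": recipe} if new_style else {}))

print(launch(3))  # (3, (1, 128, 128)) when new_style is True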
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -290,11 +290,12 @@ def sglang_per_token_group_quant_fp8(
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
-        x_s = torch.empty(
+        # TODO(FIXME): Fix cuda kernel and recover here to empty.
+        x_s = torch.zeros(
             (aligned_k // 4, aligned_mn),
             device=x.device,
             dtype=torch.int,
-        ).permute(-1, -2)[:x_s_mn, :]
+        ).transpose(0, 1)[:x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function

@@ -768,7 +769,7 @@ def prepare_block_fp8_matmul_inputs(
     if As.dtype == torch.float:
         assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
-    elif Bs.dtype == torch.int:
+    elif As.dtype == torch.int:
         assert (
             triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
         ), f"{A.shape=} {As.shape=} {block_size=}"
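Editor's note: the buffer shape above reflects the packed scale layout the new DeepGEMM path expects: one scale per 1x128 group, both scale dimensions aligned up to a multiple of 4, and (by this reading) four UE8M0 exponents packed into each int32 along K, which is why the assertion later checks cdiv(cdiv(K, block_k), 4) against As.shape[-1]. A shape-only sketch under that assumption (packed_scale_buffer and this align helper are illustrative, not sglang code):

# Shape-only sketch of the packed int32 scale buffer; assumes 4 UE8M0 exponents per int32.
import torch

def align(x: int, y: int) -> int:
    return (x + y - 1) // y * y

def packed_scale_buffer(x_q_mn: int, x_q_k: int) -> torch.Tensor:
    x_s_mn, x_s_k = x_q_mn, x_q_k // 128            # one scale per 1x128 group
    aligned_mn, aligned_k = align(x_s_mn, 4), align(x_s_k, 4)
    return torch.zeros(
        (aligned_k // 4, aligned_mn), dtype=torch.int32
    ).transpose(0, 1)[:x_s_mn, :]                   # viewed back as (mn, packed_k)

print(packed_scale_buffer(83, 7168).shape)  # torch.Size([83, 14]): 56 scales -> 14 int32s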
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -241,9 +241,10 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )

-    if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
-        _check_ue8m0("x_scale", x_scale)
-        _check_ue8m0("weight_scale", weight_scale)
+    # NOTE(alcanderian): Useless when scale is packed to int32
+    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
+    #     _check_ue8m0("x_scale", x_scale)
+    #     _check_ue8m0("weight_scale", ws)

     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
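Editor's note: a UE8M0 scale is an exponent-only value, i.e. an exact power of two, so a sanity check on unpacked float scales can simply verify that log2 of each scale is integral. The sketch below is a hypothetical stand-in for the `_check_ue8m0` call that the commit comments out (the real scales are already packed into int32, which is why the check becomes useless there):

# Hypothetical UE8M0 sanity check on *unpacked* float scales; not sglang's _check_ue8m0.
import torch

def check_ue8m0(name: str, scales: torch.Tensor) -> None:
    # A UE8M0 value is 2**e for an integer exponent e, so log2 must be integral.
    exponents = torch.log2(scales.float())
    assert torch.all(exponents == exponents.round()), f"{name} is not UE8M0"

check_ue8m0("x_scale", torch.tensor([0.5, 1.0, 4.0]))   # passes
# check_ue8m0("x_scale", torch.tensor([0.3]))           # would raise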
python/sglang/srt/models/deepseek_v2.py

@@ -1829,8 +1829,10 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and weight_block_size[1] == 128
                 and model_dtype == torch.bfloat16
             ):
-                if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var(
-                    "SGL_USE_DEEPGEMM_BMM", "false"
+                if (
+                    deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                    and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
+                    and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                 ):
                     block_scale = weight_scale
                     use_deep_gemm_bmm = True
python/sglang/test/test_block_fp8.py

@@ -343,6 +343,7 @@ class TestW8A8BlockFP8Matmul(CustomTestCase):
         OUT_DTYPES = [torch.bfloat16]
         M = [64, 128, 512, 1024, 4096]
         NKs = [
+            (2112, 7168),
             (1536, 7168),
             (3072, 1536),
             (24576, 7168),
python/sglang/test/test_block_fp8_deep_gemm_blackwell.py (new file, mode 100644)

import itertools
import os
import unittest
from typing import List, Tuple

import torch
from deep_gemm import fp8_gemm_nt

from sglang.test.test_utils import CustomTestCase

_is_cuda = torch.cuda.is_available() and torch.version.cuda


# Modify form DeepGEMM Blackwell
def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y


def align(x: int, y: int) -> int:
    return ceil_div(x, y) * y


def per_token_group_quant_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    sf = x_amax / 448.0
    return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf


def per_block_quant_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2)
    )


def ceil_to_ue8m0(x: torch.Tensor):
    assert x.view(-1).amax().item() > 0
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))


def per_token_group_quant_mxfp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    sf = ceil_to_ue8m0(x_amax / 448.0)
    return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf


def per_block_quant_mxfp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = ceil_to_ue8m0(x_amax / 448.0)
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2)
    )


# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N,)
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
    B_tiles = [
        [
            B[
                j * block_n : min((j + 1) * block_n, N),
                i * block_k : min((i + 1) * block_k, K),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]
    C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
    As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]
            b = B_tiles[j][i]
            c = C_tiles[j]
            s = As_tiles[i] * Bs[j][i]
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C


def block_quant_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """This function converts block-wise quantization to unquantized.
    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
    and the block size.
    The output is an unquantized tensor with dtype.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)

    for j in range(n_tiles):
        for i in range(k_tiles):
            x_q_block_tile = x_q_block[
                j * block_n : min((j + 1) * block_n, n),
                i * block_k : min((i + 1) * block_k, k),
            ]
            x_dq_block_tile = x_dq_block[
                j * block_n : min((j + 1) * block_n, n),
                i * block_k : min((i + 1) * block_k, k),
            ]
            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]

    return x_dq_block


class TestDeepGemmBlackwell(CustomTestCase):
    if not _is_cuda:
        OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
        M = [1, 7, 83, 512, 2048]
        NKs = [
            (N, K)
            for N in [128, 512, 1024, 4096, 7748, 13824]
            for K in [256, 4096, 5120, 3884, 13824]
        ]
        # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
        BLOCK_SIZE = [[128, 128]]
        SEEDS = [0]
    else:
        # use practical shape in DeepSeek V3 for test
        OUT_DTYPES = [torch.bfloat16]
        M = [64, 128, 512, 1024, 4096]
        NKs = [
            (2112, 7168),
            (1536, 7168),
            # (3072, 1536),
            # (24576, 7168),
            # (4096, 512),
            # (7168, 2048),
            # (4608, 7168),
            # (512, 7168),
            # (7168, 2304),
            # (7168, 512),
        ]
        BLOCK_SIZE = [[128, 128]]
        SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _test_deep_gemm_blackwell(self, M, NK, block_size, out_dtype, seed):
        N, K = NK
        torch.manual_seed(seed)

        A = torch.empty((M, K), dtype=torch.bfloat16).normal_(0, 0.2)
        B = torch.empty((N, K), dtype=torch.bfloat16).normal_(0, 0.2)

        A_q, A_s = per_token_group_quant_fp8(A)
        B_q, B_s = per_block_quant_fp8(B)

        A_dq = block_quant_dequant(A_q, A_s, [1, block_size[1]], out_dtype)
        B_dq = block_quant_dequant(B_q, B_s, block_size, out_dtype)

        A_qu = per_token_group_quant_mxfp8(A_dq)
        B_qu = per_block_quant_mxfp8(B_dq)

        out = None
        with torch.inference_mode():
            ref_out = native_w8a8_block_fp8_matmul(
                A_q, B_q, A_s, B_s, block_size, out_dtype
            )
            out = torch.empty_like(ref_out)
            fp8_gemm_nt(A_qu, B_qu, out)

        torch.testing.assert_close(out, ref_out, atol=1e-1, rtol=1e-2)

    def test_deep_gemm_blackwell(self):
        for params in itertools.product(
            self.M,
            self.NKs,
            self.BLOCK_SIZE,
            self.OUT_DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                NKs=params[1],
                block_size=params[2],
                out_dtype=params[3],
                seed=params[4],
            ):
                self._test_deep_gemm_blackwell(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
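Editor's note: the `ceil_to_ue8m0` helper in the new test rounds each positive scale up to the next power of two, which is exactly what an exponent-only UE8M0 scale can represent. A quick worked example of the values it produces (the function is copied from the test above, minus its assert, so the snippet is self-contained):

# Quick check of the power-of-two rounding done by ceil_to_ue8m0.
import torch

def ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

print(ceil_to_ue8m0(torch.tensor([0.3, 0.5, 0.75, 3.0])))
# tensor([0.5000, 0.5000, 1.0000, 4.0000])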