Commit 99456bca (unverified) in sglang

[perf] introduce deep gemm group_gemm_masked as bmm (#5432)

Authored by JieXin Liang on Apr 20, 2025; committed via GitHub on Apr 20, 2025.
Parent: d07e797a
Showing 3 changed files with 361 additions and 20 deletions (+361, -20):
python/sglang/srt/layers/quantization/fp8_kernel.py  (+108, -4)
python/sglang/srt/models/deepseek_v2.py              (+86, -16)
python/sglang/test/test_block_fp8.py                 (+167, -0)
python/sglang/srt/layers/quantization/fp8_kernel.py
```diff
@@ -44,6 +44,7 @@ else:
 fp8_min = -fp8_max
 
 _enable_jit_deepgemm = False
+_enable_jit_deepgemm_bmm = False
 if _is_cuda:
     import deep_gemm
     from sgl_kernel import (
```
```diff
@@ -53,10 +54,11 @@ if _is_cuda:
     )
 
     sm_version = get_device_sm()
-    if sm_version == 90 and get_bool_env_var(
-        "SGL_ENABLE_JIT_DEEPGEMM", default="false"
-    ):
-        _enable_jit_deepgemm = True
+    if sm_version == 90:
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+            _enable_jit_deepgemm = True
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
+            _enable_jit_deepgemm_bmm = True
 
 logger = logging.getLogger(__name__)
```
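Both environment variables are read once when fp8_kernel is imported, and only take effect on SM90 (Hopper) CUDA devices. A minimal opt-in sketch, not part of this commit, assuming the variables are set before sglang is imported:

```python
import os

# Both flags default to "false"; set them before sglang imports fp8_kernel.
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "true"      # existing DeepGEMM JIT path
os.environ["SGL_ENABLE_JIT_DEEPGEMM_BMM"] = "true"  # new grouped-GEMM-masked bmm path

# The flags are evaluated at import time and only honored on SM90 CUDA devices.
from sglang.srt.layers.quantization import fp8_kernel

print(fp8_kernel._enable_jit_deepgemm, fp8_kernel._enable_jit_deepgemm_bmm)
```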
```diff
@@ -940,6 +942,108 @@ def per_tensor_quant_mla_fp8(
     return x_q, x_s_out
 
 
+@triton.jit
+def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    masked_m_ptr,
+    group_size,
+    y_stride_b,
+    y_stride_t,
+    y_q_stride_b,
+    y_q_stride_t,
+    y_s_stride_b,
+    y_s_stride_g,
+    eps,
+    fp8_min,
+    fp8_max,
+    NUM_GROUP: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor for deep_gemm grouped_gemm_masked.
+    This function converts the tensor values into float8 values.
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_tensor_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with per-token-group-quantization
+    for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
+        -fp8_max,
+        fp8_max,
+        num_tiles_k,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
+
+
 def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
```
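A usage sketch (not part of the commit) of the new wrapper's return contract, assuming a CUDA device with FP8 support; the shapes follow the docstring and the code above:

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    per_tensor_quant_mla_deep_gemm_masked_fp8,
)

b, m, k = 16, 100, 512  # e.g. (heads, tokens, kv_lora_rank); k must be a multiple of group_size
x = torch.randn(b, m, k, dtype=torch.bfloat16, device="cuda")

x_q, x_s, masked_m, expected_m, aligned_m = per_tensor_quant_mla_deep_gemm_masked_fp8(
    x, group_size=128
)
# Rows are padded to a multiple of 256 so the masked grouped GEMM can run on fixed
# tiles; masked_m records how many rows per batch slot hold real tokens.
assert aligned_m == (m + 255) // 256 * 256 and expected_m == m
assert x_q.shape == (b, aligned_m, k)         # fp8 values
assert x_s.shape == (b, aligned_m, k // 128)  # per-token-group scales (float32)
assert masked_m.tolist() == [m] * b
```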
python/sglang/srt/models/deepseek_v2.py
```diff
@@ -57,7 +57,11 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    _enable_jit_deepgemm_bmm,
+    per_tensor_quant_mla_deep_gemm_masked_fp8,
+    per_tensor_quant_mla_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
```
```diff
@@ -82,6 +86,7 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 
 if _is_cuda:
+    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
 else:
     from vllm._custom_ops import awq_dequantize
```
```diff
@@ -530,6 +535,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.w_scale_k = None
+        self.w_scale_v = None
+        self.use_deep_gemm_bmm = False
+
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
```
```diff
@@ -684,7 +693,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+        if self.use_deep_gemm_bmm:
+            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            q_nope_out = q_nope.new_empty(
+                (self.num_local_heads, aligned_m, self.kv_lora_rank)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (q_nope_val, q_nope_scale),
+                (self.w_kc, self.w_scale_k),
+                q_nope_out,
+                masked_m,
+                expected_m,
+            )
+            q_nope_out = q_nope_out[:, :expected_m, :]
+        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
```
```diff
@@ -716,7 +742,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+        if self.use_deep_gemm_bmm:
+            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            attn_bmm_output = attn_output.new_empty(
+                (self.num_local_heads, aligned_m, self.v_head_dim)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (attn_output_val, attn_output_scale),
+                (self.w_vc, self.w_scale_v),
+                attn_bmm_output,
+                masked_m,
+                expected_m,
+            )
+            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
```
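For orientation, a dequantized reference (not from the commit) of what the two masked grouped GEMM calls above compute per head; it mirrors the torch_w8a8_block_fp8_bmm helper added in the test file further down, and all names here are illustrative:

```python
import torch


def masked_grouped_bmm_reference(lhs_q, lhs_s, rhs_q, rhs_s, masked_m, block=128):
    """out[i, :masked_m[i]] ~= dequant(lhs_q[i, :masked_m[i]]) @ dequant(rhs_q[i]).T.

    Assumes N and K are multiples of `block`, as in the MLA shapes used here.
    lhs_q: (heads, aligned_m, K) fp8, lhs_s: (heads, aligned_m, K // block)
    rhs_q: (heads, N, K) fp8,         rhs_s: (heads, N // block, K // block)
    """
    heads, aligned_m, K = lhs_q.shape
    N = rhs_q.shape[1]
    out = torch.zeros(heads, aligned_m, N, dtype=torch.bfloat16, device=lhs_q.device)
    for i in range(heads):
        rows = int(masked_m[i])  # only the first masked_m[i] rows are real tokens
        a = lhs_q[i, :rows].to(torch.float32) * lhs_s[i, :rows].repeat_interleave(block, 1)
        ws = rhs_s[i].repeat_interleave(block, 0).repeat_interleave(block, 1)
        w = rhs_q[i].to(torch.float32) * ws
        out[i, :rows] = (a @ w.t()).to(torch.bfloat16)
    return out
```

The `[:, :expected_m, :]` slices in the attention code above then drop the padded rows again.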
```diff
@@ -1439,6 +1482,10 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w = self_attn.kv_b_proj.weight
             # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
             # This may affect the accuracy of fp8 model.
+            # Fix deepseek v3 blockwise bmm by using deep_gemm
+            use_deep_gemm_bmm = False
+            model_dtype = torch.get_default_dtype()
+
             if w.dtype in (
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
```
```diff
@@ -1457,10 +1504,20 @@ class DeepseekV2ForCausalLM(nn.Module):
                     weight = w
                     weight_scale = self_attn.kv_b_proj.weight_scale_inv
 
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
+                if (
+                    _is_cuda
+                    and _enable_jit_deepgemm_bmm
+                    and weight_block_size[0] == 128
+                    and weight_block_size[1] == 128
+                    and model_dtype == torch.bfloat16
+                ):
+                    block_scale = weight_scale
+                    use_deep_gemm_bmm = True
+                else:
+                    w, scale = block_quant_to_tensor_quant(
+                        weight, weight_scale, weight_block_size
+                    )
+                    self_attn.w_scale = scale
             else:
                 weight = w
                 weight_scale = self_attn.kv_b_proj.weight_scale
```
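Stated plainly: the new path is taken only on CUDA with SGL_ENABLE_JIT_DEEPGEMM_BMM enabled, a 128x128 block-quantized checkpoint, and a bfloat16 model dtype; otherwise the weights are still requantized to a per-tensor scale for bmm_fp8, with the accuracy caveat from the NOTE above. A self-contained restatement of that gate (names mirror the diff; the values are examples only):

```python
import torch

# Example values; in the real code these come from the checkpoint's quantization
# config, the fp8_kernel module flags, and torch.get_default_dtype().
_is_cuda = True
_enable_jit_deepgemm_bmm = True   # requires SGL_ENABLE_JIT_DEEPGEMM_BMM=true on SM90
weight_block_size = [128, 128]    # block-quantized checkpoint (e.g. DeepSeek-V3 FP8)
model_dtype = torch.bfloat16

use_deep_gemm_bmm = (
    _is_cuda
    and _enable_jit_deepgemm_bmm
    and weight_block_size[0] == 128
    and weight_block_size[1] == 128
    and model_dtype == torch.bfloat16
)
print("deep_gemm masked bmm" if use_deep_gemm_bmm else "requantize + bmm_fp8")
```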
```diff
@@ -1483,18 +1540,31 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                     torch.bfloat16
                 )
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-            if (
-                hasattr(self_attn.kv_b_proj, "weight_scale")
-                and self_attn.w_scale is None
-            ):
-                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if _is_hip:
-                    self_attn.w_scale *= 2.0
+            if not use_deep_gemm_bmm:
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+            else:
+                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
+                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
+                ws_kc, ws_vc = block_scale.unflatten(
+                    0, (-1, (num_tiles_k + num_tiles_n))
+                ).split([num_tiles_k, num_tiles_n], dim=1)
+                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+                self_attn.w_scale_v = ws_vc.contiguous()
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+                self_attn.w_vc = w_vc.contiguous()
+                self_attn.use_deep_gemm_bmm = True
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
python/sglang/test/test_block_fp8.py
```diff
@@ -7,6 +7,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
+    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
     static_quant_fp8,
```
```diff
@@ -212,6 +213,62 @@ class TestPerTensorQuantMlaFP8(CustomTestCase):
                 self._per_tensor_quant_mla_fp8(*params)
 
 
+class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    B = [128]
+    NUM_TOKENS = [7, 83, 2048, 1024 * 16]
+    D = [512, 128]
+    GROUP_SIZE = [128]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+        self, b, num_tokens, d, dtype, group_size, seed
+    ):
+        torch.manual_seed(seed)
+
+        x = torch.rand(b, num_tokens, d, dtype=dtype)
+
+        with torch.inference_mode():
+            ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
+            out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
+                x, group_size
+            )
+            out = out[:, :num_tokens, :]
+            scale = scale[:, :num_tokens, :]
+
+        self.assertTrue(
+            torch.allclose(
+                out.to(torch.float32),
+                ref_out.to(torch.float32),
+                rtol=0.20,
+                atol=1e-2,
+            )
+        )
+        self.assertTrue(torch.allclose(scale, ref_scale))
+
+    def test_per_token_group_quant_mla_deep_gemm_masked_fp8(self):
+        for params in itertools.product(
+            self.B,
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.GROUP_SIZE,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                b=params[0],
+                num_tokens=params[1],
+                d=params[2],
+                dtype=params[3],
+                group_size=params[4],
+                seed=params[5],
+            ):
+                self._per_token_group_quant_mla_deep_gemm_masked_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
```
```diff
@@ -485,5 +542,115 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
                 self._w8a8_block_fp8_fused_moe(*params)
 
 
+# For test
+def torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_shape, out_dtype):
+    """This function performs bmm with block-wise quantization using native torch."""
+    B, N, _ = w.shape
+    _, M, _ = a.shape
+    out = torch.empty((B, M, N), dtype=out_dtype, device=a.device)
+    for i in range(B):
+        out[i] = native_w8a8_block_fp8_matmul(
+            a[i], w[i], a_s[i], w_s[i], block_shape, output_dtype=out_dtype
+        )
+    return out
+
+
+class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
+    DTYPES = [torch.bfloat16]
+    M = [1, 33, 64, 222, 8192]
+    N = [128, 512]
+    K = [128, 512]
+    BATCH = [128]
+    BLOCK_SIZE = [[128, 128]]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        try:
+            import deep_gemm
+        except ImportError:
+            raise unittest.SkipTest("DeepGEMM is not available")
+        torch.set_default_device("cuda")
+
+    def _w8a8_block_fp8_batched_deep_gemm(self, M, N, K, B, block_size, dtype, seed):
+        torch.manual_seed(seed)
+        factor_for_scale = 1e-2
+        fp8_info = torch.finfo(torch.float8_e4m3fn)
+        fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+        a_fp32 = torch.randn((B, M, K), dtype=torch.float32) / 10
+        a = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        w_fp32 = (torch.rand((B, N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max
+        w = w_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        block_n, block_k = block_size[0], block_size[1]
+        n_tiles_w = (N + block_n - 1) // block_n
+        k_tiles_w = (K + block_k - 1) // block_k
+
+        w_s = (
+            torch.rand((B, n_tiles_w, k_tiles_w), dtype=torch.float32)
+            * factor_for_scale
+        )
+        a_s = torch.rand((B, M, k_tiles_w), dtype=torch.float32) * factor_for_scale
+
+        ae = a.new_empty(B, (M + 255) // 256 * 256, K)
+        ae_s = a_s.new_empty(B, (M + 255) // 256 * 256, k_tiles_w)
+        oe = torch.empty((B, (M + 255) // 256 * 256, N), dtype=dtype)
+        ae[:, :M, :] = a
+        ae_s[:, :M, :] = a_s
+        masked_m = torch.full((B,), M, dtype=torch.int)
+        expected_m = M
+
+        lhs = (ae, ae_s)
+        rhs = (w, w_s)
+
+        from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
+
+        with torch.inference_mode():
+            ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
+            out = oe[:, :M, :]
+
+        self.assertTrue(
+            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
+            / torch.mean(torch.abs(ref_out.to(torch.float32)))
+            < 0.0001
+        )
+
+    def test_w8a8_block_fp8_batched_deep_gemm(self):
+        for params in itertools.product(
+            self.M,
+            self.N,
+            self.K,
+            self.BATCH,
+            self.BLOCK_SIZE,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                M=params[0],
+                N=params[1],
+                K=params[2],
+                B=params[3],
+                block_size=params[4],
+                dtype=params[5],
+                seed=params[6],
+            ):
+                self._w8a8_block_fp8_batched_deep_gemm(*params)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
```
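To run only the two new test classes, one option is sketched below (assuming sglang is installed; both classes skip themselves without CUDA, and the batched test also skips when DeepGEMM is missing):

```python
import unittest

from sglang.test import test_block_fp8

# Load just the new test classes added in this commit and run them.
suite = unittest.TestLoader().loadTestsFromNames(
    [
        "TestPerTokenGroupQuantMlaDeepGemmMaskedFP8",
        "TestW8A8BlockFP8BatchedDeepGemm",
    ],
    module=test_block_fp8,
)
unittest.TextTestRunner(verbosity=2).run(suite)
```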