change / sglang · Commits · 183d9f96

Commit 183d9f96 (unverified), authored May 27, 2025 by HAI and committed by GitHub on May 27, 2025.
Parent: 63195028

DeepSeek: enable none block-quant FP8 quantizations (#6638)

Showing 1 changed file with 55 additions and 43 deletions:

python/sglang/srt/models/deepseek_v2.py (+55, -43)

@@ -57,6 +57,7 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
     per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
@@ -101,6 +102,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
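
The new `is_fp8_fnuz` import and the module-level `_is_fp8_fnuz` flag record whether the current platform's native FP8 format is `torch.float8_e4m3fnuz` (as on ROCm) rather than `torch.float8_e4m3fn`; the weight-loading hunk at the end of this diff branches on this flag instead of on `_is_hip`. As a hedged illustration of why that distinction matters, the sketch below shows the usual e4m3fn-to-e4m3fnuz normalization trick, modeled on the common vLLM-style helper; the name and details here are illustrative, not necessarily sglang's exact `normalize_e4m3fn_to_e4m3fnuz`:

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight: torch.Tensor, weight_scale: torch.Tensor):
    # Illustrative only: reinterpret e4m3fn bits as e4m3fnuz and compensate the scale.
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.view(torch.int8).clone()
    # Bit pattern 0b10000000 is -0.0 in e4m3fn but NaN in e4m3fnuz; map it to 0.
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # For identical bits, the e4m3fnuz value is half the e4m3fn value,
    # so double the dequantization scale to keep weight * scale unchanged.
    return weight_fnuz, weight_scale * 2.0
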
@@ -684,7 +686,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_kc = None
         self.w_vc = None
-        self.w_scale = None
+        self.w_scale = 1.0
         self.w_scale_k = None
         self.w_scale_v = None
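
Initializing `self.w_scale` to `1.0` instead of `None` lets the bfloat16 fallback in the hunks below always multiply by `self.w_scale`, even for checkpoints that never attach a per-tensor scale; a scale of 1.0 is simply a no-op. A trivial, hedged illustration in plain PyTorch (shapes are made up):

import torch

w_kc = torch.randn(8, 64, 128, dtype=torch.bfloat16)
q_nope = torch.randn(8, 4, 64, dtype=torch.bfloat16)

w_scale = 1.0  # previously None, which would break the multiplication below
q_nope_out = torch.bmm(q_nope, w_kc * w_scale)  # identical to torch.bmm(q_nope, w_kc)
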
@@ -948,8 +950,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 expected_m,
             )
             q_nope_out = q_nope_out[:, :expected_m, :]
-        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
-            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+        elif _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
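
This hunk, and the three analogous ones below, changes how the bmm fallback is selected: previously the bfloat16 path was taken only when the absorbed weights were already `torch.float8_e4m3fnuz`; now it is taken whenever the code runs on ROCm (`_is_hip`), where a `bmm_fp8` kernel is not yet available (per the TODO), regardless of which FP8 flavor the checkpoint used. A minimal, hedged sketch of the resulting fallback in plain PyTorch, not sglang's code, with the CUDA `bmm_fp8` branch elided:

import torch

def absorbed_bmm_sketch(x, w_fp8, w_scale, is_hip):
    # x: [heads, tokens, k]; w_fp8: [heads, k, n] stored in FP8; w_scale: per-tensor scale.
    if is_hip:
        # ROCm has no bmm_fp8 kernel yet, so upcast to bfloat16 and fold the
        # per-tensor scale into the weight, as the hunk's body does.
        return torch.bmm(x.to(torch.bfloat16), w_fp8.to(torch.bfloat16) * w_scale)
    # On CUDA the surrounding code uses sgl_kernel's bmm_fp8 path instead (elided here).
    raise NotImplementedError

# Toy usage of the ROCm-style fallback on ordinary tensors:
x = torch.randn(16, 4, 512)
w = torch.randn(16, 512, 128).clamp(-1, 1).to(torch.float8_e4m3fn)
out = absorbed_bmm_sketch(x, w, w_scale=0.5, is_hip=True)  # [16, 4, 128]
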
@@ -1000,8 +1002,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 expected_m,
             )
             attn_bmm_output = attn_bmm_output[:, :expected_m, :]
-        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
-            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+        elif _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -1052,8 +1054,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.w_kc.dtype == torch.float8_e4m3fnuz:
-            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+        if _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
@@ -1186,8 +1188,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.w_vc.dtype == torch.float8_e4m3fnuz:
-            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+        if _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -1749,46 +1751,56 @@ class DeepseekV2ForCausalLM(nn.Module):
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
             ):
-                if hasattr(self.quant_config, "weight_block_size"):
+                if (
+                    hasattr(self.quant_config, "weight_block_size")
+                    and self.quant_config.weight_block_size is not None
+                ):
                     weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
-                            )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                        if (
-                            _is_cuda
-                            and weight_block_size[0] == 128
-                            and weight_block_size[1] == 128
-                            and model_dtype == torch.bfloat16
-                        ):
-                            if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
-                                "SGL_USE_DEEPGEMM_BMM", "false"
-                            ):
-                                block_scale = weight_scale
-                                use_deep_gemm_bmm = True
-                            else:
-                                w = block_quant_dequant(
-                                    weight,
-                                    weight_scale,
-                                    weight_block_size,
-                                    model_dtype,
-                                )
-                        else:
-                            w, scale = block_quant_to_tensor_quant(
-                                weight, weight_scale, weight_block_size
-                            )
-                            self_attn.w_scale = scale
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                            weight=w,
+                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                            input_scale=None,
+                        )
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                    if (
+                        _is_cuda
+                        and weight_block_size[0] == 128
+                        and weight_block_size[1] == 128
+                        and model_dtype == torch.bfloat16
+                    ):
+                        if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                            "SGL_USE_DEEPGEMM_BMM", "false"
+                        ):
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
+                        else:
+                            w = block_quant_dequant(
+                                weight,
+                                weight_scale,
+                                weight_block_size,
+                                model_dtype,
+                            )
+                    else:
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
                 else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                            weight=w,
+                            weight_scale=self_attn.kv_b_proj.weight_scale,
+                            input_scale=None,
+                        )
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale
                     w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                     self_attn.w_scale = scale
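
The rewritten weight-loading block now handles two FP8 checkpoint layouts. If the quant config carries a `weight_block_size`, the kv_b_proj weights are block-quantized (scales in `weight_scale_inv`) and are either kept block-scaled for the DeepGEMM BMM path, dequantized with `block_quant_dequant`, or converted with `block_quant_to_tensor_quant`. Otherwise, the case this commit enables, the checkpoint is a non-block ("none block-quant") FP8 one with per-channel scales in `weight_scale`; those weights are now also normalized to e4m3fnuz on fnuz platforms before being collapsed to a single per-tensor scale via `channel_quant_to_tensor_quant`. The sketch below is a toy illustration of what such a per-channel-to-per-tensor conversion conceptually does; it is plain PyTorch, and the names, shapes and clamping details are assumptions, not sglang's implementation:

import torch

def channel_to_tensor_quant_sketch(w_fp8: torch.Tensor, channel_scale: torch.Tensor):
    """w_fp8: [out, in] FP8 weight; channel_scale: [out, 1] per-output-channel scales."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Dequantize with the per-channel scales ...
    dequant = w_fp8.to(torch.float32) * channel_scale.to(torch.float32)
    # ... pick one scale for the whole tensor ...
    tensor_scale = dequant.abs().max().clamp(min=1e-12) / finfo.max
    # ... and requantize so a single scalar scale can feed the per-tensor bmm path.
    requant = (dequant / tensor_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return requant, tensor_scale

# Example: a random per-channel-quantized weight collapsed to per-tensor FP8.
w = torch.randn(128, 64).clamp(-1, 1).to(torch.float8_e4m3fn)
scales = torch.rand(128, 1) * 0.01
w_q, s = channel_to_tensor_quant_sketch(w, scales)
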