sglang / Commits / c5210dfa
"graphbolt/vscode:/vscode.git/clone" did not exist on "15695ed0ecc9eb39f862a8c8b2c23e453ee7b8f2"
Unverified commit c5210dfa, authored Dec 30, 2024 by HAI, committed Dec 30, 2024 by GitHub

AMD DeepSeek_V3 FP8 Numerical fix (#2667)
parent a29dd950
Showing 1 changed file with 34 additions and 7 deletions (+34, −7).
python/sglang/srt/models/deepseek_v2.py
```diff
@@ -46,6 +46,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     input_to_float8,
+    normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
```
```diff
@@ -55,7 +56,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, is_hip
+
+is_hip_ = is_hip()
 
 if is_flashinfer_available():
     from flashinfer import bmm_fp8
```
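The hunk above resolves the ROCm check once at import time so the hot paths below can branch on a plain boolean; this matters because AMD hardware exposes the `float8_e4m3fnuz` variant rather than the OCP `float8_e4m3fn` used on NVIDIA GPUs. A minimal, self-contained sketch of the same pattern, assuming `torch.version.hip` is how the build is detected (sglang's `is_hip` helper may differ in detail):

```python
import torch


def is_hip() -> bool:
    # True when this PyTorch build targets AMD ROCm (HIP) rather than NVIDIA CUDA.
    return torch.version.hip is not None


# Cache the check once at import time, mirroring `is_hip_ = is_hip()` in the hunk above.
is_hip_ = is_hip()

# ROCm GPUs use the "fnuz" fp8 variant; NVIDIA GPUs use the OCP e4m3fn variant.
fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
print(f"is_hip_={is_hip_}, fp8 dtype in use: {fp8_dtype}")
```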
```diff
@@ -573,7 +576,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.w_kc.dtype == torch.float8_e4m3fn:
+        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            q_nope_out = torch.bmm(
+                q_nope.to(torch.bfloat16).transpose(0, 1),
+                self.w_kc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = input_to_float8(
                 q_nope.transpose(0, 1), torch.float8_e4m3fn
             )
```
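Because `bmm_fp8` has no `float8_e4m3fnuz` kernel yet (the TODO above), the new branch dequantizes to bfloat16, folds the per-tensor weight scale back in, and falls back to a plain `torch.bmm`; the next hunk applies the same fallback to `w_vc`. Below is a standalone sketch of that dequantize-then-bmm fallback. The shapes and the way the fp8 weight is fabricated are illustrative assumptions, not the model's real configuration, and it assumes a PyTorch build with fp8 dtype support:

```python
import torch

num_heads, seq_len, qk_nope_head_dim, kv_lora_rank = 16, 4, 128, 512

# Pretend w_kc was stored as per-tensor-scaled fp8 (e4m3fnuz on ROCm).
w_ref = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
w_scale = w_ref.abs().amax() / torch.finfo(torch.float8_e4m3fnuz).max
w_kc = (w_ref / w_scale).to(torch.float8_e4m3fnuz)

q_nope = torch.randn(seq_len, num_heads, qk_nope_head_dim, dtype=torch.bfloat16)

# Fallback used in the diff: upcast to bf16, re-apply the per-tensor scale, then bmm.
q_nope_out = torch.bmm(
    q_nope.to(torch.bfloat16).transpose(0, 1),  # (num_heads, seq_len, qk_nope_head_dim)
    w_kc.to(torch.bfloat16) * w_scale,          # (num_heads, qk_nope_head_dim, kv_lora_rank)
)
print(q_nope_out.shape)  # torch.Size([16, 4, 512])
```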
```diff
@@ -598,7 +607,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.w_vc.dtype == torch.float8_e4m3fn:
+        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            attn_bmm_output = torch.bmm(
+                attn_output.to(torch.bfloat16).transpose(0, 1),
+                self.w_vc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = input_to_float8(
                 attn_output.transpose(0, 1), torch.float8_e4m3fn
             )
```
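The retained `float8_e4m3fn` branch still quantizes the activation per-tensor with `input_to_float8` before the fp8 batched matmul. The sketch below shows what tensor-wise fp8 quantization of an activation typically looks like; `input_to_float8_sketch` is a hypothetical stand-in and only an assumption about the helper's behaviour, not a copy of sglang's implementation:

```python
from typing import Tuple

import torch


def input_to_float8_sketch(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Tensor-wise quantization: one scale for the whole tensor, derived from its abs-max.
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().clamp(min=1e-12)
    scale = finfo.max / amax
    x_q = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    # Return the quantized tensor and the dequantization scale (1 / scale).
    return x_q.contiguous(), scale.float().reciprocal()


attn_output = torch.randn(16, 4, 512)
x_q, deq_scale = input_to_float8_sketch(attn_output.transpose(0, 1))
print(x_q.shape, x_q.dtype, deq_scale.item())
```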
```diff
@@ -940,15 +955,25 @@ class DeepseekV2ForCausalLM(nn.Module):
             w = self_attn.kv_b_proj.weight
             # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
             # This may affect the accuracy of fp8 model.
-            if (
-                hasattr(self.quant_config, "weight_block_size")
-                and w.dtype == torch.float8_e4m3fn
+            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                torch.float8_e4m3fn,
+                torch.float8_e4m3fnuz,
             ):
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    if is_hip_:
+                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                            weight=w,
+                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                            input_scale=None,
+                        )
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
                     w, scale = block_quant_to_tensor_quant(
-                        w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                        weight, weight_scale, weight_block_size
                     )
                     self_attn.w_scale = scale
             w_kc, w_vc = w.unflatten(
```
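On ROCm, the checkpoint's `float8_e4m3fn` weight and its block-wise `weight_scale_inv` are first normalized to `float8_e4m3fnuz` before the block-to-tensor requantization. The sketch below is a guess at what such a conversion involves, modeled on the analogous helper in comparable codebases: reinterpret the bits, clear the one bit pattern that decodes to NaN in the fnuz format, and double the scale. Treat it as an assumption rather than sglang's exact code:

```python
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.view(torch.int8).clone()
    # Bit pattern 0b10000000 is -0.0 in e4m3fn but NaN in e4m3fnuz, so map it to 0.
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # The same bits decode to half the magnitude under e4m3fnuz, so double the
    # scale(s) to keep the dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale


w = torch.randn(4, 8).to(torch.float8_e4m3fn)
w_fnuz, s_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz_sketch(w, torch.tensor(0.01))
print(w_fnuz.dtype, s_fnuz.item())  # torch.float8_e4m3fnuz, scale doubled to ~0.02
```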
```diff
@@ -961,6 +986,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                if is_hip_:
+                    self_attn.w_scale *= 2.0
 
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
```
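The trailing `w_scale *= 2.0` applies the same correction when the scale is taken directly from `kv_b_proj.weight_scale`: with exponent bias 8 instead of 7, a given bit pattern decodes to half the value under `float8_e4m3fnuz`, so the dequantization scale must double for `value = code * scale` to stay put. A quick check of that factor-of-two relationship, assuming a PyTorch build that exposes both fp8 dtypes:

```python
import torch

x = torch.tensor([1.0, 2.0, 0.5, 240.0], dtype=torch.float8_e4m3fn)
bits = x.view(torch.int8)

as_fn = bits.view(torch.float8_e4m3fn).float()
as_fnuz = bits.view(torch.float8_e4m3fnuz).float()

# Each bit pattern decodes to exactly half the value under e4m3fnuz,
# which is why the dequantization scale is multiplied by 2.0 on HIP.
print(as_fn.tolist())    # [1.0, 2.0, 0.5, 240.0]
print(as_fnuz.tolist())  # [0.5, 1.0, 0.25, 120.0]
print(torch.allclose(as_fn, as_fnuz * 2))  # True
```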