Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
481f608b
Unverified
Commit
481f608b
authored
Mar 12, 2025
by
lambert0312
Committed by
GitHub
Mar 12, 2025
Browse files
Add INT8 support MTP NextN function (#3911)
parent
ed91561f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
0 deletions
+20
-0
python/sglang/srt/models/deepseek_nextn.py
python/sglang/srt/models/deepseek_nextn.py
+20
-0
No files found.
python/sglang/srt/models/deepseek_nextn.py
View file @
481f608b
...
@@ -30,6 +30,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
...
@@ -30,6 +30,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant
,
block_quant_to_tensor_quant
,
normalize_e4m3fn_to_e4m3fnuz
,
normalize_e4m3fn_to_e4m3fnuz
,
)
)
from
sglang.srt.layers.quantization.int8_utils
import
(
block_dequant
as
int8_block_dequant
,
)
from
sglang.srt.layers.vocab_parallel_embedding
import
(
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
...
@@ -291,6 +294,23 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
...
@@ -291,6 +294,23 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
weight
,
weight_scale
,
weight_block_size
weight
,
weight_scale
,
weight_block_size
)
)
self_attn
.
w_scale
=
scale
self_attn
.
w_scale
=
scale
if
w
.
dtype
==
torch
.
int8
:
if
hasattr
(
self
.
quant_config
,
"weight_block_size"
):
# block-wise int8 need it
weight_block_size
=
self
.
quant_config
.
weight_block_size
if
weight_block_size
is
not
None
:
assert
hasattr
(
self_attn
.
kv_b_proj
,
"weight_scale_inv"
)
weight
=
w
weight_scale
=
self_attn
.
kv_b_proj
.
weight_scale_inv
w
=
int8_block_dequant
(
weight
,
weight_scale
,
weight_block_size
).
to
(
torch
.
bfloat16
)
else
:
# channel-wise int8 need it
assert
hasattr
(
self_attn
.
kv_b_proj
,
"weight_scale"
)
w
=
w
.
to
(
torch
.
bfloat16
)
*
self_attn
.
kv_b_proj
.
weight_scale
.
to
(
torch
.
bfloat16
)
w_kc
,
w_vc
=
w
.
unflatten
(
w_kc
,
w_vc
=
w
.
unflatten
(
0
,
(
-
1
,
self_attn
.
qk_nope_head_dim
+
self_attn
.
v_head_dim
)
0
,
(
-
1
,
self_attn
.
qk_nope_head_dim
+
self_attn
.
v_head_dim
)
).
split
([
self_attn
.
qk_nope_head_dim
,
self_attn
.
v_head_dim
],
dim
=
1
)
).
split
([
self_attn
.
qk_nope_head_dim
,
self_attn
.
v_head_dim
],
dim
=
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment