sglang · Commits · b7361cc4

Unverified commit b7361cc4, authored Sep 02, 2025 by Guoyuan Lin; committed by GitHub on Sep 02, 2025
[Fix] Fix the issue encountered when running LongCat-Flash/MTP EP MoE inference on B200 (#9916)
Parent: a96c5b5c

Showing 2 changed files with 49 additions and 30 deletions.
python/sglang/srt/models/longcat_flash.py (+26, -15)
python/sglang/srt/models/longcat_flash_nextn.py (+23, -15)
python/sglang/srt/models/longcat_flash.py
...
@@ -651,9 +651,6 @@ class LongcatFlashForCausalLM(nn.Module):
                     ).T
                 else:
                     w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
-                # Fix deepseek v3 blockwise bmm by using deep_gemm
                 use_deep_gemm_bmm = False
                 if w.dtype in (
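The NOTE removed above explains why `kv_b_proj` is requantized at all: `bmm_fp8` accepts only a single per-tensor scale, while the checkpoint stores block-wise FP8 scales. For illustration only, a stand-alone sketch of what collapsing block scales into one per-tensor scale looks like; the helper name `blockwise_to_per_tensor` and its exact arithmetic are assumptions, not sglang code:

    import torch

    def blockwise_to_per_tensor(weight_fp8, scale_inv, block=128):
        # Dequantize each (block x block) tile with its own scale, then
        # requantize the whole tensor with one scale, as a per-tensor
        # bmm_fp8-style kernel expects.
        w = weight_fp8.to(torch.float32)
        n, k = w.shape
        for i in range(0, n, block):
            for j in range(0, k, block):
                w[i : i + block, j : j + block] *= scale_inv[i // block, j // block]
        per_tensor_scale = w.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
        w_q = (w / per_tensor_scale).to(torch.float8_e4m3fn)
        # Losing per-block resolution is why the removed NOTE warned that
        # accuracy of the fp8 model may be affected.
        return w_q, per_tensor_scale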
...
@@ -790,6 +787,9 @@ class LongcatFlashForCausalLM(nn.Module):
                     self.config.hidden_size / self.config.kv_lora_rank
                 ) ** 0.5
 
+        # TODO(linguoyuan) EPMoE not support DEEPGEMM_BLACKWELL, DeepEP needs to be supported in the future
+        deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False
+
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
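The added TODO pins `deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0` to False because EP MoE does not yet support the DEEPGEMM_BLACKWELL path; with the flag forced off, the guard directly below can never take the UE8M0 requantization branch on B200. A minimal stand-alone sketch of that guard behaviour, with a stub standing in for sglang's `deep_gemm_wrapper` module (the stub and its default values are assumptions):

    class _DeepGemmWrapperStub:
        # Stand-in for sglang's deep_gemm_wrapper module (assumed values).
        ENABLE_JIT_DEEPGEMM = True
        DEEPGEMM_SCALE_UE8M0 = True

    deep_gemm_wrapper = _DeepGemmWrapperStub()

    # The commit pins the flag to False before the guard below is evaluated.
    deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False

    if (
        deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
        and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
    ):
        print("would requantize weights to UE8M0 scales")
    else:
        print("UE8M0 requantization skipped")  # taken after the fix, even with JIT DeepGEMM on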
...
@@ -804,24 +804,35 @@ class LongcatFlashForCausalLM(nn.Module):
         for layer_id in range(self.config.num_hidden_layers):
             layer = self.model.layers[layer_id]
             for i in range(2):
-                for module in [
-                    layer.self_attn[i].fused_qkv_a_proj_with_mqa,
-                    layer.self_attn[i].q_b_proj,
-                    layer.self_attn[i].kv_b_proj,
-                    layer.self_attn[i].o_proj,
-                ]:
-                    requant_weight_ue8m0_inplace(
-                        module.weight, module.weight_scale_inv, weight_block_size
-                    )
+                self_attn = layer.self_attn[i]
+                module_list = [
+                    self_attn.kv_b_proj,
+                    self_attn.o_proj,
+                ]
+                if self.config.q_lora_rank is not None:
+                    module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+                    module_list.append(self_attn.q_b_proj)
+                else:
+                    module_list.append(self_attn.kv_a_proj_with_mqa)
+                    module_list.append(self_attn.q_proj)
+                for module in module_list:
+                    if hasattr(module, "weight_scale_inv"):
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )
                 mlp = layer.mlps[i]
                 assert isinstance(mlp, LongcatFlashMLP)
                 for module in [
                     mlp.gate_up_proj,
                     mlp.down_proj,
                 ]:
-                    requant_weight_ue8m0_inplace(
-                        module.weight, module.weight_scale_inv, weight_block_size
-                    )
+                    if hasattr(module, "weight_scale_inv"):
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )
 
         for layer_id in range(self.config.num_hidden_layers):
             experts = layer.mlp.experts
...
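Two behavioural changes land in this hunk: the set of attention projections now depends on whether `config.q_lora_rank` is set, and every requantization call is guarded with `hasattr(module, "weight_scale_inv")`, so modules that do not carry block-wise FP8 scales are skipped instead of raising AttributeError. A small illustrative sketch of the guard, with `SimpleNamespace` objects standing in for projection modules and a stubbed `requant_weight_ue8m0_inplace` (both are assumptions for the sketch, not sglang code):

    from types import SimpleNamespace

    def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
        # Stand-in for sglang's helper, which rewrites the block scales to UE8M0 in place.
        print("requantized", weight)

    weight_block_size = [128, 128]

    fp8_proj = SimpleNamespace(weight="w_fp8", weight_scale_inv="s_inv")  # block-wise FP8 module
    plain_proj = SimpleNamespace(weight="w_bf16")                         # no block scales

    for module in [fp8_proj, plain_proj]:
        if hasattr(module, "weight_scale_inv"):
            requant_weight_ue8m0_inplace(
                module.weight, module.weight_scale_inv, weight_block_size
            )
        # plain_proj is skipped rather than raising AttributeError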
python/sglang/srt/models/longcat_flash_nextn.py
...
@@ -344,9 +344,6 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
                 ).T
             else:
                 w = self_attn.kv_b_proj.weight
-            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-            # This may affect the accuracy of fp8 model.
-            # Fix deepseek v3 blockwise bmm by using deep_gemm
             use_deep_gemm_bmm = False
             if w.dtype in (
                 torch.float8_e4m3fn,
...
@@ -480,24 +477,35 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
     def _weight_requant_ue8m0(self):
         weight_block_size = self.quant_config.weight_block_size
         layer = self.model.decoder
-        for module in [
-            layer.self_attn.fused_qkv_a_proj_with_mqa,
-            layer.self_attn.q_b_proj,
-            layer.self_attn.kv_b_proj,
-            layer.self_attn.o_proj,
-        ]:
-            requant_weight_ue8m0_inplace(
-                module.weight, module.weight_scale_inv, weight_block_size
-            )
+        self_attn = layer.self_attn
+        module_list = [
+            self_attn.kv_b_proj,
+            self_attn.o_proj,
+        ]
+        if self.config.q_lora_rank is not None:
+            module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+            module_list.append(self_attn.q_b_proj)
+        else:
+            module_list.append(self_attn.kv_a_proj_with_mqa)
+            module_list.append(self_attn.q_proj)
+        for module in module_list:
+            if hasattr(module, "weight_scale_inv"):
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
         mlp = layer.mlps
         assert isinstance(mlp, LongcatFlashMLP)
         for module in [
             mlp.gate_up_proj,
             mlp.down_proj,
         ]:
-            requant_weight_ue8m0_inplace(
-                module.weight, module.weight_scale_inv, weight_block_size
-            )
+            if hasattr(module, "weight_scale_inv"):
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
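The NextN (MTP draft) model gets the same treatment, applied to its single `self.model.decoder` layer instead of a loop over layers. For reference, this is the projection-selection logic both files now share, written as a stand-alone sketch; `select_attn_projections` and the `SimpleNamespace` stand-ins are illustrative, not part of the commit:

    from types import SimpleNamespace

    def select_attn_projections(attn, q_lora_rank):
        # kv_b_proj and o_proj are always requantization candidates.
        module_list = [attn.kv_b_proj, attn.o_proj]
        if q_lora_rank is not None:
            # Low-rank Q path: fused QKV-A projection plus the Q-B projection.
            module_list.append(attn.fused_qkv_a_proj_with_mqa)
            module_list.append(attn.q_b_proj)
        else:
            # Plain Q path: separate KV-A (with MQA) and Q projections.
            module_list.append(attn.kv_a_proj_with_mqa)
            module_list.append(attn.q_proj)
        return module_list

    attn = SimpleNamespace(
        kv_b_proj="kv_b_proj", o_proj="o_proj",
        fused_qkv_a_proj_with_mqa="fused_qkv_a_proj_with_mqa", q_b_proj="q_b_proj",
        kv_a_proj_with_mqa="kv_a_proj_with_mqa", q_proj="q_proj",
    )
    print(select_attn_projections(attn, q_lora_rank=64))
    # ['kv_b_proj', 'o_proj', 'fused_qkv_a_proj_with_mqa', 'q_b_proj']
    print(select_attn_projections(attn, q_lora_rank=None))
    # ['kv_b_proj', 'o_proj', 'kv_a_proj_with_mqa', 'q_proj']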