zhaoyu6 / sglang · Commit a6cc86df (unverified)

Fix DSR1 accuracy for flashinfer_trtllm MoE with FP8 quantization (#11081)

Authored Sep 30, 2025 by Trevor Morris; committed by GitHub on Sep 30, 2025.
Parent: 229d2b95
Showing 2 changed files with 4 additions and 4 deletions:

  python/sglang/srt/layers/moe/fused_moe_triton/layer.py  (+3, -3)
  python/sglang/srt/server_args.py                        (+1, -1)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

```diff
@@ -575,9 +575,9 @@ class FusedMoE(torch.nn.Module):
             )
         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if (
-            should_use_flashinfer_trtllm_moe()
-            and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+        if should_use_flashinfer_trtllm_moe() and (
+            isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+            or isinstance(self.quant_method, Fp8MoEMethod)
         ):
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
```
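This hunk carries the actual fix. The w1/w3 shard-id swap, which reorders the fused gate/up weight into the w31 layout that flashinfer's TRT-LLM kernels expect, previously ran only when the quant method's class name string-matched "ModelOptNvFp4FusedMoEMethod"; FP8 checkpoints were therefore loaded in w13 order, which is what degraded DSR1 (DeepSeek-R1) accuracy. The isinstance checks also cover subclasses, which the old string comparison could not. A minimal sketch of the remap in isolation; the stub classes and the helper below are illustrative stand-ins, not sglang API:

```python
# Minimal sketch of the shard-id remap, assuming the caller iterates over
# ("w1", "w2", "w3") shards while loading expert weights. The stub classes
# and maybe_swap_shard_id are illustrative stand-ins, not sglang API.
class ModelOptNvFp4FusedMoEMethod: ...
class Fp8MoEMethod: ...

def maybe_swap_shard_id(shard_id: str, quant_method: object, use_trtllm_moe: bool) -> str:
    # flashinfer's TRT-LLM MoE reads the fused gate/up weight in w31 order,
    # so w1 and w3 trade places at load time; w2 is untouched.
    if use_trtllm_moe and isinstance(
        quant_method, (ModelOptNvFp4FusedMoEMethod, Fp8MoEMethod)
    ):
        return {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
    return shard_id

assert maybe_swap_shard_id("w1", Fp8MoEMethod(), use_trtllm_moe=True) == "w3"  # FP8 now swapped too
assert maybe_swap_shard_id("w2", Fp8MoEMethod(), use_trtllm_moe=True) == "w2"
assert maybe_swap_shard_id("w1", object(), use_trtllm_moe=True) == "w1"        # other methods: no swap
```

The tuple form of isinstance used here is equivalent to the two chained isinstance calls in the diff.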
python/sglang/srt/server_args.py

```diff
@@ -916,7 +916,7 @@ class ServerArgs:
         if self.moe_runner_backend == "flashinfer_trtllm":
             assert (
                 self.quantization == "modelopt_fp4" or self.quantization == "fp8"
-            ), "modelopt_fp4 quantization is required for Flashinfer TRTLLM MoE"
+            ), "modelopt_fp4 or fp8 quantization is required for Flashinfer TRTLLM MoE"
             self.disable_shared_experts_fusion = True
             logger.warning(
                 "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
```
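The second hunk only corrects the assertion message: the condition already accepted fp8, but a failure would have claimed that modelopt_fp4 alone was required. A minimal sketch of the start-up check, assuming just the two ServerArgs fields visible in the diff (validate_moe_backend is a hypothetical stand-in for sglang's real validation method):

```python
# Minimal sketch of the start-up check, assuming only the two ServerArgs
# fields visible in the diff; validate_moe_backend is a hypothetical
# stand-in, not sglang's real validation method.
def validate_moe_backend(moe_runner_backend: str, quantization: str) -> None:
    if moe_runner_backend == "flashinfer_trtllm":
        assert quantization in ("modelopt_fp4", "fp8"), (
            "modelopt_fp4 or fp8 quantization is required for Flashinfer TRTLLM MoE"
        )

validate_moe_backend("flashinfer_trtllm", "fp8")           # accepted
validate_moe_backend("flashinfer_trtllm", "modelopt_fp4")  # accepted
```

Assuming the usual mapping from ServerArgs fields to CLI flags, the corresponding launch flags would be `--moe-runner-backend flashinfer_trtllm --quantization fp8`; verify the exact flag names against your sglang version.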