sglang · Commits · 8e2363dc
"vscode:/vscode.git/clone" did not exist on "5aa809ff144ffc192cd66a2bbb194400e4aee1f4"
Unverified commit 8e2363dc, authored Jun 17, 2025 by Alex Sun, committed by GitHub on Jun 16, 2025.
fix amd EP MoE FP8 issue (#7125)
Parent: f9dc9dd2

Showing 1 changed file with 30 additions and 0 deletions.

python/sglang/srt/layers/moe/ep_moe/layer.py (+30, -0)
@@ -33,10 +33,12 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     scaled_fp8_quant,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.managers.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -50,6 +52,7 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()

 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
@@ -843,6 +846,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             torch.max(layer.w13_weight_scale, dim=1).values,
             requires_grad=False,
         )
+        if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if _is_fp8_fnuz:
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
         return

     def apply(
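For context on why the new block is needed: AMD GPUs that report FP8 as "fnuz" use the e4m3fnuz format, which has a different exponent bias than the standard e4m3fn format used in most FP8 checkpoints. The same bit pattern decodes to half the value, and the bit pattern e4m3fn uses for -0.0 means NaN in e4m3fnuz. So block-quantized FP8 expert weights loaded on ROCm must be reinterpreted as e4m3fnuz with their scales doubled, which is what the `normalize_e4m3fn_to_e4m3fnuz` calls above do. The sketch below is only an illustration of that conversion under these assumptions, not sglang's actual helper; the function name `normalize_e4m3fn_to_e4m3fnuz_sketch` is hypothetical, while `torch.float8_e4m3fn` and `torch.float8_e4m3fnuz` are real PyTorch dtypes.

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
    # Illustrative sketch (not sglang's implementation) of converting an
    # e4m3fn-quantized weight to ROCm's e4m3fnuz format.
    assert weight.dtype == torch.float8_e4m3fn
    # e4m3fn encodes -0.0 as the byte 0x80 (int8 -128); in e4m3fnuz that
    # same byte means NaN, so clear it to +0.0 before reinterpreting bits.
    as_int8 = weight.view(torch.int8).clone()
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # For identical bits, an e4m3fnuz value is half the e4m3fn value
    # (exponent bias 8 instead of 7), so double the scales to keep the
    # dequantized weights unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale

In the diff itself, the converted weights and scales are then re-registered as torch.nn.Parameter objects with requires_grad=False, and the input scales are reset to None because the activation scheme is dynamic, so activation scales are computed at runtime rather than loaded from the checkpoint.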