sglang commit f445a1d9 (unverified)
Authored by Hubert Lu on Aug 22, 2025; committed via GitHub on Aug 22, 2025
[AMD] Fix Llama 4 FP8 accuracy issues on MI300X (#7699)
Parent: e5638573
Showing 5 changed files with 212 additions and 17 deletions (+212, −17):
- python/sglang/srt/layers/moe/ep_moe/layer.py (+0, −1)
- python/sglang/srt/layers/moe/rocm_moe_utils.py (+141, −0)
- python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py (+66, −15)
- python/sglang/srt/layers/quantization/fp8.py (+1, −0)
- python/sglang/srt/server_args.py (+4, −1)
python/sglang/srt/layers/moe/ep_moe/layer.py

```diff
@@ -52,7 +52,6 @@ if not (_is_npu or _is_hip):
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.ops.shuffle import shuffle_weight


 logger = logging.getLogger(__name__)
```
python/sglang/srt/layers/moe/rocm_moe_utils.py (new file, mode 100644)

```python
# Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1rc2/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
from functools import cache
from typing import Optional

import torch

from sglang.srt.utils import direct_register_custom_op, get_bool_env_var, is_hip

_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip


class ActivationMethod(IntEnum):
    # This allows interfacing with AITER ActivationType enum
    # without importing the ActivationType enum from AITER globally.
    SILU = 0
    GELU = 1


def rocm_aiter_asm_moe_tkw1_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    fc1_smooth_scale: Optional[torch.Tensor] = None,
    fc2_smooth_scale: Optional[torch.Tensor] = None,
    a16: bool = False,
    per_tensor_quant_scale: Optional[torch.Tensor] = None,
    expert_mask: Optional[torch.Tensor] = None,
    activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe_tkw1

    activation = ActivationType(activation_method)

    return asm_moe_tkw1(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        fc1_smooth_scale=fc1_smooth_scale,
        fc2_smooth_scale=fc2_smooth_scale,
        a16=a16,
        per_tensor_quant_scale=per_tensor_quant_scale,
        expert_mask=expert_mask,
        activation=activation,
    )


def rocm_aiter_asm_moe_tkw1_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    fc1_smooth_scale: Optional[torch.Tensor] = None,
    fc2_smooth_scale: Optional[torch.Tensor] = None,
    a16: bool = False,
    per_tensor_quant_scale: Optional[torch.Tensor] = None,
    expert_mask: Optional[torch.Tensor] = None,
    activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


if _use_aiter:
    direct_register_custom_op(
        op_name="rocm_aiter_asm_moe_tkw1",
        op_func=rocm_aiter_asm_moe_tkw1_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_asm_moe_tkw1_fake,
    )


def rocm_fused_experts_tkw1(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    activation_method = (
        ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU
    )
    # All AITER Fused MoE kernels are expecting the following datatypes
    topk_weights = topk_weights.to(torch.float32)
    topk_ids = topk_ids.to(torch.int32)

    # w8a8 per-channel quantization
    if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`.
        # This applies topk_weights on the GEMM output of the first FC layer
        # rather than the second FC.
        assert (
            topk_weights.dim() == 2
        ), "`topk_weights` should be in shape (num_tokens, topk)"
        assert (
            topk_weights.shape[-1] == 1
        ), "Only support topk=1 when `apply_router_weight_on_input` is True"

        return torch.ops.sglang.rocm_aiter_asm_moe_tkw1(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            fc1_scale=w1_scale,
            fc2_scale=w2_scale,
            fc1_smooth_scale=None,
            fc2_smooth_scale=None,
            a16=False,
            per_tensor_quant_scale=None,
            expert_mask=None,
            activation_method=activation_method,
        )
    else:
        assert False, "This should not be called."
```
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

```diff
@@ -19,7 +19,14 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    set_weight_attrs,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -29,6 +36,13 @@ if TYPE_CHECKING:
     CompressedTensorsConfig,
 )

+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
+    from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
+
 try:
     import vllm
@@ -265,6 +279,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             max_w13_scales, requires_grad=False
         )

+        if _use_aiter:
+            with torch.no_grad():
+                # Pre-shuffle weights
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -274,20 +302,43 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts

-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
-            use_fp8_w8a8=True,
-            per_channel_quant=self.weight_quant.strategy
-            == QuantizationStrategy.CHANNEL,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-        )
+        if (
+            _use_aiter
+            and self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and moe_runner_config.apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids, _ = topk_output
+            return rocm_fused_experts_tkw1(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=moe_runner_config.activation,
+                apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+        else:
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )

 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
```
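Two things happen in this file: FP8 MoE weights get pre-shuffled into the (16, 16)-tiled layout the AITER asm kernels expect, and `apply()` routes per-channel FP8 models that use `apply_router_weight_on_input` to the new tkw1 kernel. The parameter-replacement pattern used for the pre-shuffle is worth seeing in isolation; here is a minimal CPU-runnable sketch in which `tile_shuffle` is a hypothetical stand-in for `aiter.ops.shuffle.shuffle_weight`:

```python
import torch


def tile_shuffle(w: torch.Tensor, tile=(16, 16)) -> torch.Tensor:
    """Stand-in for aiter.ops.shuffle.shuffle_weight (hypothetical). The real
    transform reorders elements into 16x16 tiles for the AITER asm MoE
    kernels; an identity copy suffices here to show the replacement pattern."""
    return w.contiguous().clone()


class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative expert-weight shape: (num_experts, out_dim, in_dim).
        self.w13_weight = torch.nn.Parameter(
            torch.randn(16, 256, 128), requires_grad=False
        )


layer = Layer()
with torch.no_grad():
    # Replace the parameter wholesale rather than mutating it in place, then
    # release the old allocation before the next weight is shuffled.
    layer.w13_weight = torch.nn.Parameter(
        tile_shuffle(layer.w13_weight.data, (16, 16)), requires_grad=False
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```

Shuffling materializes a second copy of each weight, so calling `torch.cuda.empty_cache()` after each replacement keeps the peak overhead to roughly one extra weight buffer at a time.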
python/sglang/srt/layers/quantization/fp8.py

```diff
@@ -966,6 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 requires_grad=False,
             )
             torch.cuda.empty_cache()

             # ROCm (_use_aiter): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
```
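This hunk's scale handling folds each expert's per-tensor weight scale into its column-wise (per-channel) scale, so the AITER path sees a single combined scale per channel. The broadcast works because `unsqueeze(-1)` turns the per-expert vector into a column that multiplies every channel; a tiny self-contained illustration (the sizes are made up):

```python
import torch

num_experts, num_channels = 4, 8  # illustrative sizes
# Column-wise scales: one value per output channel of each expert.
scale1 = torch.rand(num_experts, num_channels)
# Per-tensor scales: one value per expert.
per_tensor = torch.rand(num_experts)

# (E,) -> (E, 1) broadcasts across the channel dimension, mirroring
# `layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)`.
scale1 *= per_tensor.unsqueeze(-1)
assert scale1.shape == (num_experts, num_channels)
```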
python/sglang/srt/server_args.py

```diff
@@ -2228,7 +2228,10 @@ class ServerArgs:
             # use bf16 for mxfp4 triton kernels
             self.dtype = "bfloat16"
         elif "Llama4" in model_arch:
-            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+            assert self.attention_backend in {
+                "fa3",
+                "aiter",
+            }, "fa3 or aiter is required for Llama4 model"
         elif model_arch in [
             "Gemma2ForCausalLM",
             "Gemma3ForCausalLM",
```
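With this relaxation, a Llama 4 launch on MI300X can pair `SGLANG_USE_AITER=1` with the `aiter` attention backend instead of being pinned to `fa3`. A minimal sketch of the new check in isolation (the arch string below is illustrative, not taken from the diff):

```python
# Sketch of the relaxed Llama4 backend check; the arch name is illustrative.
model_arch = "Llama4ForConditionalGeneration"

for attention_backend in ("fa3", "aiter"):
    if "Llama4" in model_arch:
        assert attention_backend in {
            "fa3",
            "aiter",
        }, "fa3 or aiter is required for Llama4 model"

print("both fa3 and aiter pass the Llama4 backend check")
```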