Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
679c6a3e
Unverified
Commit
679c6a3e
authored
Mar 24, 2026
by
Andreas Karatzas
Committed by
GitHub
Mar 25, 2026
Browse files
[Bugfix][ROCm][MoE] Fix mxfp4 oracle regressions from #37128 (#37787)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
8bbb7c7f
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
69 additions
and
15 deletions
+69
-15
.buildkite/test-amd.yaml
.buildkite/test-amd.yaml
+1
-0
tests/lora/test_gptoss_tp.py
tests/lora/test_gptoss_tp.py
+11
-0
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+2
-2
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+5
-1
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+3
-0
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+5
-5
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+8
-1
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+3
-3
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+12
-2
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+11
-1
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+8
-0
No files found.
.buildkite/test-amd.yaml
View file @
679c6a3e
...
@@ -2526,6 +2526,7 @@ steps:
...
@@ -2526,6 +2526,7 @@ steps:
-
pytest -v -s -x lora/test_llm_with_multi_loras.py
-
pytest -v -s -x lora/test_llm_with_multi_loras.py
-
pytest -v -s -x lora/test_olmoe_tp.py
-
pytest -v -s -x lora/test_olmoe_tp.py
-
pytest -v -s -x lora/test_gptoss_tp.py
-
pytest -v -s -x lora/test_gptoss_tp.py
-
pytest -v -s -x lora/test_qwen35_densemoel_lora.py
-
label
:
Weight Loading Multiple GPU
# 7.5m
-
label
:
Weight Loading Multiple GPU
# 7.5m
...
...
tests/lora/test_gptoss_tp.py
View file @
679c6a3e
...
@@ -5,6 +5,7 @@ import pytest
...
@@ -5,6 +5,7 @@ import pytest
import
vllm
import
vllm
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.platforms
import
current_platform
from
..utils
import
multi_gpu_test
from
..utils
import
multi_gpu_test
...
@@ -69,6 +70,16 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
...
@@ -69,6 +70,16 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
assert
generated_texts
[
i
].
startswith
(
EXPECTED_LORA_OUTPUT
[
i
])
assert
generated_texts
[
i
].
startswith
(
EXPECTED_LORA_OUTPUT
[
i
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
(
"Mxfp4 LoRA on ROCm is blocked by a spawn compatibility issue. "
"The fused_moe_lora Triton kernel crashes in spawned subprocesses, "
"and vLLM forces spawn mode when HIP is initialized before "
"multiprocessing. Fixing this requires either making the LoRA "
"Triton kernel spawn-safe or pre-warming the kernel cache."
),
)
@
pytest
.
mark
.
parametrize
(
"mxfp4_use_marlin"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"mxfp4_use_marlin"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"specialize_active_lora"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"specialize_active_lora"
,
[
True
,
False
])
def
test_gpt_oss_lora
(
def
test_gpt_oss_lora
(
...
...
vllm/lora/layers/fused_moe.py
View file @
679c6a3e
...
@@ -109,8 +109,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -109,8 +109,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
else
:
# fall back to the default config
else
:
# fall back to the default config
get_config_func
=
functools
.
partial
(
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_lora_config
,
try_get_optimal_moe_lora_config
,
w1_shape
=
layer
.
w13_weight
.
s
ize
()
,
w1_shape
=
layer
.
w13_weight
.
s
hape
,
w2_shape
=
layer
.
w2_weight
.
s
ize
()
,
w2_shape
=
layer
.
w2_weight
.
s
hape
,
rank
=
rank
,
rank
=
rank
,
top_k
=
top_k
,
top_k
=
top_k
,
dtype
=
config_dtype
,
dtype
=
config_dtype
,
...
...
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
View file @
679c6a3e
...
@@ -379,7 +379,11 @@ def _fused_moe_lora_kernel(
...
@@ -379,7 +379,11 @@ def _fused_moe_lora_kernel(
)
)
a_ptrs
+=
BLOCK_SIZE_K
*
SPLIT_K
*
stride_ak
a_ptrs
+=
BLOCK_SIZE_K
*
SPLIT_K
*
stride_ak
accumulator
+=
tl
.
dot
(
a
,
b
)
# Cast operands to matching dtype for tl.dot. On ROCm, Triton's
# compiler may infer different types for a and b when merging
# if/else branches (TMA desc path returns fp32, tl.load returns
# the pointer's element type).
accumulator
+=
tl
.
dot
(
a
.
to
(
tl
.
bfloat16
),
b
.
to
(
tl
.
bfloat16
))
if
MUL_ROUTED_WEIGHT
:
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0.0
)
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0.0
)
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
679c6a3e
...
@@ -229,6 +229,9 @@ class FusedMoEQuantConfig:
...
@@ -229,6 +229,9 @@ class FusedMoEQuantConfig:
_w1
:
FusedMoEQuantDesc
_w1
:
FusedMoEQuantDesc
_w2
:
FusedMoEQuantDesc
_w2
:
FusedMoEQuantDesc
is_nvfp4_scale_swizzled
:
bool
=
True
is_nvfp4_scale_swizzled
:
bool
=
True
# CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe()
hidden_pad
:
int
=
0
intermediate_pad
:
int
=
0
def
__post_init__
(
self
):
def
__post_init__
(
self
):
assert
not
self
.
per_act_token_quant
or
self
.
block_shape
is
None
,
(
assert
not
self
.
per_act_token_quant
or
self
.
block_shape
is
None
,
(
...
...
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
679c6a3e
...
@@ -257,7 +257,7 @@ def triton_kernel_moe_forward(
...
@@ -257,7 +257,7 @@ def triton_kernel_moe_forward(
# sparse_logits.indx contains global expert IDs – remap to local.
# sparse_logits.indx contains global expert IDs – remap to local.
topk_ids
=
expert_map
[
sparse_logits
.
indx
.
to
(
torch
.
long
)]
topk_ids
=
expert_map
[
sparse_logits
.
indx
.
to
(
torch
.
long
)]
topk_weights
=
sparse_logits
.
vals
topk_weights
=
sparse_logits
.
vals
local_num_experts
=
w1
.
s
ize
(
0
)
local_num_experts
=
w1
.
s
hape
[
0
]
routing_data
,
gather_idx
,
scatter_idx
=
make_routing_data
(
routing_data
,
gather_idx
,
scatter_idx
=
make_routing_data
(
topk_ids
,
topk_weights
,
local_num_experts
topk_ids
,
topk_weights
,
local_num_experts
)
)
...
@@ -604,8 +604,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
...
@@ -604,8 +604,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
require a specialized implementation, like MarlinExperts, they are free
require a specialized implementation, like MarlinExperts, they are free
to override this function.
to override this function.
"""
"""
assert
w1
.
dim
()
==
3
and
w2
.
dim
(
)
==
3
assert
len
(
w1
.
shape
)
==
3
and
len
(
w2
.
shape
)
==
3
E
,
_
,
N
=
w1
.
s
ize
()
E
,
_
,
N
=
w1
.
s
hape
K
=
a1
.
size
(
-
1
)
K
=
a1
.
size
(
-
1
)
assert
a1
.
dim
()
==
2
assert
a1
.
dim
()
==
2
...
@@ -683,7 +683,7 @@ class OAITritonExperts(BaseOAITritonExperts):
...
@@ -683,7 +683,7 @@ class OAITritonExperts(BaseOAITritonExperts):
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
topk_ids
=
expert_map
[
topk_ids
]
topk_ids
=
expert_map
[
topk_ids
]
local_num_experts
=
w1
.
s
ize
(
0
)
local_num_experts
=
w1
.
s
hape
[
0
]
if
global_num_experts
==
-
1
:
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
global_num_experts
=
local_num_experts
...
@@ -781,7 +781,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
...
@@ -781,7 +781,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
topk_ids
=
expert_map
[
topk_ids
]
topk_ids
=
expert_map
[
topk_ids
]
local_num_experts
=
w1
.
s
ize
(
0
)
local_num_experts
=
w1
.
s
hape
[
0
]
if
global_num_experts
==
-
1
:
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
global_num_experts
=
local_num_experts
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
679c6a3e
...
@@ -567,6 +567,13 @@ class FusedMoE(CustomOp):
...
@@ -567,6 +567,13 @@ class FusedMoE(CustomOp):
# for heuristic purposes, so it must be initialized first.
# for heuristic purposes, so it must be initialized first.
self
.
quant_method
:
FusedMoEMethodBase
=
_get_quant_method
()
self
.
quant_method
:
FusedMoEMethodBase
=
_get_quant_method
()
# Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim
# and intermediate_size in moe_config during __init__. Sync
# self.hidden_size so downstream consumers (e.g. LoRA) see the
# padded value.
if
self
.
moe_config
.
hidden_dim
!=
self
.
hidden_size
:
self
.
hidden_size
=
self
.
moe_config
.
hidden_dim
if
not
self
.
moe_config
.
is_act_and_mul
and
not
current_platform
.
is_cuda_alike
():
if
not
self
.
moe_config
.
is_act_and_mul
and
not
current_platform
.
is_cuda_alike
():
raise
NotImplementedError
(
raise
NotImplementedError
(
"is_act_and_mul=False is supported only for CUDA and ROCm for now"
"is_act_and_mul=False is supported only for CUDA and ROCm for now"
...
@@ -586,7 +593,7 @@ class FusedMoE(CustomOp):
...
@@ -586,7 +593,7 @@ class FusedMoE(CustomOp):
moe_quant_params
=
{
moe_quant_params
=
{
"num_experts"
:
self
.
local_num_experts
,
"num_experts"
:
self
.
local_num_experts
,
"hidden_size"
:
hidden_size
,
"hidden_size"
:
self
.
hidden_size
,
"unpadded_hidden_size"
:
unpadded_hidden_size
,
"unpadded_hidden_size"
:
unpadded_hidden_size
,
"intermediate_size_per_partition"
:
self
.
intermediate_size_per_partition
,
"intermediate_size_per_partition"
:
self
.
intermediate_size_per_partition
,
"params_dtype"
:
params_dtype
,
"params_dtype"
:
params_dtype
,
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
679c6a3e
...
@@ -768,8 +768,8 @@ class FusedMoEExpertsModular(FusedMoEExperts):
...
@@ -768,8 +768,8 @@ class FusedMoEExpertsModular(FusedMoEExperts):
require a specialized implementation, like MarlinExperts, they are free
require a specialized implementation, like MarlinExperts, they are free
to override this function.
to override this function.
"""
"""
assert
w1
.
dim
()
==
3
and
w2
.
dim
(
)
==
3
assert
len
(
w1
.
shape
)
==
3
and
len
(
w2
.
shape
)
==
3
E
,
N
,
_
=
w1
.
s
ize
()
E
,
N
,
_
=
w1
.
s
hape
K
=
a1
.
size
(
-
1
)
K
=
a1
.
size
(
-
1
)
if
a1
.
dim
()
==
2
:
if
a1
.
dim
()
==
2
:
...
@@ -1349,7 +1349,7 @@ class FusedMoEKernelModularImpl:
...
@@ -1349,7 +1349,7 @@ class FusedMoEKernelModularImpl:
else
:
else
:
output
=
torch
.
empty_like
(
hidden_states
)
output
=
torch
.
empty_like
(
hidden_states
)
local_num_experts
=
w1
.
s
ize
(
0
)
local_num_experts
=
w1
.
s
hape
[
0
]
if
global_num_experts
==
-
1
:
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
global_num_experts
=
local_num_experts
...
...
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
View file @
679c6a3e
...
@@ -212,7 +212,11 @@ def select_mxfp4_moe_backend(
...
@@ -212,7 +212,11 @@ def select_mxfp4_moe_backend(
# LoRA: separate experts backend path
# LoRA: separate experts backend path
if
config
.
is_lora_enabled
:
if
config
.
is_lora_enabled
:
if
not
current_platform
.
is_cuda
():
if
not
current_platform
.
is_cuda
():
raise
NotImplementedError
(
"Mxfp4 LoRA only supported on CUDA Platform."
)
# ROCm: Triton mxfp4 LoRA hits GPU memory faults due to
# triton_kernels.tensor.Tensor / HIP read-only page issues
# during weight swizzle and LoRA forward. Needs work from
# the triton_kernels/aiter side.
raise
NotImplementedError
(
"Mxfp4 LoRA is currently only supported on CUDA."
)
if
envs
.
VLLM_MXFP4_USE_MARLIN
is
False
and
triton_kernels_supported
:
if
envs
.
VLLM_MXFP4_USE_MARLIN
is
False
and
triton_kernels_supported
:
logger
.
info_once
(
"Using Triton backend for mxfp4 lora"
)
logger
.
info_once
(
"Using Triton backend for mxfp4 lora"
)
return
Mxfp4MoeBackend
.
TRITON_UNFUSED
,
backend_to_kernel_cls
(
return
Mxfp4MoeBackend
.
TRITON_UNFUSED
,
backend_to_kernel_cls
(
...
@@ -775,6 +779,8 @@ def make_mxfp4_moe_quant_config(
...
@@ -775,6 +779,8 @@ def make_mxfp4_moe_quant_config(
w2_scale
:
Union
[
torch
.
Tensor
,
"PrecisionConfig"
],
w2_scale
:
Union
[
torch
.
Tensor
,
"PrecisionConfig"
],
w1_bias
:
torch
.
Tensor
|
None
=
None
,
w1_bias
:
torch
.
Tensor
|
None
=
None
,
w2_bias
:
torch
.
Tensor
|
None
=
None
,
w2_bias
:
torch
.
Tensor
|
None
=
None
,
hidden_pad
:
int
=
0
,
intermediate_pad
:
int
=
0
,
)
->
FusedMoEQuantConfig
|
None
:
)
->
FusedMoEQuantConfig
|
None
:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if
mxfp4_backend
in
(
if
mxfp4_backend
in
(
...
@@ -796,12 +802,16 @@ def make_mxfp4_moe_quant_config(
...
@@ -796,12 +802,16 @@ def make_mxfp4_moe_quant_config(
Mxfp4MoeBackend
.
FLASHINFER_CUTLASS_MXFP4_BF16
,
Mxfp4MoeBackend
.
FLASHINFER_CUTLASS_MXFP4_BF16
,
Mxfp4MoeBackend
.
CK
,
Mxfp4MoeBackend
.
CK
,
):
):
return
mxfp4_w4a16_moe_quant_config
(
config
=
mxfp4_w4a16_moe_quant_config
(
w1_bias
=
w1_bias
,
w1_bias
=
w1_bias
,
w2_bias
=
w2_bias
,
w2_bias
=
w2_bias
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
)
)
if
mxfp4_backend
==
Mxfp4MoeBackend
.
CK
:
config
.
hidden_pad
=
hidden_pad
config
.
intermediate_pad
=
intermediate_pad
return
config
else
:
else
:
return
ocp_mx_moe_quant_config
(
return
ocp_mx_moe_quant_config
(
quant_dtype
=
"mxfp4"
,
quant_dtype
=
"mxfp4"
,
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
679c6a3e
...
@@ -292,6 +292,8 @@ def rocm_aiter_fused_experts(
...
@@ -292,6 +292,8 @@ def rocm_aiter_fused_experts(
doweight_stage1
=
apply_router_weight_on_input
,
doweight_stage1
=
apply_router_weight_on_input
,
num_local_tokens
=
num_local_tokens
,
num_local_tokens
=
num_local_tokens
,
output_dtype
=
output_dtype
,
output_dtype
=
output_dtype
,
hidden_pad
=
quant_config
.
hidden_pad
,
intermediate_pad
=
quant_config
.
intermediate_pad
,
bias1
=
quant_config
.
w1_bias
if
quant_config
.
use_mxfp4_w4a16
else
None
,
bias1
=
quant_config
.
w1_bias
if
quant_config
.
use_mxfp4_w4a16
else
None
,
bias2
=
quant_config
.
w2_bias
if
quant_config
.
use_mxfp4_w4a16
else
None
,
bias2
=
quant_config
.
w2_bias
if
quant_config
.
use_mxfp4_w4a16
else
None
,
)
)
...
@@ -332,7 +334,15 @@ class AiterExperts(mk.FusedMoEExpertsModular):
...
@@ -332,7 +334,15 @@ class AiterExperts(mk.FusedMoEExpertsModular):
(
kFp8StaticChannelSym
,
kFp8DynamicTokenSym
),
(
kFp8StaticChannelSym
,
kFp8DynamicTokenSym
),
(
kMxfp4Static
,
None
),
(
kMxfp4Static
,
None
),
]
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
if
(
weight_key
,
activation_key
)
not
in
SUPPORTED_W_A
:
return
False
# CK MXFP4 MoE kernels are only supported on gfx950.
if
weight_key
==
kMxfp4Static
:
from
vllm.platforms.rocm
import
on_gfx950
if
not
on_gfx950
():
return
False
return
True
@
staticmethod
@
staticmethod
def
_supports_activation
(
activation
:
MoEActivation
)
->
bool
:
def
_supports_activation
(
activation
:
MoEActivation
)
->
bool
:
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
679c6a3e
...
@@ -158,6 +158,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -158,6 +158,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad
intermediate_size_per_partition_after_pad
)
)
# CK (gfx950) padding info for rocm_aiter_ops.fused_moe()
self
.
hidden_pad
=
extra_weight_attrs
.
get
(
"hidden_pad"
,
0
)
self
.
intermediate_pad
=
(
intermediate_size_per_partition_after_pad
-
intermediate_size_per_partition
)
# Fused gate_up_proj (column parallel)
# Fused gate_up_proj (column parallel)
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
torch
.
zeros
(
...
@@ -362,6 +368,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -362,6 +368,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
w1_bias
=
w1_bias
,
w1_bias
=
w1_bias
,
w2_bias
=
w2_bias
,
w2_bias
=
w2_bias
,
hidden_pad
=
self
.
hidden_pad
,
intermediate_pad
=
self
.
intermediate_pad
,
)
)
def
select_gemm_impl
(
def
select_gemm_impl
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment