Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5dc35387
Unverified
Commit
5dc35387
authored
Mar 04, 2026
by
Chuan (Richard) Li
Committed by
GitHub
Mar 04, 2026
Browse files
[ROCm][Bugfix] Fall back from CK MXFP4 MoE when GEMM dimensions are unsupported (#35893)
Signed-off-by:
Li
<
chuali@amd.com
>
parent
36bf2131
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
1 deletion
+63
-1
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+26
-0
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+30
-1
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+7
-0
No files found.
vllm/model_executor/layers/quantization/mxfp4.py
View file @
5dc35387
...
@@ -48,6 +48,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
...
@@ -48,6 +48,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin
,
prepare_moe_fp4_layer_for_marlin
,
)
)
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
(
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
(
CK_MXFP4_MOE_DIM_ALIGNMENT
,
_can_support_mxfp4
,
_can_support_mxfp4
,
_swizzle_mxfp4
,
_swizzle_mxfp4
,
get_padding_alignment
,
get_padding_alignment
,
...
@@ -259,6 +260,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
...
@@ -259,6 +260,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
get_current_vllm_config
().
compilation_config
.
max_cudagraph_capture_size
get_current_vllm_config
().
compilation_config
.
max_cudagraph_capture_size
)
)
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
# alignment requirements. When they are not met, fall back to Triton if
# its kernels are available; otherwise raise a ValueError.
if
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
CK
and
moe
.
intermediate_size_per_partition
%
CK_MXFP4_MOE_DIM_ALIGNMENT
!=
0
):
if
has_triton_kernels
():
logger
.
warning_once
(
"CK MXFP4 MoE GEMM does not support "
"intermediate_size_per_partition=%d (not a multiple of "
"%d). Falling back to Triton backend."
,
moe
.
intermediate_size_per_partition
,
CK_MXFP4_MOE_DIM_ALIGNMENT
,
)
self
.
mxfp4_backend
=
Mxfp4Backend
.
TRITON
else
:
raise
ValueError
(
f
"CK MXFP4 MoE GEMM does not support "
f
"intermediate_size_per_partition="
f
"
{
moe
.
intermediate_size_per_partition
}
(not a multiple "
f
"of
{
CK_MXFP4_MOE_DIM_ALIGNMENT
}
) and no Triton "
f
"fallback is available. Use a compatible "
f
"tensor_parallel_size."
)
assert
self
.
mxfp4_backend
!=
Mxfp4Backend
.
NONE
,
(
assert
self
.
mxfp4_backend
!=
Mxfp4Backend
.
NONE
,
(
f
"get_mxfp4_backend(with_lora_support=
{
moe
.
is_lora_enabled
}
) found"
f
"get_mxfp4_backend(with_lora_support=
{
moe
.
is_lora_enabled
}
) found"
"no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."
"no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
5dc35387
...
@@ -32,7 +32,10 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
...
@@ -32,7 +32,10 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_fp8_moe_layer_for_marlin
,
prepare_fp8_moe_layer_for_marlin
,
)
)
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
_swizzle_mxfp4
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
(
CK_MXFP4_MOE_DIM_ALIGNMENT
,
_swizzle_mxfp4
,
)
from
vllm.model_executor.layers.quantization.utils.ocp_mx_utils
import
(
from
vllm.model_executor.layers.quantization.utils.ocp_mx_utils
import
(
OCP_MX_BLOCK_SIZE
,
OCP_MX_BLOCK_SIZE
,
OCP_MX_Scheme
,
OCP_MX_Scheme
,
...
@@ -732,6 +735,32 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
...
@@ -732,6 +735,32 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
or
not
self
.
ocp_mx_scheme
.
startswith
(
"w_mxfp4"
)
or
not
self
.
ocp_mx_scheme
.
startswith
(
"w_mxfp4"
)
)
and
(
self
.
mxfp4_backend
is
None
or
not
self
.
use_rocm_aiter_moe
)
)
and
(
self
.
mxfp4_backend
is
None
or
not
self
.
use_rocm_aiter_moe
)
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
# alignment requirements. When violated (e.g. MiniMax-M2.1 with
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
# "device_gemm ... does not support this GEMM problem".
# Fall back to emulation in that case.
if
(
not
self
.
emulate
and
self
.
use_rocm_aiter_moe
and
self
.
ocp_mx_scheme
is
not
None
and
self
.
ocp_mx_scheme
.
startswith
(
"w_mxfp4"
)
and
moe
.
intermediate_size_per_partition
%
CK_MXFP4_MOE_DIM_ALIGNMENT
!=
0
):
logger
.
warning_once
(
"AITER CK MXFP4 MoE GEMM does not support "
"intermediate_size_per_partition=%d (not a multiple of %d). "
"This typically happens when intermediate_size / "
"tensor_parallel_size produces an incompatible dimension. "
"Falling back to emulation mode. To avoid this overhead, "
"use a compatible tensor_parallel_size or set "
"VLLM_ROCM_USE_AITER_MOE=0."
,
moe
.
intermediate_size_per_partition
,
CK_MXFP4_MOE_DIM_ALIGNMENT
,
)
self
.
use_rocm_aiter_moe
=
False
self
.
emulate
=
True
if
self
.
emulate
:
if
self
.
emulate
:
logger
.
warning_once
(
logger
.
warning_once
(
f
"The current mode (supports_mx=
{
current_platform
.
supports_mx
()
}
, "
f
"The current mode (supports_mx=
{
current_platform
.
supports_mx
()
}
, "
...
...
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
View file @
5dc35387
...
@@ -14,6 +14,13 @@ from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_
...
@@ -14,6 +14,13 @@ from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# CK's pre-compiled MXFP4 MoE GEMM kernel instances require
# intermediate_size_per_partition (the intermediate size after the
# tensor-parallel split) to be a multiple of this value.
# This arises from FP4 packing (2 values per byte) combined with CK
# tile size constraints. When violated, AITER raises:
# "device_gemm ... does not support this GEMM problem".
CK_MXFP4_MOE_DIM_ALIGNMENT
=
256
def
_swizzle_mxfp4
(
quant_tensor
,
scale
,
num_warps
):
def
_swizzle_mxfp4
(
quant_tensor
,
scale
,
num_warps
):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment