Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e99fb988
Unverified
Commit
e99fb988
authored
Mar 23, 2026
by
Chuan (Richard) Li
Committed by
GitHub
Mar 23, 2026
Browse files
[ROCm] Fix fused_moe_fake signature mismatch and other AITER bugs (#36100)
Signed-off-by:
Li
<
chuali@amd.com
>
parent
a16133a0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
26 deletions
+16
-26
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+5
-1
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+1
-1
vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
...xecutor/layers/quantization/quark/schemes/quark_ocp_mx.py
+5
-20
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+5
-4
No files found.
vllm/_aiter_ops.py
View file @
e99fb988
...
...
@@ -137,6 +137,10 @@ def _rocm_aiter_fused_moe_fake(
a2_scale
:
torch
.
Tensor
|
None
=
None
,
num_local_tokens
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
|
None
=
None
,
hidden_pad
:
int
=
0
,
intermediate_pad
:
int
=
0
,
bias1
:
torch
.
Tensor
|
None
=
None
,
bias2
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
if
output_dtype
is
not
None
:
return
torch
.
empty_like
(
hidden_states
,
dtype
=
output_dtype
)
...
...
@@ -1700,7 +1704,7 @@ class rocm_aiter_ops:
)
@
staticmethod
def
triton_fp4_gemm_dynamic_q
a
unt
(
def
triton_fp4_gemm_dynamic_qu
a
nt
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
e99fb988
...
...
@@ -765,7 +765,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
if
self
.
emulate
:
logger
.
warning_once
(
f
"The current mode (supports_mx=
{
current_platform
.
supports_mx
()
}
, "
f
"use_
mxfp4
_aiter_moe=
{
self
.
use_rocm_aiter_moe
}
, "
f
"use_
rocm
_aiter_moe=
{
self
.
use_rocm_aiter_moe
}
, "
f
"ocp_mx_scheme=
{
self
.
ocp_mx_scheme
}
) "
"does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
...
...
vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
View file @
e99fb988
...
...
@@ -3,13 +3,12 @@
from
collections.abc
import
Callable
from
fractions
import
Fraction
from
functools
import
cache
,
partial
from
functools
import
partial
from
typing
import
Any
import
torch
import
torch.nn.functional
as
F
from
vllm
import
envs
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
(
...
...
@@ -37,22 +36,6 @@ from .quark_scheme import QuarkScheme
logger
=
init_logger
(
__name__
)
# TODO: move registration of custom op to aiter_ops.py
# `from vllm._aiter_ops import rocm_aiter_ops`
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
# for envs checks which does not require @cache anymore.
# triton kernel is torch compile compatible.
# does not require direct registration.
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
@
cache
def
is_rocm_aiter_fp4_asm_gemm_enabled
()
->
bool
:
return
(
current_platform
.
is_rocm
()
and
envs
.
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
and
envs
.
VLLM_ROCM_USE_AITER
)
try
:
from
aiter.ops.shuffle
import
shuffle_weight
from
aiter.ops.triton.gemm_afp4wfp4
import
(
...
...
@@ -63,7 +46,7 @@ try:
from
vllm.utils.torch_utils
import
direct_register_custom_op
if
is_
rocm_aiter_
fp4_asm_gemm
_enabled
():
if
rocm_aiter_
ops
.
is_asm_fp4_gemm_dynamic_quant
_enabled
():
from
aiter
import
gemm_a4w4
,
per_1x32_f4_quant_hip
def
gemm_with_dynamic_quant
(
...
...
@@ -233,7 +216,9 @@ class QuarkOCP_MX(QuarkScheme):
self
.
input_dtype
!=
"mxfp4"
or
self
.
weight_dtype
!=
"mxfp4"
)
self
.
rocm_use_aiter_fp4_asm_gemm
=
is_rocm_aiter_fp4_asm_gemm_enabled
()
self
.
rocm_use_aiter_fp4_asm_gemm
=
(
rocm_aiter_ops
.
is_asm_fp4_gemm_dynamic_quant_enabled
()
)
if
not
self
.
emulate
and
(
dynamic_mxfp4_quant
is
None
or
gemm_afp4wfp4
is
None
):
# Currently need these kernels if not emulating
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
e99fb988
...
...
@@ -157,13 +157,13 @@ if current_platform.is_rocm():
total_tokens
:
int
,
):
assert
kv_cache_layout
in
[
"NHD"
,
"SHUFFLE"
],
(
"kv_cache_layout only support NHD, SHUFFLE"
"kv_cache_layout only support
s
NHD, SHUFFLE"
)
head_dim
=
key
.
shape
[
2
]
x
=
16
//
key_cache
.
element_size
()
# assert dequant is True, "Currently, we only support "\
# "gather cache with dequant"
# For k cache layout: [num_blocks,
num_heads, page_size
, head_dim]
# For k cache layout: [num_blocks,
page_size, num_heads
, head_dim]
assert
head_dim
==
key_cache
.
shape
[
3
],
(
"We assume your kv cache layout is [num_blocks, "
"page_size, num_heads, head_dim], but got otherwise"
...
...
@@ -832,7 +832,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
if
attn_type
not
in
[
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_DECODER
]:
raise
NotImplementedError
(
"Encoder self-attention is not implemented for FlashAttentionImpl"
"Encoder self-attention is not implemented for
Aiter
FlashAttentionImpl"
)
def
extend_for_sliding_window
(
...
...
@@ -1047,7 +1047,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported for FlashAttentionImpl"
"fused output quantization is not yet supported "
"for AiterFlashAttentionImpl"
)
if
attn_metadata
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment