Unverified Commit f080a835 authored by vllmellm's avatar vllmellm Committed by GitHub
Browse files

[RFC][ROCm][AITER] Keep all AITER kernels in `_aiter_ops` class like...


[RFC][ROCm][AITER] Keep all AITER kernels in `_aiter_ops` class like `_custom_ops` and `_ipex_ops` (#24490)
Signed-off-by: default avatarvllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 40e2eeeb
...@@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels ...@@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] | | rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
| naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | | naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
......
...@@ -6,6 +6,8 @@ Run `pytest tests/kernels/test_moe.py`. ...@@ -6,6 +6,8 @@ Run `pytest tests/kernels/test_moe.py`.
""" """
import functools import functools
import importlib
import sys
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
...@@ -20,6 +22,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock ...@@ -20,6 +22,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
...@@ -412,14 +415,12 @@ def test_mixtral_moe( ...@@ -412,14 +415,12 @@ def test_mixtral_moe(
huggingface.""" huggingface."""
# clear the cache before every test # clear the cache before every test
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # Force reload aiter_ops to pick up the new environment variables.
is_rocm_aiter_moe_enabled, if "rocm_aiter_ops" in sys.modules:
) importlib.reload(rocm_aiter_ops)
is_rocm_aiter_moe_enabled.cache_clear()
if use_rocm_aiter: if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
if dtype == torch.float32: if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32") pytest.skip("AITER ROCm test skip for float32")
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import pytest import pytest
import torch import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import ( from vllm.model_executor.layers.activation import (
...@@ -15,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -15,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_topk_func, dispatch_topk_func,
vllm_topk_softmax, vllm_topk_softmax,
) )
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
RMSNorm, RMSNorm,
dispatch_rocm_rmsnorm_func, dispatch_rocm_rmsnorm_func,
...@@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str): ...@@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled() RMSNorm(1024).enabled()
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize(
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) )
topk_func = dispatch_topk_func() def test_topk_dispatch(use_rocm_aiter: bool):
is_rocm_aiter_moe_enabled.cache_clear() topk_func = dispatch_topk_func(use_rocm_aiter)
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_topk_softmax,
)
assert topk_func == rocm_aiter_topk_softmax if current_platform.is_rocm() and use_rocm_aiter:
assert topk_func == rocm_aiter_ops.topk_softmax
else: else:
assert topk_func == vllm_topk_softmax assert topk_func == vllm_topk_softmax
@pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm" not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
) )
def test_rms_norm_dispatch( def test_rms_norm_dispatch(
add_residual: bool, add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
dtype: torch.dtype,
use_rocm_aiter: str,
use_rocm_aiter_norm: str,
monkeypatch,
): ):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
should_use_rocm_aiter = ( should_use_rocm_aiter = (
current_platform.is_rocm() current_platform.is_rocm()
and int(use_rocm_aiter) and use_rocm_aiter
and int(use_rocm_aiter_norm)
and dtype in RMS_NORM_SUPPORTED_DTYPES and dtype in RMS_NORM_SUPPORTED_DTYPES
) )
if add_residual and should_use_rocm_aiter: if add_residual and should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
elif should_use_rocm_aiter: elif should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm assert rms_norm_func == rocm_aiter_ops.rms_norm
elif add_residual: elif add_residual:
assert rms_norm_func == fused_add_rms_norm assert rms_norm_func == fused_add_rms_norm
else: else:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
def is_aiter_found() -> bool:
from importlib.util import find_spec
return find_spec("aiter") is not None
# `find_spec` is not torch.compile compatible.
# In cases where aiter availability might have
# been checked in forward passes that are torch compiled.
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found()
def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if
ROCm AITER package is supported on gfx9 archs.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
# checks the platform, device arch and aiter library existance.
from vllm.platforms.rocm import on_gfx9
if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND:
return func(*args, **kwargs)
else:
# Return None or do nothing if not supported
return None
return wrapper
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
quant_method: int = 0,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
activation = ActivationType(activation_method)
quant_type = QuantType(quant_method)
return fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation,
quant_type,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
def _rocm_aiter_fused_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
quant_method: int = 0,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def _rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
) -> torch.Tensor:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
activation = ActivationType(activation_method)
return asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
per_tensor_quant_scale=per_tensor_quant_scale,
expert_mask=expert_mask,
activation=activation,
)
def _rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def _rocm_aiter_topk_softmax_impl(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
from aiter import topk_softmax
topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
def _rocm_aiter_topk_softmax_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
pass
def _rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import biased_grouped_topk
biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
def _rocm_aiter_biased_grouped_topk_fake(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def _rocm_aiter_grouped_topk_impl(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
is_softmax = scoring_func == "softmax"
from aiter import grouped_topk
grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
is_softmax,
routed_scaling_factor,
)
def _rocm_aiter_grouped_topk_fake(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def _rocm_aiter_mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
from aiter.mla import mla_decode_fwd
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
def _rocm_aiter_mla_decode_fwd_fake(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
pass
def _rocm_aiter_gemm_w8a8_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def _rocm_aiter_gemm_w8a8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
def _rocm_aiter_gemm_w8a8_blockscale_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_blockscale
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
def _rocm_aiter_gemm_w8a8_blockscale_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
def _rocm_aiter_rms_norm_impl(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
from aiter import rms_norm
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rms_norm(x, weight, variance_epsilon)
def _rocm_aiter_rms_norm_fake(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.empty_like(x)
def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter import rmsnorm2d_fwd_with_add
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rmsnorm2d_fwd_with_add(
output, # output
x, # input
residual, # residual input
residual_out, # residual output
weight,
variance_epsilon,
)
return output, residual_out
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
class rocm_aiter_ops:
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
@classmethod
@if_aiter_supported
def is_enabled(cls) -> bool:
"""Verifies device specs and availability of aiter main env variable."""
return cls._AITER_ENABLED
@classmethod
@if_aiter_supported
def is_linear_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._LINEAR_ENABLED
@classmethod
@if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
@classmethod
@if_aiter_supported
def is_rmsnorm_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
@classmethod
@if_aiter_supported
def is_fused_moe_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._FMOE_ENABLED
@classmethod
@if_aiter_supported
def is_fusion_moe_shared_experts_enabled(cls) -> bool:
return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED
@classmethod
@if_aiter_supported
def is_mla_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MLA_ENABLED
@classmethod
@if_aiter_supported
def is_mha_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MHA_ENABLED
@classmethod
@if_aiter_supported
def is_pa_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
@classmethod
@if_aiter_supported
def is_triton_unified_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED
@classmethod
@if_aiter_supported
def is_fp8bmm_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP8BMM_ENABLED
@classmethod
@if_aiter_supported
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM
@classmethod
@if_aiter_supported
def is_triton_rotary_embed_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED
@staticmethod
@if_aiter_supported
def register_ops_once() -> None:
global _OPS_REGISTERED
if not _OPS_REGISTERED:
tags = (
tuple()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
)
# register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=_rocm_aiter_asm_moe_tkw1_impl,
mutates_args=[],
fake_impl=_rocm_aiter_asm_moe_tkw1_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_fused_moe",
op_func=_rocm_aiter_fused_moe_impl,
mutates_args=[],
fake_impl=_rocm_aiter_fused_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=_rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=_rocm_aiter_topk_softmax_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=_rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=_rocm_aiter_biased_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_grouped_topk",
op_func=_rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=_rocm_aiter_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd",
op_func=_rocm_aiter_mla_decode_fwd_impl,
mutates_args=["o"],
fake_impl=_rocm_aiter_mla_decode_fwd_fake,
tags=tags,
)
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8",
op_func=_rocm_aiter_gemm_w8a8_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_w8a8_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True
@staticmethod
def rms_norm2d_with_add(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add(
x, residual, weight, variance_epsilon
)
@staticmethod
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
@staticmethod
def gemm_w8a8(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)
@staticmethod
def gemm_w8a8_blockscale(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
A, B, As, Bs, output_dtype
)
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
quant_method: int = 0,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation_method,
quant_method,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
@staticmethod
def asm_moe_tkw1(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale,
fc2_scale,
fc1_smooth_scale,
fc2_smooth_scale,
a16,
per_tensor_quant_scale,
expert_mask,
activation_method,
)
@staticmethod
def topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_indices
@staticmethod
def biased_grouped_topk(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0,
) -> None:
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
@staticmethod
def grouped_topk(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> None:
torch.ops.vllm.rocm_aiter_grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
scoring_func,
routed_scaling_factor,
)
@staticmethod
def mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
max_seqlen_qo,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
@staticmethod
def triton_fp4_gemm_dynamic_qaunt(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype | None = torch.bfloat16,
x_scales: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
if x_scales is None:
x_q, x_s = dynamic_mxfp4_quant(x)
else:
x_q = x
x_s = x_scales
y = torch.empty(
x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
)
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y
@staticmethod
def triton_rotary_embed(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
cos_sin_cache: torch.Tensor,
head_size: int,
rotary_dim: int,
is_neox_style: bool,
):
from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace
num_tokens = positions.numel()
cos, sin = cos_sin_cache.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
rotate_style = 0 if is_neox_style else 1
query = query.view(num_tokens, -1, head_size)
key = key.view(num_tokens, -1, head_size)
query_ = query[..., :rotary_dim]
key_ = key[..., :rotary_dim]
positions = positions.view(*query.shape[:1])
rope_cached_thd_positions_2c_fwd_inplace(
positions,
sin,
cos,
query_,
key_,
rotate_style,
reuse_freqs_front_part=True,
is_nope_first=False,
)
query = query.view(query_shape)
key = key.view(key_shape)
@staticmethod
def triton_fp8_bmm(
X: torch.Tensor,
WQ: torch.Tensor,
w_scale: torch.Tensor,
group_size: int = 128,
bias: torch.Tensor | None = None,
dtype: torch.dtype | None = torch.bfloat16,
splitK: int | None = None,
YQ: torch.Tensor | None = None,
transpose_bm: bool | None = False,
config: dict | None = None,
) -> torch.Tensor:
# ruff: noqa: E501 # isort: skip
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
)
return aiter_triton_fp8_bmm(
X,
WQ,
w_scale,
group_size=group_size,
bias=bias,
dtype=dtype,
splitK=splitK,
YQ=YQ,
transpose_bm=transpose_bm,
config=config,
)
@staticmethod
def triton_gemm_a8w8_blockscale(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
@staticmethod
def per_1x128_fp8_quant(
input_2d: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
"""Only applies quantization method for fp8 data type only."""
from aiter import QuantType, dtypes, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)
@staticmethod
def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
return (n, k) in [
(1024, 8192),
(2112, 7168),
(3072, 1536),
(32768, 8192),
(4096, 7168),
(4608, 7168),
(512, 7168),
(7168, 2048),
(7168, 256),
(8192, 1024),
(8192, 32768),
]
@staticmethod
def shuffle_weight(
self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> torch.Tensor:
from aiter.ops.shuffle import shuffle_weight
return shuffle_weight(tensor, layout=layout)
@staticmethod
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Rearranges (shuffles) the input tensor/s
into a specified block layout for optimized computation.
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
rocm_aiter_ops.register_ops_once()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
def get_aiter_mla_metadata(
max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device
) -> tuple[torch.Tensor, ...]:
paged_kv_indices = torch.zeros(
max_batch_size * max_block_per_batch, dtype=torch.int32, device=device
)
paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
paged_kv_last_page_lens = torch.full(
(max_batch_size,), block_size, dtype=torch.int32
)
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
def aiter_mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
max_seqlen_qo,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
def mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
from aiter.mla import mla_decode_fwd
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
def mla_decode_fwd_fake(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
pass
if current_platform.is_rocm():
if is_torch_equal_or_newer("2.7.0"):
tags = ()
else:
tags = ((torch.Tag.needs_fixed_stride_order,),)
direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd",
op_func=mla_decode_fwd_impl,
mutates_args=["o"],
fake_impl=mla_decode_fwd_fake,
tags=tags,
)
...@@ -109,7 +109,7 @@ if TYPE_CHECKING: ...@@ -109,7 +109,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
...@@ -926,8 +926,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -926,8 +926,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
), ),
# Whether to use aiter rope. # Whether to use aiter rope.
# By default is disabled. # By default is disabled.
"VLLM_ROCM_USE_TRITON_ROPE": lambda: ( "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1") os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
), ),
# Whether to use aiter triton fp8 bmm kernel # Whether to use aiter triton fp8 bmm kernel
# By default is enabled. # By default is enabled.
...@@ -1589,7 +1589,7 @@ def compute_hash() -> str: ...@@ -1589,7 +1589,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_MLA", "VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA", "VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
"VLLM_ROCM_USE_TRITON_ROPE", "VLLM_ROCM_USE_AITER_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM", "VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ROCM_USE_AITER_TRITON_GEMM", "VLLM_ROCM_USE_AITER_TRITON_GEMM",
......
...@@ -14,6 +14,7 @@ import torch.nn.functional as F ...@@ -14,6 +14,7 @@ import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
...@@ -55,8 +56,6 @@ from vllm.triton_utils import tl, triton ...@@ -55,8 +56,6 @@ from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1089,11 +1088,11 @@ def vllm_topk_softmax( ...@@ -1089,11 +1088,11 @@ def vllm_topk_softmax(
return topk_weights, topk_indices return topk_weights, topk_indices
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: def dispatch_topk_func(
if is_rocm_aiter_moe_enabled(): use_rocm_aiter: bool = False,
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax ) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_topk_softmax return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax return vllm_topk_softmax
...@@ -1121,7 +1120,7 @@ def fused_topk( ...@@ -1121,7 +1120,7 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device M, topk, dtype=torch.int32, device=hidden_states.device
) )
topk_func = dispatch_topk_func() topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func( topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
) )
......
...@@ -13,6 +13,7 @@ import torch.nn.functional as F ...@@ -13,6 +13,7 @@ import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import ( from vllm.distributed import (
...@@ -41,8 +42,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -41,8 +42,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
) )
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data, init_aiter_topK_meta_data,
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
) )
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -92,13 +91,11 @@ else: ...@@ -92,13 +91,11 @@ else:
return topk_ids return topk_ids
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk_aiter,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu(): if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas from .moe_pallas import fused_moe as fused_moe_pallas
else: else:
...@@ -463,7 +460,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -463,7 +460,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
...@@ -620,13 +618,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -620,13 +618,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
# Padding the weight for better performance on ROCm # Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
shuffle_weights,
)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data layer.w13_weight.data, layer.w2_weight.data
) )
...@@ -1002,6 +996,7 @@ def determine_expert_map( ...@@ -1002,6 +996,7 @@ def determine_expert_map(
global_num_experts: int, global_num_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear", expert_placement_strategy: ExpertPlacementStrategy = "linear",
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
return_expert_mask: bool = False,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: ) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
""" """
Calculates how many experts should be assigned to each rank for EP and Calculates how many experts should be assigned to each rank for EP and
...@@ -1064,7 +1059,7 @@ def determine_expert_map( ...@@ -1064,7 +1059,7 @@ def determine_expert_map(
) )
expert_mask = None expert_mask = None
if is_rocm_aiter_moe_enabled(): if return_expert_mask:
expert_mask = torch.ones( expert_mask = torch.ones(
(global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32 (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
) )
...@@ -1292,14 +1287,18 @@ class FusedMoE(CustomOp): ...@@ -1292,14 +1287,18 @@ class FusedMoE(CustomOp):
self.logical_replica_count: torch.Tensor | None = None self.logical_replica_count: torch.Tensor | None = None
# ROCm aiter shared experts fusion # ROCm aiter shared experts fusion
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
self.aiter_fmoe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
self.num_fused_shared_experts = ( self.num_fused_shared_experts = (
n_shared_experts n_shared_experts
if n_shared_experts is not None if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
and is_rocm_aiter_fusion_shared_expert_enabled()
else 0 else 0
) )
if ( if (
not is_rocm_aiter_fusion_shared_expert_enabled() not self.aiter_fmoe_shared_expert_enabled
and self.num_fused_shared_experts != 0 and self.num_fused_shared_experts != 0
): ):
raise ValueError( raise ValueError(
...@@ -1346,6 +1345,7 @@ class FusedMoE(CustomOp): ...@@ -1346,6 +1345,7 @@ class FusedMoE(CustomOp):
global_num_experts=self.global_num_experts, global_num_experts=self.global_num_experts,
expert_placement_strategy=expert_placement_strategy, expert_placement_strategy=expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts, num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
) )
self.local_num_experts = local_num_experts self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map) self.register_buffer("expert_map", expert_map)
...@@ -1570,13 +1570,16 @@ class FusedMoE(CustomOp): ...@@ -1570,13 +1570,16 @@ class FusedMoE(CustomOp):
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts, global_num_experts=self.global_num_experts,
num_fused_shared_experts=self.num_fused_shared_experts, num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
) )
self.local_num_experts = local_num_experts self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map) self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask) self.register_buffer("expert_mask", expert_mask)
self._init_aiter_shared_experts_topK_buffer( if self.aiter_fmoe_shared_expert_enabled:
vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size self._init_aiter_shared_experts_topK_buffer(
) vllm_config=get_current_vllm_config(),
dp_size=get_dp_group().world_size,
)
def _load_per_tensor_weight_scale( def _load_per_tensor_weight_scale(
self, self,
...@@ -1753,20 +1756,19 @@ class FusedMoE(CustomOp): ...@@ -1753,20 +1756,19 @@ class FusedMoE(CustomOp):
def _init_aiter_shared_experts_topK_buffer( def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int self, vllm_config: VllmConfig, dp_size: int
): ):
if is_rocm_aiter_fusion_shared_expert_enabled(): if self.num_fused_shared_experts > 0:
if self.num_fused_shared_experts > 0: init_aiter_topK_meta_data(
init_aiter_topK_meta_data( n_routed_experts=self.global_num_experts,
n_routed_experts=self.global_num_experts, n_shared_experts=self.num_fused_shared_experts,
n_shared_experts=self.num_fused_shared_experts, top_k=self.top_k,
top_k=self.top_k, tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
tp_rank=self.ep_rank if self.use_ep else self.tp_rank, tp_size=self.ep_size if self.use_ep else self.tp_size,
tp_size=self.ep_size if self.use_ep else self.tp_size, shared_experts_score=1.0,
shared_experts_score=1.0, max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens * dp_size,
* dp_size, is_EP=self.use_ep,
is_EP=self.use_ep, )
) self.local_num_experts += self.num_fused_shared_experts
self.local_num_experts += self.num_fused_shared_experts
@overload @overload
def weight_loader( def weight_loader(
...@@ -2208,15 +2210,16 @@ class FusedMoE(CustomOp): ...@@ -2208,15 +2210,16 @@ class FusedMoE(CustomOp):
elif use_grouped_topk: elif use_grouped_topk:
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
if is_rocm_aiter_moe_enabled(): if rocm_aiter_ops.is_fused_moe_enabled():
if not is_rocm_aiter_fusion_shared_expert_enabled(): if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert num_fused_shared_experts == 0 assert num_fused_shared_experts == 0
grouped_topk_impl = partial( grouped_topk_impl = partial(
grouped_topk_aiter, rocm_aiter_grouped_topk,
num_fused_shared_experts=num_fused_shared_experts, num_fused_shared_experts=num_fused_shared_experts,
) )
else: else:
grouped_topk_impl = grouped_topk grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl( topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
...@@ -2448,7 +2451,7 @@ class FusedMoE(CustomOp): ...@@ -2448,7 +2451,7 @@ class FusedMoE(CustomOp):
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts, global_num_experts=self.global_num_experts,
expert_map=self.expert_map expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled() if not self.rocm_aiter_fmoe_enabled
else self.expert_mask, else self.expert_mask,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group, num_expert_group=self.num_expert_group,
...@@ -2612,7 +2615,7 @@ class FusedMoE(CustomOp): ...@@ -2612,7 +2615,7 @@ class FusedMoE(CustomOp):
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts, global_num_experts=self.global_num_experts,
expert_map=self.expert_map expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled() if not self.rocm_aiter_fmoe_enabled
else self.expert_mask, else self.expert_mask,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group, num_expert_group=self.num_expert_group,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum from enum import IntEnum
from functools import cache, lru_cache from functools import lru_cache
import torch import torch
from vllm import envs from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
class QuantMethod(IntEnum): class QuantMethod(IntEnum):
...@@ -37,27 +35,6 @@ class ActivationMethod(IntEnum): ...@@ -37,27 +35,6 @@ class ActivationMethod(IntEnum):
GELU = 1 GELU = 1
@cache
def is_rocm_aiter_moe_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_MOE
and envs.VLLM_ROCM_USE_AITER
)
@cache
def use_mxfp4_aiter_moe() -> bool:
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
@cache
def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
return (
envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled()
)
aiter_topK_meta_data = None aiter_topK_meta_data = None
...@@ -114,250 +91,6 @@ def init_aiter_topK_meta_data( ...@@ -114,250 +91,6 @@ def init_aiter_topK_meta_data(
aiter_topK_meta_data = (total_topk_weights, total_topk_ids) aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
def rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
activation = ActivationType(activation_method)
return asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
per_tensor_quant_scale=per_tensor_quant_scale,
expert_mask=expert_mask,
activation=activation,
)
def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_topk_softmax_impl(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
from aiter import topk_softmax
topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
def rocm_aiter_topk_softmax_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
pass
def rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import biased_grouped_topk
biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
def rocm_aiter_biased_grouped_topk_fake(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def rocm_aiter_grouped_topk_impl(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import grouped_topk
grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
scoring_func,
routed_scaling_factor,
)
def rocm_aiter_grouped_topk_fake(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
activation = ActivationType(activation_method)
quant_type = QuantType(quant_method)
return fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation,
quant_type,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
def rocm_aiter_fused_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=rocm_aiter_asm_moe_tkw1_impl,
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_fused_moe",
op_func=rocm_aiter_fused_moe_impl,
fake_impl=rocm_aiter_fused_moe_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=rocm_aiter_topk_softmax_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_biased_grouped_topk_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_grouped_topk",
op_func=rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_grouped_topk_fake,
)
def rocm_aiter_grouped_topk( def rocm_aiter_grouped_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
...@@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk( ...@@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
token = hidden_states.shape[0] token = hidden_states.shape[0]
device = hidden_states.device device = hidden_states.device
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: if (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
and num_fused_shared_experts > 0
):
assert aiter_topK_meta_data is not None, ( assert aiter_topK_meta_data is not None, (
"AITER topK meta data is not initialized. " "AITER topK meta data is not initialized. "
"Please ensure that init_aiter_topK_meta_data " "Please ensure that init_aiter_topK_meta_data "
...@@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk( ...@@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
torch.ops.vllm.rocm_aiter_biased_grouped_topk( rocm_aiter_ops.biased_grouped_topk(
gating_output, gating_output,
e_score_correction_bias.to(gating_output.dtype), e_score_correction_bias.to(gating_output.dtype),
topk_weights, topk_weights,
...@@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk( ...@@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk(
) )
else: else:
assert scoring_func == "softmax" or scoring_func == "sigmoid" assert scoring_func == "softmax" or scoring_func == "sigmoid"
torch.ops.vllm.rocm_aiter_grouped_topk( rocm_aiter_ops.grouped_topk(
gating_output, gating_output,
topk_weights, topk_weights,
topk_ids, topk_ids,
...@@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk( ...@@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk(
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
) )
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: if (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
and num_fused_shared_experts > 0
):
return total_topk_weights, total_topk_ids return total_topk_weights, total_topk_ids
return topk_weights, topk_ids return topk_weights, topk_ids
...@@ -464,7 +203,7 @@ def rocm_aiter_fused_experts( ...@@ -464,7 +203,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True" "Only support topk=1 when `apply_router_weight_on_input` is True"
) )
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( return rocm_aiter_ops.asm_moe_tkw1(
hidden_states, hidden_states,
w1, w1,
w2, w2,
...@@ -482,7 +221,9 @@ def rocm_aiter_fused_experts( ...@@ -482,7 +221,9 @@ def rocm_aiter_fused_experts(
else: else:
quant_method = QuantMethod.NO.value quant_method = QuantMethod.NO.value
# quark moe for mxfp4 w_dtype
if quant_config.use_mxfp4_w4a16:
quant_method = QuantMethod.BLOCK_1X32.value
# w8a8 block-scaled # w8a8 block-scaled
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
...@@ -507,7 +248,7 @@ def rocm_aiter_fused_experts( ...@@ -507,7 +248,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True" "Only support topk=1 when `apply_router_weight_on_input` is True"
) )
return torch.ops.vllm.rocm_aiter_fused_moe( return rocm_aiter_ops.fused_moe(
hidden_states, hidden_states,
w1, w1,
w2, w2,
...@@ -522,39 +263,3 @@ def rocm_aiter_fused_experts( ...@@ -522,39 +263,3 @@ def rocm_aiter_fused_experts(
a2_scale=quant_config.a2_scale, a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input, doweight_stage1=apply_router_weight_on_input,
) )
def rocm_aiter_topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_indices
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Rearranges (shuffles) the input tensor/s
into a specified block layout for optimized computation.
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
...@@ -6,18 +6,13 @@ import torch ...@@ -6,18 +6,13 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant, rms_norm_batch_invariant,
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
def rms_norm( def rms_norm(
...@@ -58,80 +53,34 @@ def fused_add_rms_norm( ...@@ -58,80 +53,34 @@ def fused_add_rms_norm(
return x, residual return x, residual
def rocm_aiter_rms_norm_impl( def poly_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor: ) -> torch.Tensor:
import aiter as rocm_aiter from vllm import _custom_ops as ops
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
def rocm_aiter_rmsnorm2d_fwd_with_add_impl( out = torch.empty_like(x)
x: torch.Tensor, ops.poly_norm(
residual: torch.Tensor, out,
weight: torch.Tensor, x,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rocm_aiter.rmsnorm2d_fwd_with_add(
output, # output
x, # input
residual, # residual input
residual_out, # residual output
weight, weight,
bias,
variance_epsilon, variance_epsilon,
) )
return output, residual_out return out
def rocm_aiter_rms_norm_fake(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.empty_like(x)
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
fake_impl=rocm_aiter_rms_norm_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
)
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype): def dispatch_rocm_rmsnorm_func(
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [ with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
use_aiter = use_aiter and dtype in [
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
] ]
if use_aiter and with_fused_add: if use_aiter and with_fused_add:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add return rocm_aiter_ops.rms_norm2d_with_add
if use_aiter: if use_aiter:
return torch.ops.vllm.rocm_aiter_rms_norm return rocm_aiter_ops.rms_norm
# fall back to CUDA implementation # fall back to CUDA implementation
if with_fused_add: if with_fused_add:
...@@ -169,11 +118,14 @@ class RMSNorm(CustomOp): ...@@ -169,11 +118,14 @@ class RMSNorm(CustomOp):
self.weight = nn.Parameter(self.weight) self.weight = nn.Parameter(self.weight)
if current_platform.is_rocm(): if current_platform.is_rocm():
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
self.rocm_norm_func = dispatch_rocm_rmsnorm_func( self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False, dtype=weight_dtype with_fused_add=False,
dtype=weight_dtype,
use_aiter=aiter_rmsnorm_enabled,
) )
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
) )
@staticmethod @staticmethod
......
...@@ -12,6 +12,7 @@ from compressed_tensors.quantization import ActivationOrdering, QuantizationStra ...@@ -12,6 +12,7 @@ from compressed_tensors.quantization import ActivationOrdering, QuantizationStra
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
...@@ -582,11 +583,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -582,11 +583,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Disable marlin for rocm # Disable marlin for rocm
if current_platform.is_rocm(): if current_platform.is_rocm():
self.use_marlin = False self.use_marlin = False
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# cutlass path # cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
...@@ -829,12 +827,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -829,12 +827,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Property to determine if AITER is used # Property to determine if AITER is used
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
shuffle_weights,
)
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data layer.w13_weight.data, layer.w2_weight.data
) )
......
...@@ -7,12 +7,12 @@ import torch ...@@ -7,12 +7,12 @@ import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_input_scale,
create_fp8_scale_parameter, create_fp8_scale_parameter,
create_fp8_weight_parameter, create_fp8_weight_parameter,
...@@ -61,7 +61,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -61,7 +61,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
) )
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
if self.weight_block_size is not None: if self.weight_block_size is not None:
assert not self.is_static_input_scheme assert not self.is_static_input_scheme
......
...@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter ...@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
...@@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_input_scale,
create_fp8_scale_parameter, create_fp8_scale_parameter,
create_fp8_weight_parameter, create_fp8_weight_parameter,
...@@ -369,7 +369,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -369,7 +369,7 @@ class Fp8LinearMethod(LinearMethodBase):
if vllm_is_batch_invariant(): if vllm_is_batch_invariant():
self.use_marlin = False self.use_marlin = False
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
self.use_deep_gemm = is_deep_gemm_supported() self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size self.weight_block_size = self.quant_config.weight_block_size
...@@ -869,12 +869,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -869,12 +869,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early. # Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
shuffle_weights,
)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# TODO (rob): refactor block quant into separate class. # TODO (rob): refactor block quant into separate class.
if self.block_quant: if self.block_quant:
...@@ -916,7 +912,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -916,7 +912,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data layer.w13_weight.data, layer.w2_weight.data
) )
...@@ -962,7 +958,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -962,7 +958,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight layer.w13_weight, layer.w2_weight
) )
...@@ -1042,7 +1038,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1042,7 +1038,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
start += shard_size start += shard_size
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight layer.w13_weight, layer.w2_weight
) )
......
...@@ -4,54 +4,14 @@ ...@@ -4,54 +4,14 @@
import torch import torch
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def rocm_aiter_gemm_w8a8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8",
op_func=rocm_aiter_gemm_w8a8_impl,
fake_impl=rocm_aiter_gemm_w8a8_fake,
)
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
...@@ -75,7 +35,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): ...@@ -75,7 +35,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+ "installed on ROCm.", + "installed on ROCm.",
) )
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER): if not (rocm_aiter_ops.is_linear_enabled()):
return ( return (
False, False,
"AiterScaledMMLinearKernel is disabled. " "AiterScaledMMLinearKernel is disabled. "
...@@ -157,6 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): ...@@ -157,6 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
# a to be [M, K] # a to be [M, K]
# b to be [N, K] # b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return torch.ops.vllm.rocm_aiter_gemm_w8a8( return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
x_q, w_q.t(), x_s, w_s, bias, out_dtype
)
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
...@@ -21,10 +22,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -21,10 +22,6 @@ from vllm.model_executor.layers.fused_moe.config import (
ocp_mx_moe_quant_config, ocp_mx_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
use_mxfp4_aiter_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin,
) )
...@@ -122,7 +119,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -122,7 +119,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
if current_platform.is_rocm(): if current_platform.is_rocm():
self.use_marlin = False self.use_marlin = False
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
def create_weights( def create_weights(
self, self,
...@@ -309,12 +306,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -309,12 +306,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
) )
# Property to determine if AITER is used # Property to determine if AITER is used
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
shuffle_weights,
)
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data layer.w13_weight.data, layer.w2_weight.data
) )
...@@ -470,13 +463,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -470,13 +463,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"not implemented. Please open an issue." "not implemented. Please open an issue."
) )
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
self.emulate = not current_platform.supports_mx() or not ( self.emulate = not current_platform.supports_mx() or not (
use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
) )
if self.emulate: if self.emulate:
logger.warning_once( logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, " f"The current mode (supports_mx={current_platform.supports_mx()}, "
f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, " f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
f"ocp_mx_scheme={self.ocp_mx_scheme}) " f"ocp_mx_scheme={self.ocp_mx_scheme}) "
"does not support native MXFP4/MXFP6 " "does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation " "computation. Simulated weight dequantization and activation "
...@@ -656,28 +651,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -656,28 +651,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
) )
if not self.emulate: if not self.emulate:
from aiter import ActivationType, QuantType from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from aiter.fused_moe import fused_moe rocm_aiter_fused_experts,
)
aiter_acts = {
ActivationType.No.name.lower(): ActivationType.No, out = rocm_aiter_fused_experts(
ActivationType.Silu.name.lower(): ActivationType.Silu,
ActivationType.Gelu.name.lower(): ActivationType.Gelu,
}
assert activation in aiter_acts, (
f"Aiter CK fp4 MoE doesn't support activation {activation}"
)
out = fused_moe(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights, topk_weights=topk_weights,
topk_ids, topk_ids=topk_ids,
quant_type=QuantType.per_1x32, activation=activation,
w1_scale=layer.w13_weight_scale, quant_config=self.moe_quant_config,
w2_scale=layer.w2_weight_scale,
activation=aiter_acts[activation],
doweight_stage1=False,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
......
...@@ -31,6 +31,13 @@ from .quark_scheme import QuarkScheme ...@@ -31,6 +31,13 @@ from .quark_scheme import QuarkScheme
logger = init_logger(__name__) logger = init_logger(__name__)
# TODO: move registration of custom op to aiter_ops.py
# `from vllm._aiter_ops import rocm_aiter_ops`
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
# for envs checks which does not require @cache anymore.
# triton kernel is torch compile compatible.
# does not require direct registeration.
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
@cache @cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
return ( return (
......
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
...@@ -68,78 +69,6 @@ def cutlass_scaled_mm( ...@@ -68,78 +69,6 @@ def cutlass_scaled_mm(
) )
def rocm_aiter_gemm_w8a8_blockscale_impl(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
def is_aiter_triton_kernel_tuned(n, k):
return (n, k) in [
(1024, 8192),
(2112, 7168),
(3072, 1536),
(32768, 8192),
(4096, 7168),
(4608, 7168),
(512, 7168),
(7168, 2048),
(7168, 256),
(8192, 1024),
(8192, 32768),
]
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
# MI350 case uses triton kernel
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
group_size,
column_major_scales=False,
use_ue8m0=False,
)
else:
# MI300 uses tuned AITER ASM/C++ kernel
import aiter as rocm_aiter
from aiter import gemm_a8w8_blockscale, get_hip_quant
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
q_input, input_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
)
return gemm_a8w8_blockscale(
q_input, weight, input_scale, weight_scale, dtype=output_dtype
)
def rocm_aiter_gemm_w8a8_blockscale_fake(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = input_2d.shape[0]
n = weight.shape[0]
return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
)
# TODO we should be able to change the type of block_size to GroupShape # TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue # after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270 # https://github.com/vllm-project/vllm/issues/25270
...@@ -385,14 +314,40 @@ class W8A8BlockFp8LinearOp: ...@@ -385,14 +314,40 @@ class W8A8BlockFp8LinearOp:
input_scale: torch.Tensor | None = None, input_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128) assert self.act_quant_group_shape == GroupShape(1, 128)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
input_2d, n, k = weight.shape
weight, if input_scale is not None:
input_scale, q_input = input_2d
weight_scale,
self.act_quant_group_shape.col, # MI350 case uses triton kernel
input_2d.dtype, if (
) not current_platform.is_fp8_fnuz()
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
):
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
self.act_quant_group_shape.col,
column_major_scales=False,
use_ue8m0=False,
)
return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
# MI300 uses tuned AITER ASM/C++ kernel
else:
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
return rocm_aiter_ops.gemm_w8a8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
def _run_triton( def _run_triton(
self, self,
...@@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace( ...@@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant) s_old.copy_(s_requant)
def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm for MI3XX"""
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
)
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which """Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory""" can benefit from tensors located far enough from one another in memory"""
......
...@@ -472,7 +472,7 @@ class Fp8LinearOp: ...@@ -472,7 +472,7 @@ class Fp8LinearOp:
# Example: # Example:
# When the number of token is 1, per-token scale is [[1]] # When the number of token is 1, per-token scale is [[1]]
# When per-tensor scale is [1] or (). # When per-tensor scale is [1] or ().
per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2 per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor) # TODO(luka) do this dispatch during init (after ScaledMM refactor)
......
...@@ -4,13 +4,10 @@ ...@@ -4,13 +4,10 @@
import torch import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch from .common import apply_rotary_emb_torch
from .rocm_aiter_rope_ops import (
is_rocm_triton_rotary_embedding_enabled,
rocm_aiter_rotary_emb,
)
@CustomOp.register("rotary_embedding") @CustomOp.register("rotary_embedding")
...@@ -48,8 +45,8 @@ class RotaryEmbeddingBase(CustomOp): ...@@ -48,8 +45,8 @@ class RotaryEmbeddingBase(CustomOp):
cache = cache.to(dtype) cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
self.is_rocm_triton_rotary_embedding_enabled = ( self.is_rocm_triton_rotary_embed_enabled = (
is_rocm_triton_rotary_embedding_enabled() rocm_aiter_ops.is_triton_rotary_embed_enabled()
) )
def _compute_inv_freq(self, base: float) -> torch.Tensor: def _compute_inv_freq(self, base: float) -> torch.Tensor:
...@@ -169,9 +166,9 @@ class RotaryEmbedding(RotaryEmbeddingBase): ...@@ -169,9 +166,9 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor | None = None, key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.is_rocm_triton_rotary_embedding_enabled: if self.is_rocm_triton_rotary_embed_enabled:
self._match_cos_sin_cache_dtype(query) self._match_cos_sin_cache_dtype(query)
rocm_aiter_rotary_emb( rocm_aiter_ops.triton_rotary_embed(
positions, positions,
query, query,
key, key,
......
...@@ -146,6 +146,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase): ...@@ -146,6 +146,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key = key_rot key = key_rot
return query, key return query, key
def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
def forward_cuda( def forward_cuda(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment