Unverified commit f080a835, authored by vllmellm, committed by GitHub
Browse files

[RFC][ROCm][AITER] Keep all AITER kernels in `_aiter_ops` class like...


[RFC][ROCm][AITER] Keep all AITER kernels in `_aiter_ops` class like `_custom_ops` and `_ipex_ops` (#24490)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 40e2eeeb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_triton_rotary_embedding_enabled() -> bool:
    """Report whether the AITER Triton rotary-embedding path should be used.

    The path is active only when the current platform is ROCm and both the
    ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_TRITON_ROPE`` settings are
    truthy.
    """
    # The platform check comes first so the env settings are only consulted
    # on ROCm.
    if current_platform.is_rocm():
        return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_TRITON_ROPE
    return False
def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    """Run AITER's Triton cached-RoPE kernel on ``query`` and ``key``.

    Nothing is returned; the kernel invoked is the ``_inplace`` variant, so
    the results are expected to be written into ``query`` and ``key``.

    Args:
        positions: Token positions forwarded to the kernel.
        sin: Sine table forwarded to the kernel.
        cos: Cosine table forwarded to the kernel.
        query: Query tensor, updated by the kernel.
        key: Key tensor, updated by the kernel.
        rotate_style: Integer rotation-style selector passed through as-is.
        is_nope_first: Forwarded to the kernel as ``nope_first``.
    """
    # Function-scope import: `aiter` is only resolved when the op actually
    # runs (presumably so this module imports cleanly without AITER).
    from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace

    # Note the kernel's argument order differs from this wrapper's:
    # (query, key, cos, sin, positions, rotate_style).
    rope_cached_thd_positions_2c_fwd_inplace(
        query,
        key,
        cos,
        sin,
        positions,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=is_nope_first,
    )
def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    """Fake counterpart of the Triton RoPE op; intentionally does nothing.

    It mirrors the real implementation's signature and produces no output,
    since the real op has no return value either.
    """
    return None
# Register the custom op only when the AITER Triton RoPE path is enabled;
# otherwise `torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton`
# is never created by this module.
if is_rocm_triton_rotary_embedding_enabled():
    direct_register_custom_op(
        op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
        op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
        fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
        # The op writes into both tensors instead of returning new ones.
        mutates_args=["key", "query"],
        dispatch_key=current_platform.dispatch_key,
    )
def rocm_aiter_rotary_emb(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    head_size: int,
    rotary_dim: int,
    is_neox_style: bool,
) -> None:
    """Apply rotary embedding to ``query`` and ``key`` in place.

    The work is delegated to the
    ``torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton`` custom
    op, which is registered with ``query`` and ``key`` as mutated arguments.
    Nothing is returned: callers observe the result through the tensors they
    passed in, whose shapes are left untouched.

    Args:
        positions: Token positions; its element count defines the number of
            tokens.
        query: Query tensor, viewable as ``(num_tokens, -1, head_size)``.
        key: Key tensor, viewable as ``(num_tokens, -1, head_size)``.
        cos_sin_cache: Cache whose last dimension is split into two equal
            halves, cosine first and sine second.
        head_size: Size of each attention head.
        rotary_dim: Number of leading features of each head handed to the op.
        is_neox_style: Selects rotate style ``0`` when true, ``1`` otherwise.
    """
    num_tokens = positions.numel()
    cos, sin = cos_sin_cache.chunk(2, dim=-1)
    rotate_style = 0 if is_neox_style else 1

    # `view` and slicing share storage with the caller's tensors, so the
    # in-place op below updates `query` / `key` directly. No reshape back is
    # needed because the caller's tensors were never reshaped themselves.
    query_rot = query.view(num_tokens, -1, head_size)[..., :rotary_dim]
    key_rot = key.view(num_tokens, -1, head_size)[..., :rotary_dim]

    torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
        positions.view(num_tokens),  # flatten to one position per token
        sin,
        cos,
        query_rot,
        key_rot,
        rotate_style,
        False,  # is_nope_first
    )
......@@ -33,6 +33,7 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
......@@ -50,10 +51,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
......@@ -294,10 +291,8 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_start + self.n_local_physical_experts
)
if (
config.n_shared_experts is None
or is_rocm_aiter_fusion_shared_expert_enabled()
):
self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled:
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
......@@ -330,14 +325,14 @@ class DeepseekV2MoE(nn.Module):
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
if not is_rocm_aiter_moe_enabled()
if not self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
else None,
)
......@@ -371,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not is_rocm_aiter_moe_enabled():
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
......@@ -1428,6 +1423,9 @@ class DeepseekV2ForCausalLM(
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
......@@ -1456,7 +1454,7 @@ class DeepseekV2ForCausalLM(
num_experts=self.config.n_routed_experts
+ (
self.config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
if rocm_aiter_moe_shared_expert_enabled
else 0
),
num_redundant_experts=self.num_redundant_experts,
......@@ -1472,9 +1470,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None:
continue # skip spec decode layers for main model
is_fuse_shared_experts_layer = (
is_rocm_aiter_fusion_shared_expert_enabled()
and ("mlp.shared_experts" in name)
is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
"mlp.shared_experts" in name
)
for param_name, weight_name, shard_id in stacked_params_mapping:
......
......@@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
from vllm._aiter_ops import rocm_aiter_ops
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
......@@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention(
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
and not (rocm_aiter_ops.is_pa_attn_enabled())
and sinks is None
)
......@@ -202,12 +204,15 @@ class RocmPlatform(Platform):
]
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
......@@ -228,19 +233,23 @@ class RocmPlatform(Platform):
has_sink,
use_sparse,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
if use_mla:
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
is_aiter_mla_enabled,
if not use_v1:
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
)
if use_mla:
if selected_backend is None:
selected_backend = (
_Backend.ROCM_AITER_MLA
if is_aiter_mla_enabled() or block_size == 1
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA
)
......@@ -265,12 +274,12 @@ class RocmPlatform(Platform):
logger.info("Using FlexAttention backend.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
rocm_aiter_ops.is_mha_enabled()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.")
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
rocm_aiter_ops.is_triton_unified_attn_enabled()
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.")
return (
......
......@@ -198,6 +198,7 @@ from tqdm import tqdm
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
......@@ -270,22 +271,9 @@ except ImportError:
flashinfer_available = False
def is_rocm_aiter_fp8bmm_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_FP8BMM
and envs.VLLM_ROCM_USE_AITER
)
if is_rocm_aiter_fp8bmm_enabled():
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501
)
def dynamic_per_batched_tensor_quant(
def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
......@@ -1109,6 +1097,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
self.kv_b_proj = kv_b_proj
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
......@@ -1158,7 +1147,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
......@@ -1187,7 +1176,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
dtype=torch.bfloat16,
device=self.W_K.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
......@@ -1196,7 +1185,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
dtype=torch.bfloat16,
device=self.W_V.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:
......@@ -1208,10 +1197,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = aiter_triton_fp8_bmm(
x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
# Convert from (B, N, V) to (B, N * V)
......@@ -1571,7 +1559,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
......@@ -1600,7 +1588,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
dtype=torch.bfloat16,
device=self.W_K.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
......@@ -1609,7 +1597,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
dtype=torch.bfloat16,
device=self.W_V.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:
......@@ -1958,7 +1946,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Pads the head_dim if necessary (for the underlying kernel)
if self.q_pad_num_heads is not None:
B, N, L = decode_q_pe.shape
decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
......@@ -1966,9 +1953,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = aiter_triton_fp8_bmm(
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
decode_q_nope,
self.W_K,
self.W_K_scale,
......
......@@ -6,9 +6,8 @@ from typing import ClassVar
import torch
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionLayer
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import (
......@@ -22,10 +21,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
......@@ -284,7 +279,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
aiter_mla_decode_fwd(
rocm_aiter_ops.mla_decode_fwd(
q,
kv_buffer,
o,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment