"vllm/vscode:/vscode.git/clone" did not exist on "fbd8595c5c6f969dfa6cf33e5a371d93d55025fb"
Unverified Commit d5b6f3ba authored by Douglas Lehr's avatar Douglas Lehr Committed by GitHub
Browse files

[ROCm][Quantization] Add Composable Kernel (CK) backend support for M… (#34301)


Signed-off-by: default avatarDoug Lehr <douglehr@amd.com>
Signed-off-by: default avatarDouglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
Signed-off-by: default avatarDouglas Lehr <Doug.Lehr@amd.com>
Co-authored-by: default avatarDoug Lehr <douglehr@amd.com>
Co-authored-by: default avatarCursor <cursoragent@cursor.com>
Co-authored-by: default avatarRohan Potdar <66227218+Rohan138@users.noreply.github.com>
parent 1a014a0a
...@@ -87,6 +87,10 @@ def _rocm_aiter_fused_moe_impl( ...@@ -87,6 +87,10 @@ def _rocm_aiter_fused_moe_impl(
a2_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None, num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None, output_dtype: torch.dtype | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
from aiter import ActivationType, QuantType from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe from aiter.fused_moe import fused_moe
...@@ -110,6 +114,10 @@ def _rocm_aiter_fused_moe_impl( ...@@ -110,6 +114,10 @@ def _rocm_aiter_fused_moe_impl(
a2_scale, a2_scale,
num_local_tokens=num_local_tokens, num_local_tokens=num_local_tokens,
dtype=output_dtype, dtype=output_dtype,
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
bias1=bias1,
bias2=bias2,
) )
...@@ -307,6 +315,28 @@ def _rocm_aiter_grouped_topk_fake( ...@@ -307,6 +315,28 @@ def _rocm_aiter_grouped_topk_fake(
pass pass
def _rocm_aiter_fused_topk_impl(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.fused_moe import fused_topk
# fused_topk returns (topk_weights, topk_indices)
return fused_topk(x, router_logits, top_k, gate_up)
def _rocm_aiter_fused_topk_fake(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
) -> None:
# tuple[torch.Tensor, torch.Tensor]:
pass
# Cache whether aiter supports FP8 MLA parameters # Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None _AITER_MLA_SUPPORTS_FP8: bool | None = None
...@@ -994,6 +1024,70 @@ class rocm_aiter_ops: ...@@ -994,6 +1024,70 @@ class rocm_aiter_ops:
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
@staticmethod
def get_aiter_activation_type(activation_str: str):
"""
Given an activation type as a string, returns the corresponding aiter ActivationType enum.
Supported activation types: "no", "none", "silu", "gelu", "swiglu".
Returns None if the mapping fails.
Args:
activation_str (str): Activation type as string.
Returns:
Aiter ActivationType enum value, or None if not found.
"""
# Import only locally, since aiter may not always be available.
try:
from aiter import ActivationType
except ImportError:
return None
if not isinstance(activation_str, str):
return None
name = activation_str.strip().lower()
mapping = {
"none": ActivationType.No,
"no": ActivationType.No,
"silu": ActivationType.Silu,
"gelu": ActivationType.Gelu,
"swiglu": ActivationType.Swiglu,
}
return mapping.get(name)
@staticmethod
def get_aiter_quant_type(quant_type_str: str):
"""
Given a quantization type as a string, returns the corresponding aiter QuantType enum.
Supported quantization types: "no", "per_tensor", "per_token", "per_1x32", "per_1x128", "per_128x128".
Returns None if the mapping fails.
Args:
quant_type_str (str): Quantization type as string.
Returns:
Aiter QuantType enum value, or None if not found.
"""
try:
from aiter import QuantType
except ImportError:
return None
if not isinstance(quant_type_str, str):
return None
name = quant_type_str.strip().lower()
mapping = {
"no": QuantType.No,
"per_tensor": QuantType.per_Tensor,
"per_token": QuantType.per_Token,
"per_1x32": QuantType.per_1x32,
"per_1x128": QuantType.per_1x128,
"per_128x128": QuantType.per_128x128,
}
return mapping.get(name)
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_enabled(cls) -> bool: def is_enabled(cls) -> bool:
...@@ -1127,6 +1221,14 @@ class rocm_aiter_ops: ...@@ -1127,6 +1221,14 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op(
op_name="rocm_aiter_fused_topk",
op_func=_rocm_aiter_fused_topk_impl,
mutates_args=[],
fake_impl=_rocm_aiter_fused_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd", op_name="rocm_aiter_mla_decode_fwd",
op_func=_rocm_aiter_mla_decode_fwd_impl, op_func=_rocm_aiter_mla_decode_fwd_impl,
...@@ -1360,6 +1462,10 @@ class rocm_aiter_ops: ...@@ -1360,6 +1462,10 @@ class rocm_aiter_ops:
a2_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None, num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None, output_dtype: torch.dtype | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_fused_moe( return torch.ops.vllm.rocm_aiter_fused_moe(
hidden_states, hidden_states,
...@@ -1377,6 +1483,10 @@ class rocm_aiter_ops: ...@@ -1377,6 +1483,10 @@ class rocm_aiter_ops:
a2_scale, a2_scale,
num_local_tokens, num_local_tokens,
output_dtype, output_dtype,
hidden_pad,
intermediate_pad,
bias1,
bias2,
) )
@staticmethod @staticmethod
...@@ -1481,6 +1591,15 @@ class rocm_aiter_ops: ...@@ -1481,6 +1591,15 @@ class rocm_aiter_ops:
routed_scaling_factor, routed_scaling_factor,
) )
@staticmethod
def fused_topk(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_fused_topk(x, router_logits, top_k, gate_up)
@staticmethod @staticmethod
def mla_decode_fwd( def mla_decode_fwd(
q: torch.Tensor, q: torch.Tensor,
...@@ -1701,6 +1820,47 @@ class rocm_aiter_ops: ...@@ -1701,6 +1820,47 @@ class rocm_aiter_ops:
return shuffle_weight(tensor, layout=layout) return shuffle_weight(tensor, layout=layout)
@staticmethod
def shuffle_weight_a16w4(
tensor: "torch.Tensor",
nLane: int,
gate_up: bool,
) -> "torch.Tensor":
"""
Shuffles the weight tensor into (A16W4) layout for AITER kernels.
Args:
tensor: The input weight tensor to be shuffled.
layout: The block layout to use, defaults to (16, 4).
Returns:
torch.Tensor: The shuffled tensor.
"""
from aiter.ops.shuffle import shuffle_weight_a16w4
return shuffle_weight_a16w4(tensor, nLane, gate_up)
@staticmethod
def shuffle_scale_a16w4(
tensor: "torch.Tensor",
num_experts: int,
gate_up: bool,
) -> "torch.Tensor":
"""
Shuffles the scale tensor into (A16W4) layout for AITER kernels.
Args:
tensor: The input scale tensor to be shuffled.
num_experts: Number of experts, needed for reshaping logic.
gate_up: Whether the scale is for w13 (True) or w2 (False).
Returns:
torch.Tensor: The shuffled scale tensor.
"""
from aiter.ops.shuffle import shuffle_scale_a16w4
return shuffle_scale_a16w4(tensor, num_experts, gate_up)
@staticmethod @staticmethod
def shuffle_weights( def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
...@@ -77,6 +78,8 @@ class Mxfp4Backend(Enum): ...@@ -77,6 +78,8 @@ class Mxfp4Backend(Enum):
# Triton Backend # Triton Backend
TRITON = 6 TRITON = 6
CK = 7
def get_mxfp4_backend_with_lora() -> Mxfp4Backend: def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
""" """
...@@ -167,7 +170,13 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: ...@@ -167,7 +170,13 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
elif current_platform.is_xpu(): elif current_platform.is_xpu():
logger.info_once("Using xpu backend on XPU") logger.info_once("Using xpu backend on XPU")
return Mxfp4Backend.MARLIN return Mxfp4Backend.MARLIN
elif current_platform.is_rocm() and has_triton_kernels(): elif current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx950
if rocm_aiter_ops.is_enabled() and on_gfx950():
logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
return Mxfp4Backend.CK
elif has_triton_kernels():
logger.info_once("Using Triton backend") logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON return Mxfp4Backend.TRITON
...@@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.intermediate_size = intermediate_size_per_partition_after_pad self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
self.intermediate_pad = (
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
)
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.zeros( torch.zeros(
...@@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
), ),
shared_experts=None, shared_experts=None,
) )
elif self.mxfp4_backend == Mxfp4Backend.CK:
if layer.w13_bias is not None:
layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
if layer.w2_bias.data is not None:
layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)
e, n, k = layer.w13_weight.shape
layer.w13_weight.view(torch.uint8).copy_(
layer.w13_weight.data.view(torch.uint8)
.view(e, n // 2, 2, k)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, k)
)
layer.w13_weight_scale.data = (
layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, -1)
)
layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)
layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w13_weight, 16, True
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
self.num_experts,
True,
)
layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w2_weight, 16, False
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
self.num_experts,
False,
)
layer.w13_bias.data = (
layer.w13_bias.data.view(-1, n // 2, 2)
.permute(0, 2, 1)
.contiguous()
.view(-1, n)
)
layer.w13_weight_scale = torch.nn.Parameter(
shuffled_w13_scale, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
shuffled_w2_scale, requires_grad=False
)
# replace_parameter(layer, "w13_bias", w13_bias)
# replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
# replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
# replace_parameter(layer, "w13_weight", w13_weight)
# replace_parameter(layer, "w2_weight", w2_weight)
elif self.mxfp4_backend == Mxfp4Backend.TRITON: elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
...@@ -792,7 +865,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -792,7 +865,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w13_bias = Parameter(w13_bias, requires_grad=False) layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False) layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
# (stored in self.fused_experts) to determine if the MoE has a # (stored in self.fused_experts) to determine if the MoE has a
# batched activation format. As self.fused_experts is not # batched activation format. As self.fused_experts is not
...@@ -803,7 +875,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -803,7 +875,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else: else:
num_warps = 8 num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps layer.w13_weight, layer.w13_weight_scale, num_warps
) )
...@@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.w2_precision_config = PrecisionConfig( self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
) )
self.w13_weight = w13_weight self.w13_weight = w13_weight
self.w2_weight = w2_weight self.w2_weight = w2_weight
del layer.w13_weight del layer.w13_weight
del layer.w2_weight del layer.w2_weight
layer.w13_weight = w13_weight layer.w13_weight = w13_weight
layer.w2_weight = w2_weight layer.w2_weight = w2_weight
else: else:
raise ValueError( raise ValueError(
f"Unsupported mxfp4_backend: {self.mxfp4_backend}: " f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
...@@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
elif self.mxfp4_backend in [ elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_BF16, Mxfp4Backend.SM100_FI_MXFP4_BF16,
Mxfp4Backend.SM90_FI_MXFP4_BF16, Mxfp4Backend.SM90_FI_MXFP4_BF16,
Mxfp4Backend.CK,
]: ]:
return mxfp4_w4a16_moe_quant_config( return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias, w1_bias=layer.w13_bias,
...@@ -933,6 +1005,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -933,6 +1005,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.TRITON or self.mxfp4_backend == Mxfp4Backend.TRITON
or self.mxfp4_backend == Mxfp4Backend.CK
) )
def apply( def apply(
...@@ -1054,6 +1127,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1054,6 +1127,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
return trtllm_gen_output return trtllm_gen_output
elif self.mxfp4_backend == Mxfp4Backend.CK:
topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
x, router_logits, layer.top_k, True
)
output = rocm_aiter_ops.fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
doweight_stage1=False,
hidden_pad=self.hidden_pad // 128 * 128,
intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
bias1=layer.w13_bias,
bias2=layer.w2_bias,
)
return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON: elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward, triton_kernel_moe_forward,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment