Unverified Commit b57b9673 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor][7/N] AITER MK (#31102)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent 6d518ffb
...@@ -2132,6 +2132,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2132,6 +2132,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
torch.float8_e4m3fn, torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
] ]
E, num_tokens, N, K, top_k_num = self.moe_problem_size( E, num_tokens, N, K, top_k_num = self.moe_problem_size(
...@@ -2156,7 +2157,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2156,7 +2157,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
compute_type = tl.float16 compute_type = tl.float16
elif hidden_states.dtype == torch.float32: elif hidden_states.dtype == torch.float32:
compute_type = tl.float32 compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn: elif (
hidden_states.dtype == torch.float8_e4m3fn
or hidden_states.dtype == torch.float8_e4m3fnuz
):
compute_type = tl.bfloat16 compute_type = tl.bfloat16
else: else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
......
...@@ -13,6 +13,10 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input ...@@ -13,6 +13,10 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__(self, defer_input_quant: bool = False) -> None:
super().__init__()
self.defer_input_quant = defer_input_quant
@property @property
def activation_format(self) -> mk.FusedMoEActivationFormat: def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
...@@ -48,6 +52,11 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -48,6 +52,11 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
# Note: do not use inplace for shared experts overlap # Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts.
if self.defer_input_quant:
return a1, None, None, None, None
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
quant_config.a1_scale, quant_config.a1_scale,
......
...@@ -5,11 +5,15 @@ from functools import lru_cache ...@@ -5,11 +5,15 @@ from functools import lru_cache
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
class QuantMethod(IntEnum): class QuantMethod(IntEnum):
...@@ -263,3 +267,78 @@ def rocm_aiter_fused_experts( ...@@ -263,3 +267,78 @@ def rocm_aiter_fused_experts(
a2_scale=quant_config.a2_scale, a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input, doweight_stage1=apply_router_weight_on_input,
) )
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config):
super().__init__(quant_config)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_expert_map(self):
return True
def supports_chunking(self):
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Workspaces are managed internally by AITER.
workspace1 = (0,)
workspace2 = (0,)
output = (M, K)
return (workspace1, workspace2, output)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert a1q_scale is None
assert a2_scale is None
assert expert_tokens_meta is None
result = rocm_aiter_fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.quant_config,
)
assert result.shape == output.shape
output.copy_(result)
...@@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum): ...@@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum):
DEEPGEMM = 3 DEEPGEMM = 3
MARLIN = 4 MARLIN = 4
TRITON = 5 TRITON = 5
AITER = 6
def get_fp8_moe_backend( def get_fp8_moe_backend(
...@@ -189,6 +190,10 @@ def get_fp8_moe_backend( ...@@ -189,6 +190,10 @@ def get_fp8_moe_backend(
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local") logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
return Fp8MoeBackend.DEEPGEMM return Fp8MoeBackend.DEEPGEMM
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
return Fp8MoeBackend.AITER
# default to Triton # default to Triton
logger.info_once("Using Triton backend for FP8 MoE") logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON return Fp8MoeBackend.TRITON
...@@ -888,16 +893,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -888,16 +893,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
# Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# TODO (rob): refactor block quant into separate class. # TODO (rob): refactor block quant into separate class.
if self.block_quant: if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic" assert self.quant_config.activation_scheme == "dynamic"
...@@ -932,7 +931,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -932,7 +931,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv) replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
replace_parameter(layer, "w2_weight", w2_weight) replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv) replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
if self.rocm_aiter_moe_enabled: if self.fp8_backend == Fp8MoeBackend.AITER:
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data layer.w13_weight.data, layer.w2_weight.data
...@@ -1026,7 +1025,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1026,7 +1025,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
start += shard_size start += shard_size
if self.rocm_aiter_moe_enabled: if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight layer.w13_weight, layer.w2_weight
) )
...@@ -1072,6 +1071,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1072,6 +1071,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = config self.moe_quant_config = config
self.kernel = mk.FusedMoEModularKernel( self.kernel = mk.FusedMoEModularKernel(
# TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
# with the changes to defer input quantization
FlashInferAllGatherMoEPrepareAndFinalize( FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=(self.moe.dp_size > 1), use_dp=(self.moe.dp_size > 1),
use_deepseek_fp8_block_scale=self.block_quant, use_deepseek_fp8_block_scale=self.block_quant,
...@@ -1093,6 +1094,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1093,6 +1094,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.DEEPGEMM,
Fp8MoeBackend.TRITON, Fp8MoeBackend.TRITON,
Fp8MoeBackend.MARLIN, Fp8MoeBackend.MARLIN,
Fp8MoeBackend.AITER,
]: ]:
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
...@@ -1103,24 +1105,33 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1103,24 +1105,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
config = self.get_fused_moe_quant_config(layer) config = self.get_fused_moe_quant_config(layer)
assert config is not None assert config is not None
self.moe_quant_config = config self.moe_quant_config = config
use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
moe_kernel = (
MarlinExperts(quant_config=self.moe_quant_config)
if use_marlin
else TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=allow_deep_gemm,
)
)
self.kernel = mk.FusedMoEModularKernel( if self.fp8_backend == Fp8MoeBackend.AITER:
MoEPrepareAndFinalizeNoEP(), moe_kernel self.kernel = mk.FusedMoEModularKernel(
) # TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
AiterExperts(quant_config=self.moe_quant_config),
)
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=self.moe_quant_config),
)
else:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
),
)
self.use_inplace = True self.use_inplace = True
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
...@@ -1128,7 +1139,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1128,7 +1139,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if ( if (
self.rocm_aiter_moe_enabled self.fp8_backend == Fp8MoeBackend.AITER
or self.fp8_backend == Fp8MoeBackend.MARLIN or self.fp8_backend == Fp8MoeBackend.MARLIN
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
...@@ -1161,11 +1172,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1161,11 +1172,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
assert ( if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
self.fp8_backend != Fp8MoeBackend.MARLIN raise NotImplementedError(
) and not self.rocm_aiter_moe_enabled, ( "Marlin and ROCm AITER are not supported with all2all yet."
"Marlin and ROCm AITER are not supported with all2all yet." )
)
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
...@@ -1313,37 +1323,18 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1313,37 +1323,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
) )
result = self.kernel(
if self.rocm_aiter_moe_enabled: x,
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 layer.w13_weight,
rocm_aiter_fused_experts, layer.w2_weight,
) topk_weights,
topk_ids,
# TODO(rob): convert this to MK. inplace=self.use_inplace,
result = rocm_aiter_fused_experts( activation=layer.activation,
x, global_num_experts=layer.global_num_experts,
layer.w13_weight, expert_map=layer.expert_map,
layer.w2_weight, apply_router_weight_on_input=layer.apply_router_weight_on_input,
topk_weights=topk_weights, )
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
else:
result = self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return result return result
...@@ -1456,15 +1447,10 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1456,15 +1447,10 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
# Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# If checkpoint is fp16, quantize in place. # If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
...@@ -1481,7 +1467,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1481,7 +1467,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", w2_weight) replace_parameter(layer, "w2_weight", w2_weight)
# Reshuffle weights for AITER if needed. # Reshuffle weights for AITER if needed.
if self.rocm_aiter_moe_enabled: if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight layer.w13_weight, layer.w2_weight
) )
...@@ -1489,7 +1475,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1489,7 +1475,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", shuffled_w2) replace_parameter(layer, "w2_weight", shuffled_w2)
# Rushuffle weights for MARLIN if needed. # Rushuffle weights for MARLIN if needed.
if self.fp8_backend == Fp8MoeBackend.MARLIN: elif self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin( prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype layer, False, input_dtype=self.marlin_input_dtype
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment