Unverified Commit 92edf358 authored by Sijia(Jackson) Chen's avatar Sijia(Jackson) Chen Committed by GitHub
Browse files

[ROCM] enable aiter fused moe kernel for llama4 bf16 checkpoints (#16674)

parent eb5819b2
......@@ -26,6 +26,7 @@ def rocm_aiter_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
apply_router_weight_on_input: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
......@@ -39,6 +40,18 @@ def rocm_aiter_fused_experts(
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
if apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
topk_ids = topk_ids.to(torch.int32)
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
assert w1_scale is not None
assert w2_scale is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment