Unverified commit b69bf2f0, authored by Wei Zhao and committed by GitHub
Browse files

[Perf] Use torch compile to fuse pack topk in trtllm moe (#37695)


Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
parent 88149b63
......@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
......@@ -152,11 +153,8 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType
# Pack topk_ids and topk_weights into single tensor
# Format: (expert_id << 16) | (weight_bf16.view(int16))
packed_topk_ids = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(
torch.int16
)
# Pack topk ids and weights into format expected by the kernel.
packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
# trtllm_fp8_block_scale_routed_moe does not support autotuning
# so skip this kernel during dummy run for autotuning.
......
......@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
......@@ -183,9 +184,7 @@ class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModula
assert self.quant_config.w2_scale is not None
# Pack topk ids and weights into format expected by the kernel.
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
# trtllm_fp4_block_scale_routed_moe does not support autotuning
# so skip this kernel during dummy run for autotuning.
......
......@@ -323,3 +323,16 @@ def normalize_batched_scales_shape(
@functools.cache
def disable_inplace() -> bool:
    """
    Return the outcome of the ``is_torch_equal_or_newer("2.9")`` check.

    The result is memoized by ``functools.cache``, so the version check is
    evaluated only on the first call.
    """
    torch_at_least_2_9 = is_torch_equal_or_newer("2.9")
    return torch_at_least_2_9
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def trtllm_moe_pack_topk_ids_weights(
    topk_ids: torch.Tensor, topk_weights: torch.Tensor
) -> torch.Tensor:
    """
    Pack topk_ids and topk_weights into a single int32 tensor.

    Format: (expert_id << 16) | (bf16 weight bits in the low 16 bits)

    Args:
        topk_ids: Expert ids; cast to int32 before being shifted into the
            upper 16 bits.
        topk_weights: Routing weights; cast to bfloat16 and reinterpreted
            bitwise as 16-bit integers.

    Returns:
        An int32 tensor with the expert id in the upper 16 bits and the raw
        bfloat16 bit pattern of the weight in the lower 16 bits.
    """
    expert_bits = topk_ids.to(torch.int32) << 16
    # Bitwise reinterpretation of the bf16 payload (no numeric conversion).
    weight_bits = topk_weights.to(torch.bfloat16).view(torch.int16)
    # int16 is signed: widening it to int32 sign-extends any value whose bf16
    # sign bit is set (negative weights, -0.0), which would set all upper
    # 16 bits and clobber the expert id. Mask to keep only the low 16 bits.
    weight_bits = weight_bits.to(torch.int32) & 0xFFFF
    return expert_bits | weight_bits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment