Unverified Commit a3f2f409 authored by Yongye Zhu's avatar Yongye Zhu Committed by GitHub
Browse files

[MoE Refactor] Explicit construct mk for flashinfer bf16 kernel (#31504)


Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 5a468ff7
...@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
biased_moe_quant_config, biased_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
...@@ -27,7 +30,10 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( ...@@ -27,7 +30,10 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
...@@ -73,18 +79,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -73,18 +79,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logger.info_once( logger.info_once(
"Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod" "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
) )
from functools import partial
from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
self.flashinfer_cutlass_moe = partial(
flashinfer_cutlass_moe,
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
tp_rank=self.moe.moe_parallel_config.tp_rank,
tp_size=self.moe.moe_parallel_config.tp_size,
ep_rank=self.moe.moe_parallel_config.ep_rank,
ep_size=self.moe.moe_parallel_config.ep_size,
)
else: else:
if ( if (
self.moe.moe_parallel_config.use_ep self.moe.moe_parallel_config.use_ep
...@@ -101,7 +95,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -101,7 +95,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"FlashInfer CUTLASS MoE is currently not available for DP.", "FlashInfer CUTLASS MoE is currently not available for DP.",
scope="local", scope="local",
) )
self.flashinfer_cutlass_moe = None # type: ignore
@property @property
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
...@@ -222,12 +215,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -222,12 +215,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = shuffled_w13 layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2 layer.w2_weight.data = shuffled_w2
if self.flashinfer_cutlass_moe_enabled:
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
if current_platform.is_xpu(): if current_platform.is_xpu():
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
...@@ -271,11 +258,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -271,11 +258,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_cuda_alike(): elif current_platform.is_cuda_alike():
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
self.kernel = mk.FusedMoEModularKernel( if self.flashinfer_cutlass_moe_enabled:
MoEPrepareAndFinalizeNoEP(), self.use_inplace = False
TritonExperts(self.moe_quant_config), # Swap halves to arrange as [w3; w1] (kernel expectation)
shared_experts=None, w13_weight = swap_w13_to_w31(layer.w13_weight.data)
) replace_parameter(layer, "w13_weight", w13_weight)
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
out_dtype=layer.params_dtype,
quant_config=self.moe_quant_config,
tp_rank=self.moe.moe_parallel_config.tp_rank,
tp_size=self.moe.moe_parallel_config.tp_size,
ep_rank=self.moe.moe_parallel_config.ep_rank,
ep_size=self.moe.moe_parallel_config.ep_size,
),
)
else:
self.use_inplace = True
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(self.moe_quant_config),
shared_experts=None,
)
def apply( def apply(
self, self,
...@@ -320,16 +326,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -320,16 +326,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=layer.activation, activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
elif self.flashinfer_cutlass_moe_enabled:
return self.flashinfer_cutlass_moe(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else: else:
result = self.kernel( result = self.kernel(
hidden_states=x, hidden_states=x,
...@@ -337,7 +333,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -337,7 +333,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=self.use_inplace,
activation=layer.activation, activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment