Unverified Commit a1448b4b authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Split up fused_moe/layer.py, isolate more modular kernel code (#28064)

parent fa197020
......@@ -6,6 +6,10 @@ import torch
# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
......@@ -21,7 +25,6 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
......@@ -399,9 +402,7 @@ def make_prepare_finalize(
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
moe, quant_config
)
prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
......
......@@ -25,7 +25,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
try_get_optimal_moe_config,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoEModularMethod
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
class FusedMoEWithLoRA(BaseLayerWithLoRA):
......
......@@ -5,9 +5,11 @@ from contextlib import contextmanager
from typing import Any
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed import (
get_ep_group,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_pplx
if current_platform.is_cuda_alike():
if has_pplx():
from .pplx_prepare_finalize import (
PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes,
)
if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (
DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize,
)
def maybe_roundup_layer_hidden_size(
hidden_size: int,
act_dtype: torch.dtype,
moe_parallel_config: FusedMoEParallelConfig,
) -> int:
"""
Given layer hidden size and MoE configurations, round up hidden_size
if necessary.
Args:
hidden_size: Layer hidden-size
act_dtype: Data type of the layer activations.
moe_parallel_config: Fused MoE parallelization strategy configuration.
Return:
Rounded up hidden_size if rounding up is required based on the configs
and all2all backend.
Original hidden size otherwise.
"""
if moe_parallel_config.use_deepep_ht_kernels:
hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
hidden_size, act_dtype
)
if moe_parallel_config.use_deepep_ll_kernels:
hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(
hidden_size
)
return hidden_size
def maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None,
) -> FusedMoEPrepareAndFinalize | None:
if not moe.moe_parallel_config.use_all2all_kernels:
return None
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
# TODO: could allow this now
assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py"
if moe.use_pplx_kernels:
assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=all2all_manager.rank,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=hidden_scale_bytes,
)
num_dispatchers = (
all2all_manager.world_size // all2all_manager.tp_group.world_size
)
# Intranode pplx a2a takes a group name while internode does not.
if not all2all_manager.internode:
all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
num_local_experts=moe.num_local_experts,
num_dispatchers=num_dispatchers,
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
all_to_all_args = dict()
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = DeepEPHTPrepareAndFinalize(
handle,
num_dispatchers=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank * moe.num_local_experts,
)
elif moe.use_deepep_ll_kernels:
assert quant_config is not None
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
num_ep_ranks=all2all_manager.world_size,
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts // all2all_manager.world_size,
)
handle = all2all_manager.get_handle(all_to_all_args)
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch = (
quant_config.quant_dtype == current_platform.fp8_dtype()
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
)
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
)
return prepare_finalize
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Callable
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
logger = init_logger(__name__)
class FusedMoEMethodBase(QuantizeMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
@abstractmethod
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
def uses_weight_scale_2_pattern(self) -> bool:
"""
Returns True if this quantization method uses 'weight_scale_2' pattern
for per-tensor weight scales (e.g., FP4 variants), False otherwise.
This method should be overridden by subclasses that use the
'weight_scale_2' pattern instead of the standard 'weight_scale' pattern.
"""
return False
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
from .all2all_utils import maybe_make_prepare_finalize
return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize"
)
@abstractmethod
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
raise NotImplementedError
@property
def topk_indices_dtype(self) -> torch.dtype | None:
return None
@property
def supports_eplb(self) -> bool:
return False
@property
def allow_inplace(self) -> bool:
return False
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel,
FusedMoEPrepareAndFinalize,
)
logger = init_logger(__name__)
@CustomOp.register("modular_fused_moe")
class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def __init__(
self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
):
super().__init__(old_quant_method.moe)
self.moe_quant_config = old_quant_method.moe_quant_config
self.fused_experts = experts
self.disable_expert_map = getattr(
old_quant_method,
"disable_expert_map",
not self.fused_experts.supports_expert_map(),
)
self.old_quant_method = old_quant_method
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
@staticmethod
def make(
moe_layer: torch.nn.Module,
old_quant_method: FusedMoEMethodBase,
prepare_finalize: FusedMoEPrepareAndFinalize,
shared_experts: torch.nn.Module | None,
) -> "FusedMoEModularMethod":
return FusedMoEModularMethod(
old_quant_method,
FusedMoEModularKernel(
prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
),
)
@property
def topk_indices_dtype(self) -> torch.dtype | None:
return self.fused_experts.prepare_finalize.topk_indices_dtype()
@property
def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb
@property
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return self.moe_quant_config
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# Is getattr needed?
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
if enable_eplb:
if self.supports_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
else:
raise NotImplementedError(
"EPLB is not supported for "
f"{self.old_quant_method.__class__.__name__}."
)
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
)
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.allow_inplace,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else expert_map,
)
if zero_expert_num != 0 and zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
return result, zero_expert_result
else:
return result
......@@ -38,7 +38,7 @@ class SharedFusedMoE(FusedMoE):
and not (
# TODO(wentao): find the root cause and remove this condition
self.enable_eplb
or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
)
and self._shared_experts is not None
)
......
......@@ -741,15 +741,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# need to delete the original weights to save memory on single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
self.w13_weight = w13_weight
self.w2_weight = w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
......@@ -824,18 +819,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"EP batched experts format"
)
else:
layer.w13_weight = (
self.w13_weight_triton_tensor
if layer.w13_weight is None
else layer.w13_weight
)
layer.w2_weight = (
self.w2_weight_triton_tensor
if layer.w2_weight is None
else layer.w2_weight
)
assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
assert self.moe_quant_config is not None
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
......@@ -1070,8 +1053,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w2=self.w2_weight_triton_tensor,
w1=self.w13_weight,
w2=self.w2_weight,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment