Unverified Commit a57c8228 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 1ee95841
......@@ -46,6 +46,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method: FusedMoEMethodBase,
prepare_finalize: FusedMoEPrepareAndFinalize,
shared_experts: torch.nn.Module | None,
inplace: bool = False,
) -> "FusedMoEModularMethod":
return FusedMoEModularMethod(
old_quant_method,
......@@ -54,6 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
moe_parallel_config=moe_layer.moe_parallel_config,
inplace=inplace,
),
)
......@@ -61,10 +63,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb
@property
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace
@property
def method_name(self) -> str:
return self.old_quant_method.method_name
......@@ -99,7 +97,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.allow_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
......
......@@ -50,6 +50,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.utils import (
disable_inplace,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
......@@ -560,6 +563,8 @@ class FusedMoE(CustomOp):
activation=activation,
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
# TODO: in_dtype == out_dtype?
disable_inplace=disable_inplace() or self.shared_experts is not None,
)
if self.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, (
......@@ -650,7 +655,11 @@ class FusedMoE(CustomOp):
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
)
self.quant_method = FusedMoEModularMethod.make(
self, self.quant_method, prepare_finalize, self.shared_experts
self,
self.quant_method,
prepare_finalize,
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
)
@property
......
......@@ -811,11 +811,13 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
inplace: bool = False,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
self.inplace = inplace
# prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests).
......@@ -1292,7 +1294,6 @@ class FusedMoEModularKernel(torch.nn.Module):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
......@@ -1309,8 +1310,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
......@@ -1326,7 +1325,9 @@ class FusedMoEModularKernel(torch.nn.Module):
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if inplace and self.shared_experts is None and not disable_inplace():
if self.inplace:
assert self.shared_experts is None
assert not disable_inplace()
output = hidden_states
else:
output = torch.zeros_like(hidden_states)
......
......@@ -472,7 +472,7 @@ def make_fp8_moe_kernel(
fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]:
) -> mk.FusedMoEModularKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
......@@ -512,8 +512,10 @@ def make_fp8_moe_kernel(
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
inplace=(
not moe_config.disable_inplace
and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
),
)
# TODO(rob): update inplace logic to be part of the kernel.
inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
return kernel, inplace
return kernel
......@@ -437,6 +437,7 @@ def make_nvfp4_moe_kernel(
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
inplace=False,
)
# TODO(rob): update inplace logic to be part of the kernel.
......
......@@ -154,11 +154,9 @@ def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
) -> tuple[mk.FusedMoEModularKernel | None, bool]:
use_inplace = True
) -> mk.FusedMoEModularKernel | None:
if backend in UNSUPPORTED_BACKEND:
return None, use_inplace
return None
if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
......@@ -171,8 +169,9 @@ def make_unquantized_moe_kernel(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
)
use_inplace = False
elif backend == UnquantizedMoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
......@@ -184,6 +183,7 @@ def make_unquantized_moe_kernel(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts
......@@ -194,6 +194,7 @@ def make_unquantized_moe_kernel(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.XPU:
from vllm.model_executor.layers.fused_moe import XPUExperts
......@@ -204,5 +205,6 @@ def make_unquantized_moe_kernel(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
return kernel, use_inplace
return kernel
......@@ -101,10 +101,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def supports_eplb(self) -> bool:
return True
@property
def allow_inplace(self) -> bool:
return True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
......@@ -225,7 +221,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.kernel, self.use_inplace = make_unquantized_moe_kernel(
self.kernel = make_unquantized_moe_kernel(
backend=self.unquantized_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
......@@ -329,7 +325,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
......
......@@ -785,4 +785,5 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
)
......@@ -515,7 +515,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
......
......@@ -357,7 +357,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
......@@ -669,7 +668,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
......@@ -960,7 +958,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
......@@ -1073,7 +1071,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
# TODO(rob): investigate the disable_expert_map introduced by:
......@@ -1212,7 +1209,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
......@@ -1739,6 +1736,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
is_k_full=self.is_k_full,
inplace=not self.moe.disable_inplace,
)
......@@ -1969,7 +1967,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_packed,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
......@@ -2605,6 +2603,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
s_strides1=self.s_strides1,
s_strides2=self.s_strides2,
group_size=self.group_size,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
@property
......
......@@ -149,7 +149,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
......
......@@ -854,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
......@@ -958,10 +958,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
@property
def allow_inplace(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
......@@ -1032,7 +1028,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
......
......@@ -924,4 +924,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
)
......@@ -853,7 +853,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
......@@ -967,7 +967,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
......@@ -1538,7 +1537,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
......
......@@ -378,7 +378,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
......
......@@ -881,10 +881,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
)
@property
def allow_inplace(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return (
......@@ -923,6 +919,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=layer.activation,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
inplace=not self.moe.disable_inplace,
)
assert _can_support_mxfp4(
......
......@@ -388,6 +388,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
inplace=not self.moe.disable_inplace,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -398,7 +399,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
......@@ -734,10 +735,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
block_shape=None,
)
@property
def allow_inplace(self) -> bool:
return True
def apply(
self,
layer: FusedMoE,
......@@ -769,7 +766,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment