Unverified Commit a57c8228 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 1ee95841
...@@ -46,6 +46,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -46,6 +46,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method: FusedMoEMethodBase, old_quant_method: FusedMoEMethodBase,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
shared_experts: torch.nn.Module | None, shared_experts: torch.nn.Module | None,
inplace: bool = False,
) -> "FusedMoEModularMethod": ) -> "FusedMoEModularMethod":
return FusedMoEModularMethod( return FusedMoEModularMethod(
old_quant_method, old_quant_method,
...@@ -54,6 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -54,6 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts, shared_experts,
moe_parallel_config=moe_layer.moe_parallel_config, moe_parallel_config=moe_layer.moe_parallel_config,
inplace=inplace,
), ),
) )
...@@ -61,10 +63,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -61,10 +63,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb return self.old_quant_method.supports_eplb
@property
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace
@property @property
def method_name(self) -> str: def method_name(self) -> str:
return self.old_quant_method.method_name return self.old_quant_method.method_name
...@@ -99,7 +97,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -99,7 +97,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=self.allow_inplace,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
......
...@@ -50,6 +50,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import ( ...@@ -50,6 +50,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod, UnquantizedFusedMoEMethod,
) )
from vllm.model_executor.layers.fused_moe.utils import (
disable_inplace,
)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
) )
...@@ -560,6 +563,8 @@ class FusedMoE(CustomOp): ...@@ -560,6 +563,8 @@ class FusedMoE(CustomOp):
activation=activation, activation=activation,
device=vllm_config.device_config.device, device=vllm_config.device_config.device,
routing_method=self.routing_method_type, routing_method=self.routing_method_type,
# TODO: in_dtype == out_dtype?
disable_inplace=disable_inplace() or self.shared_experts is not None,
) )
if self.use_mori_kernels: if self.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, ( assert self.rocm_aiter_fmoe_enabled, (
...@@ -650,7 +655,11 @@ class FusedMoE(CustomOp): ...@@ -650,7 +655,11 @@ class FusedMoE(CustomOp):
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
) )
self.quant_method = FusedMoEModularMethod.make( self.quant_method = FusedMoEModularMethod.make(
self, self.quant_method, prepare_finalize, self.shared_experts self,
self.quant_method,
prepare_finalize,
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
) )
@property @property
......
...@@ -811,11 +811,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -811,11 +811,13 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_experts: FusedMoEPermuteExpertsUnpermute, fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None, moe_parallel_config: FusedMoEParallelConfig | None = None,
inplace: bool = False,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.shared_experts = shared_experts self.shared_experts = shared_experts
self.inplace = inplace
# prefer an explicit FusedMoEParallelConfig when available (from # prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests). # FusedMoE layers / tests).
...@@ -1292,7 +1294,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1292,7 +1294,6 @@ class FusedMoEModularKernel(torch.nn.Module):
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu", activation: str = "silu",
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
...@@ -1309,8 +1310,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1309,8 +1310,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- topk_weights (torch.Tensor): The topk weights applied at the end of - topk_weights (torch.Tensor): The topk weights applied at the end of
the layer. the layer.
- topk_ids (torch.Tensor): A map of row to expert id. - topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first - activation (str): The activation function to apply after the first
MoE layer. MoE layer.
- global_num_experts (int): The total number of experts in the global - global_num_experts (int): The total number of experts in the global
...@@ -1326,7 +1325,9 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1326,7 +1325,9 @@ class FusedMoEModularKernel(torch.nn.Module):
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
if inplace and self.shared_experts is None and not disable_inplace(): if self.inplace:
assert self.shared_experts is None
assert not disable_inplace()
output = hidden_states output = hidden_states
else: else:
output = torch.zeros_like(hidden_states) output = torch.zeros_like(hidden_states)
......
...@@ -472,7 +472,7 @@ def make_fp8_moe_kernel( ...@@ -472,7 +472,7 @@ def make_fp8_moe_kernel(
fp8_backend: Fp8MoeBackend, fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]: ) -> mk.FusedMoEModularKernel:
# Create Prepare/Finalize. # Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize( prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config, moe=moe_config,
...@@ -512,8 +512,10 @@ def make_fp8_moe_kernel( ...@@ -512,8 +512,10 @@ def make_fp8_moe_kernel(
else None else None
), ),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
inplace=(
not moe_config.disable_inplace
and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
),
) )
# TODO(rob): update inplace logic to be part of the kernel. return kernel
inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
return kernel, inplace
...@@ -437,6 +437,7 @@ def make_nvfp4_moe_kernel( ...@@ -437,6 +437,7 @@ def make_nvfp4_moe_kernel(
else None else None
), ),
moe_parallel_config=moe_config.moe_parallel_config, moe_parallel_config=moe_config.moe_parallel_config,
inplace=False,
) )
# TODO(rob): update inplace logic to be part of the kernel. # TODO(rob): update inplace logic to be part of the kernel.
......
...@@ -154,11 +154,9 @@ def make_unquantized_moe_kernel( ...@@ -154,11 +154,9 @@ def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend, backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
) -> tuple[mk.FusedMoEModularKernel | None, bool]: ) -> mk.FusedMoEModularKernel | None:
use_inplace = True
if backend in UNSUPPORTED_BACKEND: if backend in UNSUPPORTED_BACKEND:
return None, use_inplace return None
if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS: if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
...@@ -171,8 +169,9 @@ def make_unquantized_moe_kernel( ...@@ -171,8 +169,9 @@ def make_unquantized_moe_kernel(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=False,
) )
use_inplace = False
elif backend == UnquantizedMoeBackend.AITER: elif backend == UnquantizedMoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts, AiterExperts,
...@@ -184,6 +183,7 @@ def make_unquantized_moe_kernel( ...@@ -184,6 +183,7 @@ def make_unquantized_moe_kernel(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=not moe_config.disable_inplace,
) )
elif backend == UnquantizedMoeBackend.TRITON: elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe import TritonExperts
...@@ -194,6 +194,7 @@ def make_unquantized_moe_kernel( ...@@ -194,6 +194,7 @@ def make_unquantized_moe_kernel(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=not moe_config.disable_inplace,
) )
elif backend == UnquantizedMoeBackend.XPU: elif backend == UnquantizedMoeBackend.XPU:
from vllm.model_executor.layers.fused_moe import XPUExperts from vllm.model_executor.layers.fused_moe import XPUExperts
...@@ -204,5 +205,6 @@ def make_unquantized_moe_kernel( ...@@ -204,5 +205,6 @@ def make_unquantized_moe_kernel(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
), ),
inplace=not moe_config.disable_inplace,
) )
return kernel, use_inplace return kernel
...@@ -101,10 +101,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -101,10 +101,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return True return True
@property
def allow_inplace(self) -> bool:
return True
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
...@@ -225,7 +221,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -225,7 +221,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
self.kernel, self.use_inplace = make_unquantized_moe_kernel( self.kernel = make_unquantized_moe_kernel(
backend=self.unquantized_backend, backend=self.unquantized_backend,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
...@@ -329,7 +325,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -329,7 +325,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=self.use_inplace,
activation=layer.activation, activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
......
...@@ -785,4 +785,5 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -785,4 +785,5 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
w2_zeros=layer.w2_qzeros, w2_zeros=layer.w2_qzeros,
workspace=layer.workspace, workspace=layer.workspace,
input_dtype=self.input_dtype, input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
) )
...@@ -515,7 +515,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -515,7 +515,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
w2=w2, w2=w2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=not self.moe.disable_inplace,
activation=layer.activation, activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
......
...@@ -357,7 +357,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -357,7 +357,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight, layer.w2_weight,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=False,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
...@@ -669,7 +668,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -669,7 +668,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight, layer.w2_weight,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=False,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
...@@ -960,7 +958,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -960,7 +958,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel( self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
...@@ -1073,7 +1071,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1073,7 +1071,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight, layer.w2_weight,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=self.use_inplace,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
# TODO(rob): investigate the disable_expert_map introduced by: # TODO(rob): investigate the disable_expert_map introduced by:
...@@ -1212,7 +1209,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1212,7 +1209,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=not self.moe.disable_inplace,
activation=layer.activation, activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
...@@ -1739,6 +1736,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1739,6 +1736,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
workspace=layer.workspace, workspace=layer.workspace,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
inplace=not self.moe.disable_inplace,
) )
...@@ -1969,7 +1967,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1969,7 +1967,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_packed, layer.w2_weight_packed,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=not self.moe.disable_inplace,
activation=layer.activation, activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
...@@ -2605,6 +2603,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2605,6 +2603,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
s_strides1=self.s_strides1, s_strides1=self.s_strides1,
s_strides2=self.s_strides2, s_strides2=self.s_strides2,
group_size=self.group_size, group_size=self.group_size,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
@property @property
......
...@@ -149,7 +149,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -149,7 +149,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=not self.moe.disable_inplace,
activation=layer.activation, activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
......
...@@ -854,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -854,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel( self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
...@@ -958,10 +958,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -958,10 +958,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def supports_eplb(self) -> bool: def supports_eplb(self) -> bool:
return True return True
@property
def allow_inplace(self) -> bool:
return True
@property @property
def is_monolithic(self) -> bool: def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
...@@ -1032,7 +1028,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1032,7 +1028,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight, layer.w2_weight,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=self.use_inplace,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
......
...@@ -924,4 +924,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -924,4 +924,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
workspace=layer.workspace, workspace=layer.workspace,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
input_dtype=self.input_dtype, input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
) )
...@@ -853,7 +853,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -853,7 +853,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel( self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
...@@ -967,7 +967,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -967,7 +967,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight, layer.w2_weight,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=self.use_inplace,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
...@@ -1538,7 +1537,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1538,7 +1537,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_weight, layer.w2_weight,
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=False,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
......
...@@ -378,7 +378,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -378,7 +378,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.w2_qweight, layer.w2_qweight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=not self.moe.disable_inplace,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
......
...@@ -881,10 +881,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -881,10 +881,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
) )
@property
def allow_inplace(self) -> bool:
return True
@property @property
def is_monolithic(self) -> bool: def is_monolithic(self) -> bool:
return ( return (
...@@ -923,6 +919,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -923,6 +919,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=layer.activation, activation=layer.activation,
expert_map=layer.expert_map, expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype, input_dtype=self.marlin_input_dtype,
inplace=not self.moe.disable_inplace,
) )
assert _can_support_mxfp4( assert _can_support_mxfp4(
......
...@@ -388,6 +388,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -388,6 +388,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
inplace=not self.moe.disable_inplace,
) )
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -398,7 +399,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -398,7 +399,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=not self.moe.disable_inplace,
activation=layer.activation, activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
...@@ -734,10 +735,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -734,10 +735,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
block_shape=None, block_shape=None,
) )
@property
def allow_inplace(self) -> bool:
return True
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
...@@ -769,7 +766,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -769,7 +766,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=not self.moe.disable_inplace,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment