Unverified commit 7cf56a59, authored by bnellnm, committed by GitHub
Browse files

[MoE Refactor] Make SharedExperts class for use with DefaultMoERunner (#35153)


Signed-off-by: Bill Nell <bnell@redhat.com>
parent 5e30e9b9
......@@ -3,14 +3,10 @@
import torch
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
# TODO(bnell): Remove this entirely
class SharedFusedMoE(FusedMoE):
"""
A FusedMoE operation that also computes the results of shared experts.
......@@ -23,36 +19,11 @@ class SharedFusedMoE(FusedMoE):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped:
if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states)
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
if (
self.reduce_results
and get_tensor_model_parallel_world_size() > 1
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
shared_out = None
fused_out = super().forward(
result = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
if self.shared_experts is None:
return None, result
else:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
# ensure early TP reduction of shared expert outputs when required
if (
shared_out is not None
and self.reduce_results
and get_tensor_model_parallel_world_size() > 1
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_out
return result
......@@ -245,7 +245,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
return self.forward(
layer=layer,
x=x,
......@@ -261,7 +261,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
......@@ -283,7 +283,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
return self.forward_native(
layer, x, topk_weights, topk_ids, shared_experts_input
)
......@@ -293,7 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.is_monolithic
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
assert self.moe_kernel is None
......
......@@ -811,7 +811,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
return fused_marlin_moe(
x,
layer.w13_qweight,
......
......@@ -483,7 +483,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
# TODO(bnell): Do these need to be called on the hot path?
......
......@@ -355,7 +355,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
......@@ -603,7 +603,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
......@@ -628,7 +628,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
......@@ -963,7 +963,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
......@@ -987,7 +987,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
......@@ -1127,7 +1127,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
......@@ -1611,7 +1611,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.kernel_backend == "Flashinfer"
return flashinfer_trtllm_mxint4_moe(
x=x,
......@@ -1638,7 +1638,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.kernel_backend == "Marlin"
return fused_marlin_moe(
x,
......@@ -1887,7 +1887,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
......@@ -2502,7 +2502,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
......
......@@ -141,7 +141,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
......
......@@ -877,7 +877,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
......@@ -902,7 +902,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
......
......@@ -650,7 +650,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
if layer.apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
......
......@@ -907,7 +907,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
return fused_marlin_moe(
x,
layer.w13_qweight,
......
......@@ -935,7 +935,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
......@@ -960,7 +960,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
......@@ -1419,7 +1419,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
......@@ -1444,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
......
......@@ -369,7 +369,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == MoEActivation.SILU, (
......
......@@ -377,7 +377,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
......@@ -398,7 +398,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
......
......@@ -444,7 +444,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
......@@ -634,7 +634,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
......@@ -1027,7 +1027,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
if not self.emulate:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
......
......@@ -94,6 +94,8 @@ def transformers_moe_forward(
self = forward_context.no_compile_layers[layer_name]
self._topk_ids = topk_ids
# Clone hidden_states because it will be mutated in-place in FusedMoE
# TODO(bnell): Figure out a way to avoid calling the runner directly.
# It is a hack that the weights are being passed via logits.
return self.runner.forward(hidden_states.clone(), topk_weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment