Unverified Commit 7cf56a59 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Make SharedExperts class for use with DefaultMoERunner (#35153)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 5e30e9b9
...@@ -3,14 +3,10 @@ ...@@ -3,14 +3,10 @@
import torch import torch
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. + # TODO(bnell): Remove this entirely
class SharedFusedMoE(FusedMoE): class SharedFusedMoE(FusedMoE):
""" """
A FusedMoE operation that also computes the results of shared experts. A FusedMoE operation that also computes the results of shared experts.
...@@ -23,36 +19,11 @@ class SharedFusedMoE(FusedMoE): ...@@ -23,36 +19,11 @@ class SharedFusedMoE(FusedMoE):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped: result = super().forward(
if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states)
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
if (
self.reduce_results
and get_tensor_model_parallel_world_size() > 1
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
shared_out = None
fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
if self.shared_experts is None:
return None, result
else: else:
shared_out, fused_out = super().forward( return result
hidden_states=hidden_states,
router_logits=router_logits,
)
# ensure early TP reduction of shared expert outputs when required
if (
shared_out is not None
and self.reduce_results
and get_tensor_model_parallel_world_size() > 1
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_out
...@@ -245,7 +245,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -245,7 +245,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
return self.forward( return self.forward(
layer=layer, layer=layer,
x=x, x=x,
...@@ -261,7 +261,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -261,7 +261,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply( return self.moe_kernel.apply(
hidden_states=x, hidden_states=x,
...@@ -283,7 +283,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -283,7 +283,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
return self.forward_native( return self.forward_native(
layer, x, topk_weights, topk_ids, shared_experts_input layer, x, topk_weights, topk_ids, shared_experts_input
) )
...@@ -293,7 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -293,7 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
if self.unquantized_backend == UnquantizedMoeBackend.CPU: if self.unquantized_backend == UnquantizedMoeBackend.CPU:
assert self.moe_kernel is None assert self.moe_kernel is None
......
...@@ -811,7 +811,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -811,7 +811,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_qweight, layer.w13_qweight,
......
...@@ -483,7 +483,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -483,7 +483,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
# TODO(bnell): Do these need to be called on the hot path? # TODO(bnell): Do these need to be called on the hot path?
......
...@@ -355,7 +355,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -355,7 +355,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply( return self.moe_kernel.apply(
x, x,
...@@ -603,7 +603,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -603,7 +603,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic( return self.moe_kernel.apply_monolithic(
...@@ -628,7 +628,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -628,7 +628,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply( return self.moe_kernel.apply(
x, x,
...@@ -963,7 +963,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -963,7 +963,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic( return self.moe_kernel.apply_monolithic(
x, x,
...@@ -987,7 +987,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -987,7 +987,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply( return self.moe_kernel.apply(
...@@ -1127,7 +1127,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1127,7 +1127,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts( return fused_experts(
...@@ -1611,7 +1611,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1611,7 +1611,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.kernel_backend == "Flashinfer" assert self.kernel_backend == "Flashinfer"
return flashinfer_trtllm_mxint4_moe( return flashinfer_trtllm_mxint4_moe(
x=x, x=x,
...@@ -1638,7 +1638,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1638,7 +1638,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.kernel_backend == "Marlin" assert self.kernel_backend == "Marlin"
return fused_marlin_moe( return fused_marlin_moe(
x, x,
...@@ -1887,7 +1887,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1887,7 +1887,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts( return fused_experts(
...@@ -2502,7 +2502,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2502,7 +2502,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
if layer.enable_eplb: if layer.enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
......
...@@ -141,7 +141,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -141,7 +141,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts( return fused_experts(
......
...@@ -877,7 +877,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -877,7 +877,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic( return self.moe_kernel.apply_monolithic(
...@@ -902,7 +902,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -902,7 +902,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply( return self.moe_kernel.apply(
......
...@@ -650,7 +650,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -650,7 +650,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
if layer.apply_router_weight_on_input: if layer.apply_router_weight_on_input:
raise NotImplementedError( raise NotImplementedError(
"Apply router weight on input is not supported for" "Apply router weight on input is not supported for"
......
...@@ -907,7 +907,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -907,7 +907,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
return fused_marlin_moe( return fused_marlin_moe(
x, x,
layer.w13_qweight, layer.w13_qweight,
......
...@@ -935,7 +935,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -935,7 +935,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic( return self.moe_kernel.apply_monolithic(
...@@ -960,7 +960,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -960,7 +960,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply( return self.moe_kernel.apply(
...@@ -1419,7 +1419,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1419,7 +1419,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic( return self.moe_kernel.apply_monolithic(
...@@ -1444,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1444,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply( return self.moe_kernel.apply(
......
...@@ -369,7 +369,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -369,7 +369,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == MoEActivation.SILU, ( assert layer.activation == MoEActivation.SILU, (
......
...@@ -377,7 +377,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -377,7 +377,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert not self.is_monolithic assert not self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply( return self.moe_kernel.apply(
...@@ -398,7 +398,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -398,7 +398,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: FusedMoE, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
assert self.is_monolithic assert self.is_monolithic
assert self.moe_kernel is not None assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic( return self.moe_kernel.apply_monolithic(
......
...@@ -444,7 +444,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -444,7 +444,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts, rocm_aiter_fused_experts,
...@@ -634,7 +634,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -634,7 +634,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts, rocm_aiter_fused_experts,
) )
...@@ -1027,7 +1027,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -1027,7 +1027,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
if not self.emulate: if not self.emulate:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts, rocm_aiter_fused_experts,
......
...@@ -94,6 +94,8 @@ def transformers_moe_forward( ...@@ -94,6 +94,8 @@ def transformers_moe_forward(
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self._topk_ids = topk_ids self._topk_ids = topk_ids
# Clone hidden_states because it will be mutated in-place in FusedMoE # Clone hidden_states because it will be mutated in-place in FusedMoE
# TODO(bnell): figure out a way to avoid calling runner directly.
# it is a hack that the weight are being passed via logits.
return self.runner.forward(hidden_states.clone(), topk_weights) return self.runner.forward(hidden_states.clone(), topk_weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment