Unverified Commit f0c98cae authored by Varun Sundar Rabindranath's avatar Varun Sundar Rabindranath Committed by GitHub
Browse files

[Misc] MoE ModularKernel : Introduce TopKWeightAndReduce (#20648)


Signed-off-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
parent 574ad60d
...@@ -32,6 +32,8 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( ...@@ -32,6 +32,8 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up from vllm.utils import round_up
...@@ -371,6 +373,7 @@ def pplx_prepare_finalize( ...@@ -371,6 +373,7 @@ def pplx_prepare_finalize(
chunk_topk_weight, chunk_topk_weight,
chunk_topk_ids, chunk_topk_ids,
False, False,
weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
) )
torch.cuda.synchronize() torch.cuda.synchronize()
......
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -217,6 +219,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -217,6 +219,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
......
...@@ -88,6 +88,25 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -88,6 +88,25 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return ((bdge is None or bdge.supports_expert_map()) return ((bdge is None or bdge.supports_expert_map())
and (bte is None or bte.supports_expert_map())) and (bte is None or bte.supports_expert_map()))
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
bdge = self.batched_deep_gemm_experts
bte = self.batched_triton_experts
bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
is_bdge_war = bdge_war is not None
is_bte_war = bte_war is not None
if is_bdge_war and is_bte_war:
assert bdge_war == bte_war, (
"Both implementations should agree on WeightAndReduce impls. "
f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}")
if bdge_war is not None:
return bdge_war
assert bte_war is not None
return bte_war
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
......
...@@ -11,6 +11,8 @@ from vllm.logger import init_logger ...@@ -11,6 +11,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize, _fp8_quantize,
_resize_cache) _resize_cache)
...@@ -255,6 +257,10 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -255,6 +257,10 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return not self.use_batched_format return not self.use_batched_format
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
......
...@@ -12,6 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( ...@@ -12,6 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute) _moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, per_token_group_quant_fp8) _resize_cache, per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up from vllm.utils import has_deep_gemm, round_up
...@@ -85,6 +87,10 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -85,6 +87,10 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int topk: int, global_num_experts: int, local_num_experts: int
......
...@@ -6,8 +6,9 @@ import deep_ep ...@@ -6,8 +6,9 @@ import deep_ep
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input) moe_kernel_quantize_input)
...@@ -187,45 +188,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -187,45 +188,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) expert_topk_weights)
def _apply_weights_and_reduce(self, num_tokens: int,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
apply_router_weight_on_input: bool,
output_dtype: torch.dtype):
hidden_dim = fused_expert_output.size(-1)
if fused_expert_output.ndim == 2:
fused_expert_output = fused_expert_output.view(
num_tokens, -1, hidden_dim)
if not apply_router_weight_on_input:
# The DeepEP combine kernels don't do the topk weight
# multiplication. We multiply the weights locally.
m_x_topk = fused_expert_output.size(0)
fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))
out = torch.empty((num_tokens, hidden_dim),
device=fused_expert_output.device,
dtype=output_dtype)
ops.moe_sum(fused_expert_output, out)
return out
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None: apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
assert self.handle is not None assert self.handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the # fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank. # tokens from the all2all reach this EP rank.
if fused_expert_output.numel() != 0: if fused_expert_output.numel() != 0:
fused_expert_output = self._apply_weights_and_reduce( if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
num_tokens=topk_ids.size(0), weight_and_reduce_impl = TopKWeightAndReduceContiguous()
fused_expert_output = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output, fused_expert_output=fused_expert_output,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
output_dtype=output.dtype) )
combined_x, _, event = self.buffer.combine( combined_x, _, event = self.buffer.combine(
x=fused_expert_output, x=fused_expert_output,
......
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, normalize_batched_scales_shape) moe_kernel_quantize_input, normalize_batched_scales_shape)
...@@ -166,8 +168,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -166,8 +168,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None: apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
assert self.handle is not None assert self.handle is not None
combine_topk_weights = topk_weights combine_topk_weights = topk_weights
......
...@@ -11,6 +11,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -11,6 +11,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config) get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape, _resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape,
normalize_scales_shape) normalize_scales_shape)
...@@ -600,25 +602,17 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -600,25 +602,17 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: ) -> None:
num_tokens = topk_ids.size(0) if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
num_local_experts = fused_expert_output.size(0) weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
K = fused_expert_output.size(-1) weight_and_reduce_impl.apply(
assert output.size(0) == num_tokens and output.size(1) == K output=output,
fused_expert_output=fused_expert_output,
output.fill_(0) topk_weights=topk_weights,
topk_ids=topk_ids,
first_expert = num_local_experts * self.rank apply_router_weight_on_input=apply_router_weight_on_input,
last_expert = first_expert + num_local_experts )
for expert_id in range(first_expert, last_expert):
matching_tokens = topk_ids == expert_id
topks = torch.any(matching_tokens, dim=1).flatten()
rows = torch.count_nonzero(topks)
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
if not apply_router_weight_on_input:
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
output[topks] = output[topks] + rhs
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
...@@ -670,6 +664,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -670,6 +664,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
...@@ -877,6 +875,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -877,6 +875,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
......
...@@ -25,6 +25,8 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -25,6 +25,8 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size) moe_align_block_size)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input) _resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
...@@ -1596,6 +1598,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1596,6 +1598,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
......
...@@ -23,7 +23,7 @@ from vllm.utils import cdiv ...@@ -23,7 +23,7 @@ from vllm.utils import cdiv
# #
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] # [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
# #
# Each component will be independent of the others except for # Each component will be independent of (but may inform) the others except for
# [Quantize-Dispatch] and `[Combine] (see below). The components can then be # [Quantize-Dispatch] and `[Combine] (see below). The components can then be
# mixed and matched with so that DP+EP can be supported easily for multiple # mixed and matched with so that DP+EP can be supported easily for multiple
# MoE kernel implementations. # MoE kernel implementations.
...@@ -32,13 +32,19 @@ from vllm.utils import cdiv ...@@ -32,13 +32,19 @@ from vllm.utils import cdiv
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE # * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of Moe outputs. # inputs (e.g. quantization, distribution) and finalization of Moe outputs.
# The prepare method must take care of any needed quantization and the # The prepare method must take care of any needed quantization and the
# finalize method must apply weights and do the final reduction of the output. # finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
# may apply weights and/or do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused # * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
# MoE operation. One important feature to note is that this class does not # MoE operation, i.e matmul + act_mul + optionally quant + matmul.
# apply topk weights or reduce the final output. # Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
# the weight application and/or reduction. The class communicates this
# to [Finalize] via a TopKWeightAndReduce object.
# * FusedMoEModularKernel - an interface class that combines a # * FusedMoEModularKernel - an interface class that combines a
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to # FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
# provide the standard fused MoE kernel interface. # provide the standard fused MoE kernel interface.
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
# on to [Finalize].
# #
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single # [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective # class `FusedMoEPrepareAndFinalize` since they could use collective
...@@ -117,6 +123,24 @@ class ExpertTokensMetadata: ...@@ -117,6 +123,24 @@ class ExpertTokensMetadata:
expert_num_tokens_cpu=expert_num_tokens_cpu) expert_num_tokens_cpu=expert_num_tokens_cpu)
class TopKWeightAndReduce(ABC):
"""
An abstract base class for weight application and reduction implementations.
"""
@abstractmethod
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
"""
Apply topk_weights to the fused_experts_outputs and/or reduce.
If an output tensor is not passed, it will be created in the
function.
"""
raise NotImplementedError
# TODO: pass FusedMoEParallelConfig in as ctor parameter? # TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
""" """
...@@ -173,6 +197,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -173,6 +197,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> None: ) -> None:
""" """
Perform any combine plus apply weights and perform a reduction on the Perform any combine plus apply weights and perform a reduction on the
...@@ -184,6 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -184,6 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
- topk_ids: The topk_ids. - topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to - apply_router_weight_on_input: When False, apply the weights to
fused_expert_output. fused_expert_output.
- weight_and_reduce_impl: An optional TopKWeightAndReduce
implementation.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -323,6 +350,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -323,6 +350,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \ return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
self.supports_chunking() self.supports_chunking()
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
raise NotImplementedError
@abstractmethod @abstractmethod
def apply( def apply(
self, self,
...@@ -702,7 +732,9 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -702,7 +732,9 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale, a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta) expert_tokens_meta=expert_tokens_meta)
self.prepare_finalize.finalize(output, fused_out, topk_weights, self.prepare_finalize.finalize(
topk_ids, apply_router_weight_on_input) output, fused_out, topk_weights, topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl())
return output return output
...@@ -8,6 +8,8 @@ import torch ...@@ -8,6 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_validate_scale_shape, moe_kernel_quantize_input) _validate_scale_shape, moe_kernel_quantize_input)
from vllm.utils import cdiv, round_up from vllm.utils import cdiv, round_up
...@@ -222,7 +224,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -222,7 +224,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: ) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
# This argument is optional # This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0) # There's not much point setting this unless it is != topk_ids.size(0)
bound_m: Optional[torch.Tensor] = None bound_m: Optional[torch.Tensor] = None
......
...@@ -6,8 +6,8 @@ import torch ...@@ -6,8 +6,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
_moe_unpermute_and_reduce) TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input) moe_kernel_quantize_input)
...@@ -62,6 +62,13 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -62,6 +62,13 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: ) -> None:
_moe_unpermute_and_reduce(output, fused_expert_output, None, if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
topk_weights, apply_router_weight_on_input) weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
"""
Useful in the case when some FusedMoEPermuteExpertsUnpermute
implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize
implementations.
For example, BatchedTritonExperts is compatible with both
PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
does the weight-application + reduction as part of the pplx combine kernel.
But the BatchedPrepareAndFinalize needs an implementation. To facilitate
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
so the PrepareAndFinalize implementations could choose how to
weight + reduce.
"""
def __eq__(self, other):
return isinstance(other, TopKWeightAndReduceDelegate)
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
raise RuntimeError("The caller is expected to choose an appropriate "
"TopKWeightAndReduce implementation.")
class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
"""
The fused_experts outputs have already been weight applied and reduced.
This implementation is a no-op.
"""
def __eq__(self, other):
return isinstance(other, TopKWeightAndReduceNoOP)
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
# Relax this if an explicit copy is necessary. Note that,
# if a copy is employed we have to make sure that the
# tensors don't overlap
assert output is None
return fused_expert_output
class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
"""
TopKWeightAndReduce implementation for a fused_experts output
of shape (m, topk, K)
"""
def __eq__(self, other):
return isinstance(other, TopKWeightAndReduceContiguous)
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
m, num_topk = topk_ids.size()
k = fused_expert_output.size(-1)
if fused_expert_output.ndim == 2:
fused_expert_output = fused_expert_output.view(m, num_topk, k)
assert fused_expert_output.size() == (m, num_topk, k), (
f"Expected fused_expert_output size {(m, num_topk, k)}. But got "
f"{fused_expert_output.size()}")
if not apply_router_weight_on_input:
fused_expert_output.mul_(topk_weights.view(m, -1, 1))
if output is None:
output = torch.empty((m, k),
device=fused_expert_output.device,
dtype=fused_expert_output.dtype)
assert output.size() == (m, k), (
f"Expected output size {(m, k)}. But got {output.size()}")
ops.moe_sum(fused_expert_output, output)
return output
class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce):
"""
TopKWeightAndReduce implementation for a fused_experts output
of shape (num_experts, batch_size, K)
"""
def __init__(self, rank: int):
self.rank = rank
def __eq__(self, other):
return (isinstance(other, TopKWeightAndReduceNaiveBatched)
and (other.rank == self.rank))
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
assert fused_expert_output.ndim == 3
num_tokens = topk_ids.size(0)
num_local_experts = fused_expert_output.size(0)
K = fused_expert_output.size(-1)
if output is None:
output = torch.zeros((num_tokens, K),
device=fused_expert_output.device,
dtype=fused_expert_output.dtype)
else:
output.fill_(0)
assert output.size() == (num_tokens, K), (
f"Expected output size {(num_tokens, K)}, but got {output.size()}")
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
matching_tokens = topk_ids == expert_id
topks = torch.any(matching_tokens, dim=1).flatten()
rows = torch.count_nonzero(topks)
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
if not apply_router_weight_on_input:
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
output[topks] = output[topks] + rhs
return output
...@@ -69,6 +69,25 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -69,6 +69,25 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return ((dge is None or dge.supports_expert_map()) return ((dge is None or dge.supports_expert_map())
and (te is None or te.supports_expert_map())) and (te is None or te.supports_expert_map()))
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
dge = self.deep_gemm_expert
te = self.triton_expert
dge_war = dge.finalize_weight_and_reduce_impl() if dge else None
te_war = te.finalize_weight_and_reduce_impl() if te else None
is_dge_war = dge_war is not None
is_te_war = te_war is not None
if is_dge_war and is_te_war:
assert dge_war == te_war, (
"Both implementations should agree on WeightAndReduce impls. "
f"Got dge_war: {dge_war}, and te_war: {te_war}")
if dge_war is not None:
return dge_war
assert te_war is not None
return te_war
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment