Unverified commit 447c372a, authored by Jackmin801 and committed by GitHub
Browse files

[MoE] Move remaining PrepareAndFinalize to prepare finalize folder (#39009)


Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Signed-off-by: Jackmin801 <ongjackm@gmail.com>
Co-authored-by: Robert Shaw <robertgshaw2@gmail.com>
parent ff2c2bd8
...@@ -223,7 +223,7 @@ if has_deep_ep() and not current_platform.has_device_capability(100): ...@@ -223,7 +223,7 @@ if has_deep_ep() and not current_platform.has_device_capability(100):
) )
if has_mori(): if has_mori():
from vllm.model_executor.layers.fused_moe.mori_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize.mori import (
MoriPrepareAndFinalize, MoriPrepareAndFinalize,
) )
......
...@@ -10,10 +10,12 @@ from vllm.model_executor.layers.fused_moe.experts.batched_deep_gemm_moe import ( ...@@ -10,10 +10,12 @@ from vllm.model_executor.layers.fused_moe.experts.batched_deep_gemm_moe import (
BatchedDeepGemmExperts, BatchedDeepGemmExperts,
) )
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize,
BatchedTritonExperts, BatchedTritonExperts,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize.batched import (
BatchedPrepareAndFinalize,
)
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
from .test_deepgemm import make_block_quant_fp8_weights from .test_deepgemm import make_block_quant_fp8_weights
......
...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize,
BatchedTritonExperts, BatchedTritonExperts,
NaiveBatchedExperts, NaiveBatchedExperts,
) )
...@@ -27,6 +26,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -27,6 +26,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_experts,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize.batched import (
BatchedPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.deep_gemm import per_block_cast_to_fp8 from vllm.utils.deep_gemm import per_block_cast_to_fp8
......
...@@ -41,9 +41,9 @@ if current_platform.is_cuda_alike(): ...@@ -41,9 +41,9 @@ if current_platform.is_cuda_alike():
DeepEPLLPrepareAndFinalize, DeepEPLLPrepareAndFinalize,
) )
if has_mori(): if has_mori():
from .mori_prepare_finalize import MoriPrepareAndFinalize from .prepare_finalize.mori import MoriPrepareAndFinalize
if has_nixl_ep(): if has_nixl_ep():
from .nixl_ep_prepare_finalize import ( from .prepare_finalize.nixl_ep import (
NIXL_EP_QUANT_BLOCK_SHAPE, NIXL_EP_QUANT_BLOCK_SHAPE,
NixlEPPrepareAndFinalize, NixlEPPrepareAndFinalize,
) )
......
...@@ -14,13 +14,11 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -14,13 +14,11 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
TopKWeightAndReduceNaiveBatched,
) )
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
moe_kernel_quantize_input, moe_kernel_quantize_input,
normalize_batched_scales_shape, normalize_batched_scales_shape,
normalize_scales_shape,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
...@@ -489,162 +487,6 @@ def invoke_moe_batched_triton_kernel( ...@@ -489,162 +487,6 @@ def invoke_moe_batched_triton_kernel(
) )
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
    """
    Reference prepare/finalize implementation that regroups tokens into the
    expert-batched layout, i.e. E x max_num_tokens x K, which is the layout
    the batched dispatch/combine kernels operate on.
    """

    def __init__(
        self,
        max_num_tokens: int,
        num_local_experts: int,
        num_dispatchers: int,
        rank: int,
    ):
        super().__init__()
        self.rank = rank
        self.num_dispatchers_ = num_dispatchers
        self.num_local_experts = num_local_experts
        self.max_num_tokens = max_num_tokens

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activations produced by `prepare` use the BatchedExperts format."""
        return mk.FusedMoEActivationFormat.BatchedExperts

    def max_num_tokens_per_rank(self) -> int | None:
        """Return the `max_num_tokens` value given at construction."""
        return self.max_num_tokens

    def topk_indices_dtype(self) -> torch.dtype | None:
        """Return None: this class does not name a dtype for topk indices."""
        return None

    def num_dispatchers(self) -> int:
        """Return the dispatcher count given at construction."""
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        """Always False for this implementation."""
        return False

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        """
        Copy (and, when a quant dtype is configured, quantize) the rows of
        `a1` routed to each local expert into a zero-initialized
        (num_local_experts, max_num_tokens, hidden_dim) buffer.

        `a1` must be 2-D and `topk_ids` must be 2-D with the same number of
        rows. When `apply_router_weight_on_input` is set, `a1` is scaled
        in place by `topk_weights` (topk == 1 only). `expert_map` is not
        read here.

        Returns the batched activations, the batched activation scales (or
        None when the config is not quantized), an `ExpertTokensMetadata`
        carrying the per-expert token counts, and two trailing None values.

        Raises NotImplementedError when `defer_input_quant` is True.
        """
        if defer_input_quant:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not support defer_input_quant=True. "
                "Please select an MoE kernel that accepts quantized inputs."
            )

        assert a1.dim() == 2
        assert topk_ids.dim() == 2
        assert topk_ids.size(0) == a1.size(0)

        if apply_router_weight_on_input:
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk_ids.size(1) == 1, (
                "apply_router_weight_on_input is only implemented for topk=1"
            )
            # In-place: the caller's tensor is modified.
            a1.mul_(topk_weights.to(a1.dtype))

        hidden_dim = a1.size(1)
        device = a1.device
        local_experts = self.num_local_experts
        capacity = self.max_num_tokens

        # Allocated with `num_experts` entries but only written at local
        # expert offsets [0, num_local_experts).
        tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=device)

        quant_dtype = quant_config.quant_dtype
        b_a1 = torch.zeros(
            (local_experts, capacity, hidden_dim),
            dtype=a1.dtype if quant_dtype is None else quant_dtype,
            device=device,
        )

        b_a1_scale = None
        if quant_config.is_quantized:
            b_a1_scale = torch.empty(
                quant_config.batched_scale_shape(local_experts, capacity, hidden_dim),
                dtype=torch.float32,
                device=device,
            )
        else:
            assert quant_config.a1_scale is None

        # Global ids of the experts owned by this rank start here.
        first_expert = local_experts * self.rank
        a1_scale = normalize_scales_shape(quant_config.a1_scale)

        for slot in range(local_experts):
            # Rows of a1 that picked this expert in any of their topk choices.
            token_mask = torch.any(topk_ids == first_expert + slot, dim=1).flatten()
            rows = torch.count_nonzero(token_mask)
            if rows == 0:
                continue

            tokens_per_expert[slot] = rows
            rhs = a1[: token_mask.numel()][token_mask]

            if quant_dtype is None:
                b_a1[slot, :rows, :] = rhs
                continue

            rhs_a1_scale = None
            if a1_scale is not None:
                if quant_config.is_per_act_token:
                    # Per-token scales are gathered with the same row mask.
                    rhs_a1_scale = a1_scale[: token_mask.numel()][token_mask]
                else:
                    rhs_a1_scale = a1_scale

            quantized, b_s = moe_kernel_quantize_input(
                rhs,
                rhs_a1_scale,
                quant_dtype,
                quant_config.per_act_token_quant,
                quant_config.block_shape,
            )
            b_a1[slot, :rows, :] = quantized
            assert b_s is not None

            if quant_config.is_per_act_token:
                b_a1_scale[slot, :rows] = b_s[:rows]
            else:
                b_a1_scale[slot, : b_s.shape[0]] = b_s

        assert b_a1_scale is None or b_a1_scale.ndim == 3

        expert_tokens_meta = mk.ExpertTokensMetadata(
            expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None
        )

        return b_a1, b_a1_scale, expert_tokens_meta, None, None

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """
        Run the weight-and-reduce step on `fused_expert_output`, passing
        `output` through to the reducer's `apply`.

        A `TopKWeightAndReduceDelegate` is swapped for a
        `TopKWeightAndReduceNaiveBatched` built from this rank; any other
        implementation is used as given.
        """
        reducer = weight_and_reduce_impl
        if isinstance(reducer, TopKWeightAndReduceDelegate):
            reducer = TopKWeightAndReduceNaiveBatched(self.rank)

        reducer.apply(
            output=output,
            fused_expert_output=fused_expert_output,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
class NaiveBatchedExperts(mk.FusedMoEExpertsModular): class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
""" """
A reference MoE expert class that operates on expert batched format, A reference MoE expert class that operates on expert batched format,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.fused_moe.prepare_finalize.batched import (
BatchedPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import ( from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import (
MoEPrepareAndFinalizeNaiveDPEPModular, MoEPrepareAndFinalizeNaiveDPEPModular,
MoEPrepareAndFinalizeNaiveDPEPMonolithic, MoEPrepareAndFinalizeNaiveDPEPMonolithic,
...@@ -13,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import ( ...@@ -13,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import (
) )
__all__ = [ __all__ = [
"BatchedPrepareAndFinalize",
"MoEPrepareAndFinalizeNaiveDPEPMonolithic", "MoEPrepareAndFinalizeNaiveDPEPMonolithic",
"MoEPrepareAndFinalizeNaiveDPEPModular", "MoEPrepareAndFinalizeNaiveDPEPModular",
"make_moe_prepare_and_finalize_naive_dp_ep", "make_moe_prepare_and_finalize_naive_dp_ep",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNaiveBatched,
)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input,
normalize_scales_shape,
)
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
    """
    Reference prepare/finalize implementation that regroups tokens into the
    expert-batched layout, i.e. E x max_num_tokens x K, which is the layout
    the batched dispatch/combine kernels operate on.
    """

    def __init__(
        self,
        max_num_tokens: int,
        num_local_experts: int,
        num_dispatchers: int,
        rank: int,
    ):
        super().__init__()
        self.rank = rank
        self.num_dispatchers_ = num_dispatchers
        self.num_local_experts = num_local_experts
        self.max_num_tokens = max_num_tokens

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activations produced by `prepare` use the BatchedExperts format."""
        return mk.FusedMoEActivationFormat.BatchedExperts

    def max_num_tokens_per_rank(self) -> int | None:
        """Return the `max_num_tokens` value given at construction."""
        return self.max_num_tokens

    def topk_indices_dtype(self) -> torch.dtype | None:
        """Return None: this class does not name a dtype for topk indices."""
        return None

    def num_dispatchers(self) -> int:
        """Return the dispatcher count given at construction."""
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        """Always False for this implementation."""
        return False

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        """
        Copy (and, when a quant dtype is configured, quantize) the rows of
        `a1` routed to each local expert into a zero-initialized
        (num_local_experts, max_num_tokens, hidden_dim) buffer.

        `a1` must be 2-D and `topk_ids` must be 2-D with the same number of
        rows. When `apply_router_weight_on_input` is set, `a1` is scaled
        in place by `topk_weights` (topk == 1 only). `expert_map` is not
        read here.

        Returns the batched activations, the batched activation scales (or
        None when the config is not quantized), an `ExpertTokensMetadata`
        carrying the per-expert token counts, and two trailing None values.

        Raises NotImplementedError when `defer_input_quant` is True.
        """
        if defer_input_quant:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not support defer_input_quant=True. "
                "Please select an MoE kernel that accepts quantized inputs."
            )

        assert a1.dim() == 2
        assert topk_ids.dim() == 2
        assert topk_ids.size(0) == a1.size(0)

        if apply_router_weight_on_input:
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk_ids.size(1) == 1, (
                "apply_router_weight_on_input is only implemented for topk=1"
            )
            # In-place: the caller's tensor is modified.
            a1.mul_(topk_weights.to(a1.dtype))

        hidden_dim = a1.size(1)
        device = a1.device
        local_experts = self.num_local_experts
        capacity = self.max_num_tokens

        # Allocated with `num_experts` entries but only written at local
        # expert offsets [0, num_local_experts).
        tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=device)

        quant_dtype = quant_config.quant_dtype
        b_a1 = torch.zeros(
            (local_experts, capacity, hidden_dim),
            dtype=a1.dtype if quant_dtype is None else quant_dtype,
            device=device,
        )

        b_a1_scale = None
        if quant_config.is_quantized:
            b_a1_scale = torch.empty(
                quant_config.batched_scale_shape(local_experts, capacity, hidden_dim),
                dtype=torch.float32,
                device=device,
            )
        else:
            assert quant_config.a1_scale is None

        # Global ids of the experts owned by this rank start here.
        first_expert = local_experts * self.rank
        a1_scale = normalize_scales_shape(quant_config.a1_scale)

        for slot in range(local_experts):
            # Rows of a1 that picked this expert in any of their topk choices.
            token_mask = torch.any(topk_ids == first_expert + slot, dim=1).flatten()
            rows = torch.count_nonzero(token_mask)
            if rows == 0:
                continue

            tokens_per_expert[slot] = rows
            rhs = a1[: token_mask.numel()][token_mask]

            if quant_dtype is None:
                b_a1[slot, :rows, :] = rhs
                continue

            rhs_a1_scale = None
            if a1_scale is not None:
                if quant_config.is_per_act_token:
                    # Per-token scales are gathered with the same row mask.
                    rhs_a1_scale = a1_scale[: token_mask.numel()][token_mask]
                else:
                    rhs_a1_scale = a1_scale

            quantized, b_s = moe_kernel_quantize_input(
                rhs,
                rhs_a1_scale,
                quant_dtype,
                quant_config.per_act_token_quant,
                quant_config.block_shape,
            )
            b_a1[slot, :rows, :] = quantized
            assert b_s is not None

            if quant_config.is_per_act_token:
                b_a1_scale[slot, :rows] = b_s[:rows]
            else:
                b_a1_scale[slot, : b_s.shape[0]] = b_s

        assert b_a1_scale is None or b_a1_scale.ndim == 3

        expert_tokens_meta = mk.ExpertTokensMetadata(
            expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None
        )

        return b_a1, b_a1_scale, expert_tokens_meta, None, None

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """
        Run the weight-and-reduce step on `fused_expert_output`, passing
        `output` through to the reducer's `apply`.

        A `TopKWeightAndReduceDelegate` is swapped for a
        `TopKWeightAndReduceNaiveBatched` built from this rank; any other
        implementation is used as given.
        """
        reducer = weight_and_reduce_impl
        if isinstance(reducer, TopKWeightAndReduceDelegate):
            reducer = TopKWeightAndReduceNaiveBatched(self.rank)

        reducer.apply(
            output=output,
            fused_expert_output=fused_expert_output,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment