Unverified Commit 2cdf9222 authored by Xinan Miao's avatar Xinan Miao Committed by GitHub
Browse files

[Feature]: Remove Chunking From FusedMoE (#34086)


Signed-off-by: default avatarSouthWest7 <am1ao@qq.com>
Signed-off-by: default avatarSouthwest <1403572259@qq.com>
Signed-off-by: default avatarsouthwest <am1ao@qq.com>
Signed-off-by: default avatarXinan Miao <1403572259@qq.com>
Co-authored-by: default avatarSouthWest7 <am1ao@qq.com>
parent c973ecde
...@@ -658,9 +658,6 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -658,9 +658,6 @@ class MarlinExperts(MarlinExpertsBase):
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
def workspace_shapes( def workspace_shapes(
self, self,
M: int, M: int,
...@@ -786,9 +783,6 @@ class BatchedMarlinExperts(MarlinExpertsBase): ...@@ -786,9 +783,6 @@ class BatchedMarlinExperts(MarlinExpertsBase):
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts return mk.FusedMoEActivationFormat.BatchedExperts
def supports_chunking(self) -> bool:
return False
def workspace_shapes( def workspace_shapes(
self, self,
M: int, M: int,
......
...@@ -1693,10 +1693,8 @@ def fused_experts_impl( ...@@ -1693,10 +1693,8 @@ def fused_experts_impl(
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
top_k_num = topk_ids.size(1) top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938 M = num_tokens
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = _get_config_dtype_str( config_dtype = _get_config_dtype_str(
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
...@@ -1787,139 +1785,114 @@ def fused_experts_impl( ...@@ -1787,139 +1785,114 @@ def fused_experts_impl(
else: else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
for chunk in range((num_tokens // CHUNK_SIZE) + 1): qhidden_states, a1q_scale = moe_kernel_quantize_input(
begin_chunk_idx, end_chunk_idx = ( A=hidden_states,
chunk * CHUNK_SIZE, A_scale=a1_scale,
min((chunk + 1) * CHUNK_SIZE, num_tokens), quant_dtype=quant_dtype,
) per_act_token_quant=per_channel_quant,
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] block_shape=block_shape,
tokens_in_chunk, _ = curr_hidden_states.size() ocp_mx_scheme=ocp_mx_scheme,
)
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[
: tokens_in_chunk * topk_ids.size(1)
]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
config = get_config_func(tokens_in_chunk)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
A_scale=a1_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape,
ocp_mx_scheme=ocp_mx_scheme,
)
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k # SPARSITY_FACTOR is a heuristic margin ensuring num_tokens * top_k
# activates only a small fraction of total experts # activates only a small fraction of total experts
SPARSITY_FACTOR = 4 SPARSITY_FACTOR = 4
# block quantized code path is not implemented yet. # block quantized code path is not implemented yet.
naive_block_assignment = ( naive_block_assignment = (
expert_map is None expert_map is None
and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts and num_tokens * top_k_num * SPARSITY_FACTOR <= global_num_experts
and not ( and not (
(use_int8_w8a16 or use_int4_w4a16) (use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None and block_shape is not None
and block_shape[1] > 0 and block_shape[1] > 0
)
) )
)
if not naive_block_assignment: if not naive_block_assignment:
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, topk_ids,
config["BLOCK_SIZE_M"], config["BLOCK_SIZE_M"],
global_num_experts, global_num_experts,
expert_map, expert_map,
ignore_invalid_experts=True, ignore_invalid_experts=True,
)
else:
max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
expert_ids = curr_topk_ids.view(-1)
num_tokens_post_padded = torch.empty(
(1), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_padded.fill_(max_num_tokens_padded)
sorted_token_ids = None
dispatch_fused_moe_kernel(
qcurr_hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
curr_topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w1_bias,
) )
else:
apply_moe_activation( max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N) expert_ids = topk_ids.view(-1)
num_tokens_post_padded = torch.empty(
(1), dtype=torch.int32, device=topk_ids.device
) )
num_tokens_post_padded.fill_(max_num_tokens_padded)
sorted_token_ids = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( dispatch_fused_moe_kernel(
A=intermediate_cache2, qhidden_states,
A_scale=a2_scale, w1,
quant_dtype=quant_dtype, intermediate_cache1,
per_act_token_quant=per_channel_quant, a1q_scale,
block_shape=block_shape, w1_scale,
ocp_mx_scheme=ocp_mx_scheme, w1_zp,
) topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w1_bias,
)
if expert_map is not None: apply_moe_activation(
intermediate_cache3.zero_() activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
)
dispatch_fused_moe_kernel( qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
qintermediate_cache2, A=intermediate_cache2,
w2, A_scale=a2_scale,
intermediate_cache3, quant_dtype=quant_dtype,
a2q_scale, per_act_token_quant=per_channel_quant,
w2_scale, block_shape=block_shape,
w2_zp, ocp_mx_scheme=ocp_mx_scheme,
curr_topk_weights, )
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w2_bias,
)
ops.moe_sum( if expert_map is not None:
intermediate_cache3.view(*intermediate_cache3.size()), intermediate_cache3.zero_()
out_hidden_states[begin_chunk_idx:end_chunk_idx],
) dispatch_fused_moe_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w2_bias,
)
ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states,
)
return out_hidden_states return out_hidden_states
...@@ -1994,9 +1967,6 @@ class TritonExperts(mk.FusedMoEExpertsModular): ...@@ -1994,9 +1967,6 @@ class TritonExperts(mk.FusedMoEExpertsModular):
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not moe_parallel_config.use_fi_all2allv_kernels return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
......
...@@ -609,9 +609,6 @@ class OAITritonExperts(BaseOAITritonExperts): ...@@ -609,9 +609,6 @@ class OAITritonExperts(BaseOAITritonExperts):
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
def workspace_shapes( def workspace_shapes(
self, self,
M: int, M: int,
...@@ -696,9 +693,6 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -696,9 +693,6 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
def workspace_shapes( def workspace_shapes(
self, self,
M: int, M: int,
......
...@@ -9,8 +9,6 @@ from typing import final ...@@ -9,8 +9,6 @@ from typing import final
import torch import torch
import vllm.envs as envs
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import ( from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation, MoEActivation,
...@@ -24,14 +22,12 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -24,14 +22,12 @@ from vllm.model_executor.layers.fused_moe.config import (
) )
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
count_expert_num_tokens,
disable_inplace, disable_inplace,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import ( from vllm.v1.worker.ubatching import (
dbo_enabled, dbo_enabled,
dbo_maybe_run_recv_hook, dbo_maybe_run_recv_hook,
...@@ -719,15 +715,6 @@ class FusedMoEExperts(ABC): ...@@ -719,15 +715,6 @@ class FusedMoEExperts(ABC):
def g2_alphas(self) -> torch.Tensor | None: def g2_alphas(self) -> torch.Tensor | None:
return self.quant_config.g2_alphas return self.quant_config.g2_alphas
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
"""
A flag indicating whether or not this class supports activation
chunking.
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
""" """
...@@ -742,11 +729,6 @@ class FusedMoEExperts(ABC): ...@@ -742,11 +729,6 @@ class FusedMoEExperts(ABC):
""" """
return False return False
def enable_chunking(self):
return (
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
)
class FusedMoEExpertsModular(FusedMoEExperts): class FusedMoEExpertsModular(FusedMoEExperts):
""" """
...@@ -995,17 +977,6 @@ class FusedMoEExpertsMonolithic(FusedMoEExperts): ...@@ -995,17 +977,6 @@ class FusedMoEExpertsMonolithic(FusedMoEExperts):
raise NotImplementedError raise NotImplementedError
def _slice_scales(
scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
if scales is not None:
if scales.numel() == 1:
return scales
else:
return scales[start:end]
return None
################################################################################ ################################################################################
# Kernel # Kernel
################################################################################ ################################################################################
...@@ -1032,26 +1003,6 @@ class FusedMoEKernelModularImpl: ...@@ -1032,26 +1003,6 @@ class FusedMoEKernelModularImpl:
and moe_parallel_config.use_ep and moe_parallel_config.use_ep
) )
def _chunk_info(self, M: int) -> tuple[int, int]:
"""
Compute number of chunks and chunk size for given M.
If chunking is not supported, set the CHUNK_SIZE to M so we
get num_chunks == 1. Take max(M, 1) to avoid divide by zero.
If there are no tokens to process, the number of chunks will be zero.
"""
CHUNK_SIZE = max(
1,
(
M
if not self.fused_experts.enable_chunking()
else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
),
)
num_chunks = cdiv(M, CHUNK_SIZE)
# If there are no tokens, then there should be no loop iterations.
assert M > 0 or num_chunks == 0
return num_chunks, CHUNK_SIZE
def _allocate_buffers( def _allocate_buffers(
self, self,
out_dtype: torch.dtype, out_dtype: torch.dtype,
...@@ -1076,40 +1027,8 @@ class FusedMoEKernelModularImpl: ...@@ -1076,40 +1027,8 @@ class FusedMoEKernelModularImpl:
""" """
assert M_full > 0 and M_chunk > 0 assert M_full > 0 and M_chunk > 0
num_chunks, _ = self._chunk_info(M_full)
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for
# "mk.FusedMoEKernel.Standard" formats where this is only bounded
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
# DP+EP due to the random token routing.
is_profile_run = (
is_forward_context_available()
and get_forward_context().attn_metadata is None
)
if is_profile_run and self.fused_experts.enable_chunking() and self.is_dp_ep:
max_workspace_13, max_workspace_2, max_fused_out_shape = (
self.fused_experts.workspace_shapes(
envs.VLLM_FUSED_MOE_CHUNK_SIZE,
N,
K,
top_k,
global_num_experts,
local_num_experts,
# expert_tokens_meta help in allocating optimal/minimal
# amount of workspace. Mark it None, so we allocate for
# the worst-case scenario.
expert_tokens_meta=None,
activation=activation,
)
)
current_workspace_manager().get_simultaneous(
(max_workspace_13, workspace_dtype),
(max_workspace_2, workspace_dtype),
(max_fused_out_shape, out_dtype),
)
# Get intermediate workspace shapes based off the chunked M size. # Get intermediate workspace shapes based off the chunked M size.
workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes( workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
M_chunk, M_chunk,
...@@ -1136,79 +1055,16 @@ class FusedMoEKernelModularImpl: ...@@ -1136,79 +1055,16 @@ class FusedMoEKernelModularImpl:
# We can reuse the memory between cache1 and cache3 because by the # We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1. # time we need cache3, we're done with cache1.
# Construct the entire output that can then be processed in chunks. # Reuse workspace13 for the output since there is only one chunk.
# Reuse workspace13 for the output in the non-chunked case. max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
# This will not always be the case for standard common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
# format experts and with experts that have empty workspaces. ((max_shape_size,), workspace_dtype),
if num_chunks == 1: (workspace2_shape, workspace_dtype),
max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
((max_shape_size,), workspace_dtype),
(workspace2_shape, workspace_dtype),
)
workspace13 = _resize_cache(common_workspace, workspace13_shape)
fused_out = _resize_cache(common_workspace, fused_out_shape)
else:
workspace13, workspace2, fused_out = (
current_workspace_manager().get_simultaneous(
(workspace13_shape, workspace_dtype),
(workspace2_shape, workspace_dtype),
(fused_out_shape, out_dtype),
)
)
return workspace13, workspace2, fused_out
@staticmethod
def _slice_output_tensor(
fused_out: torch.Tensor,
chunk_idx: int,
num_chunks: int,
CHUNK_SIZE: int,
M: int,
) -> torch.Tensor:
if num_chunks == 1:
return fused_out
assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} vs M {M}"
factor = fused_out.size(0) // M
out_chunk_size = CHUNK_SIZE * factor
s = chunk_idx * out_chunk_size
e = min(s + out_chunk_size, fused_out.size(0))
return fused_out[s:e]
@staticmethod
def _slice_expert_tokens_metadata(
num_chunks: int,
full_expert_tokens_meta: ExpertTokensMetadata | None,
chunk_topk_ids: torch.Tensor,
local_num_experts: int,
expert_map: torch.Tensor | None,
) -> ExpertTokensMetadata | None:
if num_chunks == 1 or full_expert_tokens_meta is None:
return full_expert_tokens_meta
# The existing expert_num_tokens is for the entire a1q
# input. Chunking forces recomputation of the number
# of tokens assigned to each expert.
c_expert_num_tokens = count_expert_num_tokens(
chunk_topk_ids, local_num_experts, expert_map
)
c_expert_num_tokens_cpu = None
need_expert_num_tokens_cpu = (
full_expert_tokens_meta.expert_num_tokens_cpu is not None
) )
if need_expert_num_tokens_cpu: workspace13 = _resize_cache(common_workspace, workspace13_shape)
# This is blocking as some implementations need the count fused_out = _resize_cache(common_workspace, fused_out_shape)
# on the CPU to determine appropriate input/out fused-moe
# buffers
c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False)
return ExpertTokensMetadata( return workspace13, workspace2, fused_out
expert_num_tokens=c_expert_num_tokens,
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
)
def _prepare( def _prepare(
self, self,
...@@ -1318,18 +1174,6 @@ class FusedMoEKernelModularImpl: ...@@ -1318,18 +1174,6 @@ class FusedMoEKernelModularImpl:
a1q, w1, w2, topk_ids a1q, w1, w2, topk_ids
) )
num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
if num_chunks == 1:
# Use a1q.size(0) here since batched format does not
# keep M in the first dimension.
return 0, a1q.size(0)
else:
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M_full)
return s, e
# This happens when none of the tokens from the all2all reach this # This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph # EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput # incompatible all2all kernels like the DeepEP high-throughput
...@@ -1337,58 +1181,39 @@ class FusedMoEKernelModularImpl: ...@@ -1337,58 +1181,39 @@ class FusedMoEKernelModularImpl:
# low-latency kernels are always batched and can never run into # low-latency kernels are always batched and can never run into
# the tensor.numel() == 0 case. # the tensor.numel() == 0 case.
if M_full == 0: if M_full == 0:
assert num_chunks == 0 return torch.empty_like(a1q, dtype=in_dtype)
workspace13 = None
workspace2 = None
fused_out = torch.empty_like(a1q, dtype=in_dtype)
else:
assert num_chunks > 0
workspace13, workspace2, fused_out = self._allocate_buffers(
in_dtype,
a1q.device,
CHUNK_SIZE,
M_full,
N,
K,
top_k,
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
for chunk_idx in range(num_chunks): workspace13, workspace2, fused_out = self._allocate_buffers(
s, e = input_chunk_range(chunk_idx) in_dtype,
a1q.device,
c_expert_tokens_meta = self._slice_expert_tokens_metadata( M_full,
num_chunks, M_full,
expert_tokens_meta, N,
topk_ids[s:e], K,
local_num_experts, top_k,
expert_map, global_num_experts,
) local_num_experts,
expert_tokens_meta,
c_fused_out = self._slice_output_tensor( activation,
fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full )
)
self.fused_experts.apply( self.fused_experts.apply(
output=c_fused_out, output=fused_out,
hidden_states=a1q[s:e], hidden_states=a1q,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=topk_weights[s:e], topk_weights=topk_weights,
topk_ids=topk_ids[s:e], topk_ids=topk_ids,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
a1q_scale=_slice_scales(a1q_scale, s, e), a1q_scale=a1q_scale,
a2_scale=_slice_scales(self.fused_experts.a2_scale, s, e), a2_scale=self.fused_experts.a2_scale,
workspace13=workspace13, workspace13=workspace13,
workspace2=workspace2, workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
return fused_out return fused_out
......
...@@ -337,9 +337,6 @@ class AiterExperts(mk.FusedMoEExpertsModular): ...@@ -337,9 +337,6 @@ class AiterExperts(mk.FusedMoEExpertsModular):
def supports_expert_map(self): def supports_expert_map(self):
return True return True
def supports_chunking(self):
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP() return TopKWeightAndReduceNoOP()
......
...@@ -83,9 +83,6 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular): ...@@ -83,9 +83,6 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
"This method should not be called." "This method should not be called."
) )
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
......
...@@ -79,9 +79,6 @@ class XPUExperts(mk.FusedMoEExpertsModular): ...@@ -79,9 +79,6 @@ class XPUExperts(mk.FusedMoEExpertsModular):
] ]
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
......
...@@ -244,8 +244,7 @@ def _get_grouped_gemm_params( ...@@ -244,8 +244,7 @@ def _get_grouped_gemm_params(
device = w1.device device = w1.device
# Assumes all ranks have the same max_num_batched_tokens # Assumes all ranks have the same max_num_batched_tokens
max_tokens_across_dp = get_dp_group().world_size * max_tokens max_tokens = get_dp_group().world_size * max_tokens
max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
# This is the maximum GroupedGemm M size that we expect to run # This is the maximum GroupedGemm M size that we expect to run
# the grouped_gemm with. # the grouped_gemm with.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment