Unverified commit df44df01, authored by Wentao Ye, committed by GitHub
Browse files

[Feature] Shared Experts Overlap with FI deepgemm swap kernel, 2.2% throughput...


[Feature] Shared Experts Overlap with FI deepgemm swap kernel, 2.2% throughput improvement and 3.6% TTFT improvement (#28879)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 87cbbdff
...@@ -50,6 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -50,6 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize, prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts, shared_experts,
getattr(moe_layer, "shared_experts_stream", None),
), ),
) )
......
...@@ -850,6 +850,45 @@ class FusedMoE(CustomOp): ...@@ -850,6 +850,45 @@ class FusedMoE(CustomOp):
dp_size=get_dp_group().world_size, dp_size=get_dp_group().world_size,
) )
def _maybe_setup_shared_experts_stream(
self,
hidden_states: torch.Tensor,
has_separate_shared_experts: bool,
use_chunked_impl: bool,
) -> tuple[bool, torch.Tensor | None]:
use_shared_experts_stream = (
has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
hidden_states_clone: torch.Tensor | None = None
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We dont need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
# router/gate (next op below)
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())
return use_shared_experts_stream, hidden_states_clone
def _load_per_tensor_weight_scale( def _load_per_tensor_weight_scale(
self, self,
shard_id: str, shard_id: str,
...@@ -1819,36 +1858,12 @@ class FusedMoE(CustomOp): ...@@ -1819,36 +1858,12 @@ class FusedMoE(CustomOp):
use_chunked_impl = self.use_dp_chunking use_chunked_impl = self.use_dp_chunking
use_shared_experts_stream = ( use_shared_experts_stream, hidden_states_clone = (
has_separate_shared_experts self._maybe_setup_shared_experts_stream(
and not use_chunked_impl hidden_states, has_separate_shared_experts, use_chunked_impl
and self.shared_experts_stream is not None
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
) )
) )
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We dont need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
# router/gate (next op below)
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())
# If router/gate provided, then apply it here. # If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow # (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via # parallel execution of shared experts with the FusedMoE via
......
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
count_expert_num_tokens, count_expert_num_tokens,
disable_inplace, disable_inplace,
) )
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import ( from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id, dbo_current_ubatch_id,
...@@ -709,11 +710,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -709,11 +710,13 @@ class FusedMoEModularKernel(torch.nn.Module):
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute, fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
shared_experts_stream: torch.cuda.Stream | None = None,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.shared_experts = shared_experts self.shared_experts = shared_experts
self.shared_experts_stream = shared_experts_stream
self._post_init_setup() self._post_init_setup()
assert ( assert (
...@@ -890,6 +893,34 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -890,6 +893,34 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens_cpu=c_expert_num_tokens_cpu, expert_num_tokens_cpu=c_expert_num_tokens_cpu,
) )
def _maybe_setup_shared_experts_stream(
self, hidden_states: torch.Tensor
) -> tuple[bool, torch.Tensor | None]:
# decide whether to run shared experts on a separate CUDA stream to
# overlap with the main fused MoE kernel.
use_shared_experts_stream = (
self.shared_experts is not None
and self.shared_experts_stream is not None
and hidden_states.is_cuda
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
hidden_states_clone: torch.Tensor | None = None
if use_shared_experts_stream and self.shared_experts_stream is not None:
# TODO: Optimize this (complicated)
# Note: this clone adds overhead but is required
# for correctness with multiple CUDA streams and CUDA graph capture.
hidden_states_clone = hidden_states.clone()
# record that the clone will be used by the separate stream so its
# lifetime is correctly tracked.
hidden_states_clone.record_stream(self.shared_experts_stream)
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
return use_shared_experts_stream, hidden_states_clone
def _prepare( def _prepare(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -1077,12 +1108,30 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1077,12 +1108,30 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
hidden_states_clone: torch.Tensor | None = None,
use_shared_experts_stream: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
The _finalize method is a wrapper around self.prepare_finalize.finalize The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async and shared expert overlap. that handles DBO, async and shared expert overlap.
""" """
shared_output: torch.Tensor | None = None
def maybe_run_shared_experts() -> torch.Tensor | None:
    """Run the shared experts, on the dedicated stream when overlap is enabled.

    Reads ``self``, ``hidden_states``, ``hidden_states_clone`` and
    ``use_shared_experts_stream`` from the enclosing scope.

    Returns:
        ``None`` when there are no shared experts, otherwise their output.
    """
    if self.shared_experts is None:
        return None

    side_stream = self.shared_experts_stream
    # Every reason to stay on the current stream is spelled out as its own
    # term. The earlier mixed ``or``/``and`` expression tested
    # ``shared_experts_stream is not None``, so a missing stream fell through
    # to ``torch.cuda.stream(None)`` and the clone assert below.
    run_on_current_stream = (
        not use_shared_experts_stream
        or side_stream is None
        or not hidden_states.is_cuda
        or not torch.cuda.is_available()
    )
    if run_on_current_stream:
        # fall back to running on the current stream
        return self.shared_experts(hidden_states)

    assert hidden_states_clone is not None
    # launch shared experts on the dedicated stream.
    with torch.cuda.stream(side_stream):
        return self.shared_experts(hidden_states_clone)
if not self.prepare_finalize.supports_async(): if not self.prepare_finalize.supports_async():
assert not dbo_enabled() assert not dbo_enabled()
...@@ -1095,8 +1144,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1095,8 +1144,7 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(), self.fused_experts.finalize_weight_and_reduce_impl(),
) )
if self.shared_experts is not None: shared_output = maybe_run_shared_experts()
shared_output = self.shared_experts(hidden_states)
else: else:
finalize_ret = self.prepare_finalize.finalize_async( finalize_ret = self.prepare_finalize.finalize_async(
output, output,
...@@ -1107,8 +1155,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1107,8 +1155,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(), self.fused_experts.finalize_weight_and_reduce_impl(),
) )
if self.shared_experts is not None: shared_output = maybe_run_shared_experts()
shared_output = self.shared_experts(hidden_states)
# TODO(lucas): refactor this in the alternative schedules followup # TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just # currently unpack if we have hook + receiver pair or just
...@@ -1131,12 +1178,28 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1131,12 +1178,28 @@ class FusedMoEModularKernel(torch.nn.Module):
receiver() receiver()
self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
if self.shared_experts is None: if self.shared_experts is None:
return output return output
else: else:
assert shared_output is not None assert shared_output is not None
return shared_output, output return shared_output, output
def _wait_for_shared_experts_stream(
self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
) -> None:
# ensure that any work enqueued on the shared_experts_stream is
# completed before the shared_output tensor is consumed
if (
self.shared_experts is not None
and use_shared_experts_stream
and self.shared_experts_stream is not None
and hidden_states.is_cuda
and current_platform.is_cuda()
):
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -1183,6 +1246,10 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1183,6 +1246,10 @@ class FusedMoEModularKernel(torch.nn.Module):
else: else:
output = torch.zeros_like(hidden_states) output = torch.zeros_like(hidden_states)
use_shared_experts_stream, hidden_states_clone = (
self._maybe_setup_shared_experts_stream(hidden_states)
)
local_num_experts = w1.size(0) local_num_experts = w1.size(0)
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
...@@ -1219,4 +1286,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1219,4 +1286,6 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights, topk_weights,
topk_ids, topk_ids,
apply_router_weight_on_input, apply_router_weight_on_input,
hidden_states_clone=hidden_states_clone,
use_shared_experts_stream=use_shared_experts_stream,
) )
...@@ -45,7 +45,8 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -45,7 +45,8 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
assert topk == 1, ( assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1" "apply_router_weight_on_input is only implemented for topk=1"
) )
a1.mul_(topk_weights.to(a1.dtype)) # Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment