Unverified Commit dc2979c5 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Overlap shared experts with combine instead of dispatch (#24254)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 027d37df
...@@ -240,7 +240,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -240,7 +240,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_config) quant_config)
return receiver() return receiver()
def finalize( def _finalize(
self, self,
output: torch.Tensor, output: torch.Tensor,
fused_expert_output: torch.Tensor, fused_expert_output: torch.Tensor,
...@@ -248,7 +248,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -248,7 +248,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: do_async: bool,
) -> Optional[Callable]:
assert self.handle is not None assert self.handle is not None
...@@ -271,7 +272,46 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -271,7 +272,46 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights=None, topk_weights=None,
config=self._get_combine_config(), config=self._get_combine_config(),
previous_event=None, previous_event=None,
async_finish=False, async_finish=do_async,
allocate_on_comm_stream=False) allocate_on_comm_stream=False)
if do_async:
def _receiver():
event.current_stream_wait()
# Respect inplace outputs. # Respect inplace outputs.
output.copy_(combined_x, non_blocking=True) output.copy_(combined_x, non_blocking=True)
return lambda: _receiver()
else:
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return None
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
receiver = self._finalize(output, fused_expert_output, topk_weights,
topk_ids, apply_router_weight_on_input,
weight_and_reduce_impl, True)
assert receiver is not None
return receiver
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
self._finalize(output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, weight_and_reduce_impl,
False)
...@@ -12,8 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -12,8 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, normalize_batched_scales_shape) moe_kernel_quantize_input, normalize_batched_scales_shape)
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook, dbo_maybe_run_recv_hook)
dbo_register_recv_hook, dbo_yield)
# DeepEP kernels quantize dispatch inputs in 128 element chunks. # DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128 DEEPEP_QUANT_BLOCK_SIZE = 128
...@@ -198,7 +197,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -198,7 +197,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
hook() hook()
return receiver() return receiver()
def finalize( def _finalize(
self, self,
output: torch.Tensor, output: torch.Tensor,
fused_expert_output: torch.Tensor, fused_expert_output: torch.Tensor,
...@@ -206,13 +205,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -206,13 +205,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: do_async: bool,
) -> Optional[Callable]:
assert isinstance( assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.") ), ("Weight application and reduction happens in the combine kernel.")
a2a_idx = dbo_current_ubatch_id() a2a_idx = dbo_current_ubatch_id()
do_recv_hook = dbo_enabled() do_recv_hook = dbo_enabled() or do_async
handle = self.handles[a2a_idx] handle = self.handles[a2a_idx]
assert handle is not None assert handle is not None
...@@ -232,6 +232,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -232,6 +232,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
zero_copy=False, zero_copy=False,
return_recv_hook=do_recv_hook, return_recv_hook=do_recv_hook,
out=output) out=output)
if recv_hook is not None:
dbo_register_recv_hook(recv_hook) return recv_hook
dbo_yield()
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
recv_hook = self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
do_async=True,
)
assert recv_hook is not None
return recv_hook
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
do_async=False,
)
...@@ -209,7 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -209,7 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
def supports_async(self) -> bool: def supports_async(self) -> bool:
""" """
Indicates whether or not this class implements prepare_async. Indicates whether or not this class implements prepare_async and
finalize_async.
""" """
return False return False
...@@ -275,6 +276,42 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -275,6 +276,42 @@ class FusedMoEPrepareAndFinalize(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> Callable:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output but do not wait for results from other workers.
- output: The output tensor, written in place. Must be (M, K) shape.
- fused_expert_output: The unweighted, unreduced output of the fused
experts, it will have (M, topk, K) shape.
- topk_weights: The weights to be applied to the fused_experts_output.
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to
fused_expert_output.
- weight_and_reduce_impl: An optional TopKWeightAndReduce
implementation.
Returns a callback that when invoked waits for results from other
workers and has the same return signature as `finalize`, e.g.
receiver = obj.finalize_async(output, ...)
... output not valid yet ...
receiver()
... output valid here ...
is equivalent to:
obj.finalize(output, ...)
"""
raise NotImplementedError
@property @property
@abstractmethod @abstractmethod
def activation_format(self) -> FusedMoEActivationFormat: def activation_format(self) -> FusedMoEActivationFormat:
...@@ -814,23 +851,20 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -814,23 +851,20 @@ class FusedMoEModularKernel(torch.nn.Module):
""" """
a1 = hidden_states a1 = hidden_states
output = a1 if inplace else torch.zeros_like(a1) if inplace and self.shared_experts is None:
output = a1
else:
output = torch.zeros_like(a1)
local_num_experts = w1.size(0) local_num_experts = w1.size(0)
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
shared_output: torch.Tensor
if not self.prepare_finalize.supports_async(): if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't # We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize # support async prepare/finalize
assert not dbo_enabled() assert not dbo_enabled()
# Run shared experts serially with dispatch.
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare( _expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1,
...@@ -854,9 +888,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -854,9 +888,6 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.quant_config, self.fused_experts.quant_config,
) )
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
# If DBO is being used, register the hook with the ubatch context # If DBO is being used, register the hook with the ubatch context
# and call it in dbo_maybe_run_recv_hook instead of passing it to # and call it in dbo_maybe_run_recv_hook instead of passing it to
# the receiver. # the receiver.
...@@ -900,6 +931,11 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -900,6 +931,11 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
shared_output: Optional[torch.Tensor] = None
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
self.prepare_finalize.finalize( self.prepare_finalize.finalize(
output, output,
fused_out, fused_out,
...@@ -908,8 +944,29 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -908,8 +944,29 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(), self.fused_experts.finalize_weight_and_reduce_impl(),
) )
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
else:
recv_hook = self.prepare_finalize.finalize_async(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
assert recv_hook is not None
dbo_register_recv_hook(recv_hook)
dbo_yield()
if not dbo_enabled():
recv_hook()
if self.shared_experts is None: if self.shared_experts is None:
return output return output
else: else:
assert shared_output is not None
return shared_output, output return shared_output, output
...@@ -272,7 +272,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -272,7 +272,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
hook() hook()
return receiver() return receiver()
def finalize( def finalize_async(
self, self,
output: torch.Tensor, output: torch.Tensor,
fused_expert_output: torch.Tensor, fused_expert_output: torch.Tensor,
...@@ -280,7 +280,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -280,7 +280,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None: ) -> Callable:
assert isinstance( assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.") ), ("Weight application and reduction happens in the combine kernel.")
...@@ -303,8 +303,39 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -303,8 +303,39 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights) topk_weights = torch.ones_like(topk_weights)
topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
self.a2a.combine(out_tokens=output, self.a2a.combine(out_tokens=output,
indices=topk_ids.view(dtype=torch.uint32), indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=True,
do_recv=False)
return lambda: self.a2a.combine(out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights, weights=topk_weights,
expert_y=fused_expert_output, expert_y=fused_expert_output,
bound_m=bound_m) bound_m=bound_m,
do_send=False,
do_recv=True)
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
receiver = self.finalize_async(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
)
receiver()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment