Commit 1e2ac05c authored by chenhw5's avatar chenhw5
Browse files

add DeepGEMM SBO for DeepEP LL

parent 18459e7a
......@@ -271,6 +271,8 @@ if TYPE_CHECKING:
VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_USE_NN: bool = False
VLLM_ENABLE_TBO: bool = False
# Whether to use single batch overlapping (SBO) for MoE with DeepEP low-latency.
VLLM_EP_USE_SBO: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False
......@@ -1229,6 +1231,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DEEPEPLL_NVFP4_DISPATCH": lambda: bool(
int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0"))
),
# Whether to use single batch overlapping optimization
"VLLM_EP_USE_SBO": lambda: bool(int(os.getenv("VLLM_EP_USE_SBO", "0"))),
# Whether to turn on the outlines cache for V0
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
"VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get(
"VLLM_V0_USE_OUTLINES_CACHE", "0"
)
== "1",
# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
......
......@@ -34,13 +34,13 @@ from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.import_utils import has_deep_gemm
from lightop import fuse_silu_mul_quant_ep
from lightop import fuse_silu_mul_quant_ep, fuse_silu_mul_fp8_quant_ep
if has_deep_gemm():
from deep_gemm import m_grouped_w8a8_gemm_nt_masked
from deep_gemm import m_grouped_w8a8_gemm_nt_masked, m_grouped_fp8_gemm_nt_masked
else:
from lightop import m_grouped_w8a8_gemm_nt_masked
from lightop import m_grouped_w8a8_gemm_nt_masked, m_grouped_fp8_gemm_nt_masked
from typing import Any
logger = init_logger(__name__)
......@@ -415,6 +415,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False,
w2_gemm_overlap_args: Any = None,
meta_overlap_args: dict[str, Any] | None = None,
):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
......@@ -443,7 +445,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8:
fp8_m_grouped_gemm_nt_masked(
m_grouped_fp8_gemm_nt_masked(
(a1q, a1q_scale),
(w1, self.w1_scale),
workspace1,
......@@ -451,20 +453,40 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m,
)
quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
workspace1,
expert_num_tokens,
quant_scale_fmt=quant_scale_fmt,
# ---- SiLU + quant (corresponds to SGLang's fuse_silu_mul_fp8_quant_ep) ----
# workspace1: [E, max_num_tokens, N]; apply silu*up within each expert and quantize to fp8
q_a2_all, q_a2_scale = fuse_silu_mul_fp8_quant_ep(
input=workspace1,
fp8type=0, # 和你们 deepgemm 约定一致
tokens_per_expert=expert_num_tokens,
)
fp8_m_grouped_gemm_nt_masked(
(a2q, a2q_scale),
# When using SBO, we record event here to indicate
# that the signal tensor and input to deepep ll combine
# are ready
enable_overlap = w2_gemm_overlap_args is not None
signal = w2_gemm_overlap_args.signal if enable_overlap else None
if enable_overlap:
w2_gemm_overlap_args.start_event.record()
block_m, threshold = m_grouped_fp8_gemm_nt_masked(
(q_a2_all, q_a2_scale),
(w2, self.w2_scale),
output,
expert_num_tokens,
expected_m,
enable_overlap,
signal,
)
# return meta_overlap_args to DeepEP combine.
if meta_overlap_args is not None:
if block_m is not None:
meta_overlap_args["block_m"] = block_m
if threshold is not None:
meta_overlap_args["threshold"] = threshold
elif self.quant_config.use_int8_w8a8:
m_grouped_w8a8_gemm_nt_masked((a1q, a1q_scale),
(w1, self.w1_scale),
......
......@@ -254,6 +254,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -299,6 +300,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -308,6 +310,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights,
topk_ids,
num_experts,
local_num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
......
......@@ -23,6 +23,11 @@ from vllm.v1.worker.ubatching import (
dbo_maybe_run_recv_hook,
)
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any
alt_stream = torch.cuda.Stream()
logger = init_logger(__name__)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
......@@ -31,6 +36,42 @@ DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
logger = init_logger(__name__)
@dataclass
class CombineOverlapArgs:
# Arguments consumed on the DeepEP low-latency combine side when single
# batch overlapping (SBO) is enabled.
#
# Whether to overlap the w2 GEMM with the DeepEP LL combine.
# Required field (no default): it must be supplied at construction time.
overlap: bool
# The DeepEP LL combine is launched on this stream, which is
# different from the default compute stream.
stream: torch.cuda.Stream
# This event is recorded on the compute stream between the SiLU-mul
# FP8 quantization and the w2 GEMM.
# The combine stream waits on this event before the DeepEP LL combine
# to ensure the signal tensor has been allocated.
wait_event: torch.cuda.Event
# Number of CUs used by the combine kernel; currently hardcoded to 32.
num_sms: int
# The signal tensor is shared by the w2 GEMM and the DeepEP LL combine.
# The w2 GEMM atomically adds to this tensor to signal that the DeepEP
# combine can start sending data.
signal: torch.Tensor | None = None
# Default block_m value.
# NOTE(review): the combine call in this change reads block_m from
# meta_overlap_args rather than from this field -- confirm whether this
# default is still consumed anywhere.
block_m: int = 64
# Set to the number of CUs used by the w2 GEMM, which is a persistent
# kernel. Once all CUs have completed the computation for an expert,
# the combine kernel can start sending data for that expert.
threshold: int = 32
@dataclass
class W2GemmOverlapArgs:
# Arguments handed to the expert (w2 GEMM) side when single batch
# overlapping (SBO) is enabled. All fields are required.
#
# Number of CUs used by the w2 GEMM.
num_sms: int
# The same signal tensor that CombineOverlapArgs.signal refers to.
signal: torch.Tensor
# The same event as wait_event in CombineOverlapArgs.
start_event: torch.cuda.Event
def dequant_fp8(
expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
......@@ -122,6 +163,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# time. This setting is handled by post_init_setup.
self.use_ue8m0_dispatch = False
# SBO: DeepEP LL overlap configuration
self.combine_overlap_args: CombineOverlapArgs | None = None
self.meta_overlap_args: dict[str, Any] | None = None
self.packed_recv_count: torch.Tensor | None = None
def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute):
if not fused_experts.supports_packed_ue8m0_act_scales():
# Early exit.
......@@ -247,6 +294,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -317,6 +365,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
self.handles[a2a_idx] = handle
# We need to pass w2_gemm_overlap_args to the MoE implementation,
# so return it as an output parameter.
w2_gemm_overlap_args = self._create_sbo_args(local_num_experts, a1.device)
return (
hook,
lambda: self._receiver(
......@@ -326,6 +378,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1.dtype,
quant_config,
),
w2_gemm_overlap_args,
)
def _receiver(
......@@ -341,7 +394,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
)
# SBO: save the latest expert_num_tokens as packed_recv_count; it is needed by the DeepEP LL combine when SBO is used.
self.packed_recv_count = expert_num_tokens
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
......@@ -350,15 +404,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
hook, receiver, _ = self.prepare_async(
a1,
topk_weights,
topk_ids,
num_experts,
local_num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
......@@ -393,6 +449,28 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
# TODO (varun) : Enable zero copy mode
dbo_maybe_run_recv_hook()
ctx = nullcontext()
if self.combine_overlap_args is not None:
# For SBO, we need to wait for compute stream
# to have completed signal tensor allocation
self.combine_overlap_args.stream.wait_event(
self.combine_overlap_args.wait_event
)
# And we launch ll combine phase 1 in a separate stream
# for overlapping
ctx = torch.cuda.stream(self.combine_overlap_args.stream)
overlap_args_dict = dict(
overlap=self.combine_overlap_args.overlap,
packed_recv_count=self.packed_recv_count,
comp_signal=self.combine_overlap_args.signal,
block_m=self.meta_overlap_args["block_m"],
threshold=self.meta_overlap_args["threshold"],
num_sms=self.combine_overlap_args.num_sms,
)
else:
overlap_args_dict = {}
with ctx:
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
combine_topk_ids,
......@@ -402,8 +480,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
zero_copy=False,
return_recv_hook=do_recv_hook,
out=output,
**overlap_args_dict,
)
if self.combine_overlap_args is not None:
return recv_hook, lambda: self._sbo_wait_stream()
else:
return recv_hook, lambda: None
def finalize_async(
......@@ -443,3 +525,47 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
weight_and_reduce_impl,
do_async=False,
)
def _create_sbo_args(
    self, local_num_experts: int, device: torch.device
) -> W2GemmOverlapArgs | None:
    """Build the per-call single batch overlapping (SBO) state.

    ``self.combine_overlap_args`` and ``self.meta_overlap_args`` are reset
    on every call so that values from a previous call are never reused.

    Args:
        local_num_experts: Number of elements in the signal tensor that is
            allocated for this call.
        device: Device whose multiprocessor count is queried and on which
            the signal tensor is allocated.

    Returns:
        The ``W2GemmOverlapArgs`` to pass to the expert implementation, or
        ``None`` when ``envs.VLLM_EP_USE_SBO`` is not set.
    """
    self.combine_overlap_args = None
    if not envs.VLLM_EP_USE_SBO:
        self.meta_overlap_args = None
        return None

    # Start from an empty dict every time to avoid using stale values.
    self.meta_overlap_args = {}

    total_num_sms = torch.cuda.get_device_properties(
        device=device
    ).multi_processor_count
    # The combine kernel gets a fixed share; the w2 GEMM gets the remainder.
    communicate_num_sms = 32
    compute_num_sms = total_num_sms - communicate_num_sms

    # The event and the signal tensor are shared by both argument objects.
    combine_wait_event = torch.cuda.Event()
    combine_signal = torch.zeros(
        local_num_experts, dtype=torch.uint32, device=device
    )

    # `overlap` is a required dataclass field without a default, so it has
    # to be passed to the constructor; omitting it raises TypeError before
    # any later attribute assignment can run.
    self.combine_overlap_args = CombineOverlapArgs(
        overlap=True,
        stream=alt_stream,
        wait_event=combine_wait_event,
        num_sms=communicate_num_sms,
        signal=combine_signal,
    )
    return W2GemmOverlapArgs(
        num_sms=compute_num_sms,
        signal=combine_signal,
        start_event=combine_wait_event,
    )
def _sbo_wait_stream(self) -> None:
# When SBO enabled, ll combine phase 2 is still launched
# on the main compute stream, but we need to wait for
# ll combine 1 to complete
if self.combine_overlap_args is not None:
torch.cuda.current_stream().wait_stream(self.combine_overlap_args.stream)
\ No newline at end of file
......@@ -96,6 +96,7 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -177,6 +178,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......
......@@ -530,6 +530,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......
......@@ -34,6 +34,9 @@ from vllm.v1.worker.ubatching import (
dbo_yield,
)
from vllm.v1.worker.workspace import current_workspace_manager
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
logger = init_logger(__name__)
......@@ -177,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -217,6 +221,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -1059,6 +1064,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
) -> tuple[
......@@ -1072,6 +1078,7 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
"""
w2_gemm_overlap_args = None
if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
......@@ -1089,6 +1096,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
global_num_experts,
local_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
......@@ -1101,6 +1109,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
global_num_experts,
local_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
......@@ -1109,7 +1118,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# receiver (see finalize_async docstring)
hook, receiver = (
hook, receiver, w2_gemm_overlap_args = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
)
......@@ -1137,7 +1146,7 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights if _expert_topk_weights is None else _expert_topk_weights
)
return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights
return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights, w2_gemm_overlap_args
def _fused_experts(
self,
......@@ -1155,6 +1164,7 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False,
w2_gemm_overlap_args = None,
) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids
......@@ -1217,6 +1227,28 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
)
if isinstance(self.fused_experts, BatchedDeepGemmExperts):
self.fused_experts.apply(
output=c_fused_out,
hidden_states=a1q[s:e],
w1=w1,
w2=w2,
topk_weights=topk_weights[s:e],
topk_ids=topk_ids[s:e],
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
a1q_scale=_slice_scales(a1q_scale, s, e),
a2_scale=_slice_scales(self.fused_experts.a2_scale, s, e),
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe,
w2_gemm_overlap_args=w2_gemm_overlap_args,
meta_overlap_args=self.prepare_finalize.meta_overlap_args,
)
else:
self.fused_experts.apply(
output=c_fused_out,
hidden_states=a1q[s:e],
......@@ -1365,11 +1397,12 @@ class FusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1:
global_num_experts = local_num_experts
a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = self._prepare(
a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights, w2_gemm_overlap_args = self._prepare(
hidden_states,
topk_weights,
topk_ids,
global_num_experts,
local_num_experts,
expert_map,
apply_router_weight_on_input,
)
......@@ -1389,6 +1422,7 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe,
w2_gemm_overlap_args=w2_gemm_overlap_args,
)
return self._finalize(
......
......@@ -55,6 +55,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......
......@@ -103,6 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -271,6 +272,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -280,6 +282,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights,
topk_ids,
num_experts,
local_num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
......
......@@ -39,6 +39,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment