Commit 0d3ae2fc authored by yangql
Browse files

up auto deepep

parent 94c4ca4d
...@@ -273,3 +273,56 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -273,3 +273,56 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
# in get_or_create must be updated. # in get_or_create must be updated.
handle.set_num_sms(self.num_sms) handle.set_num_sms(self.num_sms)
return handle return handle
class DeepEPAutoAll2AllManager(All2AllManagerBase):
    """
    All2All manager for the "deepep_auto" backend.

    Owns one low-latency (LL) and one high-throughput (HT) DeepEP manager.
    Buffers are only ever created through the LL manager's handle cache;
    the HT manager is consulted solely for its buffer-size requirements so
    that a single buffer can be sized for both modes.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)
        self.ll_manager = DeepEPLLAll2AllManager(cpu_group)
        self.ht_manager = DeepEPHTAll2AllManager(cpu_group)

    def get_handle(self, kwargs):
        """
        Build (or fetch from the LL handle cache) a DeepEP buffer that uses
        the LL arguments, with ``num_nvl_bytes`` and ``num_rdma_bytes``
        raised to the larger of the HT and LL requirements.

        Args:
            kwargs: key/value arguments, copied with ``dict()`` and forwarded
                as keyword arguments to the LL manager's
                ``_make_all2all_kwargs``. The caller's object is not mutated.

        Returns:
            The handle returned by the LL manager's handle cache, after
            ``set_num_sms`` has been applied with the LL manager's
            ``num_sms``.
        """
        import deep_ep

        # Canonical buffer arguments for each path.
        ll_kwargs = self.ll_manager._make_all2all_kwargs(**dict(kwargs))
        ht_kwargs = self.ht_manager._make_all2all_kwargs()

        # Start from the LL arguments and widen only the two size fields so
        # the resulting buffer is large enough for either mode.
        merged_kwargs = dict(ll_kwargs)
        for size_key in ("num_nvl_bytes", "num_rdma_bytes"):
            merged_kwargs[size_key] = max(ll_kwargs[size_key],
                                          ht_kwargs[size_key])

        logger.debug("DeepEP auto merged args %s", merged_kwargs)
        handle: deep_ep.Buffer = self.ll_manager.handle_cache.get_or_create(
            merged_kwargs, deep_ep.Buffer)
        handle.set_num_sms(self.ll_manager.num_sms)
        return handle

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        """Not supported on the auto manager; always raises."""
        raise NotImplementedError(
            "DeepEPAutoAll2AllManager does not support dispatch directly; "
            "use the underlying HT/LL managers.")

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Not supported on the auto manager; always raises."""
        raise NotImplementedError(
            "DeepEPAutoAll2AllManager does not support combine directly; "
            "use the underlying HT/LL managers.")

    def destroy(self):
        """Tear down both underlying managers created in ``__init__``."""
        self.ll_manager.destroy()
        # The HT manager is constructed alongside the LL one, so it must be
        # released alongside it as well.
        self.ht_manager.destroy()
...@@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import DeepEPLLAll2AllManager from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.") logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "deepep_auto":
from .all2all import DeepEPAutoAll2AllManager
self.all2all_manager = DeepEPAutoAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Auto all2all manager.")
elif all2all_backend == "mori": elif all2all_backend == "mori":
pass pass
else: else:
......
...@@ -187,6 +187,11 @@ class FusedMoEParallelConfig: ...@@ -187,6 +187,11 @@ class FusedMoEParallelConfig:
return (self.use_all2all_kernels return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@property
def use_deepep_auto_kernels(self):
    """Whether all2all kernels are in use with the "deepep_auto" backend."""
    all2all_enabled = self.use_all2all_kernels
    if not all2all_enabled:
        # Same result as a short-circuiting `and`: hand back the falsy
        # value itself without consulting the environment.
        return all2all_enabled
    return envs.VLLM_ALL2ALL_BACKEND == "deepep_auto"
@staticmethod @staticmethod
def make(tp_size_: int, dp_size_: int, def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
...@@ -385,6 +390,10 @@ class FusedMoEConfig: ...@@ -385,6 +390,10 @@ class FusedMoEConfig:
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_deepep_auto_kernels(self):
    """Forward the "deepep_auto" kernel flag from the MoE parallel config."""
    parallel_config = self.moe_parallel_config
    return parallel_config.use_deepep_auto_kernels
def make( def make(
num_experts: int, num_experts: int,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.forward_context import get_forward_context
class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    Prepare/Finalize wrapper for the "deepep_auto" all2all backend.

    Holds both a DeepEP High-Throughput (HT) and a Low-Latency (LL)
    implementation. The prefill/decode phase is tracked in
    ``_current_phase``, but phase-based routing is currently disabled:
    every prepare/finalize call is delegated to the LL implementation.
    The HT implementation is only queried for ``num_dispatchers``.
    """

    def __init__(self,
                 ht_prepare_finalize: mk.FusedMoEPrepareAndFinalize,
                 ll_prepare_finalize: mk.FusedMoEPrepareAndFinalize):
        super().__init__()
        self.ht_prepare_finalize = ht_prepare_finalize
        self.ll_prepare_finalize = ll_prepare_finalize
        # Phase assumed until a forward context provides better information.
        self._current_phase = "decode"

    def _refresh_phase_from_forward_context(self) -> None:
        """
        Best-effort update of ``_current_phase`` from the forward context.

        The phase becomes "prefill" only when the attention metadata reports
        prefill tokens, no decode tokens, and the forward context has
        ``skip_cuda_graphs`` set; otherwise it becomes "decode". If the
        metadata is missing or lacks the token counters, or anything raises,
        the stored phase is left untouched.
        """
        try:
            forward_context = get_forward_context()
            attn_metadata = forward_context.attn_metadata

            # Handle both v0 (single AttentionMetadata) and v1 (dict)
            # formats; for a dict, inspect its first value.
            if isinstance(attn_metadata, dict):
                attn_metadata = (next(iter(attn_metadata.values()))
                                 if attn_metadata else None)

            if (attn_metadata is None
                    or not hasattr(attn_metadata, 'num_prefill_tokens')
                    or not hasattr(attn_metadata, 'num_decode_tokens')):
                return

            prefill_only = (attn_metadata.num_prefill_tokens > 0
                            and attn_metadata.num_decode_tokens == 0)
            # Read unconditionally so a missing attribute keeps the stored
            # phase instead of silently resolving to "decode".
            skip_cuda_graphs = forward_context.skip_cuda_graphs
            self._current_phase = ("prefill" if
                                   (prefill_only and skip_cuda_graphs) else
                                   "decode")
        except Exception:
            # Deliberately broad: phase inference is best-effort. If the
            # forward context is not available, keep the stored phase.
            pass

    def _get_current_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize:
        """
        Return the implementation that calls are delegated to.

        The phase is refreshed for bookkeeping, but the LL implementation is
        returned regardless of the phase (HT routing is not enabled).
        """
        self._refresh_phase_from_forward_context()
        return self.ll_prepare_finalize

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activation format of the implementation currently delegated to."""
        return self._get_current_prepare_finalize().activation_format

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        """Return ``torch.int64`` (both HT and LL use int64 indices)."""
        return torch.int64

    def max_num_tokens_per_rank(self) -> Optional[int]:
        """Return the LL implementation's per-rank token limit."""
        # LL has a limit, HT returns None.
        return self.ll_prepare_finalize.max_num_tokens_per_rank()

    def num_dispatchers(self) -> int:
        """Return the HT implementation's dispatcher count."""
        # Both implementations are expected to report the same value.
        return self.ht_prepare_finalize.num_dispatchers()

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Delegate ``prepare`` to the selected implementation, unchanged."""
        delegate = self._get_current_prepare_finalize()
        return delegate.prepare(
            a1, a1_scale, a2_scale, topk_weights, topk_ids,
            num_experts, expert_map, apply_router_weight_on_input,
            quant_config)

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        apply_weights_and_reduce: bool = True
    ) -> None:
        """Delegate ``finalize`` to the selected implementation, unchanged."""
        delegate = self._get_current_prepare_finalize()
        return delegate.finalize(
            output, fused_expert_output, topk_weights, topk_ids,
            apply_router_weight_on_input, apply_weights_and_reduce)

    def finalize_async(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        apply_weights_and_reduce: bool = True
    ):
        """
        Delegate to the selected implementation's ``finalize_async`` when it
        has one; otherwise fall back to its synchronous ``finalize``.
        """
        delegate = self._get_current_prepare_finalize()
        if hasattr(delegate, 'finalize_async'):
            finalize_fn = delegate.finalize_async
        else:
            # Fallback to synchronous finalize.
            finalize_fn = delegate.finalize
        return finalize_fn(
            output, fused_expert_output, topk_weights, topk_ids,
            apply_router_weight_on_input, apply_weights_and_reduce)
...@@ -55,6 +55,7 @@ if current_platform.is_cuda_alike(): ...@@ -55,6 +55,7 @@ if current_platform.is_cuda_alike():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
from .deepep_auto_prepare_finalize import DeepEPAutoPrepareAndFinalize
else: else:
fused_experts = None # type: ignore fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore
...@@ -140,6 +141,48 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -140,6 +141,48 @@ class FusedMoEMethodBase(QuantizeMethodBase):
num_local_experts=moe.num_local_experts, num_local_experts=moe.num_local_experts,
num_dispatchers=num_dispatchers, num_dispatchers=num_dispatchers,
) )
elif moe.use_deepep_auto_kernels:
# Initialize both HT and LL prepare_finalize but reuse the single
# LL handle for both (sglang-style single handle)
assert moe.dp_size == all2all_manager.dp_world_size
ll_all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
num_ep_ranks=all2all_manager.world_size,
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts //
all2all_manager.world_size,
)
ll_handle = all2all_manager.get_handle(ll_all_to_all_args)
# HT prepare/finalize built on the same LL handle per request
ht_prepare_finalize = DeepEPHTPrepareAndFinalize(
ll_handle,
num_dispatchers=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank *
moe.num_local_experts,
)
use_fp8_dispatch = (moe.quant_config is not None
and moe.quant_config.quant_dtype
== current_platform.fp8_dtype()
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = False
ll_prepare_finalize = DeepEPLLPrepareAndFinalize(
ll_handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
use_int8_dispatch=use_int8_dispatch,
)
prepare_finalize = DeepEPAutoPrepareAndFinalize(
ht_prepare_finalize, ll_prepare_finalize)
elif moe.use_deepep_ht_kernels: elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size assert moe.dp_size == all2all_manager.dp_world_size
......
...@@ -84,11 +84,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -84,11 +84,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
self.ep_size = get_ep_group().world_size self.ep_size = get_ep_group().world_size
backend = envs.VLLM_ALL2ALL_BACKEND
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (backend == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") backend == "deepep_low_latency" or \
backend == "deepep_auto")
self.use_deepep_ll = self.use_deepep and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" self.use_deepep_ll = self.use_deepep and (backend == "deepep_low_latency" or \
(backend == "deepep_auto"))
if self.use_deepep: if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
......
...@@ -174,8 +174,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -174,8 +174,12 @@ class DeepseekV2MoE(nn.Module):
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori' self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori'
self.enable_expert_parallel = parallel_config.enable_expert_parallel self.enable_expert_parallel = parallel_config.enable_expert_parallel
self.use_deepep_ll = dp_size > 1 and parallel_config.enable_expert_parallel and \ backend = envs.VLLM_ALL2ALL_BACKEND
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" self.use_deepep_ll = (
dp_size > 1
and parallel_config.enable_expert_parallel
and (backend == "deepep_low_latency" or backend == "deepep_auto")
)
if not self.use_deepep_ll: if not self.use_deepep_ll:
moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE
...@@ -717,8 +721,12 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -717,8 +721,12 @@ class DeepseekV2DecoderLayer(nn.Module):
self.dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.use_deepep_ll = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ backend = envs.VLLM_ALL2ALL_BACKEND
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" self.use_deepep_ll = (
self.dp_size > 1
and parallel_config.enable_expert_parallel
and (backend == "deepep_low_latency" or backend == "deepep_auto")
)
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment