Commit aae1cc39 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ds_auto_12.29' into 'v0.9.2-dev-ds'

Merge the DeepEP auto-mode branch (合并 deepep 的 auto 模式分支)

See merge request dcutoolkit/deeplearing/vllm!335
parents ace32edb c773cc66
...@@ -194,6 +194,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -194,6 +194,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes, num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False, low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank, num_qps_per_rank=num_qps_per_rank,
allow_mnnvl=envs.VLLM_ALLOW_MNNVL,
explicitly_destroy=False) explicitly_destroy=False)
def get_handle(self, kwargs): def get_handle(self, kwargs):
...@@ -275,3 +276,56 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -275,3 +276,56 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
# in get_or_create must be updated. # in get_or_create must be updated.
handle.set_num_sms(self.num_sms) handle.set_num_sms(self.num_sms)
return handle return handle
class DeepEPAutoAll2AllManager(All2AllManagerBase):
    """
    All2All manager for the ``deepep_auto`` backend.

    Owns one low-latency (LL) and one high-throughput (HT) DeepEP manager,
    but only ever hands out a single ``deep_ep.Buffer``: it is built from
    the LL arguments and sized to the larger of the HT/LL byte
    requirements. The handle is stored in the LL manager's handle cache;
    this class never populates the HT manager's cache.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)
        self.ll_manager = DeepEPLLAll2AllManager(cpu_group)
        self.ht_manager = DeepEPHTAll2AllManager(cpu_group)

    def get_handle(self, kwargs):
        """
        Return the shared DeepEP buffer for the given LL arguments.

        Args:
            kwargs: keyword arguments forwarded to the LL manager's
                ``_make_all2all_kwargs``.

        Returns:
            The ``deep_ep.Buffer`` obtained from the LL manager's handle
            cache, with its SM count set to the LL manager's ``num_sms``.
        """
        import deep_ep

        kwargs = dict(kwargs)

        # Canonical buffer arguments for each path.
        ll_kwargs = self.ll_manager._make_all2all_kwargs(**kwargs)
        ht_kwargs = self.ht_manager._make_all2all_kwargs()

        # Start from the LL arguments and grow the two buffer sizes so the
        # resulting buffer is large enough for either mode.
        merged_kwargs = dict(ll_kwargs)
        for size_key in ("num_nvl_bytes", "num_rdma_bytes"):
            merged_kwargs[size_key] = max(ll_kwargs[size_key],
                                          ht_kwargs[size_key])

        logger.debug("DeepEP auto merged args %s", merged_kwargs)
        handle: deep_ep.Buffer = self.ll_manager.handle_cache.get_or_create(
            merged_kwargs, deep_ep.Buffer)
        handle.set_num_sms(self.ll_manager.num_sms)
        return handle

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        """Not supported on the auto manager; always raises."""
        raise NotImplementedError(
            "DeepEPAutoAll2AllManager does not support dispatch directly; "
            "use the underlying HT/LL managers.")

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Not supported on the auto manager; always raises."""
        raise NotImplementedError(
            "DeepEPAutoAll2AllManager does not support combine directly; "
            "use the underlying HT/LL managers.")

    def destroy(self):
        """Tear down both wrapped managers."""
        self.ll_manager.destroy()
        # The HT manager is constructed in __init__ as well, so release it
        # symmetrically instead of leaving it to garbage collection.
        self.ht_manager.destroy()
...@@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import DeepEPLLAll2AllManager from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.") logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "deepep_auto":
from .all2all import DeepEPAutoAll2AllManager
self.all2all_manager = DeepEPAutoAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Auto all2all manager.")
elif all2all_backend == "mori": elif all2all_backend == "mori":
pass pass
else: else:
......
...@@ -128,6 +128,7 @@ if TYPE_CHECKING: ...@@ -128,6 +128,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive" VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MOE_HT_THRESHOLD: int = 128
VLLM_ALLOW_MNNVL: bool = False VLLM_ALLOW_MNNVL: bool = False
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
...@@ -954,6 +955,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -954,6 +955,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALL2ALL_BACKEND": "VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
# VLLM_MOE_HT_THRESHOLD
"VLLM_MOE_HT_THRESHOLD":
lambda: int(os.getenv("VLLM_MOE_HT_THRESHOLD", "128")),
# use ALLOW_MNNVL # use ALLOW_MNNVL
"VLLM_ALLOW_MNNVL": "VLLM_ALLOW_MNNVL":
lambda: (os.environ.get("VLLM_ALLOW_MNNVL", "False").lower() in lambda: (os.environ.get("VLLM_ALLOW_MNNVL", "False").lower() in
......
...@@ -187,6 +187,11 @@ class FusedMoEParallelConfig: ...@@ -187,6 +187,11 @@ class FusedMoEParallelConfig:
return (self.use_all2all_kernels return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@property
def use_deepep_auto_kernels(self):
    """True when all2all kernels are in use and the configured
    ``VLLM_ALL2ALL_BACKEND`` is ``"deepep_auto"``."""
    all2all_enabled = self.use_all2all_kernels
    return all2all_enabled and envs.VLLM_ALL2ALL_BACKEND == "deepep_auto"
@staticmethod @staticmethod
def make(tp_size_: int, dp_size_: int, def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
...@@ -385,6 +390,10 @@ class FusedMoEConfig: ...@@ -385,6 +390,10 @@ class FusedMoEConfig:
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_deepep_auto_kernels(self):
    """Forward the ``use_deepep_auto_kernels`` flag of the MoE parallel
    config."""
    parallel_cfg = self.moe_parallel_config
    return parallel_cfg.use_deepep_auto_kernels
@staticmethod @staticmethod
def make( def make(
num_experts: int, num_experts: int,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.forward_context import get_forward_context
class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    Prepare/Finalize wrapper that holds both the DeepEP High-Throughput (HT)
    and Low-Latency (LL) implementations and forwards every call to one of
    them according to the current phase: prefill -> HT, decode -> LL.

    The two wrapped objects stay reachable as ``ht_prepare_finalize`` and
    ``ll_prepare_finalize`` so callers can pick one explicitly.
    """

    def __init__(self,
                 ht_prepare_finalize: mk.FusedMoEPrepareAndFinalize,
                 ll_prepare_finalize: mk.FusedMoEPrepareAndFinalize):
        super().__init__()
        self.ht_prepare_finalize = ht_prepare_finalize
        self.ll_prepare_finalize = ll_prepare_finalize
        # Nothing in this class updates the phase, so this initial value
        # decides which delegate is used. The previous code stored "decode"
        # but mapped it to the HT implementation; now that the mapping
        # matches its documentation, "prefill" keeps the wrapper delegating
        # to HT exactly as before.
        # NOTE(review): confirm HT is the intended default delegate.
        self._current_phase = "prefill"

    def _get_current_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize:
        """Return the delegate for the current phase.

        "prefill" selects the HT implementation; any other phase value
        selects the LL implementation.
        """
        # Inferring the phase from the forward context (e.g. from
        # attn_metadata.num_decode_tokens) is intentionally not done here;
        # the stored phase is the only input.
        if self._current_phase == "prefill":
            return self.ht_prepare_finalize
        return self.ll_prepare_finalize

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activation format of the current delegate.

        Falls back to the Standard format when the delegate raises
        ``NotImplementedError``.
        """
        pf = self._get_current_prepare_finalize()
        try:
            return pf.activation_format
        except NotImplementedError:
            return mk.FusedMoEActivationFormat.Standard

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        """Forward to the current delegate's ``topk_indices_dtype``."""
        return self._get_current_prepare_finalize().topk_indices_dtype()

    def max_num_tokens_per_rank(self) -> Optional[int]:
        """Forward to the current delegate's ``max_num_tokens_per_rank``."""
        return self._get_current_prepare_finalize().max_num_tokens_per_rank()

    def num_dispatchers(self) -> int:
        """Forward to the current delegate's ``num_dispatchers``."""
        return self._get_current_prepare_finalize().num_dispatchers()

    def prepare_async(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ):
        """Forward all arguments, unchanged, to the current delegate's
        ``prepare_async`` and return its result."""
        pf = self._get_current_prepare_finalize()
        return pf.prepare_async(
            a1, a1_scale, a2_scale, topk_weights, topk_ids,
            num_experts, expert_map, apply_router_weight_on_input,
            quant_config)

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Forward all arguments, unchanged, to the current delegate's
        ``prepare`` and return its result."""
        pf = self._get_current_prepare_finalize()
        return pf.prepare(
            a1, a1_scale, a2_scale, topk_weights, topk_ids,
            num_experts, expert_map, apply_router_weight_on_input,
            quant_config)

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        apply_weights_and_reduce: bool = True
    ) -> None:
        """Forward all arguments, unchanged, to the current delegate's
        ``finalize``."""
        pf = self._get_current_prepare_finalize()
        return pf.finalize(
            output, fused_expert_output, topk_weights, topk_ids,
            apply_router_weight_on_input, apply_weights_and_reduce)

    def finalize_async(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        apply_weights_and_reduce: bool = True
    ):
        """Forward to the current delegate's ``finalize_async``.

        If the delegate has no ``finalize_async`` attribute, its synchronous
        ``finalize`` is called instead and that result is returned.
        """
        pf = self._get_current_prepare_finalize()
        if hasattr(pf, "finalize_async"):
            return pf.finalize_async(
                output, fused_expert_output, topk_weights, topk_ids,
                apply_router_weight_on_input, apply_weights_and_reduce)
        return pf.finalize(
            output, fused_expert_output, topk_weights, topk_ids,
            apply_router_weight_on_input, apply_weights_and_reduce)
...@@ -55,6 +55,7 @@ if current_platform.is_cuda_alike(): ...@@ -55,6 +55,7 @@ if current_platform.is_cuda_alike():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
from .deepep_auto_prepare_finalize import DeepEPAutoPrepareAndFinalize
else: else:
fused_experts = None # type: ignore fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore
...@@ -140,6 +141,62 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -140,6 +141,62 @@ class FusedMoEMethodBase(QuantizeMethodBase):
num_local_experts=moe.num_local_experts, num_local_experts=moe.num_local_experts,
num_dispatchers=num_dispatchers, num_dispatchers=num_dispatchers,
) )
elif moe.use_deepep_auto_kernels:
# Initialize both HT and LL prepare_finalize but reuse the single
# LL handle for both (sglang-style single handle)
assert moe.dp_size == all2all_manager.dp_world_size
ll_all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
num_ep_ranks=all2all_manager.world_size,
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts //
all2all_manager.world_size,
)
ll_handle = all2all_manager.get_handle(ll_all_to_all_args)
# HT prepare/finalize built on the same LL handle per request
ht_prepare_finalize = DeepEPHTPrepareAndFinalize(
ll_handle,
num_dispatchers=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank *
moe.num_local_experts,
)
use_fp8_dispatch = (moe.quant_config is not None
and moe.quant_config.quant_dtype
== current_platform.fp8_dtype()
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = False
ll_prepare_finalize = DeepEPLLPrepareAndFinalize(
ll_handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
use_int8_dispatch=use_int8_dispatch,
)
prepare_finalize = DeepEPAutoPrepareAndFinalize(
ht_prepare_finalize, ll_prepare_finalize)
experts_ht = self.select_gemm_impl(ht_prepare_finalize, moe)
experts_ll = self.select_gemm_impl(ll_prepare_finalize, moe)
self.topk_indices_dtype = ll_prepare_finalize.topk_indices_dtype()
self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize,
experts_ll,
experts_ht=experts_ht,
experts_ll=experts_ll,
shared_experts=layer.shared_experts if hasattr(layer, "shared_experts") else None,
)
return
elif moe.use_deepep_ht_kernels: elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size assert moe.dp_size == all2all_manager.dp_world_size
...@@ -854,7 +911,7 @@ class FusedMoE(torch.nn.Module): ...@@ -854,7 +911,7 @@ class FusedMoE(torch.nn.Module):
self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_parallel_config.use_deepep_auto_kernels):
self.batched_hidden_states = torch.zeros( self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size), (moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype, dtype=moe.in_dtype,
...@@ -916,7 +973,10 @@ class FusedMoE(torch.nn.Module): ...@@ -916,7 +973,10 @@ class FusedMoE(torch.nn.Module):
@property @property
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_deepep_auto_kernels(self):
    """Report the ``use_deepep_auto_kernels`` flag held by this layer's
    ``moe_parallel_config``."""
    moe_parallel_cfg = self.moe_parallel_config
    return moe_parallel_cfg.use_deepep_auto_kernels
@property @property
def shared_experts(self) -> Optional[torch.nn.Module]: def shared_experts(self) -> Optional[torch.nn.Module]:
return None return None
...@@ -1443,7 +1503,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1443,7 +1503,7 @@ class FusedMoE(torch.nn.Module):
early. early.
""" """
return (self.use_pplx_kernels or self.use_deepep_ht_kernels return (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels) or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels)
def maybe_all_reduce_tensor_model_parallel( def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor): self, final_hidden_states: torch.Tensor):
...@@ -1451,7 +1511,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1451,7 +1511,7 @@ class FusedMoE(torch.nn.Module):
The pplx combine kernel reduces across GPU ranks by default. The pplx combine kernel reduces across GPU ranks by default.
""" """
if (self.use_pplx_kernels or self.use_deepep_ht_kernels if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels): or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels):
return final_hidden_states return final_hidden_states
else: else:
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
......
...@@ -6,7 +6,9 @@ from math import prod ...@@ -6,7 +6,9 @@ from math import prod
from typing import Optional, final from typing import Optional, final
from dataclasses import dataclass from dataclasses import dataclass
from collections.abc import Callable from collections.abc import Callable
from vllm.logger import init_logger
logger = init_logger(__name__)
import torch import torch
import vllm.envs as envs import vllm.envs as envs
...@@ -843,11 +845,16 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -843,11 +845,16 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute, fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute,
experts_ht: CustomizedFusedMoEPermuteExpertsUnpermute = None,
experts_ll: CustomizedFusedMoEPermuteExpertsUnpermute = None,
shared_experts: Optional[torch.nn.Module] = None, shared_experts: Optional[torch.nn.Module] = None,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.fused_experts_ht = experts_ht
self.fused_experts_ll = experts_ll
self.shared_experts = shared_experts self.shared_experts = shared_experts
if self.shared_experts is not None: if self.shared_experts is not None:
...@@ -919,7 +926,29 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -919,7 +926,29 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
prepare_finalize = self.prepare_finalize
fused_experts = self.fused_experts
# from vllm.config import get_current_vllm_config
# vllm_cfg = get_current_vllm_config()
# max_tokens_for_cudagraph = vllm_cfg.compilation_config.max_capture_size
# num_ht_ll_tokens = max_tokens_for_cudagraph
if envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
num_ht_ll_tokens = envs.VLLM_MOE_HT_THRESHOLD
num_tokens = hidden_states.size(0)
# logger.info("num_tokens=%d", num_tokens)
if num_tokens > num_ht_ll_tokens:
prepare_finalize = self.prepare_finalize.ht_prepare_finalize
fused_experts = self.fused_experts_ht
else:
prepare_finalize = self.prepare_finalize.ll_prepare_finalize
fused_experts = self.fused_experts_ll
a1 = hidden_states a1 = hidden_states
if inplace and self.shared_experts is None: if inplace and self.shared_experts is None:
...@@ -931,7 +960,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -931,7 +960,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
prepare_ret = self.prepare_finalize.prepare_async( prepare_ret = prepare_finalize.prepare_async(
a1, a1,
a1_scale, a1_scale,
a2_scale, a2_scale,
...@@ -940,7 +969,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -940,7 +969,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
global_num_experts, global_num_experts,
expert_map, expert_map,
apply_router_weight_on_input, apply_router_weight_on_input,
self.fused_experts.quant_config, fused_experts.quant_config,
) )
hook, receiver = ( hook, receiver = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret) prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
...@@ -971,7 +1000,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -971,7 +1000,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
# and can never run into the tensor.numel() == 0 case. # and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else: else:
fused_out = self.fused_experts.apply( fused_out = fused_experts.apply(
None, None,
a1, a1,
a1q, a1q,
...@@ -1008,12 +1037,12 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module): ...@@ -1008,12 +1037,12 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self.alt_stream.wait_event(self.alt_event) self.alt_stream.wait_event(self.alt_event)
hook = None hook = None
if self.prepare_finalize.activation_format == \ if prepare_finalize.activation_format == \
FusedMoEActivationFormat.BatchedExperts: FusedMoEActivationFormat.BatchedExperts:
self.prepare_finalize.finalize(output, fused_out, topk_weights, prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True) topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
else: else:
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights, hook = prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True) topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if hook is not None: if hook is not None:
hook() hook()
......
...@@ -87,9 +87,11 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -87,9 +87,11 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
self.ep_size = get_ep_group().world_size self.ep_size = get_ep_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.use_deepgemm = False self.use_deepgemm = False
if self.use_deepep: if self.use_deepep:
...@@ -97,7 +99,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -97,7 +99,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
assert all2all_manager is not None assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size self.num_dispatchers = all2all_manager.world_size
self.block_shape = [256, 256] self.block_shape = [256, 256]
self.use_deepgemm = envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM self.use_deepgemm = envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM or envs.VLLM_ALL2ALL_BACKEND == "deepep_auto"
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
......
...@@ -176,7 +176,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -176,7 +176,8 @@ class DeepseekV2MoE(nn.Module):
self.enable_expert_parallel = parallel_config.enable_expert_parallel self.enable_expert_parallel = parallel_config.enable_expert_parallel
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
if not self.use_deepep: if not self.use_deepep:
moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE
...@@ -724,7 +725,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -724,7 +725,8 @@ class DeepseekV2DecoderLayer(nn.Module):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.config = config self.config = config
...@@ -961,7 +963,8 @@ class DeepseekV2Model(nn.Module): ...@@ -961,7 +963,8 @@ class DeepseekV2Model(nn.Module):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
......
...@@ -10,8 +10,9 @@ from vllm.attention.layer import Attention ...@@ -10,8 +10,9 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import DPMetadata, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
import vllm.envs as envs
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
...@@ -184,7 +185,10 @@ class EagleProposer: ...@@ -184,7 +185,10 @@ class EagleProposer:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
...@@ -223,7 +227,8 @@ class EagleProposer: ...@@ -223,7 +227,8 @@ class EagleProposer:
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens,): num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp):
#skip_cuda_graphs=not decoding): #skip_cuda_graphs=not decoding):
ret_hidden_states = self.model( ret_hidden_states = self.model(
self.input_ids[:num_input_tokens], self.input_ids[:num_input_tokens],
...@@ -369,7 +374,8 @@ class EagleProposer: ...@@ -369,7 +374,8 @@ class EagleProposer:
# Run the model. # Run the model.
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch_size): num_tokens=input_batch_size,
num_tokens_across_dp=num_tokens_across_dp):
ret_hidden_states = self.model( ret_hidden_states = self.model(
self.input_ids[:input_batch_size], self.input_ids[:input_batch_size],
self.positions[:input_batch_size], self.positions[:input_batch_size],
...@@ -496,6 +502,40 @@ class EagleProposer: ...@@ -496,6 +502,40 @@ class EagleProposer:
logger.info("Loading EAGLE LM head weights from the target model.") logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head self.model.lm_head = target_language_model.lm_head
def get_dp_padding(self,
                   num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
    """Compute the data-parallel (DP) padding for ``num_tokens``.

    Args:
        num_tokens: number of tokens this rank is about to run.

    Returns:
        ``(num_pad, num_tokens_after_padding)``. ``num_pad`` is how many
        tokens to add so this rank matches the largest token count across
        DP ranks; ``num_tokens_after_padding`` is a CPU int32 tensor of
        length ``dp_size`` filled with that maximum. ``(0, None)`` is
        returned when no padding applies or when the cross-rank token
        counts could not be obtained.
    """
    parallel_config = self.vllm_config.parallel_config
    dp_size = parallel_config.data_parallel_size
    if dp_size == 1:
        # A single DP rank has nothing to align with. Return before the
        # cross-rank token-count exchange, including for "deepep_auto".
        return 0, None

    # For DP: Don't pad when setting enforce_eager.
    # This lets us set enforce_eager on the prefiller in a P/D setup and
    # still use CUDA graphs (enabled by this padding) on the decoder.
    #
    # TODO(tms) : There are many cases where padding is enabled for
    # prefills, causing unnecessary and excessive padding of activations.
    backend = envs.VLLM_ALL2ALL_BACKEND
    if backend != "deepep_auto" and (
            self.vllm_config.model_config.enforce_eager
            or backend != "naive"):
        # Early exit. "deepep_auto" is exempt and always pads.
        # NOTE(review): presumably so every DP rank sees the same token
        # count when choosing between HT and LL - confirm.
        return 0, None

    dp_rank = parallel_config.data_parallel_rank
    try:
        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
            num_tokens, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
        num_tokens_after_padding = torch.tensor(
            [max_tokens_across_dp_cpu] * dp_size,
            device="cpu",
            dtype=torch.int32)
        return (max_tokens_across_dp_cpu - num_tokens,
                num_tokens_after_padding)
    except (RuntimeError, AttributeError) as e:
        # DP group may not be initialized yet during dummy run
        # Skip padding in this case
        logger.debug(
            "Skipping DP padding in eagle get_dp_padding due to: %s", e)
        return 0, None
@torch.inference_mode() @torch.inference_mode()
def dummy_run( def dummy_run(
self, self,
...@@ -505,20 +545,28 @@ class EagleProposer: ...@@ -505,20 +545,28 @@ class EagleProposer:
if attn_metadata is not None and self.attn_metadata_cudagraph is None: if attn_metadata is not None and self.attn_metadata_cudagraph is None:
self.attn_metadata_cudagraph = attn_metadata[ self.attn_metadata_cudagraph = attn_metadata[
self.attn_layer_names[0]] self.attn_layer_names[0]]
# Padding for DP
num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_input_tokens += num_pad
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens): num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.model( self.model(
self.input_ids[:num_tokens], self.input_ids[:num_input_tokens],
self.positions[:num_tokens], self.positions[:num_input_tokens],
self.hidden_states[:num_tokens], self.hidden_states[:num_input_tokens],
) )
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1: if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1:
num_tokens = 1
for _ in range(self.num_speculative_tokens - 1): for _ in range(self.num_speculative_tokens - 1):
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens,):
self.model( self.model(
self.input_ids[:num_tokens], self.input_ids[:num_tokens],
self.positions[:num_tokens], self.positions[:num_tokens],
......
...@@ -1246,8 +1246,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1246,8 +1246,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive': if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# Early exit. # auto
return 0, None if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# Early exit.
return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp( num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank) num_tokens, dp_size, dp_rank)
......
...@@ -108,6 +108,9 @@ class V1ZeroEagleProposer(EagleProposer): ...@@ -108,6 +108,9 @@ class V1ZeroEagleProposer(EagleProposer):
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment