Commit 4e51cae7 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm-0112' into 'v0.9.2-dev'

[feat]添加dp attention功能

See merge request dcutoolkit/deeplearing/vllm!383
parents fe2e2705 cc4d1002
...@@ -1883,6 +1883,9 @@ class ParallelConfig: ...@@ -1883,6 +1883,9 @@ class ParallelConfig:
""" Use data parallelism instead of tensor parallelism for vision encoder. """ Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now""" Only support LLama4 for now"""
enable_dp_attention: bool = False
"""Enable dp attention"""
@property @property
def world_size_across_dp(self) -> int: def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world """world_size_across_dp is TPxPPxDP, it is the size of the world
...@@ -2108,6 +2111,9 @@ class ParallelConfig: ...@@ -2108,6 +2111,9 @@ class ParallelConfig:
if self.ray_workers_use_nsight and not self.use_ray: if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError("Unable to use nsight profiling unless workers " raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.") "run with Ray.")
if self.enable_dp_attention and self.enable_expert_parallel:
raise ValueError("Dp attention and expert parallel can not enable together.")
return self return self
...@@ -4805,6 +4811,7 @@ class VllmConfig: ...@@ -4805,6 +4811,7 @@ class VllmConfig:
dp_size = self.parallel_config.data_parallel_size dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1 ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
enable_dp_attention = self.parallel_config.enable_dp_attention
# add for spec decode # add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0: if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
...@@ -4813,10 +4820,10 @@ class VllmConfig: ...@@ -4813,10 +4820,10 @@ class VllmConfig:
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list)) batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0] batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp: if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list])) batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
else: else:
if ep_sp: if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list])) batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
self.compilation_config.init_with_cudagraph_sizes( self.compilation_config.init_with_cudagraph_sizes(
......
...@@ -103,7 +103,7 @@ class DeviceCommunicatorBase: ...@@ -103,7 +103,7 @@ class DeviceCommunicatorBase:
# as long as we use data parallel (coupled data parallel # as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together), # where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel. # we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1 use_ep = config.parallel_config.data_parallel_size > 1 and not config.parallel_config.enable_dp_attention
self.use_all2all = "ep" in unique_name and use_ep self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None self.all2all_manager: Optional[All2AllManagerBase] = None
......
...@@ -476,6 +476,9 @@ class EngineArgs: ...@@ -476,6 +476,9 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \ enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel ParallelConfig.enable_multimodal_encoder_data_parallel
enable_dp_attention: bool = \
ParallelConfig.enable_dp_attention
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
...@@ -718,6 +721,10 @@ class EngineArgs: ...@@ -718,6 +721,10 @@ class EngineArgs:
parallel_group.add_argument( parallel_group.add_argument(
"--enable-multimodal-encoder-data-parallel", "--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"]) **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
parallel_group.add_argument(
"--enable-dp-attention",
**parallel_kwargs["enable_dp_attention"])
# KV cache arguments # KV cache arguments
cache_kwargs = get_kwargs(CacheConfig) cache_kwargs = get_kwargs(CacheConfig)
...@@ -1204,6 +1211,7 @@ class EngineArgs: ...@@ -1204,6 +1211,7 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls, worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self. enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel, enable_multimodal_encoder_data_parallel,
enable_dp_attention=self.enable_dp_attention,
) )
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
......
...@@ -203,6 +203,7 @@ if TYPE_CHECKING: ...@@ -203,6 +203,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ENABLE_DEEPEP_INT8_DISPATCH: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
...@@ -1325,6 +1326,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1325,6 +1326,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM": "VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")), ("true", "1")),
# vLLM will use deepep int8 dispatch
"VLLM_ENABLE_DEEPEP_INT8_DISPATCH":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_INT8_DISPATCH', '1').lower() in
("true", "1")),
# Only quantized DeepSeek models supported. # Only quantized DeepSeek models supported.
# Unquantized versions are not supported. # Unquantized versions are not supported.
......
...@@ -136,8 +136,8 @@ def set_forward_context( ...@@ -136,8 +136,8 @@ def set_forward_context(
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None dp_metadata: Optional[DPMetadata] = None
dp_size = vllm_config.parallel_config.data_parallel_size dp_size = vllm_config.parallel_config.data_parallel_size
use_navie_ep = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel use_navie_all2all = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1
if use_navie_ep and dp_size > 1 and ( if use_navie_all2all and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None): attn_metadata is not None or num_tokens is not None):
dp_metadata = DPMetadata.make(vllm_config.parallel_config, dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0, attn_metadata, num_tokens or 0,
...@@ -210,4 +210,15 @@ def set_profilling(profiling): ...@@ -210,4 +210,15 @@ def set_profilling(profiling):
def get_profilling() -> bool: def get_profilling() -> bool:
global _profiling global _profiling
return _profiling return _profiling
\ No newline at end of file
_warming_up = False
@contextmanager
def set_warming_up(warming_up):
global _warming_up
_warming_up = warming_up
def get_warming_up() -> bool:
global _warming_up
return _warming_up
\ No newline at end of file
from typing import TYPE_CHECKING, List, Optional, Tuple
import logging
import torch
import vllm.envs as envs
from vllm.distributed.parallel_state import GroupCoordinator, init_model_parallel_group, get_world_group
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
tensor_model_parallel_reduce_scatter,
get_tp_group)
_ENABLE_DP_ATTENTION_FLAG: bool = False
_MOE_TP: Optional[GroupCoordinator] = None
_ATTN_DP_SIZE = 0
_ATTN_TP_SIZE = 0
_ATTN_TP_RANK = 0
_ATTN_DP_RANK = 0
_MOT_TP_SIZE = 0
_MOT_TP_RANK = 0
def initialize_dp_attention(vllm_config, backend: Optional[str] = None):
from vllm.config import VllmConfig
assert isinstance(vllm_config, VllmConfig)
global _ENABLE_DP_ATTENTION_FLAG, _ATTN_DP_SIZE, _ATTN_TP_SIZE, _ATTN_TP_RANK, _ATTN_DP_RANK, _MOT_TP_SIZE, _MOT_TP_RANK
enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
_ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
# Build the moe tensor model-parallel groups.
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
data_parallel_size = vllm_config.parallel_config.data_parallel_size
pipeline_model_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
tensor_model_parallel_size = vllm_config.parallel_config.tensor_parallel_size
moe_tp_size = world_size // pipeline_model_parallel_size
moe_ep_size = moe_tp_size if vllm_config.parallel_config.enable_expert_parallel else 1
_ATTN_DP_SIZE = data_parallel_size
_ATTN_TP_SIZE = tensor_model_parallel_size
_ATTN_TP_RANK = get_tensor_model_parallel_rank()
_ATTN_DP_RANK = vllm_config.parallel_config.data_parallel_rank
_MOT_TP_SIZE = moe_tp_size
_MOT_TP_RANK = rank % _MOT_TP_SIZE
global _MOE_TP
assert _MOE_TP is None, ("moe tensor model parallel group is already initialized")
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
group_ranks = []
for i in range(pipeline_model_parallel_size):
ranks = list(
range(i * moe_tp_size, (i + 1) * moe_tp_size)
)
group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
_MOE_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="moe_tp")
def get_attention_tp_size() -> int:
assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
return _ATTN_TP_SIZE
def get_attention_tp_rank() -> int:
assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
return _ATTN_TP_RANK
def get_moe_tp_group() -> GroupCoordinator:
assert _MOE_TP is not None, ("tensor model parallel group is not initialized")
return _MOE_TP
def get_attention_dp_size() -> int:
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _ATTN_DP_SIZE
def get_moe_tp_rank() -> int:
assert _MOT_TP_RANK is not None, "dp attention not initialized!"
return _MOT_TP_RANK
def get_moe_tp_size() -> int:
assert _MOT_TP_SIZE is not None, "dp attention not initialized!"
return _MOT_TP_SIZE
def get_attention_tp_group() -> GroupCoordinator:
return get_tp_group()
def moe_tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_moe_tp_group().all_gather(input_, dim)
def moe_tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""Reduce-Scatter the input tensor across model parallel group."""
return get_moe_tp_group().reduce_scatter(input_, dim)
def dp_gather(
hidden_states: torch.Tensor,)-> torch.Tensor:
if get_attention_tp_size() == 1:
hidden_states = moe_tensor_model_parallel_all_gather(hidden_states, dim=0)
return hidden_states
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
hidden_states = moe_tensor_model_parallel_all_gather(hidden_states, dim=0)
return hidden_states
def dp_reduce_scatter_tensor(hidden_states: torch.Tensor)-> torch.Tensor:
if get_moe_tp_group().world_size == get_attention_dp_size():
hidden_states = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
else:
hidden_states = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
return hidden_states
...@@ -38,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -38,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two) fused_topk, grouped_topk, is_power_of_two)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
...@@ -237,7 +238,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -237,7 +238,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager.world_size, all2all_manager.world_size,
) )
ll_handle = all2all_manager.get_handle(ll_all_to_all_args) ll_handle = all2all_manager.get_handle(ll_all_to_all_args)
# HT prepare/finalize built on the same LL handle per request # HT prepare/finalize built on the same LL handle per request
ht_prepare_finalize = DeepEPHTPrepareAndFinalize( ht_prepare_finalize = DeepEPHTPrepareAndFinalize(
ll_handle, ll_handle,
...@@ -253,7 +254,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -253,7 +254,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE) == DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = False use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 and envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
ll_prepare_finalize = DeepEPLLPrepareAndFinalize( ll_prepare_finalize = DeepEPLLPrepareAndFinalize(
ll_handle, ll_handle,
...@@ -265,10 +266,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -265,10 +266,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize = DeepEPAutoPrepareAndFinalize( prepare_finalize = DeepEPAutoPrepareAndFinalize(
ht_prepare_finalize, ll_prepare_finalize) ht_prepare_finalize, ll_prepare_finalize)
experts_ht = self.select_gemm_impl(ht_prepare_finalize, moe) experts_ht = self.select_gemm_impl(ht_prepare_finalize, moe)
experts_ll = self.select_gemm_impl(ll_prepare_finalize, moe) experts_ll = self.select_gemm_impl(ll_prepare_finalize, moe)
self.topk_indices_dtype = ll_prepare_finalize.topk_indices_dtype() self.topk_indices_dtype = ll_prepare_finalize.topk_indices_dtype()
self.fused_experts = DeepGemmDisabledFusedMoEModularKernel( self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize, prepare_finalize,
...@@ -276,9 +277,9 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -276,9 +277,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
experts_ht=experts_ht, experts_ht=experts_ht,
experts_ll=experts_ll, experts_ll=experts_ll,
shared_experts=layer.shared_experts if hasattr(layer, "shared_experts") else None, shared_experts=layer.shared_experts if hasattr(layer, "shared_experts") else None,
) )
return return
elif moe.use_deepep_ht_kernels: elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size assert moe.dp_size == all2all_manager.dp_world_size
...@@ -310,7 +311,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -310,7 +311,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE) == DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 and envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
# Note (varun): Whether to use FP8 dispatch or not needs some # Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now. # profiling. Turning it off for now.
...@@ -957,6 +958,8 @@ class FusedMoE(torch.nn.Module): ...@@ -957,6 +958,8 @@ class FusedMoE(torch.nn.Module):
self.logical_to_physical_map: Optional[torch.Tensor] = None self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None self.logical_replica_count: Optional[torch.Tensor] = None
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
# Determine expert maps # Determine expert maps
if self.use_ep: if self.use_ep:
if self.enable_eplb: if self.enable_eplb:
...@@ -1693,7 +1696,8 @@ class FusedMoE(torch.nn.Module): ...@@ -1693,7 +1696,8 @@ class FusedMoE(torch.nn.Module):
The pplx combine kernel reduces across GPU ranks by default. The pplx combine kernel reduces across GPU ranks by default.
""" """
if (self.use_pplx_kernels or self.use_deepep_ht_kernels if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels): or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels
or self.enable_dp_attention):
return final_hidden_states return final_hidden_states
else: else:
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
...@@ -1832,6 +1836,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1832,6 +1836,7 @@ class FusedMoE(torch.nn.Module):
and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_deepep_ll_kernels and not self.moe_parallel_config.use_deepep_ll_kernels
and not self.moe_parallel_config.use_deepep_auto_kernels and not self.moe_parallel_config.use_deepep_auto_kernels
and not self.enable_dp_attention
) )
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
......
...@@ -22,6 +22,7 @@ from vllm.utils import round_up ...@@ -22,6 +22,7 @@ from vllm.utils import round_up
try: try:
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
from lightop import op
except Exception: except Exception:
print("INFO: Please install lmslim if you want to use int utils.\n") print("INFO: Please install lmslim if you want to use int utils.\n")
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -622,52 +623,62 @@ def ep_scatter( ...@@ -622,52 +623,62 @@ def ep_scatter(
num_experts = num_recv_tokens_per_expert.shape[0] num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1] hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1] scale_hidden_size = recv_x_scale.shape[-1]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0 if hasattr(op, "ep_scatter"):
op.ep_scatter(
recv_x, recv_x_scale,
recv_topk, expert_map,
num_recv_tokens_per_expert,
output_tensor, output_tensor_scale, m_indices, output_index,
num_experts, BLOCK_E
)
else:
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
_fwd_kernel_ep_scatter_1[(grid,)]( assert m_indices.shape[0] % BLOCK_E == 0
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
grid = min(recv_topk.shape[0], 1024 * 8) _fwd_kernel_ep_scatter_1[(grid,)](
_fwd_kernel_ep_scatter_2[(grid,)]( num_recv_tokens_per_expert,
recv_topk.shape[0], expert_start_loc,
expert_start_loc, m_indices,
recv_x, num_experts=num_experts,
recv_x.stride(0), num_warps=num_warps,
recv_x.stride(1), BLOCK_E=BLOCK_E,
recv_x_scale, BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
recv_x_scale.stride(0), )
recv_x_scale.stride(1),
recv_topk, grid = min(recv_topk.shape[0], 1024 * 8)
recv_topk.stride(0), _fwd_kernel_ep_scatter_2[(grid,)](
recv_topk.stride(1), recv_topk.shape[0],
output_tensor, expert_start_loc,
output_tensor.stride(0), recv_x,
output_tensor.stride(1), recv_x.stride(0),
output_tensor_scale, recv_x.stride(1),
output_tensor_scale.stride(0), recv_x_scale,
output_tensor_scale.stride(1), recv_x_scale.stride(0),
output_index, recv_x_scale.stride(1),
output_index.stride(0), recv_topk,
output_index.stride(1), recv_topk.stride(0),
topk_num=recv_topk.shape[1], recv_topk.stride(1),
expert_map=expert_map, output_tensor,
HAS_EXPERT_MAP=expert_map is not None, output_tensor.stride(0),
num_warps=num_warps, output_tensor.stride(1),
HIDDEN_SIZE=hidden_size, output_tensor_scale,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), output_tensor_scale.stride(0),
SCALE_HIDDEN_SIZE=scale_hidden_size,#hidden_size // BLOCK_D, output_tensor_scale.stride(1),
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size)#triton.next_power_of_2(hidden_size // BLOCK_D), output_index,
) output_index.stride(0),
output_index.stride(1),
topk_num=recv_topk.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
)
return return
......
...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter, PackedvLLMParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
RowvLLMParameter) RowvLLMParameter)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
# yapf: enable # yapf: enable
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -625,12 +626,18 @@ class ColumnParallelLinear(LinearBase): ...@@ -625,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None: if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size self.tp_size = self.expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = input_size self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size) self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition] self.output_partition_sizes = [self.output_size_per_partition]
...@@ -878,6 +885,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -878,6 +885,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -888,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -888,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
tp_size = get_moe_tp_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size, super().__init__(input_size=input_size,
output_size=sum(output_sizes), output_size=sum(output_sizes),
...@@ -898,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -898,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
expect_tp_size=expect_tp_size) expect_tp_size=expect_tp_size,
enable_dp_attn_moe=enable_dp_attn_moe)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -999,6 +1012,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -999,6 +1012,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if self.expect_tp_size is not None and self.expect_tp_size == 1: if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
tp_size = 1 tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
if output_dim is not None: if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
...@@ -1121,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -1121,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if hasattr(param, "expect_tp_size"): if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe and hasattr(param, "enable_dp_attn_moe"):
tp_size = get_moe_tp_size()
param.enable_dp_attn_moe = self.enable_dp_attn_moe
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod) Fp8LinearMethod, Fp8MoEMethod)
...@@ -1552,6 +1573,7 @@ class RowParallelLinear(LinearBase): ...@@ -1552,6 +1573,7 @@ class RowParallelLinear(LinearBase):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
# Divide the weight matrix along the first dimension. # Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -1560,7 +1582,13 @@ class RowParallelLinear(LinearBase): ...@@ -1560,7 +1582,13 @@ class RowParallelLinear(LinearBase):
if expect_tp_size is not None: if expect_tp_size is not None:
self.tp_rank = 0 self.tp_rank = 0
self.tp_size = 1 self.tp_size = 1
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_rank = get_moe_tp_rank()
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
...@@ -1610,6 +1638,11 @@ class RowParallelLinear(LinearBase): ...@@ -1610,6 +1638,11 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None: if self.expect_tp_size is not None:
tp_rank = 0 tp_rank = 0
tp_size = 1 tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False) is_sharded_weight = getattr(param, "is_sharded_weight", False)
...@@ -1664,6 +1697,9 @@ class RowParallelLinear(LinearBase): ...@@ -1664,6 +1697,9 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"): if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe is not None and hasattr(param, "enable_dp_attn_moe"):
param.enable_dp_attn_moe = self.enable_dp_attn_moe
param.load_row_parallel_weight(loaded_weight=loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward( def forward(
......
...@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.dp_attention import (dp_gather, dp_reduce_scatter_tensor,
get_moe_tp_size, get_moe_tp_rank, get_attention_tp_size)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -85,17 +87,22 @@ class DeepseekV2MLP(nn.Module): ...@@ -85,17 +87,22 @@ class DeepseekV2MLP(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
vllm_config = get_current_vllm_config()
enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj") prefix=f"{prefix}.gate_up_proj",
enable_dp_attn_moe=enable_dp_attention)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, reduce_results=reduce_results if not enable_dp_attention else False,
prefix=f"{prefix}.down_proj") prefix=f"{prefix}.down_proj",
enable_dp_attn_moe=enable_dp_attention)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -991,6 +998,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -991,6 +998,8 @@ class DeepseekV2DecoderLayer(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.config = config self.config = config
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace and layer_idx >= config.first_k_dense_replace
...@@ -1010,6 +1019,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1010,6 +1019,8 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.enable_ep_sp = isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1
self.is_mtp_layer = False self.is_mtp_layer = False
if self.layer_idx == config.num_hidden_layers: if self.layer_idx == config.num_hidden_layers:
self.is_mtp_layer = True self.is_mtp_layer = True
...@@ -1018,6 +1029,9 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1018,6 +1029,9 @@ class DeepseekV2DecoderLayer(nn.Module):
DeepseekV2MoE) and self.use_deepep and \ DeepseekV2MoE) and self.use_deepep and \
self.tp_size > 1 and not self.is_mtp_layer: self.tp_size > 1 and not self.is_mtp_layer:
reduce_results = False reduce_results = False
else:
if self.enable_dp_attention:
reduce_results = False
if model_config.use_mla: if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention attn_cls = DeepseekV2MLAAttention
...@@ -1169,25 +1183,25 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1169,25 +1183,25 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
if not self.is_mtp_layer: if not self.is_mtp_layer and self.enable_ep_sp and \
if isinstance(self.mlp, self.layer_idx > self.config.first_k_dense_replace:
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \ hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
) )
if not self.is_mtp_layer: if not self.is_mtp_layer and self.enable_ep_sp:
if isinstance(self.mlp, if self.layer_idx == self.config.first_k_dense_replace:
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: residual = residual.tensor_split(self.tp_size)[self.tp_rank]
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0) hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
if self.enable_dp_attention:
if self.tp_rank == 0:
hidden_states += residual
hidden_states = dp_gather(hidden_states)
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
...@@ -1200,27 +1214,31 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1200,27 +1214,31 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1. / self.routed_scaling_factor residual *= 1. / self.routed_scaling_factor
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( if not self.enable_dp_attention:
hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.is_mtp_layer: else:
if isinstance(self.mlp, num_tokens = hidden_states.shape[0]
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: new_bs = num_tokens // get_moe_tp_size() * get_attention_tp_size()
ori_bs = hidden_states.shape[0] residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs hidden_states = self.post_attention_layernorm(hidden_states)
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous() if self.is_mtp_layer and self.enable_ep_sp:
new_bs = (ori_bs+pad_size) // self.tp_size ori_bs = hidden_states.shape[0]
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous() pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states, [0, 0, 0, pad_size], value=0)
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if self.is_mtp_layer: if self.enable_dp_attention:
if isinstance(self.mlp, hidden_states = dp_reduce_scatter_tensor(hidden_states)
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0) if self.is_mtp_layer and self.enable_ep_sp:
hidden_states = hidden_states[:ori_bs, :] hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
...@@ -10,6 +10,7 @@ from torch.nn import Parameter ...@@ -10,6 +10,7 @@ from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader from vllm.model_executor.utils import _make_synced_weight_loader
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
__all__ = [ __all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
...@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self._output_dim = output_dim self._output_dim = output_dim
super().__init__(**kwargs) super().__init__(**kwargs)
self.expect_tp_size = -1 self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property @property
...@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1: if self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.output_dim] shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size) tp_rank * shard_size, shard_size)
...@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1: if self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
param_data = param_data.narrow(self.output_dim, shard_offset, param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size) shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
...@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter):
self._input_dim = input_dim self._input_dim = input_dim
super().__init__(**kwargs) super().__init__(**kwargs)
self.expect_tp_size = -1 self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property @property
def input_dim(self): def input_dim(self):
...@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1: if self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.input_dim] shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim, loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size) tp_rank * shard_size, shard_size)
......
...@@ -12,7 +12,7 @@ from vllm.attention.layer import Attention ...@@ -12,7 +12,7 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import DPMetadata, set_forward_context from vllm.forward_context import DPMetadata, set_forward_context, get_warming_up
from vllm.logger import init_logger from vllm.logger import init_logger
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -93,6 +93,12 @@ class EagleProposer: ...@@ -93,6 +93,12 @@ class EagleProposer:
self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_size = vllm_config.parallel_config.data_parallel_size
self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.attn_tp_size = vllm_config.parallel_config.tensor_parallel_size
self.ep_sp = False
if self.enable_expert_parallel and self.dp_size > 1 and self.attn_tp_size > 1:
self.ep_sp = True
def propose( def propose(
self, self,
...@@ -189,9 +195,12 @@ class EagleProposer: ...@@ -189,9 +195,12 @@ class EagleProposer:
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
# num_input_tokens += num_pad
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
...@@ -279,6 +288,13 @@ class EagleProposer: ...@@ -279,6 +288,13 @@ class EagleProposer:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else: else:
input_batch_size = batch_size input_batch_size = batch_size
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1 attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1] attn_metadata.query_start_loc = self.arange[:batch_size + 1]
...@@ -373,6 +389,7 @@ class EagleProposer: ...@@ -373,6 +389,7 @@ class EagleProposer:
attn_metadata.num_decode_tokens) attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = ( self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills) attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = ( self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens) attn_metadata.decode.seq_lens)
...@@ -532,11 +549,9 @@ class EagleProposer: ...@@ -532,11 +549,9 @@ class EagleProposer:
# TODO(tms) : There are many cases where padding is enabled for # TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive': if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# auto # Early exit.
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto": return 0, None
# Early exit.
return 0, None
try: try:
num_tokens_across_dp = DPMetadata.num_tokens_across_dp( num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
...@@ -559,6 +574,7 @@ class EagleProposer: ...@@ -559,6 +574,7 @@ class EagleProposer:
self, self,
num_tokens: int, num_tokens: int,
attn_metadata: Optional[dict[str, Any]] = None, attn_metadata: Optional[dict[str, Any]] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
) -> None: ) -> None:
if attn_metadata is not None and self.attn_metadata_cudagraph is None: if attn_metadata is not None and self.attn_metadata_cudagraph is None:
self.attn_metadata_cudagraph = attn_metadata[ self.attn_metadata_cudagraph = attn_metadata[
...@@ -566,29 +582,73 @@ class EagleProposer: ...@@ -566,29 +582,73 @@ class EagleProposer:
# Padding for DP # Padding for DP
num_input_tokens = num_tokens num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) # num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_input_tokens += num_pad # num_input_tokens += num_pad
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.model( self.model(
self.input_ids[:num_input_tokens], self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens], self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens], self.hidden_states[:num_input_tokens],
) )
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1: if self.dp_size > 1 and (self.enable_expert_parallel or self.enable_dp_attention) and self.num_speculative_tokens > 1:
num_token = 1 num_tokens = 1
for _ in range(self.num_speculative_tokens - 1):
with set_forward_context(attn_metadata, if self.enable_dp_attention or self.ep_sp:
self.vllm_config, num_tokens = round_up(num_tokens, self.attn_tp_size)
num_tokens=num_tokens): # dp attention need all dp rank process same number tokens
self.model( if self.enable_dp_attention:
self.input_ids[:num_tokens], num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
self.positions[:num_tokens], num_tokens += num_pad
self.hidden_states[:num_tokens],
) if not get_warming_up():
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_tokens + 1],
seq_lens=self.runner.seq_lens[:num_tokens],
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
slot_mapping=self.runner.slot_mapping[:num_tokens],
spec_layer_decoding=True
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build_for_cudagraph_capture(
common_attn_metadata=common_attn_metadata
)
for i in range(self.num_speculative_tokens - 1):
if self.attn_metadata_cudagraph is not None:
if i == 0:
attn_metadata_cudagraph = self.attn_metadata_cudagraph
attn_metadata_cudagraph.num_actual_tokens = num_tokens
attn_metadata_cudagraph.num_decodes = num_tokens
attn_metadata_cudagraph.num_decode_tokens = num_tokens
attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
attn_metadata_cudagraph.decode.seq_lens[:num_tokens] = (
attn_metadata.decode.seq_lens)
attn_metadata_cudagraph.query_start_loc[:num_tokens + 1] = (
attn_metadata.query_start_loc)
attn_metadata_cudagraph.decode.block_table[:num_tokens] = (
attn_metadata.decode.block_table)
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)
def validate_same_kv_cache_group(self, def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None: kv_cache_config: KVCacheConfig) -> None:
......
...@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1: if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True self.ep_sp = True
self.enable_dp_attention = self.parallel_config.enable_dp_attention
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
...@@ -1278,13 +1280,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1278,13 +1280,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# #
# TODO(tms) : There are many cases where padding is enabled for # TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager: if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# auto # Early exit.
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto": return 0, None
# Early exit.
return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp( num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank) num_tokens, dp_size, dp_rank)
...@@ -1361,7 +1359,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1361,7 +1359,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size) num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]): and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
...@@ -2129,13 +2127,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2129,13 +2127,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
skip_eplb: bool = False, skip_eplb: bool = False,
is_profile: bool = False, is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
if num_tokens < self.tp_size: if num_tokens < self.tp_size:
num_tokens = self.tp_size num_tokens = self.tp_size
# Padding for DP num_tokens_across_dp = 0
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad num_tokens += num_pad
...@@ -2156,13 +2153,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2156,13 +2153,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots) min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots) num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
if not self.ep_sp: if not (self.ep_sp or self.enable_dp_attention):
num_scheduled_tokens_list[-1] += num_tokens % num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else: else:
if self.speculative_config is not None: if self.speculative_config is not None:
...@@ -2254,7 +2251,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2254,7 +2251,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.speculative_config and self.speculative_config.use_eagle() and not is_profile: if self.speculative_config and self.speculative_config.use_eagle() and not is_profile:
#assert isinstance(self.drafter, EagleProposer) #assert isinstance(self.drafter, EagleProposer)
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer): if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens, attn_metadata) self.drafter.dummy_run(num_tokens, attn_metadata,
num_tokens_across_dp=num_tokens_across_dp)
# This is necessary to avoid blocking DP. # This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real # For dummy runs, we typically skip EPLB since we don't have any real
...@@ -3227,7 +3225,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3227,7 +3225,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size) num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]): and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......
...@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.dp_attention import initialize_dp_attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
...@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner ...@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
from vllm.forward_context import (set_warming_up, get_warming_up)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -260,6 +262,7 @@ class Worker(WorkerBase): ...@@ -260,6 +262,7 @@ class Worker(WorkerBase):
# warm up sizes that are not in cudagraph capture sizes, # warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance, # but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill. # e.g. for the max-num-batched token size in chunked prefill.
set_warming_up(True)
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
warmup_sizes = [ warmup_sizes = [
...@@ -297,6 +300,7 @@ class Worker(WorkerBase): ...@@ -297,6 +300,7 @@ class Worker(WorkerBase):
# Reset the seed to ensure that the random state is not affected by # Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
set_warming_up(False)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()
...@@ -399,6 +403,9 @@ def init_worker_distributed_environment( ...@@ -399,6 +403,9 @@ def init_worker_distributed_environment(
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
if vllm_config.parallel_config.enable_dp_attention:
initialize_dp_attention(vllm_config, backend)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
......
...@@ -112,8 +112,11 @@ class V1ZeroEagleProposer(EagleProposer): ...@@ -112,8 +112,11 @@ class V1ZeroEagleProposer(EagleProposer):
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) if self.enable_dp_attention:
num_input_tokens += num_pad num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
# num_input_tokens += num_pad
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
...@@ -202,6 +205,13 @@ class V1ZeroEagleProposer(EagleProposer): ...@@ -202,6 +205,13 @@ class V1ZeroEagleProposer(EagleProposer):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else: else:
input_batch_size = batch_size input_batch_size = batch_size
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1 attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1] attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......
...@@ -465,7 +465,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -465,7 +465,7 @@ class V1ZeroModelRunner(GPUModelRunner):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, tp_size) num_input_tokens = round_up(num_scheduled_tokens, tp_size)
if (self.use_cuda_graph if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]): and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment