Commit cda54326 authored by 王敏's avatar 王敏
Browse files

[feat]添加dp attention功能

parent e89003dd
...@@ -1883,6 +1883,9 @@ class ParallelConfig: ...@@ -1883,6 +1883,9 @@ class ParallelConfig:
""" Use data parallelism instead of tensor parallelism for vision encoder. """ Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now""" Only support LLama4 for now"""
enable_dp_attention: bool = False
"""Enable dp attention"""
@property @property
def world_size_across_dp(self) -> int: def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world """world_size_across_dp is TPxPPxDP, it is the size of the world
...@@ -2109,6 +2112,9 @@ class ParallelConfig: ...@@ -2109,6 +2112,9 @@ class ParallelConfig:
raise ValueError("Unable to use nsight profiling unless workers " raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.") "run with Ray.")
if self.enable_dp_attention and self.enable_expert_parallel:
raise ValueError("Dp attention and expert parallel can not enable together.")
return self return self
...@@ -4805,6 +4811,7 @@ class VllmConfig: ...@@ -4805,6 +4811,7 @@ class VllmConfig:
dp_size = self.parallel_config.data_parallel_size dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1 ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
enable_dp_attention = self.parallel_config.enable_dp_attention
# add for spec decode # add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0: if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
...@@ -4813,11 +4820,15 @@ class VllmConfig: ...@@ -4813,11 +4820,15 @@ class VllmConfig:
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list)) batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0] batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp: if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list])) batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
if 1 not in batch_size_capture_list:
batch_size_capture_list.insert(0, 1)
else: else:
if ep_sp: if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list])) batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
if 1 not in batch_size_capture_list:
batch_size_capture_list.insert(0, 1)
self.compilation_config.init_with_cudagraph_sizes( self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list) batch_size_capture_list)
......
...@@ -477,6 +477,9 @@ class EngineArgs: ...@@ -477,6 +477,9 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \ enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel ParallelConfig.enable_multimodal_encoder_data_parallel
enable_dp_attention: bool = \
ParallelConfig.enable_dp_attention
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
# without having to manually construct a # without having to manually construct a
...@@ -719,6 +722,10 @@ class EngineArgs: ...@@ -719,6 +722,10 @@ class EngineArgs:
"--enable-multimodal-encoder-data-parallel", "--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"]) **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
parallel_group.add_argument(
"--enable-dp-attention",
**parallel_kwargs["enable_dp_attention"])
# KV cache arguments # KV cache arguments
cache_kwargs = get_kwargs(CacheConfig) cache_kwargs = get_kwargs(CacheConfig)
cache_group = parser.add_argument_group( cache_group = parser.add_argument_group(
...@@ -1204,6 +1211,7 @@ class EngineArgs: ...@@ -1204,6 +1211,7 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls, worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self. enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel, enable_multimodal_encoder_data_parallel,
enable_dp_attention=self.enable_dp_attention,
) )
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
......
...@@ -211,3 +211,14 @@ def set_profilling(profiling): ...@@ -211,3 +211,14 @@ def set_profilling(profiling):
def get_profilling() -> bool: def get_profilling() -> bool:
global _profiling global _profiling
return _profiling return _profiling
# Process-wide flag recording whether a warm-up pass is in progress.
_warming_up = False


def set_warming_up(warming_up):
    """Set the module-level warm-up flag.

    The flag is updated as soon as this function is called, so a plain
    ``set_warming_up(True)`` call acts as a setter. The return value is a
    context manager: when it is used in a ``with`` statement, the value the
    flag had before this call is restored on exit.

    The previous version applied ``@contextmanager`` to a function without a
    ``yield``, so entering the returned object in a ``with`` statement raised
    ``TypeError``.

    Args:
        warming_up: New value of the flag.

    Returns:
        A context manager that restores the prior flag value when exited.
    """
    global _warming_up
    previous = _warming_up
    # Assign eagerly so callers that never enter the context still see the
    # new value, exactly as before.
    _warming_up = warming_up

    @contextmanager
    def _restore_previous():
        global _warming_up
        try:
            yield
        finally:
            _warming_up = previous

    return _restore_previous()


def get_warming_up() -> bool:
    """Return the current value of the module-level warm-up flag."""
    return _warming_up
\ No newline at end of file
from typing import TYPE_CHECKING, List, Optional, Tuple
import logging
import torch
import vllm.envs as envs
from vllm.distributed.parallel_state import GroupCoordinator, init_model_parallel_group, get_world_group
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
tensor_model_parallel_reduce_scatter,
get_tp_group)
# Module-level state for DP attention. The attention/MoE sizes, the ranks,
# the enable flag and the MoE group below are (re)assigned by
# initialize_dp_attention().

# Mirrors parallel_config.enable_dp_attention once initialized.
_ENABLE_DP_ATTENTION_FLAG: bool = False
# Process group used by the MoE tensor-parallel collectives; None until the
# group has been created.
_MOE_TP: Optional[GroupCoordinator] = None
# Attention data-parallel / tensor-parallel sizes and ranks. They hold the
# placeholder 0 until initialization assigns the real values.
_ATTN_DP_SIZE = 0
_ATTN_TP_SIZE = 0
_ATTN_TP_RANK = 0
_ATTN_DP_RANK = 0
# MoE tensor-parallel size and rank (placeholder 0 until initialized).
# NOTE(review): "_MOT_" is presumably a typo for "_MOE_" -- confirm before
# renaming, since other functions in this module refer to these names.
_MOT_TP_SIZE = 0
_MOT_TP_RANK = 0
def initialize_dp_attention(vllm_config, backend: Optional[str] = None):
    """Record the DP-attention topology and create the MoE tensor-parallel group.

    The parallel configuration is read from ``vllm_config`` and stored in the
    module-level state:

    * attention DP/TP sizes come from ``data_parallel_size`` and
      ``tensor_parallel_size``;
    * the attention TP rank comes from ``get_tensor_model_parallel_rank()``
      and the attention DP rank from ``data_parallel_rank``;
    * the MoE TP size is ``world_size // pipeline_parallel_size`` and the MoE
      TP rank is the global rank modulo that size.

    One "moe_tp" group is then created per pipeline stage, each covering a
    contiguous run of ``moe_tp_size`` global ranks.

    Args:
        vllm_config: The ``VllmConfig`` instance describing the deployment.
        backend: ``torch.distributed`` backend for the new group. When falsy,
            the backend of the world group's device group is used.

    Raises:
        AssertionError: If ``vllm_config`` is not a ``VllmConfig`` or the MoE
            tensor-parallel group has already been initialized.
    """
    from vllm.config import VllmConfig
    assert isinstance(vllm_config, VllmConfig)

    global _ENABLE_DP_ATTENTION_FLAG, _ATTN_DP_SIZE, _ATTN_TP_SIZE
    global _ATTN_TP_RANK, _ATTN_DP_RANK, _MOT_TP_SIZE, _MOT_TP_RANK
    global _MOE_TP

    parallel_config = vllm_config.parallel_config
    _ENABLE_DP_ATTENTION_FLAG = parallel_config.enable_dp_attention

    # Build the moe tensor model-parallel groups.
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    pipeline_model_parallel_size = parallel_config.pipeline_parallel_size
    moe_tp_size = world_size // pipeline_model_parallel_size

    _ATTN_DP_SIZE = parallel_config.data_parallel_size
    _ATTN_TP_SIZE = parallel_config.tensor_parallel_size
    _ATTN_TP_RANK = get_tensor_model_parallel_rank()
    _ATTN_DP_RANK = parallel_config.data_parallel_rank
    _MOT_TP_SIZE = moe_tp_size
    _MOT_TP_RANK = rank % _MOT_TP_SIZE

    assert _MOE_TP is None, ("moe tensor model parallel group is already initialized")
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # One group per pipeline stage: stage i owns global ranks
    # [i * moe_tp_size, (i + 1) * moe_tp_size).
    group_ranks = [
        list(range(stage * moe_tp_size, (stage + 1) * moe_tp_size))
        for stage in range(pipeline_model_parallel_size)
    ]

    # message queue broadcaster is only used in tensor model parallel group
    _MOE_TP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="moe_tp")
def get_attention_tp_size() -> int:
    """Return the attention tensor-parallel size recorded at initialization.

    Raises:
        AssertionError: If the size is still unset (``None`` or the
            module-level placeholder ``0``).
    """
    # The module-level placeholder is 0, so an ``is not None`` test alone can
    # never detect a missing initialization; a valid size is always positive.
    assert _ATTN_TP_SIZE is not None and _ATTN_TP_SIZE > 0, (
        "dp attention not initialized!")
    return _ATTN_TP_SIZE
def get_attention_tp_rank() -> int:
    """Return the attention tensor-parallel rank held in module state.

    NOTE: the module-level default for this value is 0 rather than ``None``,
    so the assertion below only fires if the value was explicitly set to
    ``None``.

    Raises:
        AssertionError: If the stored rank is ``None``.
    """
    attn_tp_rank = _ATTN_TP_RANK
    assert attn_tp_rank is not None, "dp attention not initialized!"
    return attn_tp_rank
def get_moe_tp_group() -> GroupCoordinator:
    """Return the MoE tensor-parallel group stored in module state.

    Raises:
        AssertionError: If the group has not been created yet (still ``None``).
    """
    moe_group = _MOE_TP
    assert moe_group is not None, ("tensor model parallel group is not initialized")
    return moe_group
def get_attention_dp_size() -> int:
    """Return the attention data-parallel size recorded at initialization.

    Raises:
        AssertionError: If the size is still unset (``None`` or the
            module-level placeholder ``0``).
    """
    # The module-level placeholder is 0, so an ``is not None`` test alone can
    # never detect a missing initialization; a valid size is always positive.
    assert _ATTN_DP_SIZE is not None and _ATTN_DP_SIZE > 0, (
        "dp attention not initialized!")
    return _ATTN_DP_SIZE
def get_moe_tp_rank() -> int:
    """Return the MoE tensor-parallel rank held in module state.

    NOTE: the module-level default for this value is 0 rather than ``None``,
    so the assertion below only fires if the value was explicitly set to
    ``None``.

    Raises:
        AssertionError: If the stored rank is ``None``.
    """
    moe_rank = _MOT_TP_RANK
    assert moe_rank is not None, "dp attention not initialized!"
    return moe_rank
def get_moe_tp_size() -> int:
    """Return the MoE tensor-parallel size recorded at initialization.

    Raises:
        AssertionError: If the size is still unset (``None`` or the
            module-level placeholder ``0``).
    """
    # The module-level placeholder is 0, so an ``is not None`` test alone can
    # never detect a missing initialization; returning 0 here would surface
    # later as a division by zero in callers that shard by this size.
    assert _MOT_TP_SIZE is not None and _MOT_TP_SIZE > 0, (
        "dp attention not initialized!")
    return _MOT_TP_SIZE
def get_attention_tp_group() -> GroupCoordinator:
    """Return the group used for attention tensor parallelism.

    This simply delegates to ``get_tp_group()``.
    """
    attn_tp_group = get_tp_group()
    return attn_tp_group
def moe_tensor_model_parallel_all_gather(input_: torch.Tensor,
                                         dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` along ``dim`` across the MoE tensor-parallel group.

    Args:
        input_: Tensor contributed by this rank.
        dim: Dimension passed through to the group's ``all_gather``.

    Returns:
        Whatever the MoE group's ``all_gather`` returns.
    """
    moe_group = get_moe_tp_group()
    return moe_group.all_gather(input_, dim)
def moe_tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
                                             dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` across the MoE tensor-parallel group.

    Args:
        input_: Tensor contributed by this rank.
        dim: Dimension passed through to the group's ``reduce_scatter``.

    Returns:
        Whatever the MoE group's ``reduce_scatter`` returns.
    """
    moe_group = get_moe_tp_group()
    return moe_group.reduce_scatter(input_, dim)
def dp_gather(
        hidden_states: torch.Tensor,) -> torch.Tensor:
    """Gather hidden states across the MoE tensor-parallel group.

    When the attention tensor-parallel size is greater than 1, the tensor is
    first reduce-scattered along dim 0 over the regular tensor-parallel group.
    In every case it is then all-gathered along dim 0 over the MoE
    tensor-parallel group.

    Args:
        hidden_states: Tensor produced on this rank.

    Returns:
        The tensor after the collective(s) described above.
    """
    if get_attention_tp_size() != 1:
        hidden_states = tensor_model_parallel_reduce_scatter(hidden_states,
                                                             dim=0)
    return moe_tensor_model_parallel_all_gather(hidden_states, dim=0)
def dp_reduce_scatter_tensor(hidden_states: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter hidden states over the MoE tensor-parallel group.

    The tensor is always reduce-scattered along dim 0 over the MoE group. If
    that group's world size differs from the attention data-parallel size, the
    result is additionally all-gathered along dim 0 over the regular
    tensor-parallel group.

    Args:
        hidden_states: Tensor produced on this rank.

    Returns:
        The tensor after the collective(s) described above.
    """
    # Evaluate the condition first, matching the original evaluation order.
    needs_tp_all_gather = (
        get_moe_tp_group().world_size != get_attention_dp_size())
    scattered = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
    if needs_tp_all_gather:
        scattered = tensor_model_parallel_all_gather(scattered, dim=0)
    return scattered
...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two) fused_topk, grouped_topk, is_power_of_two)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
...@@ -900,6 +901,8 @@ class FusedMoE(torch.nn.Module): ...@@ -900,6 +901,8 @@ class FusedMoE(torch.nn.Module):
self.logical_to_physical_map: Optional[torch.Tensor] = None self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None self.logical_replica_count: Optional[torch.Tensor] = None
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
# Determine expert maps # Determine expert maps
if self.use_ep: if self.use_ep:
if self.enable_eplb: if self.enable_eplb:
...@@ -1620,7 +1623,8 @@ class FusedMoE(torch.nn.Module): ...@@ -1620,7 +1623,8 @@ class FusedMoE(torch.nn.Module):
The pplx combine kernel reduces across GPU ranks by default. The pplx combine kernel reduces across GPU ranks by default.
""" """
if (self.use_pplx_kernels or self.use_deepep_ht_kernels if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels): or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels
or self.enable_dp_attention):
return final_hidden_states return final_hidden_states
else: else:
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
...@@ -1759,6 +1763,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1759,6 +1763,7 @@ class FusedMoE(torch.nn.Module):
and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_deepep_ll_kernels and not self.moe_parallel_config.use_deepep_ll_kernels
and not self.moe_parallel_config.use_deepep_auto_kernels and not self.moe_parallel_config.use_deepep_auto_kernels
and not self.enable_dp_attention
) )
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
......
...@@ -22,6 +22,7 @@ from vllm.utils import round_up ...@@ -22,6 +22,7 @@ from vllm.utils import round_up
try: try:
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
from lightop import op
except Exception: except Exception:
print("INFO: Please install lmslim if you want to use int utils.\n") print("INFO: Please install lmslim if you want to use int utils.\n")
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -622,6 +623,16 @@ def ep_scatter( ...@@ -622,6 +623,16 @@ def ep_scatter(
num_experts = num_recv_tokens_per_expert.shape[0] num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1] hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1] scale_hidden_size = recv_x_scale.shape[-1]
if hasattr(op, "ep_scatter"):
op.ep_scatter(
recv_x, recv_x_scale,
recv_topk, expert_map,
num_recv_tokens_per_expert,
output_tensor, output_tensor_scale, m_indices, output_index,
num_experts, BLOCK_E
)
else:
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts) # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts grid = num_experts
...@@ -665,8 +676,8 @@ def ep_scatter( ...@@ -665,8 +676,8 @@ def ep_scatter(
num_warps=num_warps, num_warps=num_warps,
HIDDEN_SIZE=hidden_size, HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,#hidden_size // BLOCK_D, SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size)#triton.next_power_of_2(hidden_size // BLOCK_D), SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
) )
return return
......
...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter, PackedvLLMParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
RowvLLMParameter) RowvLLMParameter)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
# yapf: enable # yapf: enable
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -625,12 +626,18 @@ class ColumnParallelLinear(LinearBase): ...@@ -625,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None: if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size self.tp_size = self.expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = input_size self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size) self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition] self.output_partition_sizes = [self.output_size_per_partition]
...@@ -878,6 +885,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -878,6 +885,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -888,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -888,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
tp_size = get_moe_tp_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size, super().__init__(input_size=input_size,
output_size=sum(output_sizes), output_size=sum(output_sizes),
...@@ -898,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -898,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias, return_bias=return_bias,
expect_tp_size=expect_tp_size) expect_tp_size=expect_tp_size,
enable_dp_attn_moe=enable_dp_attn_moe)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -1000,6 +1013,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -1000,6 +1013,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
tp_rank = 0 tp_rank = 0
tp_size = 1 tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
if output_dim is not None: if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size
...@@ -1121,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -1121,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if hasattr(param, "expect_tp_size"): if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe and hasattr(param, "enable_dp_attn_moe"):
tp_size = get_moe_tp_size()
param.enable_dp_attn_moe = self.enable_dp_attn_moe
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod) Fp8LinearMethod, Fp8MoEMethod)
...@@ -1552,6 +1573,7 @@ class RowParallelLinear(LinearBase): ...@@ -1552,6 +1573,7 @@ class RowParallelLinear(LinearBase):
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None, expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
): ):
# Divide the weight matrix along the first dimension. # Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -1560,7 +1582,13 @@ class RowParallelLinear(LinearBase): ...@@ -1560,7 +1582,13 @@ class RowParallelLinear(LinearBase):
if expect_tp_size is not None: if expect_tp_size is not None:
self.tp_rank = 0 self.tp_rank = 0
self.tp_size = 1 self.tp_size = 1
self.expect_tp_size = expect_tp_size self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_rank = get_moe_tp_rank()
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
...@@ -1610,6 +1638,11 @@ class RowParallelLinear(LinearBase): ...@@ -1610,6 +1638,11 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None: if self.expect_tp_size is not None:
tp_rank = 0 tp_rank = 0
tp_size = 1 tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False) is_sharded_weight = getattr(param, "is_sharded_weight", False)
...@@ -1664,6 +1697,9 @@ class RowParallelLinear(LinearBase): ...@@ -1664,6 +1697,9 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"): if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe is not None and hasattr(param, "enable_dp_attn_moe"):
param.enable_dp_attn_moe = self.enable_dp_attn_moe
param.load_row_parallel_weight(loaded_weight=loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward( def forward(
......
...@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.dp_attention import (dp_gather, dp_reduce_scatter_tensor,
get_moe_tp_size, get_moe_tp_rank, get_attention_tp_size)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -85,17 +87,22 @@ class DeepseekV2MLP(nn.Module): ...@@ -85,17 +87,22 @@ class DeepseekV2MLP(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
vllm_config = get_current_vllm_config()
enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj") prefix=f"{prefix}.gate_up_proj",
enable_dp_attn_moe=enable_dp_attention)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, reduce_results=reduce_results if not enable_dp_attention else False,
prefix=f"{prefix}.down_proj") prefix=f"{prefix}.down_proj",
enable_dp_attn_moe=enable_dp_attention)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -979,6 +986,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -979,6 +986,8 @@ class DeepseekV2DecoderLayer(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.config = config self.config = config
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace and layer_idx >= config.first_k_dense_replace
...@@ -1006,6 +1015,9 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1006,6 +1015,9 @@ class DeepseekV2DecoderLayer(nn.Module):
DeepseekV2MoE) and self.use_deepep and \ DeepseekV2MoE) and self.use_deepep and \
self.tp_size > 1 and not self.is_mtp_layer: self.tp_size > 1 and not self.is_mtp_layer:
reduce_results = False reduce_results = False
else:
if self.enable_dp_attention:
reduce_results = False
if model_config.use_mla: if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention attn_cls = DeepseekV2MLAAttention
...@@ -1176,6 +1188,10 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1176,6 +1188,10 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0) hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
if self.enable_dp_attention:
if self.tp_rank == 0:
hidden_states += residual
hidden_states = dp_gather(hidden_states)
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
...@@ -1188,8 +1204,14 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1188,8 +1204,14 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1. / self.routed_scaling_factor residual *= 1. / self.routed_scaling_factor
# Fully Connected # Fully Connected
if not self.enable_dp_attention:
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
else:
num_tokens = hidden_states.shape[0]
new_bs = num_tokens // get_moe_tp_size() * get_attention_tp_size()
residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
hidden_states = self.post_attention_layernorm(hidden_states)
if self.is_mtp_layer: if self.is_mtp_layer:
if isinstance(self.mlp, if isinstance(self.mlp,
...@@ -1201,9 +1223,11 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1201,9 +1223,11 @@ class DeepseekV2DecoderLayer(nn.Module):
new_bs = (ori_bs+pad_size) // self.tp_size new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous() hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if self.enable_dp_attention:
hidden_states = dp_reduce_scatter_tensor(hidden_states)
if self.is_mtp_layer: if self.is_mtp_layer:
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
......
...@@ -10,6 +10,7 @@ from torch.nn import Parameter ...@@ -10,6 +10,7 @@ from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader from vllm.model_executor.utils import _make_synced_weight_loader
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
__all__ = [ __all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
...@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self._output_dim = output_dim self._output_dim = output_dim
super().__init__(**kwargs) super().__init__(**kwargs)
self.expect_tp_size = -1 self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property @property
...@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1: if self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.output_dim] shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size) tp_rank * shard_size, shard_size)
...@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1: if self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
param_data = param_data.narrow(self.output_dim, shard_offset, param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size) shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
...@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter):
self._input_dim = input_dim self._input_dim = input_dim
super().__init__(**kwargs) super().__init__(**kwargs)
self.expect_tp_size = -1 self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property @property
def input_dim(self): def input_dim(self):
...@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1: if self.expect_tp_size == 1:
tp_rank = 0 tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.input_dim] shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim, loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size) tp_rank * shard_size, shard_size)
......
...@@ -12,7 +12,7 @@ from vllm.attention.layer import Attention ...@@ -12,7 +12,7 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import DPMetadata, set_forward_context from vllm.forward_context import DPMetadata, set_forward_context, get_warming_up
from vllm.logger import init_logger from vllm.logger import init_logger
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -93,6 +93,8 @@ class EagleProposer: ...@@ -93,6 +93,8 @@ class EagleProposer:
self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_size = vllm_config.parallel_config.data_parallel_size
self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.attn_tp_size = vllm_config.parallel_config.tensor_parallel_size
def propose( def propose(
self, self,
...@@ -189,7 +191,8 @@ class EagleProposer: ...@@ -189,7 +191,8 @@ class EagleProposer:
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad num_input_tokens += num_pad
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
...@@ -279,6 +282,13 @@ class EagleProposer: ...@@ -279,6 +282,13 @@ class EagleProposer:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else: else:
input_batch_size = batch_size input_batch_size = batch_size
# DP attention requires every DP rank to process the same number of tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1 attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1] attn_metadata.query_start_loc = self.arange[:batch_size + 1]
...@@ -373,6 +383,7 @@ class EagleProposer: ...@@ -373,6 +383,7 @@ class EagleProposer:
attn_metadata.num_decode_tokens) attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = ( self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills) attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = ( self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens) attn_metadata.decode.seq_lens)
...@@ -532,9 +543,8 @@ class EagleProposer: ...@@ -532,9 +543,8 @@ class EagleProposer:
# TODO(tms) : There are many cases where padding is enabled for # TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if not self.enable_dp_attention and not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive': if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# auto
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# Early exit. # Early exit.
return 0, None return 0, None
...@@ -566,7 +576,7 @@ class EagleProposer: ...@@ -566,7 +576,7 @@ class EagleProposer:
# Padding for DP # Padding for DP
num_input_tokens = num_tokens num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_pad, _ = self.get_dp_padding(num_tokens)
num_input_tokens += num_pad num_input_tokens += num_pad
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
...@@ -578,9 +588,49 @@ class EagleProposer: ...@@ -578,9 +588,49 @@ class EagleProposer:
self.hidden_states[:num_input_tokens], self.hidden_states[:num_input_tokens],
) )
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1: if self.dp_size > 1 and (self.enable_expert_parallel or self.enable_dp_attention) and self.num_speculative_tokens > 1:
num_token = 1 num_tokens = 1
for _ in range(self.num_speculative_tokens - 1): # dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
num_tokens = round_up(num_tokens, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(num_tokens)
num_tokens += num_pad
if not get_warming_up():
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_tokens + 1],
seq_lens=self.runner.seq_lens[:num_tokens],
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
slot_mapping=self.runner.slot_mapping[:num_tokens],
spec_layer_decoding=True
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build_for_cudagraph_capture(
common_attn_metadata=common_attn_metadata
)
for i in range(self.num_speculative_tokens - 1):
if self.attn_metadata_cudagraph is not None:
if i == 0:
attn_metadata_cudagraph = self.attn_metadata_cudagraph
attn_metadata_cudagraph.num_actual_tokens = num_tokens
attn_metadata_cudagraph.num_decodes = num_tokens
attn_metadata_cudagraph.num_decode_tokens = num_tokens
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
attn_metadata_cudagraph.decode.seq_lens[:num_tokens] = (
attn_metadata.decode.seq_lens)
self.attn_metadata_cudagraph.query_start_loc[:num_tokens + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:num_tokens] = (
attn_metadata.decode.block_table)
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens):
......
...@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import ( ...@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import (
prepare_communication_buffer_for_model, prepare_communication_buffer_for_model,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.forward_context import (DPMetadata, get_forward_context, from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context, set_profilling) set_forward_context)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
...@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1: if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True self.ep_sp = True
self.enable_dp_attention = self.parallel_config.enable_dp_attention
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
...@@ -1275,13 +1277,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1275,13 +1277,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for # TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager: if not self.enable_dp_attention and not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# auto if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# Early exit. # Early exit.
return 0, None return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp( num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank) num_tokens, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
...@@ -1357,7 +1357,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1357,7 +1357,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size) num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]): and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
...@@ -1638,9 +1638,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1638,9 +1638,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Mask out the sampled tokens that should not be sampled. # Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices: for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear() valid_sampled_token_ids[i].clear()
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler # Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back. # doesn't need to send them back.
...@@ -1681,6 +1678,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1681,6 +1678,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
attn_metadata, attn_metadata,
) )
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
...@@ -2121,13 +2122,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2121,13 +2122,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
skip_eplb: bool = False, skip_eplb: bool = False,
is_profile: bool = False, is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
if num_tokens < self.tp_size: if num_tokens < self.tp_size:
num_tokens = self.tp_size num_tokens = self.tp_size
# Padding for DP num_tokens_across_dp = 0
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad num_tokens += num_pad
...@@ -2148,13 +2148,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2148,13 +2148,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots) min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots) num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
if not self.ep_sp: if not (self.ep_sp or self.enable_dp_attention):
num_scheduled_tokens_list[-1] += num_tokens % num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else: else:
if self.speculative_config is not None: if self.speculative_config is not None:
...@@ -3219,7 +3219,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3219,7 +3219,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp: if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size) num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]): and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......
...@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, ...@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.dp_attention import initialize_dp_attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
...@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner ...@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
from vllm.forward_context import (set_warming_up, get_warming_up)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -260,6 +262,7 @@ class Worker(WorkerBase): ...@@ -260,6 +262,7 @@ class Worker(WorkerBase):
# warm up sizes that are not in cudagraph capture sizes, # warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance, # but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill. # e.g. for the max-num-batched token size in chunked prefill.
set_warming_up(True)
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
warmup_sizes = [ warmup_sizes = [
...@@ -297,6 +300,7 @@ class Worker(WorkerBase): ...@@ -297,6 +300,7 @@ class Worker(WorkerBase):
# Reset the seed to ensure that the random state is not affected by # Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
set_warming_up(False)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()
...@@ -399,6 +403,9 @@ def init_worker_distributed_environment( ...@@ -399,6 +403,9 @@ def init_worker_distributed_environment(
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
if vllm_config.parallel_config.enable_dp_attention:
initialize_dp_attention(vllm_config, backend)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
......
...@@ -112,6 +112,8 @@ class V1ZeroEagleProposer(EagleProposer): ...@@ -112,6 +112,8 @@ class V1ZeroEagleProposer(EagleProposer):
else: else:
num_input_tokens = num_tokens num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad num_input_tokens += num_pad
...@@ -202,6 +204,13 @@ class V1ZeroEagleProposer(EagleProposer): ...@@ -202,6 +204,13 @@ class V1ZeroEagleProposer(EagleProposer):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else: else:
input_batch_size = batch_size input_batch_size = batch_size
# DP attention requires every DP rank to process the same number of tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1 attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1] attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment