Commit cda54326 authored by 王敏
Browse files

[feat]添加dp attention功能

parent e89003dd
......@@ -1883,6 +1883,9 @@ class ParallelConfig:
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
enable_dp_attention: bool = False
"""Enable dp attention"""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
......@@ -2109,6 +2112,9 @@ class ParallelConfig:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
if self.enable_dp_attention and self.enable_expert_parallel:
raise ValueError("Dp attention and expert parallel can not enable together.")
return self
......@@ -4805,6 +4811,7 @@ class VllmConfig:
dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
enable_dp_attention = self.parallel_config.enable_dp_attention
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
......@@ -4813,11 +4820,15 @@ class VllmConfig:
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp:
if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
if 1 not in batch_size_capture_list:
batch_size_capture_list.insert(0, 1)
else:
if ep_sp:
if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
if 1 not in batch_size_capture_list:
batch_size_capture_list.insert(0, 1)
self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
......
......@@ -477,6 +477,9 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
enable_dp_attention: bool = \
ParallelConfig.enable_dp_attention
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
......@@ -719,6 +722,10 @@ class EngineArgs:
"--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
parallel_group.add_argument(
"--enable-dp-attention",
**parallel_kwargs["enable_dp_attention"])
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
cache_group = parser.add_argument_group(
......@@ -1204,6 +1211,7 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel,
enable_dp_attention=self.enable_dp_attention,
)
speculative_config = self.create_speculative_config(
......
......@@ -211,3 +211,14 @@ def set_profilling(profiling):
def get_profilling() -> bool:
global _profiling
return _profiling
# Process-wide flag toggled via set_warming_up() / read via get_warming_up().
_warming_up = False


def set_warming_up(warming_up):
    """Set the module-level warm-up flag.

    This is a plain setter, not a context manager: it is meant to be called
    directly, e.g. ``set_warming_up(True)`` ... ``set_warming_up(False)``.
    The function never yields, so decorating it with ``@contextmanager``
    (as was done previously) only worked by accident for direct calls and
    would have failed as soon as someone wrote ``with set_warming_up(...)``.

    Args:
        warming_up: New value of the flag (truthy while warming up).
    """
    global _warming_up
    _warming_up = warming_up
def get_warming_up() -> bool:
    """Return the current value of the module-level ``_warming_up`` flag."""
    # Reading a module global needs no ``global`` declaration.
    return _warming_up
\ No newline at end of file
from typing import TYPE_CHECKING, List, Optional, Tuple
import logging
import torch
import vllm.envs as envs
from vllm.distributed.parallel_state import GroupCoordinator, init_model_parallel_group, get_world_group
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
tensor_model_parallel_reduce_scatter,
get_tp_group)
# Module-level DP-attention state, populated by initialize_dp_attention().
#
# The size/rank fields start out as None rather than 0: the accessor
# functions below guard with ``assert ... is not None``, and with a 0
# sentinel that guard could never fire, so an uninitialised module would
# silently report size/rank 0 instead of failing loudly.
_ENABLE_DP_ATTENTION_FLAG: bool = False
_MOE_TP: Optional[GroupCoordinator] = None
_ATTN_DP_SIZE: Optional[int] = None
_ATTN_TP_SIZE: Optional[int] = None
_ATTN_TP_RANK: Optional[int] = None
_ATTN_DP_RANK: Optional[int] = None
# NOTE: "MOT" is the historical spelling used throughout this module for the
# MoE tensor-parallel size/rank; it is kept so existing references still work.
_MOT_TP_SIZE: Optional[int] = None
_MOT_TP_RANK: Optional[int] = None
def initialize_dp_attention(vllm_config, backend: Optional[str] = None):
    """Record the DP-attention topology and create the MoE TP process group.

    Fills the module-level size/rank globals from ``vllm_config`` and the
    current ``torch.distributed`` state, then builds the ``"moe_tp"`` group.
    The MoE TP size is ``world_size // pipeline_parallel_size``, and the
    groups are contiguous rank ranges of that size, one per
    pipeline-parallel index.

    Args:
        vllm_config: The active ``VllmConfig``; its ``parallel_config`` is
            read for the DP/TP/PP sizes and the DP rank.
        backend: Distributed backend name for the new group. When falsy, the
            backend of the world group's device group is used.

    Raises:
        AssertionError: If ``vllm_config`` is not a ``VllmConfig`` or the
            MoE TP group has already been created.
    """
    from vllm.config import VllmConfig
    assert isinstance(vllm_config, VllmConfig)

    global _ENABLE_DP_ATTENTION_FLAG, _MOE_TP
    global _ATTN_DP_SIZE, _ATTN_TP_SIZE, _ATTN_TP_RANK, _ATTN_DP_RANK
    global _MOT_TP_SIZE, _MOT_TP_RANK

    # Check for double initialisation before mutating any module state, so a
    # rejected second call leaves the recorded topology untouched.
    assert _MOE_TP is None, ("moe tensor model parallel group is already initialized")

    parallel_config = vllm_config.parallel_config
    _ENABLE_DP_ATTENTION_FLAG = parallel_config.enable_dp_attention

    # Build the moe tensor model-parallel groups.
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    pipeline_model_parallel_size = parallel_config.pipeline_parallel_size
    moe_tp_size = world_size // pipeline_model_parallel_size

    _ATTN_DP_SIZE = parallel_config.data_parallel_size
    _ATTN_TP_SIZE = parallel_config.tensor_parallel_size
    _ATTN_TP_RANK = get_tensor_model_parallel_rank()
    _ATTN_DP_RANK = parallel_config.data_parallel_rank
    _MOT_TP_SIZE = moe_tp_size
    _MOT_TP_RANK = rank % moe_tp_size

    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # One contiguous block of ``moe_tp_size`` ranks per pipeline index.
    group_ranks = [
        list(range(i * moe_tp_size, (i + 1) * moe_tp_size))
        for i in range(pipeline_model_parallel_size)
    ]

    # message queue broadcaster is only used in tensor model parallel group
    _MOE_TP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="moe_tp")
def get_attention_tp_size() -> int:
    """Return the attention tensor-parallel size stored in ``_ATTN_TP_SIZE``.

    ``initialize_dp_attention`` sets this to the configured
    ``tensor_parallel_size``.

    Raises:
        AssertionError: If the stored value is ``None``.
    """
    attn_tp_size = _ATTN_TP_SIZE
    assert attn_tp_size is not None, "dp attention not initialized!"
    return attn_tp_size
def get_attention_tp_rank() -> int:
    """Return the attention tensor-parallel rank stored in ``_ATTN_TP_RANK``.

    ``initialize_dp_attention`` sets this from
    ``get_tensor_model_parallel_rank()``.

    Raises:
        AssertionError: If the stored value is ``None``.
    """
    attn_tp_rank = _ATTN_TP_RANK
    assert attn_tp_rank is not None, "dp attention not initialized!"
    return attn_tp_rank
def get_moe_tp_group() -> GroupCoordinator:
    """Return the MoE tensor-parallel group created by ``initialize_dp_attention``.

    Raises:
        AssertionError: If the group has not been created yet (``_MOE_TP``
            is still ``None``).
    """
    moe_group = _MOE_TP
    assert moe_group is not None, ("tensor model parallel group is not initialized")
    return moe_group
def get_attention_dp_size() -> int:
    """Return the attention data-parallel size stored in ``_ATTN_DP_SIZE``.

    ``initialize_dp_attention`` sets this to the configured
    ``data_parallel_size``.

    Raises:
        AssertionError: If the stored value is ``None``.
    """
    attn_dp_size = _ATTN_DP_SIZE
    assert attn_dp_size is not None, "dp attention not initialized!"
    return attn_dp_size
def get_moe_tp_rank() -> int:
    """Return this process's rank within its MoE tensor-parallel group.

    The value is read from ``_MOT_TP_RANK``, which
    ``initialize_dp_attention`` computes as ``rank % moe_tp_size``.

    Raises:
        AssertionError: If the stored value is ``None``.
    """
    moe_tp_rank = _MOT_TP_RANK
    assert moe_tp_rank is not None, "dp attention not initialized!"
    return moe_tp_rank
def get_moe_tp_size() -> int:
    """Return the MoE tensor-parallel size stored in ``_MOT_TP_SIZE``.

    ``initialize_dp_attention`` computes this as
    ``world_size // pipeline_parallel_size``.

    Raises:
        AssertionError: If the stored value is ``None``.
    """
    moe_tp_size = _MOT_TP_SIZE
    assert moe_tp_size is not None, "dp attention not initialized!"
    return moe_tp_size
def get_attention_tp_group() -> GroupCoordinator:
    """Return the group used for attention TP, i.e. the result of ``get_tp_group()``."""
    attn_group = get_tp_group()
    return attn_group
def moe_tensor_model_parallel_all_gather(input_: torch.Tensor,
                                         dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` along ``dim`` using the MoE tensor-parallel group.

    Thin wrapper that delegates to ``get_moe_tp_group().all_gather``.
    """
    moe_group = get_moe_tp_group()
    gathered = moe_group.all_gather(input_, dim)
    return gathered
def moe_tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
                                             dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` using the MoE tensor-parallel group.

    Thin wrapper that delegates to ``get_moe_tp_group().reduce_scatter``.
    """
    moe_group = get_moe_tp_group()
    scattered = moe_group.reduce_scatter(input_, dim)
    return scattered
def dp_gather(hidden_states: torch.Tensor) -> torch.Tensor:
    """Gather ``hidden_states`` along dim 0 across the MoE tensor-parallel group.

    When the attention TP size is not 1, the tensor is first passed through
    ``tensor_model_parallel_reduce_scatter`` (dim 0); the result is then
    all-gathered along dim 0 over the MoE TP group. With attention TP size 1
    only the all-gather is performed.
    """
    if get_attention_tp_size() != 1:
        hidden_states = tensor_model_parallel_reduce_scatter(hidden_states,
                                                             dim=0)
    return moe_tensor_model_parallel_all_gather(hidden_states, dim=0)
def dp_reduce_scatter_tensor(hidden_states: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter ``hidden_states`` along dim 0 over the MoE TP group.

    If the MoE TP group's world size differs from the attention DP size, the
    scattered result is additionally passed through
    ``tensor_model_parallel_all_gather`` along dim 0.
    """
    # Evaluate the topology check first, before issuing any collective.
    moe_matches_dp = get_moe_tp_group().world_size == get_attention_dp_size()
    scattered = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
    if moe_matches_dp:
        return scattered
    return tensor_model_parallel_all_gather(scattered, dim=0)
......@@ -37,6 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
......@@ -900,6 +901,8 @@ class FusedMoE(torch.nn.Module):
self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
# Determine expert maps
if self.use_ep:
if self.enable_eplb:
......@@ -1620,7 +1623,8 @@ class FusedMoE(torch.nn.Module):
The pplx combine kernel reduces across GPU ranks by default.
"""
if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels):
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels
or self.enable_dp_attention):
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
......@@ -1759,6 +1763,7 @@ class FusedMoE(torch.nn.Module):
and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_deepep_ll_kernels
and not self.moe_parallel_config.use_deepep_auto_kernels
and not self.enable_dp_attention
)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
......
......@@ -22,6 +22,7 @@ from vllm.utils import round_up
try:
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from lightop import op
except Exception:
print("INFO: Please install lmslim if you want to use int utils.\n")
from vllm.utils import cdiv
......@@ -622,6 +623,16 @@ def ep_scatter(
num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1]
if hasattr(op, "ep_scatter"):
op.ep_scatter(
recv_x, recv_x_scale,
recv_topk, expert_map,
num_recv_tokens_per_expert,
output_tensor, output_tensor_scale, m_indices, output_index,
num_experts, BLOCK_E
)
else:
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
......@@ -665,8 +676,8 @@ def ep_scatter(
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,#hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size)#triton.next_power_of_2(hidden_size // BLOCK_D),
SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
)
return
......
......@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
......@@ -625,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
......@@ -878,6 +885,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
......@@ -888,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
tp_size = get_moe_tp_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
......@@ -898,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
expect_tp_size=expect_tp_size)
expect_tp_size=expect_tp_size,
enable_dp_attn_moe=enable_dp_attn_moe)
def weight_loader(self,
param: Parameter,
......@@ -1000,6 +1013,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
tp_rank = 0
tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
......@@ -1121,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe and hasattr(param, "enable_dp_attn_moe"):
tp_size = get_moe_tp_size()
param.enable_dp_attn_moe = self.enable_dp_attn_moe
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
......@@ -1552,6 +1573,7 @@ class RowParallelLinear(LinearBase):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
......@@ -1560,7 +1582,13 @@ class RowParallelLinear(LinearBase):
if expect_tp_size is not None:
self.tp_rank = 0
self.tp_size = 1
self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_rank = get_moe_tp_rank()
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
......@@ -1610,6 +1638,11 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None:
tp_rank = 0
tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
......@@ -1664,6 +1697,9 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe is not None and hasattr(param, "enable_dp_attn_moe"):
param.enable_dp_attn_moe = self.enable_dp_attn_moe
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(
......
......@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.dp_attention import (dp_gather, dp_reduce_scatter_tensor,
get_moe_tp_size, get_moe_tp_rank, get_attention_tp_size)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -85,17 +87,22 @@ class DeepseekV2MLP(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
prefix=f"{prefix}.gate_up_proj",
enable_dp_attn_moe=enable_dp_attention)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
reduce_results=reduce_results if not enable_dp_attention else False,
prefix=f"{prefix}.down_proj",
enable_dp_attn_moe=enable_dp_attention)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -979,6 +986,8 @@ class DeepseekV2DecoderLayer(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.config = config
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
......@@ -1006,6 +1015,9 @@ class DeepseekV2DecoderLayer(nn.Module):
DeepseekV2MoE) and self.use_deepep and \
self.tp_size > 1 and not self.is_mtp_layer:
reduce_results = False
else:
if self.enable_dp_attention:
reduce_results = False
if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
......@@ -1176,6 +1188,10 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
if self.enable_dp_attention:
if self.tp_rank == 0:
hidden_states += residual
hidden_states = dp_gather(hidden_states)
if hidden_states.dtype == torch.float16:
# Fix FP16 overflow
......@@ -1188,8 +1204,14 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1. / self.routed_scaling_factor
# Fully Connected
if not self.enable_dp_attention:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
else:
num_tokens = hidden_states.shape[0]
new_bs = num_tokens // get_moe_tp_size() * get_attention_tp_size()
residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
hidden_states = self.post_attention_layernorm(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
......@@ -1201,9 +1223,11 @@ class DeepseekV2DecoderLayer(nn.Module):
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
hidden_states = self.mlp(hidden_states)
if self.enable_dp_attention:
hidden_states = dp_reduce_scatter_tensor(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
......
......@@ -10,6 +10,7 @@ from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
......@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self._output_dim = output_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property
......@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
......@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
......@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter):
self._input_dim = input_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property
def input_dim(self):
......@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)
......
......@@ -12,7 +12,7 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import DPMetadata, set_forward_context
from vllm.forward_context import DPMetadata, set_forward_context, get_warming_up
from vllm.logger import init_logger
import vllm.envs as envs
from vllm.model_executor.model_loader import get_model
......@@ -93,6 +93,8 @@ class EagleProposer:
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.attn_tp_size = vllm_config.parallel_config.tensor_parallel_size
def propose(
self,
......@@ -189,7 +191,8 @@ class EagleProposer:
else:
num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
# copy inputs to buffer for cudagraph
......@@ -279,6 +282,13 @@ class EagleProposer:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......@@ -373,6 +383,7 @@ class EagleProposer:
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
......@@ -532,9 +543,8 @@ class EagleProposer:
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if not self.enable_dp_attention and not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# auto
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# Early exit.
return 0, None
......@@ -566,7 +576,7 @@ class EagleProposer:
# Padding for DP
num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_pad, _ = self.get_dp_padding(num_tokens)
num_input_tokens += num_pad
with set_forward_context(attn_metadata,
......@@ -578,9 +588,49 @@ class EagleProposer:
self.hidden_states[:num_input_tokens],
)
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1:
num_token = 1
for _ in range(self.num_speculative_tokens - 1):
if self.dp_size > 1 and (self.enable_expert_parallel or self.enable_dp_attention) and self.num_speculative_tokens > 1:
num_tokens = 1
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
num_tokens = round_up(num_tokens, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(num_tokens)
num_tokens += num_pad
if not get_warming_up():
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_tokens + 1],
seq_lens=self.runner.seq_lens[:num_tokens],
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
slot_mapping=self.runner.slot_mapping[:num_tokens],
spec_layer_decoding=True
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build_for_cudagraph_capture(
common_attn_metadata=common_attn_metadata
)
for i in range(self.num_speculative_tokens - 1):
if self.attn_metadata_cudagraph is not None:
if i == 0:
attn_metadata_cudagraph = self.attn_metadata_cudagraph
attn_metadata_cudagraph.num_actual_tokens = num_tokens
attn_metadata_cudagraph.num_decodes = num_tokens
attn_metadata_cudagraph.num_decode_tokens = num_tokens
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
attn_metadata_cudagraph.decode.seq_lens[:num_tokens] = (
attn_metadata.decode.seq_lens)
self.attn_metadata_cudagraph.query_start_loc[:num_tokens + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:num_tokens] = (
attn_metadata.decode.block_table)
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens):
......
......@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import (
prepare_communication_buffer_for_model,
get_tensor_model_parallel_world_size)
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context, set_profilling)
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
......@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True
self.enable_dp_attention = self.parallel_config.enable_dp_attention
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
......@@ -1275,13 +1277,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# auto
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
if not self.enable_dp_attention and not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# Early exit.
return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
......@@ -1357,7 +1357,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......@@ -1638,9 +1638,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
......@@ -1681,6 +1678,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
attn_metadata,
)
if spec_token_ids is not None:
for i in discard_sampled_tokens_req_indices:
spec_token_ids[i].clear()
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
......@@ -2121,13 +2122,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
skip_eplb: bool = False,
is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
if num_tokens < self.tp_size:
num_tokens = self.tp_size
# Padding for DP
num_tokens_across_dp = 0
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
......@@ -2148,13 +2148,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
if not self.ep_sp:
if not (self.ep_sp or self.enable_dp_attention):
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else:
if self.speculative_config is not None:
......@@ -3219,7 +3219,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......
......@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.dp_attention import initialize_dp_attention
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
......@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
from vllm.forward_context import (set_warming_up, get_warming_up)
logger = init_logger(__name__)
......@@ -260,6 +262,7 @@ class Worker(WorkerBase):
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
set_warming_up(True)
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
warmup_sizes = [
......@@ -297,6 +300,7 @@ class Worker(WorkerBase):
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
set_warming_up(False)
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
......@@ -399,6 +403,9 @@ def init_worker_distributed_environment(
ensure_kv_transfer_initialized(vllm_config)
if vllm_config.parallel_config.enable_dp_attention:
initialize_dp_attention(vllm_config, backend)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
......
......@@ -112,6 +112,8 @@ class V1ZeroEagleProposer(EagleProposer):
else:
num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
......@@ -202,6 +204,13 @@ class V1ZeroEagleProposer(EagleProposer):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
# dp attention need all dp rank process same number tokens
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment