Commit 4e51cae7 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm-0112' into 'v0.9.2-dev'

[feat]添加dp attention功能

See merge request dcutoolkit/deeplearing/vllm!383
parents fe2e2705 cc4d1002
......@@ -1883,6 +1883,9 @@ class ParallelConfig:
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
enable_dp_attention: bool = False
"""Enable dp attention"""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
......@@ -2108,6 +2111,9 @@ class ParallelConfig:
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
if self.enable_dp_attention and self.enable_expert_parallel:
raise ValueError("Dp attention and expert parallel can not enable together.")
return self
......@@ -4805,6 +4811,7 @@ class VllmConfig:
dp_size = self.parallel_config.data_parallel_size
tp_size = self.parallel_config.tensor_parallel_size
ep_sp = self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1
enable_dp_attention = self.parallel_config.enable_dp_attention
# add for spec decode
if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0:
......@@ -4813,10 +4820,10 @@ class VllmConfig:
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
if ep_sp:
if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
else:
if ep_sp:
if ep_sp or enable_dp_attention:
batch_size_capture_list = sorted(set([round_up(i, tp_size) for i in batch_size_capture_list]))
self.compilation_config.init_with_cudagraph_sizes(
......
......@@ -103,7 +103,7 @@ class DeviceCommunicatorBase:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
use_ep = config.parallel_config.data_parallel_size > 1 and not config.parallel_config.enable_dp_attention
self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
......
......@@ -476,6 +476,9 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
enable_dp_attention: bool = \
ParallelConfig.enable_dp_attention
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
......@@ -718,6 +721,10 @@ class EngineArgs:
parallel_group.add_argument(
"--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
parallel_group.add_argument(
"--enable-dp-attention",
**parallel_kwargs["enable_dp_attention"])
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
......@@ -1204,6 +1211,7 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel,
enable_dp_attention=self.enable_dp_attention,
)
speculative_config = self.create_speculative_config(
......
......@@ -203,6 +203,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ENABLE_DEEPEP_INT8_DISPATCH: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
......@@ -1325,6 +1326,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
# vLLM will use deepep int8 dispatch
"VLLM_ENABLE_DEEPEP_INT8_DISPATCH":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_INT8_DISPATCH', '1').lower() in
("true", "1")),
# Only quantized DeepSeek models supported.
# Unquantized versions are not supported.
......
......@@ -136,8 +136,8 @@ def set_forward_context(
forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None
dp_size = vllm_config.parallel_config.data_parallel_size
use_navie_ep = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if use_navie_ep and dp_size > 1 and (
use_navie_all2all = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1
if use_navie_all2all and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None):
dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0,
......@@ -210,4 +210,15 @@ def set_profilling(profiling):
def get_profilling() -> bool:
global _profiling
return _profiling
\ No newline at end of file
return _profiling
# Process-wide warm-up flag reported by get_warming_up(); False until a caller
# changes it through set_warming_up().
_warming_up = False


def set_warming_up(warming_up):
    """Set the module-level warm-up flag.

    The flag is updated as soon as this function is called, so the plain-call
    form ``set_warming_up(True)`` behaves exactly as before.

    The original was decorated with ``@contextmanager`` but never yielded, so
    entering it in a ``with`` statement failed. The returned object is now a
    working context manager: leaving the ``with`` block restores the value the
    flag had before this call.

    Args:
        warming_up: New value of the flag.

    Returns:
        A context manager that restores the previous flag value on exit. It
        may be ignored when the function is used as a plain setter.
    """
    global _warming_up
    previous = _warming_up
    # Apply immediately: existing callers rely on the call itself, not on
    # entering a ``with`` block, to flip the flag.
    _warming_up = warming_up

    @contextmanager
    def _restore_on_exit():
        global _warming_up
        try:
            yield
        finally:
            _warming_up = previous

    return _restore_on_exit()
def get_warming_up() -> bool:
    """Return the current value of the module-level warm-up flag."""
    # Reading a module global does not require a ``global`` declaration.
    return _warming_up
\ No newline at end of file
from typing import TYPE_CHECKING, List, Optional, Tuple
import logging
import torch
import vllm.envs as envs
from vllm.distributed.parallel_state import GroupCoordinator, init_model_parallel_group, get_world_group
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
tensor_model_parallel_reduce_scatter,
get_tp_group)
# Module-level DP-attention state. Every name below is assigned by
# initialize_dp_attention(); the values here are the defaults that are visible
# before that function has run.

# Copy of parallel_config.enable_dp_attention.
_ENABLE_DP_ATTENTION_FLAG: bool = False
# Group created with group_name="moe_tp"; None until initialize_dp_attention()
# builds it.
_MOE_TP: Optional[GroupCoordinator] = None
# parallel_config.data_parallel_size.
_ATTN_DP_SIZE = 0
# parallel_config.tensor_parallel_size.
_ATTN_TP_SIZE = 0
# Result of get_tensor_model_parallel_rank() at initialization time.
_ATTN_TP_RANK = 0
# parallel_config.data_parallel_rank.
_ATTN_DP_RANK = 0
# World size divided by the pipeline-parallel size.
# NOTE(review): "MOT" looks like a typo for "MOE" - confirm before renaming,
# since the getters in this module read these exact names.
_MOT_TP_SIZE = 0
# Global rank modulo _MOT_TP_SIZE.
_MOT_TP_RANK = 0
def initialize_dp_attention(vllm_config, backend: Optional[str] = None):
    """Record the DP-attention layout and create the MoE tensor-parallel group.

    The attention sizes and ranks are copied from ``vllm_config.parallel_config``
    into module-level variables. The MoE tensor-parallel size is the
    torch.distributed world size divided by the pipeline-parallel size, and one
    process group is created per pipeline stage from a contiguous range of
    global ranks.

    Args:
        vllm_config: The ``VllmConfig`` describing the parallel layout.
        backend: Backend for the new group. When omitted, the backend of the
            world group's device group is used.

    Raises:
        AssertionError: If ``vllm_config`` is not a ``VllmConfig`` or the MoE
            tensor-parallel group has already been created.
    """
    from vllm.config import VllmConfig
    assert isinstance(vllm_config, VllmConfig)

    global _ENABLE_DP_ATTENTION_FLAG, _MOE_TP
    global _ATTN_DP_SIZE, _ATTN_TP_SIZE, _ATTN_TP_RANK, _ATTN_DP_RANK
    global _MOT_TP_SIZE, _MOT_TP_RANK

    # Check before writing any module state so that a repeated call cannot
    # overwrite the values recorded by the first one.
    assert _MOE_TP is None, ("moe tensor model parallel group is already initialized")

    parallel_config = vllm_config.parallel_config
    _ENABLE_DP_ATTENTION_FLAG = parallel_config.enable_dp_attention

    # Build the moe tensor model-parallel groups.
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    pipeline_model_parallel_size = parallel_config.pipeline_parallel_size
    moe_tp_size = world_size // pipeline_model_parallel_size

    _ATTN_DP_SIZE = parallel_config.data_parallel_size
    _ATTN_TP_SIZE = parallel_config.tensor_parallel_size
    _ATTN_TP_RANK = get_tensor_model_parallel_rank()
    _ATTN_DP_RANK = parallel_config.data_parallel_rank
    _MOT_TP_SIZE = moe_tp_size
    _MOT_TP_RANK = rank % _MOT_TP_SIZE

    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # One group per pipeline stage: stage i owns the contiguous global ranks
    # [i * moe_tp_size, (i + 1) * moe_tp_size).
    group_ranks = [
        list(range(stage * moe_tp_size, (stage + 1) * moe_tp_size))
        for stage in range(pipeline_model_parallel_size)
    ]

    # message queue broadcaster is only used in tensor model parallel group
    _MOE_TP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="moe_tp")
def get_attention_tp_size() -> int:
    """Return the attention tensor-parallel size recorded at initialization.

    Raises:
        AssertionError: If the size still has its pre-initialization value 0.
    """
    # The module default is 0, never None, so the previous ``is not None``
    # guard could not fail and an uninitialized module silently returned 0.
    assert _ATTN_TP_SIZE > 0, "dp attention not initialized!"
    return _ATTN_TP_SIZE
def get_attention_tp_rank() -> int:
    """Return the attention tensor-parallel rank recorded at initialization.

    Raises:
        AssertionError: If the attention tensor-parallel size still has its
            pre-initialization value 0.
    """
    # Rank 0 is a legitimate value, so initialization is detected through the
    # size that initialize_dp_attention() assigns together with the rank. The
    # previous ``is not None`` guard could never fail because the default is 0.
    assert _ATTN_TP_SIZE > 0, "dp attention not initialized!"
    return _ATTN_TP_RANK
def get_moe_tp_group() -> GroupCoordinator:
    """Return the MoE tensor-parallel group.

    Raises:
        AssertionError: If the group has not been created yet (the module-level
            ``_MOE_TP`` is still None).
    """
    # The message now names the MoE group, matching the wording used when the
    # group is created; the old text pointed at the regular TP group.
    assert _MOE_TP is not None, ("moe tensor model parallel group is not initialized")
    return _MOE_TP
def get_attention_dp_size() -> int:
    """Return the attention data-parallel size recorded at initialization.

    Raises:
        AssertionError: If the size still has its pre-initialization value 0.
    """
    # The module default is 0, never None, so the previous ``is not None``
    # guard could not fail and an uninitialized module silently returned 0.
    assert _ATTN_DP_SIZE > 0, "dp attention not initialized!"
    return _ATTN_DP_SIZE
def get_moe_tp_rank() -> int:
    """Return this process's rank inside the MoE tensor-parallel group.

    Raises:
        AssertionError: If the MoE tensor-parallel size still has its
            pre-initialization value 0.
    """
    # Rank 0 is a legitimate value, so initialization is detected through the
    # size that initialize_dp_attention() assigns together with the rank. The
    # previous ``is not None`` guard could never fail because the default is 0.
    assert _MOT_TP_SIZE > 0, "dp attention not initialized!"
    return _MOT_TP_RANK
def get_moe_tp_size() -> int:
    """Return the MoE tensor-parallel size recorded at initialization.

    Raises:
        AssertionError: If the size still has its pre-initialization value 0.
    """
    # The module default is 0, never None, so the previous ``is not None``
    # guard could not fail and an uninitialized module silently returned 0.
    assert _MOT_TP_SIZE > 0, "dp attention not initialized!"
    return _MOT_TP_SIZE
def get_attention_tp_group() -> GroupCoordinator:
    """Return the group used for attention tensor parallelism.

    This is whatever ``get_tp_group()`` returns; no separate group is created
    for attention.
    """
    attention_group = get_tp_group()
    return attention_group
def moe_tensor_model_parallel_all_gather(input_: torch.Tensor,
                                         dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` along ``dim`` across the MoE tensor-parallel group."""
    moe_group = get_moe_tp_group()
    gathered = moe_group.all_gather(input_, dim)
    return gathered
def moe_tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
                                             dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` across the MoE tensor-parallel group."""
    moe_group = get_moe_tp_group()
    scattered = moe_group.reduce_scatter(input_, dim)
    return scattered
def dp_gather(
        hidden_states: torch.Tensor, ) -> torch.Tensor:
    """Collect hidden states across the MoE tensor-parallel group along dim 0.

    When the attention tensor-parallel size is not 1, the input is first
    reduce-scattered across the regular tensor-parallel group; the result is
    then all-gathered across the MoE tensor-parallel group in every case.
    """
    if get_attention_tp_size() != 1:
        hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
    return moe_tensor_model_parallel_all_gather(hidden_states, dim=0)
def dp_reduce_scatter_tensor(hidden_states: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter hidden states across the MoE tensor-parallel group.

    The reduce-scatter along dim 0 always runs. When the MoE group's world size
    differs from the attention data-parallel size, the result is additionally
    all-gathered along dim 0 across the regular tensor-parallel group.
    """
    # Evaluate the condition first, as the original did, so the lookups happen
    # before any collective is issued.
    moe_group_matches_dp = get_moe_tp_group().world_size == get_attention_dp_size()
    scattered = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
    if moe_group_matches_dp:
        return scattered
    return tensor_model_parallel_all_gather(scattered, dim=0)
......@@ -38,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
......@@ -237,7 +238,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager.world_size,
)
ll_handle = all2all_manager.get_handle(ll_all_to_all_args)
# HT prepare/finalize built on the same LL handle per request
ht_prepare_finalize = DeepEPHTPrepareAndFinalize(
ll_handle,
......@@ -253,7 +254,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = False
use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 and envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
ll_prepare_finalize = DeepEPLLPrepareAndFinalize(
ll_handle,
......@@ -265,10 +266,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize = DeepEPAutoPrepareAndFinalize(
ht_prepare_finalize, ll_prepare_finalize)
experts_ht = self.select_gemm_impl(ht_prepare_finalize, moe)
experts_ll = self.select_gemm_impl(ll_prepare_finalize, moe)
experts_ll = self.select_gemm_impl(ll_prepare_finalize, moe)
self.topk_indices_dtype = ll_prepare_finalize.topk_indices_dtype()
self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize,
......@@ -276,9 +277,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
experts_ht=experts_ht,
experts_ll=experts_ll,
shared_experts=layer.shared_experts if hasattr(layer, "shared_experts") else None,
)
)
return
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
......@@ -310,7 +311,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8
use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 and envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
......@@ -957,6 +958,8 @@ class FusedMoE(torch.nn.Module):
self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
# Determine expert maps
if self.use_ep:
if self.enable_eplb:
......@@ -1693,7 +1696,8 @@ class FusedMoE(torch.nn.Module):
The pplx combine kernel reduces across GPU ranks by default.
"""
if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels):
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels
or self.enable_dp_attention):
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
......@@ -1832,6 +1836,7 @@ class FusedMoE(torch.nn.Module):
and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_deepep_ll_kernels
and not self.moe_parallel_config.use_deepep_auto_kernels
and not self.enable_dp_attention
)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
......
......@@ -22,6 +22,7 @@ from vllm.utils import round_up
try:
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from lightop import op
except Exception:
print("INFO: Please install lmslim if you want to use int utils.\n")
from vllm.utils import cdiv
......@@ -622,52 +623,62 @@ def ep_scatter(
num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0
if hasattr(op, "ep_scatter"):
op.ep_scatter(
recv_x, recv_x_scale,
recv_topk, expert_map,
num_recv_tokens_per_expert,
output_tensor, output_tensor_scale, m_indices, output_index,
num_experts, BLOCK_E
)
else:
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
_fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
assert m_indices.shape[0] % BLOCK_E == 0
grid = min(recv_topk.shape[0], 1024 * 8)
_fwd_kernel_ep_scatter_2[(grid,)](
recv_topk.shape[0],
expert_start_loc,
recv_x,
recv_x.stride(0),
recv_x.stride(1),
recv_x_scale,
recv_x_scale.stride(0),
recv_x_scale.stride(1),
recv_topk,
recv_topk.stride(0),
recv_topk.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor_scale,
output_tensor_scale.stride(0),
output_tensor_scale.stride(1),
output_index,
output_index.stride(0),
output_index.stride(1),
topk_num=recv_topk.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,#hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size)#triton.next_power_of_2(hidden_size // BLOCK_D),
)
_fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
grid = min(recv_topk.shape[0], 1024 * 8)
_fwd_kernel_ep_scatter_2[(grid,)](
recv_topk.shape[0],
expert_start_loc,
recv_x,
recv_x.stride(0),
recv_x.stride(1),
recv_x_scale,
recv_x_scale.stride(0),
recv_x_scale.stride(1),
recv_topk,
recv_topk.stride(0),
recv_topk.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor_scale,
output_tensor_scale.stride(0),
output_tensor_scale.stride(1),
output_index,
output_index.stride(0),
output_index.stride(1),
topk_num=recv_topk.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
)
return
......
......@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter)
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
......@@ -625,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
......@@ -878,6 +885,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
......@@ -888,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
tp_size = get_moe_tp_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
......@@ -898,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
expect_tp_size=expect_tp_size)
expect_tp_size=expect_tp_size,
enable_dp_attn_moe=enable_dp_attn_moe)
def weight_loader(self,
param: Parameter,
......@@ -999,6 +1012,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_rank = 0
tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
......@@ -1121,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe and hasattr(param, "enable_dp_attn_moe"):
tp_size = get_moe_tp_size()
param.enable_dp_attn_moe = self.enable_dp_attn_moe
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
......@@ -1552,6 +1573,7 @@ class RowParallelLinear(LinearBase):
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
enable_dp_attn_moe: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
......@@ -1560,7 +1582,13 @@ class RowParallelLinear(LinearBase):
if expect_tp_size is not None:
self.tp_rank = 0
self.tp_size = 1
self.expect_tp_size = expect_tp_size
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_rank = get_moe_tp_rank()
self.tp_size = get_moe_tp_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
......@@ -1610,6 +1638,11 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None:
tp_rank = 0
tp_size = 1
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
tp_size = get_moe_tp_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
......@@ -1664,6 +1697,9 @@ class RowParallelLinear(LinearBase):
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if self.enable_dp_attn_moe is not None and hasattr(param, "enable_dp_attn_moe"):
param.enable_dp_attn_moe = self.enable_dp_attn_moe
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(
......
......@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.dp_attention import (dp_gather, dp_reduce_scatter_tensor,
get_moe_tp_size, get_moe_tp_rank, get_attention_tp_size)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -85,17 +87,22 @@ class DeepseekV2MLP(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
prefix=f"{prefix}.gate_up_proj",
enable_dp_attn_moe=enable_dp_attention)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
reduce_results=reduce_results if not enable_dp_attention else False,
prefix=f"{prefix}.down_proj",
enable_dp_attn_moe=enable_dp_attention)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -991,6 +998,8 @@ class DeepseekV2DecoderLayer(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.config = config
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
......@@ -1010,6 +1019,8 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.mlp",
)
self.enable_ep_sp = isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1
self.is_mtp_layer = False
if self.layer_idx == config.num_hidden_layers:
self.is_mtp_layer = True
......@@ -1018,6 +1029,9 @@ class DeepseekV2DecoderLayer(nn.Module):
DeepseekV2MoE) and self.use_deepep and \
self.tp_size > 1 and not self.is_mtp_layer:
reduce_results = False
else:
if self.enable_dp_attention:
reduce_results = False
if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
......@@ -1169,25 +1183,25 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
if not self.is_mtp_layer and self.enable_ep_sp and \
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
if not self.is_mtp_layer and self.enable_ep_sp:
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
if self.enable_dp_attention:
if self.tp_rank == 0:
hidden_states += residual
hidden_states = dp_gather(hidden_states)
if hidden_states.dtype == torch.float16:
# Fix FP16 overflow
......@@ -1200,27 +1214,31 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
if not self.enable_dp_attention:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
else:
num_tokens = hidden_states.shape[0]
new_bs = num_tokens // get_moe_tp_size() * get_attention_tp_size()
residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
hidden_states = self.post_attention_layernorm(hidden_states)
if self.is_mtp_layer and self.enable_ep_sp:
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states, [0, 0, 0, pad_size], value=0)
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
hidden_states = self.mlp(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if self.enable_dp_attention:
hidden_states = dp_reduce_scatter_tensor(hidden_states)
if self.is_mtp_layer and self.enable_ep_sp:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
......@@ -10,6 +10,7 @@ from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
......@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self._output_dim = output_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property
......@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
......@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
......@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter):
self._input_dim = input_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
self.enable_dp_attn_moe = False
@property
def input_dim(self):
......@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
if self.enable_dp_attn_moe:
tp_rank = get_moe_tp_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)
......
......@@ -12,7 +12,7 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import DPMetadata, set_forward_context
from vllm.forward_context import DPMetadata, set_forward_context, get_warming_up
from vllm.logger import init_logger
import vllm.envs as envs
from vllm.model_executor.model_loader import get_model
......@@ -93,6 +93,12 @@ class EagleProposer:
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel
self.enable_dp_attention = vllm_config.parallel_config.enable_dp_attention
self.attn_tp_size = vllm_config.parallel_config.tensor_parallel_size
self.ep_sp = False
if self.enable_expert_parallel and self.dp_size > 1 and self.attn_tp_size > 1:
self.ep_sp = True
def propose(
self,
......@@ -189,9 +195,12 @@ class EagleProposer:
else:
num_input_tokens = num_tokens
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
# num_input_tokens += num_pad
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
......@@ -279,6 +288,13 @@ class EagleProposer:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
# DP attention requires every DP rank to process the same number of tokens.
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......@@ -373,6 +389,7 @@ class EagleProposer:
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
......@@ -532,11 +549,9 @@ class EagleProposer:
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# auto
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# Early exit.
return 0, None
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# Early exit.
return 0, None
try:
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
......@@ -559,6 +574,7 @@ class EagleProposer:
self,
num_tokens: int,
attn_metadata: Optional[dict[str, Any]] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
) -> None:
if attn_metadata is not None and self.attn_metadata_cudagraph is None:
self.attn_metadata_cudagraph = attn_metadata[
......@@ -566,29 +582,73 @@ class EagleProposer:
# Padding for DP
num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_input_tokens += num_pad
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
# num_input_tokens += num_pad
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens):
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens],
)
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1:
num_token = 1
for _ in range(self.num_speculative_tokens - 1):
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)
if self.dp_size > 1 and (self.enable_expert_parallel or self.enable_dp_attention) and self.num_speculative_tokens > 1:
num_tokens = 1
if self.enable_dp_attention or self.ep_sp:
num_tokens = round_up(num_tokens, self.attn_tp_size)
# DP attention requires every DP rank to process the same number of tokens.
if self.enable_dp_attention:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
if not get_warming_up():
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_tokens + 1],
seq_lens=self.runner.seq_lens[:num_tokens],
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
slot_mapping=self.runner.slot_mapping[:num_tokens],
spec_layer_decoding=True
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build_for_cudagraph_capture(
common_attn_metadata=common_attn_metadata
)
for i in range(self.num_speculative_tokens - 1):
if self.attn_metadata_cudagraph is not None:
if i == 0:
attn_metadata_cudagraph = self.attn_metadata_cudagraph
attn_metadata_cudagraph.num_actual_tokens = num_tokens
attn_metadata_cudagraph.num_decodes = num_tokens
attn_metadata_cudagraph.num_decode_tokens = num_tokens
attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
attn_metadata_cudagraph.decode.seq_lens[:num_tokens] = (
attn_metadata.decode.seq_lens)
attn_metadata_cudagraph.query_start_loc[:num_tokens + 1] = (
attn_metadata.query_start_loc)
attn_metadata_cudagraph.decode.block_table[:num_tokens] = (
attn_metadata.decode.block_table)
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)
def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None:
......
......@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.enable_expert_parallel and self.dp_size > 1 and self.tp_size > 1:
self.ep_sp = True
self.enable_dp_attention = self.parallel_config.enable_dp_attention
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
......@@ -1278,13 +1280,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
#
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# auto
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# Early exit.
return 0, None
# Early exit.
return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
......@@ -1361,7 +1359,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......@@ -2129,13 +2127,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
skip_eplb: bool = False,
is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
if num_tokens < self.tp_size:
num_tokens = self.tp_size
# Padding for DP
num_tokens_across_dp = 0
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
......@@ -2156,13 +2153,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req = (1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_tokens // min_tokens_per_req
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_actual_tokens = round_down(num_tokens, 1 + self.speculative_config.num_lookahead_slots)
num_reqs = num_actual_tokens // min_tokens_per_req
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
if not self.ep_sp:
if not (self.ep_sp or self.enable_dp_attention):
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
else:
if self.speculative_config is not None:
......@@ -2254,7 +2251,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if self.speculative_config and self.speculative_config.use_eagle() and not is_profile:
#assert isinstance(self.drafter, EagleProposer)
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens, attn_metadata)
self.drafter.dummy_run(num_tokens, attn_metadata,
num_tokens_across_dp=num_tokens_across_dp)
# This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real
......@@ -3227,7 +3225,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, self.tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......
......@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.dp_attention import initialize_dp_attention
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
......@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
from vllm.forward_context import (set_warming_up, get_warming_up)
logger = init_logger(__name__)
......@@ -260,6 +262,7 @@ class Worker(WorkerBase):
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
set_warming_up(True)
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
warmup_sizes = [
......@@ -297,6 +300,7 @@ class Worker(WorkerBase):
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
set_warming_up(False)
# Return the model by delegating to self.model_runner.get_model().
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
......@@ -399,6 +403,9 @@ def init_worker_distributed_environment(
ensure_kv_transfer_initialized(vllm_config)
if vllm_config.parallel_config.enable_dp_attention:
initialize_dp_attention(vllm_config, backend)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
......
......@@ -112,8 +112,11 @@ class V1ZeroEagleProposer(EagleProposer):
else:
num_input_tokens = num_tokens
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
if self.enable_dp_attention:
num_input_tokens = round_up(num_input_tokens, self.attn_tp_size)
# num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
# num_input_tokens += num_pad
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
......@@ -202,6 +205,13 @@ class V1ZeroEagleProposer(EagleProposer):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
# DP attention requires every DP rank to process the same number of tokens.
if self.enable_dp_attention:
input_batch_size = round_up(input_batch_size, self.attn_tp_size)
num_pad, _ = self.get_dp_padding(input_batch_size)
input_batch_size += num_pad
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
......
......@@ -465,7 +465,7 @@ class V1ZeroModelRunner(GPUModelRunner):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment