"tests/vscode:/vscode.git/clone" did not exist on "fcb73f306ccedb07ff33e3e3696018f66ccd40ea"
Commit 8813afd8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ds-wm-1218' into 'v0.9.2-dev-ds'

[feat]优化dp attention,减少1次allgather耗时,高吞吐提升明显

See merge request dcutoolkit/deeplearing/vllm!307
parents 428f3245 3e386c3b
......@@ -354,6 +354,7 @@ class DeepseekV2Attention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -412,7 +413,8 @@ class DeepseekV2Attention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
prefix=f"{prefix}.o_proj",
reduce_results=reduce_results)
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
......@@ -506,6 +508,7 @@ class DeepseekV2MLAAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -583,7 +586,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
prefix=f"{prefix}.o_proj",
reduce_results=reduce_results)
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
......@@ -722,6 +726,7 @@ class DeepseekV2DecoderLayer(nn.Module):
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.tp_size = get_tensor_model_parallel_world_size()
self.config = config
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
......@@ -745,6 +750,15 @@ class DeepseekV2DecoderLayer(nn.Module):
attn_cls = DeepseekV2MLAAttention
else:
attn_cls = DeepseekV2Attention
self.is_mtp_layer = False
if self.layer_idx == config.num_hidden_layers:
self.is_mtp_layer = True
reduce_results = True
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and \
self.tp_size > 1 and not self.is_mtp_layer:
reduce_results = False
self.self_attn = attn_cls(
config=config,
hidden_size=self.hidden_size,
......@@ -761,6 +775,7 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
reduce_results=reduce_results
)
self.input_layernorm = RMSNorm(config.hidden_size,
......@@ -769,6 +784,8 @@ class DeepseekV2DecoderLayer(nn.Module):
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.tp_rank = get_tensor_model_parallel_rank()
def forward(
self,
positions: torch.Tensor,
......@@ -829,6 +846,13 @@ class DeepseekV2DecoderLayer(nn.Module):
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
......@@ -844,27 +868,29 @@ class DeepseekV2DecoderLayer(nn.Module):
# first layer.
residual *= 1. / self.routed_scaling_factor
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
self.tp_rank = get_tensor_model_parallel_rank()
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = hidden_states.tensor_split(self.tp_size)[self.tp_rank]
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0).contiguous()
hidden_states = hidden_states[:ori_bs, :].contiguous()
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......@@ -923,6 +949,14 @@ class DeepseekV2Model(nn.Module):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.dp_size = get_dp_group().world_size
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.tp_size = get_tensor_model_parallel_world_size()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......@@ -955,6 +989,9 @@ class DeepseekV2Model(nn.Module):
})
hidden_states, _ = self.norm(hidden_states, residual)
if self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
return hidden_states
......
......@@ -23,6 +23,7 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDe
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
from vllm.utils import round_up
logger = init_logger(__name__)
......@@ -184,6 +185,13 @@ class EagleProposer:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
......
......@@ -28,7 +28,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
prepare_communication_buffer_for_model)
prepare_communication_buffer_for_model,
get_tensor_model_parallel_world_size)
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context, set_profilling)
from vllm.logger import init_logger
......@@ -1330,6 +1331,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
else:
num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
......@@ -2028,6 +2035,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_tokens = round_up(num_tokens, tp_size)
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
......@@ -3101,6 +3114,12 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
else:
num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
......
......@@ -7,6 +7,7 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
from vllm.utils import round_up
class V1ZeroEagleProposer(EagleProposer):
......@@ -107,6 +108,13 @@ class V1ZeroEagleProposer(EagleProposer):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
......
......@@ -422,6 +422,12 @@ class V1ZeroModelRunner(GPUModelRunner):
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
dp_size = self.vllm_config.parallel_config.data_parallel_size
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel and dp_size > 1 and tp_size > 1:
num_input_tokens = round_up(num_input_tokens, tp_size)
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment