Commit 3f5c2eea authored by zhuwenwen's avatar zhuwenwen
Browse files

add mla tpsp and moe share experts computation communication overlap

parent 8375370f
......@@ -178,6 +178,8 @@ if TYPE_CHECKING:
VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
VLLM_USE_PD_SPLIT: bool = False
VLLM_ENABLE_MLA_SP: bool = False
VLLM_ENABLE_MLA_QKV_MERGE: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1094,68 +1096,89 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in
("true", "1")),
# vLLM will use apex for rmsnorm
"VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")),
# vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")),
# vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")),
# vLLM will use elenmentwise not triton_
"VLLM_USE_OPT_ZEROS":
lambda: (os.environ.get("VLLM_USE_OPT_ZEROS", "False").lower() in
("true", "1")),
# vLLM will use opt cat for deepseek-v3
"VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")),
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in
("true", "1")),
# vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in
("true", "1")),
# vLLM will use opt merge_aatn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")),
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")),
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# vllm will enable minimal injection for pipeline parallel scheduling
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION":
lambda: (os.getenv("VLLM_SCHED_ENABLE_MINIMAL_INJECTION", "0").lower() in
("true", "1")),
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")),
"VLLM_ENABLE_MLA_SP":
lambda: bool(int(os.getenv("VLLM_ENABLE_MLA_SP", "0"))),
"VLLM_ENABLE_MLA_QKV_MERGE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MLA_QKV_MERGE", "0"))),
}
# --8<-- [end:env-vars-definition]
......
......@@ -637,6 +637,13 @@ def determine_expert_map(
return (local_num_experts, expert_map)
EventType = Enum(
'EventType',
['Main', 'Attention', 'QCAllgather', 'KVFinish', 'MoeShared', 'MoeChunkingOverlap', 'MoeAllgather', 'MoeReduceScatter'],
start=0,
)
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
......
......@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
......@@ -454,6 +455,86 @@ class ReplicatedLinear(LinearBase):
return s
class MergedReplicatedLinear(ReplicatedLinear):
"""Merged replicated linear layer
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
self.output_sizes = output_sizes
super().__init__(input_size,
sum(output_sizes),
bias,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
def weight_loader(self,
param: Union[Parameter, BasevLLMParameter],
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
assert loaded_shard_id is not None
assert loaded_shard_id < len(self.output_sizes)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method, (Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
)
shard_size = (
(self.output_sizes[loaded_shard_id] + block_n - 1)
// block_n
)
elif isinstance(param, PerTensorScaleParameter) and current_platform.is_rocm():
shard_offset = loaded_shard_id
shard_size = 1
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
start_offset = shard_offset
end_offset = start_offset + shard_size
assert loaded_weight.shape == param.data[start_offset:end_offset, ...].shape, (
f"Expected shape {param.data[start_offset:end_offset, ...].shape}, got {loaded_weight.shape}"
)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
param.data[start_offset:end_offset, ...].copy_(loaded_weight)
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
......@@ -1390,6 +1471,7 @@ class RowParallelLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
sp_parallel: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
......@@ -1397,6 +1479,7 @@ class RowParallelLinear(LinearBase):
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
self.sp_parallel = sp_parallel
super().__init__(input_size,
output_size,
......@@ -1525,6 +1608,9 @@ class RowParallelLinear(LinearBase):
if self.reduce_results and self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
else:
if self.sp_parallel:
output = tensor_model_parallel_reduce_scatter(output_parallel.contiguous(), dim=0)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
......
......@@ -29,10 +29,11 @@ import vllm.envs as envs
import typing
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Dict
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.attention import Attention
......@@ -40,12 +41,17 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter,
get_tensor_model_parallel_rank)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import EventType, AuxStreamType
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
......@@ -64,6 +70,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix)
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
from vllm.logger import init_logger
logger = init_logger(__name__)
class DeepseekV2MLP(nn.Module):
......@@ -75,6 +84,7 @@ class DeepseekV2MLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
enable_tpsp: bool = False
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
......@@ -82,12 +92,14 @@ class DeepseekV2MLP(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.enable_tpsp = enable_tpsp
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
prefix=f"{prefix}.down_proj",
sp_parallel=self.enable_tpsp)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -108,12 +120,57 @@ class DeepseekV2MLP(nn.Module):
return x, new_resi
else:
if self.enable_tpsp:
x = tensor_model_parallel_all_gather(x.contiguous(), dim=0)
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class SharedExpertOverlapSPMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
event_dict: dict = None,
aux_stream: torch.cuda.Stream = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedReplicatedLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = ReplicatedLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.event_dict = event_dict
self.aux_stream = aux_stream
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
self.event_dict[EventType.MoeAllgather].wait()
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
self.event_dict[EventType.MoeReduceScatter].wait()
x, _ = self.down_proj(x)
self.event_dict[EventType.MoeShared].record()
return x
class DeepseekV2MoE(nn.Module):
def __init__(
......@@ -121,7 +178,9 @@ class DeepseekV2MoE(nn.Module):
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
trt_aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream] = {},
enable_eplb: bool = False,
enable_tpsp: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
......@@ -137,6 +196,7 @@ class DeepseekV2MoE(nn.Module):
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")
self.enable_tpsp = enable_tpsp
self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,
......@@ -182,9 +242,27 @@ class DeepseekV2MoE(nn.Module):
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
self.aux_stream = trt_aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeShared, EventType.MoeAllgather, EventType.MoeReduceScatter]
}
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
if self.enable_tpsp:
self.shared_experts = SharedExpertOverlapSPMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
event_dict=self.event_dict,
aux_stream=self.aux_stream
)
else:
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
......@@ -198,12 +276,59 @@ class DeepseekV2MoE(nn.Module):
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
def tpsp_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
old_hidden_states = hidden_states
router_logits, _ = self.gate(hidden_states)
self.event_dict[EventType.MoeAllgather].record()
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
router_logits = tensor_model_parallel_all_gather(router_logits.contiguous(), dim=0)
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
self.event_dict[EventType.MoeReduceScatter].record()
final_hidden_states = tensor_model_parallel_reduce_scatter(
final_hidden_states.contiguous(), dim=0)
shared_output = None
if self.n_shared_experts is not None:
with torch.cuda.stream(self.aux_stream):
shared_output = self.shared_experts(old_hidden_states)
self.event_dict[EventType.MoeShared].wait()
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
return final_hidden_states
def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.enable_tpsp:
return self.tpsp_forward(hidden_states)
is_graph_capturing = True
do_multi_stream = is_graph_capturing
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if do_multi_stream:
self.event_dict[EventType.Main].record()
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
......@@ -211,6 +336,24 @@ class DeepseekV2MoE(nn.Module):
else:
shared_output = self.shared_experts(hidden_states)
if do_multi_stream:
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
self.event_dict[EventType.MoeShared].record()
self.event_dict[EventType.MoeShared].wait()
else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
......@@ -424,12 +567,15 @@ class DeepseekV2MLAAttention(nn.Module):
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
layer_idx: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
trt_aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream] = {},
enable_tpsp: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -449,6 +595,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.layer_idx = layer_idx
self.enable_tpsp = enable_tpsp
if self.q_lora_rank is not None:
if envs.USE_FUSED_RMS_QUANT:
......@@ -489,12 +637,21 @@ class DeepseekV2MLAAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
if not envs.VLLM_ENABLE_MLA_QKV_MERGE:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
else:
self.q_a_and_kv_a_proj = MergedReplicatedLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_and_kv_a_proj"
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
......@@ -507,7 +664,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
prefix=f"{prefix}.o_proj",
sp_parallel=self.enable_tpsp)
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
......@@ -550,6 +708,11 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
self.aux_stream = trt_aux_stream_dict[AuxStreamType.Attention]
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.QCAllgather, EventType.KVFinish]
}
def forward(
self,
......@@ -588,6 +751,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0], new_residual
else:
if not self.enable_tpsp:
if not envs.VLLM_ENABLE_MLA_QKV_MERGE:
if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(q_c)
......@@ -615,7 +780,69 @@ class DeepseekV2MLAAttention(nn.Module):
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
else:
if self.q_lora_rank is not None:
qkv_lora = self.q_a_and_kv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
hidden_states_or_q_c = hidden_states
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_c, k_pe = kv_lora.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
q_c = self.q_a_proj(hidden_states)[0]
self.event_dict[EventType.QCAllgather].record()
q_c = self.q_a_layernorm(q_c)
if self.layer_idx > 0:
q_c = tensor_model_parallel_all_gather(q_c.contiguous(), dim=0)
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.QCAllgather].wait()
kv_a_out = self.kv_a_proj_with_mqa(hidden_states)[0]
if self.layer_idx > 0:
kv_a_out = tensor_model_parallel_all_gather(kv_a_out.contiguous(), dim=0)
kv_c, k_pe = kv_a_out.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
self.event_dict[EventType.KVFinish].record()
q = self.q_b_proj(q_c)[0]
self.event_dict[EventType.KVFinish].wait()
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(kv_a_out.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
class DeepseekV2DecoderLayer(nn.Module):
......@@ -627,6 +854,8 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
enable_eplb: bool = False,
trt_aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream] = {},
mtp_layer: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -638,6 +867,9 @@ class DeepseekV2DecoderLayer(nn.Module):
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_tpsp = envs.VLLM_ENABLE_MLA_SP and self.tp_size > 1 and not mtp_layer
if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
else:
......@@ -658,6 +890,8 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
trt_aux_stream_dict=trt_aux_stream_dict,
enable_tpsp=self.enable_tpsp,
)
if (config.n_routed_experts is not None
......@@ -668,6 +902,8 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
trt_aux_stream_dict=trt_aux_stream_dict,
enable_tpsp=self.enable_tpsp
)
else:
self.mlp = DeepseekV2MLP(
......@@ -676,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_tpsp=self.enable_tpsp,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -758,6 +995,11 @@ class DeepseekV2DecoderLayer(nn.Module):
# first layer.
residual *= 1. / self.routed_scaling_factor
# split residual into sp piece
if self.layer_idx == 0 and self.enable_tpsp:
residual_per_rank = torch.chunk(residual, chunks=self.tp_size, dim=0)
residual = residual_per_rank[self.tp_rank]
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
......@@ -774,7 +1016,6 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile
class DeepseekV2Model(nn.Module):
......@@ -789,9 +1030,20 @@ class DeepseekV2Model(nn.Module):
quant_config = vllm_config.quant_config
enable_eplb = vllm_config.parallel_config.enable_eplb
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.vocab_size = config.vocab_size
self.aux_stream_dict = {
key: torch.cuda.Stream()
for key in [
AuxStreamType.Attention,
AuxStreamType.MoeShared,
AuxStreamType.MoeChunkingOverlap
]
}
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
......@@ -810,9 +1062,12 @@ class DeepseekV2Model(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
enable_eplb=enable_eplb,
trt_aux_stream_dict=self.aux_stream_dict,
),
prefix=f"{prefix}.layers")
self.enable_tpsp = envs.VLLM_ENABLE_MLA_SP and self.tp_size > 1
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
......@@ -845,6 +1100,27 @@ class DeepseekV2Model(nn.Module):
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)
# padding tpsq bs to tp_size
tpsp_bs_pad = False
bs = input_ids.shape[0]
bs_per_rank = (bs + self.tp_size - 1) // self.tp_size
pad_bs = bs_per_rank * self.tp_size if bs % self.tp_size != 0 else bs
if self.enable_tpsp and pad_bs != bs:
tpsp_bs_pad = True
additional_hidden_state = torch.zeros(pad_bs - bs, hidden_states.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
pad_hidden_state = torch.cat([hidden_states, additional_hidden_state], dim=0).contiguous()
hidden_states = pad_hidden_state
if residual:
additional_residual = torch.zeros(pad_bs - bs, residual.shape[1],
dtype=residual.dtype,
device=residual.device)
pad_residual = torch.cat([residual, additional_residual], dim=0).contiguous()
residual = pad_residual
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
......@@ -990,10 +1266,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
if not envs.VLLM_ENABLE_MLA_QKV_MERGE:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
else:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("q_a_and_kv_a_proj", "q_a_proj", 0),
("q_a_and_kv_a_proj", "kv_a_proj_with_mqa", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment