Commit 3f5c2eea authored by zhuwenwen
Browse files

Add MLA TP/SP and MoE shared-experts computation-communication overlap

parent 8375370f
...@@ -178,6 +178,8 @@ if TYPE_CHECKING: ...@@ -178,6 +178,8 @@ if TYPE_CHECKING:
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_ENABLE_MLA_SP: bool = False
VLLM_ENABLE_MLA_QKV_MERGE: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1094,68 +1096,89 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1094,68 +1096,89 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA": "VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use apex for rmsnorm # vLLM will use apex for rmsnorm
"VLLM_USE_APEX_RN": "VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13": "VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop for deepseek-v3 # vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHTOP": "VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use elementwise not triton_ # vLLM will use elementwise not triton_
"VLLM_USE_OPT_ZEROS": "VLLM_USE_OPT_ZEROS":
lambda: (os.environ.get("VLLM_USE_OPT_ZEROS", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_ZEROS", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt cat for deepseek-v3 # vLLM will use opt cat for deepseek-v3
"VLLM_USE_OPT_CAT": "VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use triton moe_sum # vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM": "VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_sum_mul_add # vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD": "VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_sum # vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM": "VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_align_block_size # vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN": "VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt merge_attn_states, not triton # vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
# vllm will use rmsquant fused op # vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vllm will use silu_mul_quant fused op # vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT": "USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vllm pd separation will be used async # vllm pd separation will be used async
"VLLM_P2P_ASYNC": "VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))), lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
# pd separation p2p async buf tokens # pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS": "VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")), lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# vllm will enable minimal injection for pipeline parallel scheduling # vllm will enable minimal injection for pipeline parallel scheduling
"VLLM_SCHED_ENABLE_MINIMAL_INJECTION": "VLLM_SCHED_ENABLE_MINIMAL_INJECTION":
lambda: (os.getenv("VLLM_SCHED_ENABLE_MINIMAL_INJECTION", "0").lower() in lambda: (os.getenv("VLLM_SCHED_ENABLE_MINIMAL_INJECTION", "0").lower() in
("true", "1")), ("true", "1")),
# vLLM will split prefill and decode, not mix up # vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT": "VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")), ("true", "1")),
"VLLM_ENABLE_MLA_SP":
lambda: bool(int(os.getenv("VLLM_ENABLE_MLA_SP", "0"))),
"VLLM_ENABLE_MLA_QKV_MERGE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MLA_QKV_MERGE", "0"))),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -637,6 +637,13 @@ def determine_expert_map( ...@@ -637,6 +637,13 @@ def determine_expert_map(
return (local_num_experts, expert_map) return (local_num_experts, expert_map)
class EventType(Enum):
    """Names for the synchronisation events shared by the model layers.

    Members are numbered consecutively from 0 in declaration order.
    Callers use them as dictionary keys for ``torch.cuda.Event`` objects.
    """

    Main = 0
    Attention = 1
    QCAllgather = 2
    KVFinish = 3
    MoeShared = 4
    MoeChunkingOverlap = 5
    MoeAllgather = 6
    MoeReduceScatter = 7
class FusedMoE(torch.nn.Module): class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
......
...@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
...@@ -454,6 +455,86 @@ class ReplicatedLinear(LinearBase): ...@@ -454,6 +455,86 @@ class ReplicatedLinear(LinearBase):
return s return s
class MergedReplicatedLinear(ReplicatedLinear):
    """Replicated linear layer that fuses several output projections.

    The layer is built with ``sum(output_sizes)`` as its output size.
    ``weight_loader`` copies each separately stored projection into its own
    slice along the leading dimension of the fused parameter.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: output dimension of every fused projection, in shard
            order.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
        return_bias: forwarded unchanged to ``ReplicatedLinear``.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Stored before the parent constructor runs; weight_loader reads it.
        self.output_sizes = output_sizes
        fused_output_size = sum(output_sizes)
        super().__init__(
            input_size,
            fused_output_size,
            bias,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
        )

    def weight_loader(self,
                      param: Union[Parameter, BasevLLMParameter],
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Copy one checkpoint shard into its slice of the fused parameter.

        Args:
            param: destination parameter; sliced along its leading dimension.
            loaded_weight: tensor read from the checkpoint for one shard.
            loaded_shard_id: index into ``output_sizes`` selecting the shard.
                Required despite the ``None`` default.

        Raises:
            AssertionError: if the shard id is missing or out of range, if a
                block-quantized scale is loaded without an FP8 quant method,
                or if the shard shape does not match the destination slice.
        """
        assert loaded_shard_id is not None
        assert loaded_shard_id < len(self.output_sizes)

        # 0-dim tensors are promoted to one-element 1-D tensors so they can
        # be compared against, and copied into, a slice.
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        rows_before_shard = sum(self.output_sizes[:loaded_shard_id])
        rows_in_shard = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method, (Fp8LinearMethod, Fp8MoEMethod))
            block_shape = self.quant_method.quant_config.weight_block_size
            block_n, _block_k = block_shape[0], block_shape[1]
            # Block-wise scales hold one entry per block of ``block_n``
            # output rows, so offset and size are ceil-divided by block_n.
            slice_begin = (rows_before_shard + block_n - 1) // block_n
            slice_len = (rows_in_shard + block_n - 1) // block_n
        elif (isinstance(param, PerTensorScaleParameter)
              and current_platform.is_rocm()):
            # One scale entry per shard, indexed by the shard id itself.
            slice_begin = loaded_shard_id
            slice_len = 1
        else:
            slice_begin = rows_before_shard
            slice_len = rows_in_shard

        slice_end = slice_begin + slice_len
        destination = param.data[slice_begin:slice_end, ...]
        assert loaded_weight.shape == destination.shape, (
            f"Expected shape {destination.shape}, got {loaded_weight.shape}"
        )

        # NOTE(review): the transpose happens after the shape check above, so
        # for a non-square shard the transposed tensor no longer matches the
        # destination slice - confirm the intended layout under VLLM_USE_NN.
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        destination.copy_(loaded_weight)
class ColumnParallelLinear(LinearBase): class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
...@@ -1390,6 +1471,7 @@ class RowParallelLinear(LinearBase): ...@@ -1390,6 +1471,7 @@ class RowParallelLinear(LinearBase):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
sp_parallel: bool = False,
): ):
# Divide the weight matrix along the first dimension. # Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -1397,6 +1479,7 @@ class RowParallelLinear(LinearBase): ...@@ -1397,6 +1479,7 @@ class RowParallelLinear(LinearBase):
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
self.sp_parallel = sp_parallel
super().__init__(input_size, super().__init__(input_size,
output_size, output_size,
...@@ -1526,7 +1609,10 @@ class RowParallelLinear(LinearBase): ...@@ -1526,7 +1609,10 @@ class RowParallelLinear(LinearBase):
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel) output = self.tbo_all_reduce(output_parallel)
else: else:
output = tensor_model_parallel_all_reduce(output_parallel) if self.sp_parallel:
output = tensor_model_parallel_reduce_scatter(output_parallel.contiguous(), dim=0)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else: else:
output = output_parallel output = output_parallel
......
...@@ -29,10 +29,11 @@ import vllm.envs as envs ...@@ -29,10 +29,11 @@ import vllm.envs as envs
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import Any, Optional, Union from typing import Any, Optional, Union, Dict
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention from vllm.attention import Attention
...@@ -40,12 +41,17 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -40,12 +41,17 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig, from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config) get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group, from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter,
get_tensor_model_parallel_rank)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import EventType, AuxStreamType
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
MergedReplicatedLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -64,6 +70,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter, ...@@ -64,6 +70,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm.logger import init_logger
logger = init_logger(__name__)
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
...@@ -75,6 +84,7 @@ class DeepseekV2MLP(nn.Module): ...@@ -75,6 +84,7 @@ class DeepseekV2MLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True, reduce_results: bool = True,
prefix: str = "", prefix: str = "",
enable_tpsp: bool = False
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
...@@ -82,12 +92,14 @@ class DeepseekV2MLP(nn.Module): ...@@ -82,12 +92,14 @@ class DeepseekV2MLP(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj") prefix=f"{prefix}.gate_up_proj")
self.enable_tpsp = enable_tpsp
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, reduce_results=reduce_results,
prefix=f"{prefix}.down_proj") prefix=f"{prefix}.down_proj",
sp_parallel=self.enable_tpsp)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -108,12 +120,57 @@ class DeepseekV2MLP(nn.Module): ...@@ -108,12 +120,57 @@ class DeepseekV2MLP(nn.Module):
return x, new_resi return x, new_resi
else: else:
if self.enable_tpsp:
x = tensor_model_parallel_all_gather(x.contiguous(), dim=0)
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x return x
class SharedExpertOverlapSPMLP(nn.Module):
    """Shared-expert MLP with replicated weights, ordered by CUDA events.

    ``forward`` waits on events from ``event_dict`` that the caller records,
    and records ``EventType.MoeShared`` once its own output has been issued.

    Args:
        hidden_size: input and output width of the MLP.
        intermediate_size: width of each of the gate and up projections.
        hidden_act: activation name; only ``"silu"`` is accepted.
        quant_config: quantization configuration passed to both linears.
        reduce_results: accepted for signature compatibility; not used here.
        prefix: state-dict prefix of this module.
        event_dict: mapping from ``EventType`` to events; ``forward`` reads
            the ``MoeAllgather``, ``MoeReduceScatter`` and ``MoeShared``
            entries.
        aux_stream: stored on the instance; not used by this class itself.

    Raises:
        ValueError: if ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        event_dict: dict = None,
        aux_stream: torch.cuda.Stream = None,
    ) -> None:
        super().__init__()
        linear_kwargs = {"bias": False, "quant_config": quant_config}

        # Gate and up projections live in one fused, replicated weight.
        self.gate_up_proj = MergedReplicatedLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **linear_kwargs,
        )
        self.down_proj = ReplicatedLinear(
            intermediate_size,
            hidden_size,
            prefix=f"{prefix}.down_proj",
            **linear_kwargs,
        )

        self.event_dict = event_dict
        self.aux_stream = aux_stream

        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Run gate/up, activation and down projection between event syncs.

        ``Event.wait()`` and ``Event.record()`` are called without a stream
        argument, so they act on whichever CUDA stream is current when this
        method runs.
        """
        events = self.event_dict

        # Hold back the first projection until MoeAllgather has been reached.
        events[EventType.MoeAllgather].wait()
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)

        # Hold back the down projection until MoeReduceScatter has been reached.
        events[EventType.MoeReduceScatter].wait()
        projected, _ = self.down_proj(activated)

        # Marks the point after which the shared-expert output can be consumed.
        events[EventType.MoeShared].record()
        return projected
class DeepseekV2MoE(nn.Module): class DeepseekV2MoE(nn.Module):
def __init__( def __init__(
...@@ -121,7 +178,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -121,7 +178,9 @@ class DeepseekV2MoE(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
trt_aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream] = {},
enable_eplb: bool = False, enable_eplb: bool = False,
enable_tpsp: bool = False,
): ):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
...@@ -137,6 +196,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -137,6 +196,7 @@ class DeepseekV2MoE(nn.Module):
raise ValueError(f"Unsupported activation: {config.hidden_act}. " raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
self.enable_tpsp = enable_tpsp
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts, config.n_routed_experts,
bias=False, bias=False,
...@@ -182,28 +242,93 @@ class DeepseekV2MoE(nn.Module): ...@@ -182,28 +242,93 @@ class DeepseekV2MoE(nn.Module):
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor) routed_scaling_factor=self.routed_scaling_factor)
self.aux_stream = trt_aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeShared, EventType.MoeAllgather, EventType.MoeReduceScatter]
}
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts) config.n_shared_experts)
self.shared_experts = DeepseekV2MLP( if self.enable_tpsp:
hidden_size=config.hidden_size, self.shared_experts = SharedExpertOverlapSPMLP(
intermediate_size=intermediate_size, hidden_size=config.hidden_size,
hidden_act=config.hidden_act, intermediate_size=intermediate_size,
quant_config=quant_config, hidden_act=config.hidden_act,
reduce_results=self.experts.must_reduce_shared_expert_outputs( quant_config=quant_config,
), reduce_results=False,
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) event_dict=self.event_dict,
aux_stream=self.aux_stream
)
else:
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
prefix=f"{prefix}.shared_experts",
)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
def tpsp_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """MoE forward pass for the tensor-parallel / sequence-parallel path.

    The local tokens and their router logits are all-gathered along dim 0 and
    passed to the routed experts. The routed output is then reduce-scattered
    along dim 0. When shared experts exist they run on ``self.aux_stream``
    over the local (un-gathered) tokens, and their output is added to the
    reduce-scattered routed output.

    Args:
        hidden_states: 2-D tensor; its last dimension is the hidden size.

    Returns:
        The reduce-scattered routed output, plus the shared-expert output
        when shared experts are configured.
    """
    _, hidden_dim = hidden_states.shape
    local_tokens = hidden_states.view(-1, hidden_dim)

    router_logits, _ = self.gate(local_tokens)

    # Recorded before the gathers are issued; the shared-expert module waits
    # on this event before its first projection.
    self.event_dict[EventType.MoeAllgather].record()
    gathered_tokens = tensor_model_parallel_all_gather(
        local_tokens.contiguous(), dim=0)
    gathered_logits = tensor_model_parallel_all_gather(
        router_logits.contiguous(), dim=0)

    is_fp16 = gathered_tokens.dtype == torch.float16

    routed_output = self.experts(hidden_states=gathered_tokens,
                                 router_logits=gathered_logits)
    if not is_fp16:
        routed_output = routed_output * self.routed_scaling_factor
    # else: Fix FP16 overflow - the routed output stays unscaled and the
    # shared output is divided by the scaling factor below instead.
    # See DeepseekV2DecoderLayer for more details.

    # Recorded before the reduce-scatter is issued; the shared-expert module
    # waits on this event before its down projection.
    self.event_dict[EventType.MoeReduceScatter].record()
    routed_output = tensor_model_parallel_reduce_scatter(
        routed_output.contiguous(), dim=0)

    shared_output = None
    if self.n_shared_experts is not None:
        with torch.cuda.stream(self.aux_stream):
            shared_output = self.shared_experts(local_tokens)
        # Back on the caller's stream: wait for the MoeShared event.
        # Presumably the shared-expert module records it - verify.
        self.event_dict[EventType.MoeShared].wait()

    if shared_output is None:
        return routed_output

    if is_fp16:
        # Fix FP16 overflow
        # See DeepseekV2DecoderLayer for more details.
        shared_output = shared_output * (1. / self.routed_scaling_factor)
    return routed_output + shared_output
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None residual: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
if self.enable_tpsp:
return self.tpsp_forward(hidden_states)
is_graph_capturing = True
do_multi_stream = is_graph_capturing
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if do_multi_stream:
self.event_dict[EventType.Main].record()
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
...@@ -211,33 +336,51 @@ class DeepseekV2MoE(nn.Module): ...@@ -211,33 +336,51 @@ class DeepseekV2MoE(nn.Module):
else: else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) if do_multi_stream:
router_logits, _ = self.gate(hidden_states) with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: # router_logits: (num_tokens, n_experts)
final_hidden_states = self.experts( router_logits, _ = self.gate(hidden_states)
hidden_states=hidden_states, if hidden_states.dtype != torch.float16:
router_logits=router_logits, final_hidden_states = self.experts(
shared_output=shared_output) hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
self.event_dict[EventType.MoeShared].record()
self.event_dict[EventType.MoeShared].wait()
else: else:
if hidden_states.dtype != torch.float16: # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor router_logits=router_logits,
shared_output=shared_output)
else: else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = self.experts(hidden_states=hidden_states,
* (1. / self.routed_scaling_factor) router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
...@@ -424,12 +567,15 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -424,12 +567,15 @@ class DeepseekV2MLAAttention(nn.Module):
v_head_dim: int, v_head_dim: int,
q_lora_rank: Optional[int], q_lora_rank: Optional[int],
kv_lora_rank: int, kv_lora_rank: int,
layer_idx: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
trt_aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream] = {},
enable_tpsp: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -449,6 +595,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -449,6 +595,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.scaling = self.qk_head_dim**-0.5 self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.layer_idx = layer_idx
self.enable_tpsp = enable_tpsp
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
...@@ -489,12 +637,21 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -489,12 +637,21 @@ class DeepseekV2MLAAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_proj") prefix=f"{prefix}.q_proj")
self.kv_a_proj_with_mqa = ReplicatedLinear( if not envs.VLLM_ENABLE_MLA_QKV_MERGE:
self.hidden_size, self.kv_a_proj_with_mqa = ReplicatedLinear(
self.kv_lora_rank + self.qk_rope_head_dim, self.hidden_size,
bias=False, self.kv_lora_rank + self.qk_rope_head_dim,
quant_config=quant_config, bias=False,
prefix=f"{prefix}.kv_a_proj_with_mqa") quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
else:
self.q_a_and_kv_a_proj = MergedReplicatedLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_and_kv_a_proj"
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear( self.kv_b_proj = ColumnParallelLinear(
...@@ -507,7 +664,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -507,7 +664,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj") prefix=f"{prefix}.o_proj",
sp_parallel=self.enable_tpsp)
if rope_scaling: if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn' rope_scaling["rope_type"] = 'deepseek_yarn'
...@@ -550,6 +708,11 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -550,6 +708,11 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2]) self.debug_layer_idx = int(self.prefix.split(".")[-2])
self.aux_stream = trt_aux_stream_dict[AuxStreamType.Attention]
self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.QCAllgather, EventType.KVFinish]
}
def forward( def forward(
self, self,
...@@ -588,34 +751,98 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -588,34 +751,98 @@ class DeepseekV2MLAAttention(nn.Module):
self.num_local_heads * self.v_head_dim)) self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0], new_residual return self.o_proj(attn_out)[0], new_residual
else: else:
if self.q_lora_rank is not None: if not self.enable_tpsp:
q_c = self.q_a_proj(hidden_states)[0] if not envs.VLLM_ENABLE_MLA_QKV_MERGE:
q_c = self.q_a_layernorm(q_c) if self.q_lora_rank is not None:
q = self.q_b_proj(q_c)[0] q_c = self.q_a_proj(hidden_states)[0]
else: q_c = self.q_a_layernorm(q_c)
q = self.q_proj(hidden_states)[0] q = self.q_b_proj(q_c)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
else:
if self.q_lora_rank is not None:
qkv_lora = self.q_a_and_kv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
hidden_states_or_q_c = hidden_states
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_c, k_pe = kv_lora.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
q_c = self.q_a_proj(hidden_states)[0]
self.event_dict[EventType.QCAllgather].record()
q_c = self.q_a_layernorm(q_c)
if self.layer_idx > 0:
q_c = tensor_model_parallel_all_gather(q_c.contiguous(), dim=0)
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.QCAllgather].wait()
kv_a_out = self.kv_a_proj_with_mqa(hidden_states)[0]
if self.layer_idx > 0:
kv_a_out = tensor_model_parallel_all_gather(kv_a_out.contiguous(), dim=0)
kv_c, k_pe = kv_a_out.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( self.event_dict[EventType.KVFinish].record()
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn( q = self.q_b_proj(q_c)[0]
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
self.event_dict[EventType.KVFinish].wait()
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(kv_a_out.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
class DeepseekV2DecoderLayer(nn.Module): class DeepseekV2DecoderLayer(nn.Module):
...@@ -627,6 +854,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -627,6 +854,8 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
enable_eplb: bool = False, enable_eplb: bool = False,
trt_aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream] = {},
mtp_layer: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -638,6 +867,9 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -638,6 +867,9 @@ class DeepseekV2DecoderLayer(nn.Module):
# with the layer's index. # with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1]) layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_tpsp = envs.VLLM_ENABLE_MLA_SP and self.tp_size > 1 and not mtp_layer
if model_config.use_mla: if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention attn_cls = DeepseekV2MLAAttention
else: else:
...@@ -658,6 +890,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -658,6 +890,8 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
trt_aux_stream_dict=trt_aux_stream_dict,
enable_tpsp=self.enable_tpsp,
) )
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
...@@ -668,6 +902,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -668,6 +902,8 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
trt_aux_stream_dict=trt_aux_stream_dict,
enable_tpsp=self.enable_tpsp
) )
else: else:
self.mlp = DeepseekV2MLP( self.mlp = DeepseekV2MLP(
...@@ -676,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -676,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
enable_tpsp=self.enable_tpsp,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -758,6 +995,11 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -758,6 +995,11 @@ class DeepseekV2DecoderLayer(nn.Module):
# first layer. # first layer.
residual *= 1. / self.routed_scaling_factor residual *= 1. / self.routed_scaling_factor
# Split the residual into tp_size chunks along dim 0 and keep only this TP rank's sequence-parallel (SP) piece
if self.layer_idx == 0 and self.enable_tpsp:
residual_per_rank = torch.chunk(residual, chunks=self.tp_size, dim=0)
residual = residual_per_rank[self.tp_rank]
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
...@@ -774,7 +1016,6 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -774,7 +1016,6 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile @support_torch_compile
class DeepseekV2Model(nn.Module): class DeepseekV2Model(nn.Module):
...@@ -789,8 +1030,19 @@ class DeepseekV2Model(nn.Module): ...@@ -789,8 +1030,19 @@ class DeepseekV2Model(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
enable_eplb = vllm_config.parallel_config.enable_eplb enable_eplb = vllm_config.parallel_config.enable_eplb
self.config = config self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.aux_stream_dict = {
key: torch.cuda.Stream()
for key in [
AuxStreamType.Attention,
AuxStreamType.MoeShared,
AuxStreamType.MoeChunkingOverlap
]
}
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
...@@ -810,9 +1062,12 @@ class DeepseekV2Model(nn.Module): ...@@ -810,9 +1062,12 @@ class DeepseekV2Model(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
trt_aux_stream_dict=self.aux_stream_dict,
), ),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.enable_tpsp = envs.VLLM_ENABLE_MLA_SP and self.tp_size > 1
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
...@@ -823,7 +1078,7 @@ class DeepseekV2Model(nn.Module): ...@@ -823,7 +1078,7 @@ class DeepseekV2Model(nn.Module):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -845,6 +1100,27 @@ class DeepseekV2Model(nn.Module): ...@@ -845,6 +1100,27 @@ class DeepseekV2Model(nn.Module):
for layer in self.layers[self.start_layer:self.end_layer]: for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
# Pad the TPSP batch: when TPSP is enabled and the number of input tokens is
# not a multiple of tp_size, append zero rows along dim 0 to hidden_states
# (and to residual, when present) until it is.
tpsp_bs_pad = False
bs = input_ids.shape[0]
# Ceil-divide then scale back up: pad_bs is the smallest multiple of tp_size
# that is >= bs, and equals bs when bs is already divisible by tp_size.
bs_per_rank = (bs + self.tp_size - 1) // self.tp_size
pad_bs = bs_per_rank * self.tp_size
if self.enable_tpsp and pad_bs != bs:
    tpsp_bs_pad = True
    num_pad_rows = pad_bs - bs
    # new_zeros inherits dtype and device from the tensor being padded.
    hidden_states = torch.cat(
        [
            hidden_states,
            hidden_states.new_zeros(num_pad_rows, hidden_states.shape[1]),
        ],
        dim=0,
    ).contiguous()
    # Compare against None explicitly: truth-testing a multi-element tensor
    # (`if residual:`) raises "Boolean value of Tensor with more than one
    # value is ambiguous" instead of checking for presence.
    if residual is not None:
        residual = torch.cat(
            [
                residual,
                residual.new_zeros(num_pad_rows, residual.shape[1]),
            ],
            dim=0,
        ).contiguous()
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
"hidden_states": hidden_states, "hidden_states": hidden_states,
...@@ -990,11 +1266,20 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -990,11 +1266,20 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ if not envs.VLLM_ENABLE_MLA_QKV_MERGE:
# (param_name, shard_name, shard_id) stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0), # (param_name, shard_name, shard_id)
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "gate_proj", 0),
] ("gate_up_proj", "up_proj", 1),
]
else:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("q_a_and_kv_a_proj", "q_a_proj", 0),
("q_a_and_kv_a_proj", "kv_a_proj_with_mqa", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment