Commit 8ae59a9c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ds-wm-1208' into 'v0.9.2-dev-ds'

[feat]支持deepep ETP,dp4 tp4 ep16相比dp32 tp1 ep32提升明显

See merge request dcutoolkit/deeplearing/vllm!289
parents ba1999c2 1a315a58
......@@ -90,6 +90,7 @@ class EPSharedExperts(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
expect_tp_size=1)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
......
......@@ -978,6 +978,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
tp_size = get_tensor_model_parallel_world_size()
if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_size = 1
if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
......@@ -1519,6 +1524,8 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(
......
......@@ -83,6 +83,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.ep_size = get_ep_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
......@@ -241,7 +242,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
#expected_m = max_num_tokens
ori_bs = x.shape[0]
expected_m = ori_bs * self.dp_size
expected_m = ori_bs * self.ep_size
# expected_m = (
# x.shape[0] * self.dp_size * topk_ids.shape[1]
# + global_num_experts
......
......@@ -40,7 +40,11 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_reduce_scatter,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
get_tensor_model_parallel_rank)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
......@@ -209,8 +213,7 @@ class DeepseekV2MoE(nn.Module):
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_mori_ep else EPSharedExperts
self.shared_experts = shared_expert_cls(
self.shared_experts = EPSharedExperts(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
......@@ -710,6 +713,33 @@ class DeepseekV2DecoderLayer(nn.Module):
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
self.dp_size = get_dp_group().world_size
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.tp_size = get_tensor_model_parallel_world_size()
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
else:
......@@ -732,23 +762,6 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
......@@ -833,8 +846,25 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
self.tp_rank = get_tensor_model_parallel_rank()
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0).contiguous()
hidden_states = hidden_states[:ori_bs, :].contiguous()
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
......
......@@ -96,6 +96,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def __init__(self, output_dim: int, **kwargs):
self._output_dim = output_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
@property
def output_dim(self):
......@@ -103,6 +105,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
......@@ -123,6 +127,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
......@@ -167,6 +173,7 @@ class RowvLLMParameter(BasevLLMParameter):
def __init__(self, input_dim: int, **kwargs):
self._input_dim = input_dim
super().__init__(**kwargs)
self.expect_tp_size = -1
@property
def input_dim(self):
......@@ -174,6 +181,8 @@ class RowvLLMParameter(BasevLLMParameter):
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment