"vscode:/vscode.git/clone" did not exist on "5d7e3d0176e0dbcf144c64b7d14d996c55e36c50"
Commit 6cabbf16 authored by 王敏
Browse files

[feat]支持deepep ETP,dp4 tp4 ep16相比dp32 tp1 ep32提升明显

parent ba1999c2
...@@ -90,6 +90,8 @@ class EPSharedExperts(nn.Module): ...@@ -90,6 +90,8 @@ class EPSharedExperts(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj", prefix=f"{prefix}.gate_up_proj",
expect_tp_size=1) expect_tp_size=1)
print("#########self.gate_up_proj quant_method:", self.gate_up_proj.quant_method)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
...@@ -97,6 +99,19 @@ class EPSharedExperts(nn.Module): ...@@ -97,6 +99,19 @@ class EPSharedExperts(nn.Module):
reduce_results=reduce_results, reduce_results=reduce_results,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
expect_tp_size=1) expect_tp_size=1)
print("#########self.down_proj quant_method:", self.down_proj.quant_method)
# self.gate_up_proj = MergedColumnParallelLinear(
# hidden_size, [intermediate_size] * 2,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.gate_up_proj",
# expect_tp_size=1)
# self.down_proj = ReplicatedLinear(intermediate_size,
# hidden_size,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.down_proj")
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
......
...@@ -783,6 +783,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -783,6 +783,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return return
if is_gguf_weight: if is_gguf_weight:
print("############is_gguf_weight")
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -978,6 +979,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -978,6 +979,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_size = 1
if hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod) Fp8LinearMethod, Fp8MoEMethod)
...@@ -1519,6 +1525,8 @@ class RowParallelLinear(LinearBase): ...@@ -1519,6 +1525,8 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1 assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
if self.expect_tp_size is not None and hasattr(param, "expect_tp_size"):
param.expect_tp_size = self.expect_tp_size
param.load_row_parallel_weight(loaded_weight=loaded_weight) param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward( def forward(
......
...@@ -83,6 +83,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -83,6 +83,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
self.ep_size = get_ep_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
...@@ -241,7 +242,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -241,7 +242,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
#expected_m = max_num_tokens #expected_m = max_num_tokens
ori_bs = x.shape[0] ori_bs = x.shape[0]
expected_m = ori_bs * self.dp_size expected_m = ori_bs * self.ep_size
# expected_m = ( # expected_m = (
# x.shape[0] * self.dp_size * topk_ids.shape[1] # x.shape[0] * self.dp_size * topk_ids.shape[1]
# + global_num_experts # + global_num_experts
......
...@@ -40,7 +40,11 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -40,7 +40,11 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig, from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config) get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group, from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size,
tensor_model_parallel_reduce_scatter,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
get_tensor_model_parallel_rank)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
...@@ -209,8 +213,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -209,8 +213,7 @@ class DeepseekV2MoE(nn.Module):
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts) config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_mori_ep else EPSharedExperts self.shared_experts = EPSharedExperts(
self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
...@@ -710,6 +713,33 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -710,6 +713,33 @@ class DeepseekV2DecoderLayer(nn.Module):
# with the layer's index. # with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1]) layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.dp_size = get_dp_group().world_size
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.tp_size = get_tensor_model_parallel_world_size()
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
if model_config.use_mla: if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention attn_cls = DeepseekV2MLAAttention
else: else:
...@@ -732,23 +762,6 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -732,23 +762,6 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
...@@ -833,8 +846,25 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -833,8 +846,25 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
self.tp_rank = get_tensor_model_parallel_rank()
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0).contiguous()
hidden_states = hidden_states[:ori_bs, :].contiguous()
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
......
...@@ -96,6 +96,8 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -96,6 +96,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def __init__(self, output_dim: int, **kwargs): def __init__(self, output_dim: int, **kwargs):
self._output_dim = output_dim self._output_dim = output_dim
super().__init__(**kwargs) super().__init__(**kwargs)
self.expect_tp_size = -1
@property @property
def output_dim(self): def output_dim(self):
...@@ -103,6 +105,8 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -103,6 +105,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def load_column_parallel_weight(self, loaded_weight: torch.Tensor): def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
shard_size = self.data.shape[self.output_dim] shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size) tp_rank * shard_size, shard_size)
...@@ -123,6 +127,8 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -123,6 +127,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data param_data = self.data
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
param_data = param_data.narrow(self.output_dim, shard_offset, param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size) shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim, loaded_weight = loaded_weight.narrow(self.output_dim,
...@@ -167,6 +173,7 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -167,6 +173,7 @@ class RowvLLMParameter(BasevLLMParameter):
def __init__(self, input_dim: int, **kwargs): def __init__(self, input_dim: int, **kwargs):
self._input_dim = input_dim self._input_dim = input_dim
super().__init__(**kwargs) super().__init__(**kwargs)
self.expect_tp_size = -1
@property @property
def input_dim(self): def input_dim(self):
...@@ -174,6 +181,8 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -174,6 +181,8 @@ class RowvLLMParameter(BasevLLMParameter):
def load_row_parallel_weight(self, loaded_weight: torch.Tensor): def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
if self.expect_tp_size == 1:
tp_rank = 0
shard_size = self.data.shape[self.input_dim] shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim, loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size) tp_rank * shard_size, shard_size)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment