Commit d997afc4 authored by 王敏
Browse files

1. 优化大 EP，合入 grouped gemm

2. 解决 mtp > 1 时大 EP 推理 all gather 卡住的问题
parent 6f5d76dc
...@@ -164,6 +164,7 @@ if TYPE_CHECKING: ...@@ -164,6 +164,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_ALLTOALL_EP: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1089,7 +1090,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1089,7 +1090,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_GLOBAL_CACHE13": "VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use all_to_all ep mode
"VLLM_USE_ALLTOALL_EP":
lambda: (os.environ.get("VLLM_USE_ALLTOALL_EP", "True").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -135,8 +135,10 @@ def set_forward_context( ...@@ -135,8 +135,10 @@ def set_forward_context(
if need_to_track_batchsize: if need_to_track_batchsize:
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None dp_metadata: Optional[DPMetadata] = None
if vllm_config.parallel_config.data_parallel_size > 1 and ( dp_size = vllm_config.parallel_config.data_parallel_size
attn_metadata is not None or num_tokens is not None): use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if not use_all2all_ep and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None) :
dp_metadata = DPMetadata.make(vllm_config.parallel_config, dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0, attn_metadata, num_tokens or 0,
num_tokens_across_dp) num_tokens_across_dp)
......
...@@ -327,7 +327,8 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes): ...@@ -327,7 +327,8 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
output = input.new_empty( output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]), size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype, dtype=input.dtype,
device=torch.cuda.current_device(), #device=torch.cuda.current_device(),
device=input.device,
) )
torch.distributed.all_to_all_single( torch.distributed.all_to_all_single(
...@@ -336,6 +337,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes): ...@@ -336,6 +337,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
output_split_sizes=output_split_sizes, output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes, input_split_sizes=input_split_sizes,
group=group, group=group,
async_op=True
) )
return output return output
...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, Unqua ...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, Unqua
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from lightop import groupgemm
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,6 +34,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -33,6 +34,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.moe = moe self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
self.zero_token_count = None
def apply( def apply(
self, self,
...@@ -55,8 +57,8 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -55,8 +57,8 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
# process MoE # process MoE
def custom_forward(layer, hidden_states, tokens_per_expert): def custom_forward(layer, hidden_states, tokens_per_expert):
if False:
tokens_per_expert = tokens_per_expert.cpu().numpy() tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = [] outputs = []
start_idx = 0 start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert): for i, num_tokens in enumerate(tokens_per_expert):
...@@ -67,7 +69,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -67,7 +69,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
w2 = layer.w2_weight[i] w2 = layer.w2_weight[i]
tokens_for_this_expert = hidden_states[start_idx:end_idx] tokens_for_this_expert = hidden_states[start_idx:end_idx]
gateup_output = torch.matmul(tokens_for_this_expert, w1.T) gateup_output = torch.matmul(tokens_for_this_expert, w1)
# Act # Act
down_input = torch.zeros( down_input = torch.zeros(
...@@ -77,9 +79,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -77,9 +79,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
dtype=hidden_states.dtype dtype=hidden_states.dtype
) )
torch.ops._C.silu_and_mul(down_input, torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, w1.shape[0])) gateup_output.view(-1, w1.shape[1]))
expert_out = torch.matmul(down_input, w2.T) expert_out = torch.matmul(down_input, w2)
outputs.append(expert_out) outputs.append(expert_out)
start_idx = end_idx start_idx = end_idx
...@@ -89,6 +91,25 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -89,6 +91,25 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
else: else:
assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}" assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}"
expert_output = hidden_states expert_output = hidden_states
else:
if self.zero_token_count is None:
self.zero_token_count = torch.zeros(1, dtype=torch.int64, device=hidden_states.device)
total_tokens = tokens_per_expert.sum()
if total_tokens > self.zero_token_count:
gateup_output = groupgemm(hidden_states, layer.w13_weight, tokens_per_expert, False)
# Act
down_input = torch.zeros(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states.dtype
)
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, layer.w13_weight.shape[2]))
expert_output = groupgemm(down_input, layer.w2_weight, tokens_per_expert, False)
else :
expert_output = hidden_states
return expert_output return expert_output
...@@ -157,6 +178,8 @@ class EPMoE(FusedMoE): ...@@ -157,6 +178,8 @@ class EPMoE(FusedMoE):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False,
num_redundant_experts: int = 0,
moe_permute_fusion: bool = True, moe_permute_fusion: bool = True,
moe_shared_expert_overlap: bool = False moe_shared_expert_overlap: bool = False
): ):
...@@ -170,7 +193,9 @@ class EPMoE(FusedMoE): ...@@ -170,7 +193,9 @@ class EPMoE(FusedMoE):
e_score_correction_bias, e_score_correction_bias,
apply_router_weight_on_input, apply_router_weight_on_input,
activation, activation,
routed_scaling_factor=routed_scaling_factor routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
) )
self.ep_moe_config: EpMoeConfig = EpMoeConfig.make( self.ep_moe_config: EpMoeConfig = EpMoeConfig.make(
......
...@@ -24,6 +24,7 @@ from vllm.platforms import current_platform ...@@ -24,6 +24,7 @@ from vllm.platforms import current_platform
cuda_dtoh_stream = torch.cuda.Stream() cuda_dtoh_stream = torch.cuda.Stream()
cuda_dtoh_sync_event = torch.cuda.Event(enable_timing=False)
class MoETokenDispatcher: class MoETokenDispatcher:
""" """
...@@ -137,7 +138,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -137,7 +138,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.output_splits = None self.output_splits = None
# [tp_size]. Represents the number of tokens received by the current rank from # [tp_size]. Represents the number of tokens received by the current rank from
# other TP ranks. # other TP ranks.
self.output_splits_tp = None #self.output_splits_tp = None
self.permute_idx_device = torch.device("cuda") if self.config.moe_permute_fusion else None self.permute_idx_device = torch.device("cuda") if self.config.moe_permute_fusion else None
input_chunk_idxs = torch.arange( input_chunk_idxs = torch.arange(
self.num_experts * self.tp_size, device=self.permute_idx_device self.num_experts * self.tp_size, device=self.permute_idx_device
...@@ -211,6 +212,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -211,6 +212,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if self.use_all_gather: if self.use_all_gather:
# Gather is not supported for some devices such as TPUs. # Gather is not supported for some devices such as TPUs.
# Use all-gather instead. # Use all-gather instead.
num_global_tokens_per_expert = expert_parallel_all_gather(num_local_tokens_per_expert) \ num_global_tokens_per_expert = expert_parallel_all_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \ .reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1) .transpose(0, 1)
...@@ -233,7 +235,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -233,7 +235,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
# [tp_size, ep_size] -> [tp_size] # [tp_size, ep_size] -> [tp_size]
# self.output_splits_tp represents the number of tokens received by the current # self.output_splits_tp represents the number of tokens received by the current
# rank from other TP rank. # rank from other TP rank.
self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1) #self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
# [tp_size, ep_size, num_local_experts] -> [num_local_experts] # [tp_size, ep_size, num_local_experts] -> [num_local_experts]
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1)) num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1))
...@@ -319,9 +321,14 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -319,9 +321,14 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
) )
# Perform expert parallel AlltoAll communication # Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize( # tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert # "before_ep_alltoall", tokens_per_expert
) # )
###test##############
#cuda_dtoh_stream.synchronize()
#cuda_dtoh_sync_event.synchronize()
###test##############
global_input_tokens = all_to_all( global_input_tokens = all_to_all(
self.ep_group.device_group, permutated_local_input_tokens, self.output_splits, self.input_splits self.ep_group.device_group, permutated_local_input_tokens, self.output_splits, self.input_splits
...@@ -331,9 +338,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -331,9 +338,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens) self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
# Permutation 2: Sort tokens by local expert. # Permutation 2: Sort tokens by local expert.
tokens_per_expert = self._maybe_dtoh_and_synchronize( # tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_2", tokens_per_expert # "before_permutation_2", tokens_per_expert
) # )
if self.num_local_experts > 1: if self.num_local_experts > 1:
global_input_tokens = sort_chunks_by_idxs( global_input_tokens = sort_chunks_by_idxs(
global_input_tokens, global_input_tokens,
...@@ -342,7 +349,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -342,7 +349,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
fused=self.config.moe_permute_fusion, fused=self.config.moe_permute_fusion,
) )
tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert) #tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return global_input_tokens, tokens_per_expert return global_input_tokens, tokens_per_expert
...@@ -444,9 +451,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -444,9 +451,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.output_splits = maybe_move_tensor_to_cpu( self.output_splits = maybe_move_tensor_to_cpu(
self.output_splits, as_numpy=True, record_stream=on_side_stream self.output_splits, as_numpy=True, record_stream=on_side_stream
) )
self.output_splits_tp = maybe_move_tensor_to_cpu( # self.output_splits_tp = maybe_move_tensor_to_cpu(
self.output_splits_tp, as_numpy=True, record_stream=on_side_stream # self.output_splits_tp, as_numpy=True, record_stream=on_side_stream
) # )
self.num_out_tokens = maybe_move_tensor_to_cpu( self.num_out_tokens = maybe_move_tensor_to_cpu(
self.num_out_tokens, record_stream=on_side_stream self.num_out_tokens, record_stream=on_side_stream
) )
...@@ -455,6 +462,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -455,6 +462,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.num_global_tokens_per_local_expert, record_stream=on_side_stream self.num_global_tokens_per_local_expert, record_stream=on_side_stream
) )
#cuda_dtoh_sync_event.record()
if point == self.cuda_sync_point: if point == self.cuda_sync_point:
# Synchronize with the dtoh stream at self.cuda_sync_point. # Synchronize with the dtoh stream at self.cuda_sync_point.
cuda_dtoh_stream.synchronize() cuda_dtoh_stream.synchronize()
......
...@@ -772,12 +772,16 @@ class FusedMoE(torch.nn.Module): ...@@ -772,12 +772,16 @@ class FusedMoE(torch.nn.Module):
self.moe_config = moe self.moe_config = moe
self.quant_config = quant_config self.quant_config = quant_config
self.quant_method = self.create_quant_method(moe, quant_config, prefix) quant_method = self.create_quant_method(moe, quant_config, prefix)
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb: if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod) Fp8MoEMethod)
if not isinstance(self.quant_method, Fp8MoEMethod): if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods. # TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not # The implementation for other quantization methods does not
# contain essential differences, but the current quant API # contain essential differences, but the current quant API
......
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -24,6 +25,7 @@ from vllm.sequence import IntermediateTensors ...@@ -24,6 +25,7 @@ from vllm.sequence import IntermediateTensors
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from .deepseek_v2 import (DeepseekV2DecoderLayer, from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name) get_spec_layer_idx_from_weight_name)
from vllm.distributed import get_dp_group
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -174,6 +176,10 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -174,6 +176,10 @@ class DeepSeekMTP(nn.Module, SupportsPP):
prefix, "model")) prefix, "model"))
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def forward( def forward(
self, self,
...@@ -205,6 +211,10 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -205,6 +211,10 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
if self.use_all2all_ep:
ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
...@@ -233,6 +243,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -233,6 +243,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if (("mlp.experts." in name) and name not in params_dict): if (("mlp.experts." in name) and name not in params_dict):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
...@@ -248,6 +261,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -248,6 +261,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
...@@ -257,6 +273,8 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -257,6 +273,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
expert_id=expert_id) expert_id=expert_id)
break break
else: else:
if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
......
...@@ -155,9 +155,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -155,9 +155,9 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts) self.n_local_physical_experts)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
moe_cls = FusedMoE if not self.use_ep_opt else EPMoE moe_cls = FusedMoE if not self.use_all2all_ep else EPMoE
self.experts = moe_cls( self.experts = moe_cls(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -172,12 +172,14 @@ class DeepseekV2MoE(nn.Module): ...@@ -172,12 +172,14 @@ class DeepseekV2MoE(nn.Module):
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias, e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor) routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts) config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_ep_opt else EPSharedExperts shared_expert_cls = DeepseekV2MLP if not self.use_all2all_ep else EPSharedExperts
self.shared_experts = shared_expert_cls( self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
...@@ -187,7 +189,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -187,7 +189,7 @@ class DeepseekV2MoE(nn.Module):
), ),
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
if self.use_all2all_ep:
self.experts.set_shared_experts(self.shared_experts) self.experts.set_shared_experts(self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
...@@ -196,13 +198,13 @@ class DeepseekV2MoE(nn.Module): ...@@ -196,13 +198,13 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_ep_opt: if not self.use_all2all_ep:
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if not self.use_ep_opt: if not self.use_all2all_ep:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -216,7 +218,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -216,7 +218,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if not self.use_ep_opt: if not self.use_all2all_ep:
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
...@@ -637,8 +639,6 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -637,8 +639,6 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
) )
#ops.print_tensor(hidden_states)
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick: if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow # Fix FP16 overflow
# We scale both hidden_states and residual before # We scale both hidden_states and residual before
...@@ -808,7 +808,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -808,7 +808,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def set_eplb_state( def set_eplb_state(
self, self,
...@@ -891,7 +891,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -891,7 +891,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
if self.use_ep_opt: if self.use_all2all_ep:
ep_moe_shared_experts_keys = "mlp.shared_experts" ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"} ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
...@@ -928,7 +928,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -928,7 +928,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if self.use_ep_opt: if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys]) name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
...@@ -957,7 +957,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -957,7 +957,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# Instead, create a new variable # Instead, create a new variable
name_mapped = name.replace(weight_name, param_name) name_mapped = name.replace(weight_name, param_name)
if self.use_ep_opt: if self.use_all2all_ep:
name_mapped = name_mapped.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys]) name_mapped = name_mapped.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
if is_pp_missing_parameter(name_mapped, self): if is_pp_missing_parameter(name_mapped, self):
...@@ -985,7 +985,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -985,7 +985,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# So we simply skip it # So we simply skip it
continue continue
if self.use_ep_opt: if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys]) name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
......
...@@ -87,6 +87,9 @@ class EagleProposer: ...@@ -87,6 +87,9 @@ class EagleProposer:
device=device, device=device,
dtype=torch.int32) dtype=torch.int32)
self.dp_size = self.vllm_config.parallel_config.data_parallel_size
self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel
def propose( def propose(
self, self,
# [num_tokens] # [num_tokens]
...@@ -510,6 +513,14 @@ class EagleProposer: ...@@ -510,6 +513,14 @@ class EagleProposer:
self.hidden_states[:num_tokens], self.hidden_states[:num_tokens],
) )
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1:
for _ in range(self.num_speculative_tokens - 1):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)
def validate_same_kv_cache_group(self, def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None: kv_cache_config: KVCacheConfig) -> None:
""" """
......
...@@ -319,6 +319,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -319,6 +319,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`. # from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {} self.shared_kv_cache_layers: dict[str, str] = {}
dp_size = self.vllm_config.parallel_config.data_parallel_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
...@@ -1229,7 +1232,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1229,7 +1232,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for # TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager: if dp_size == 1 or self.vllm_config.model_config.enforce_eager or self.use_all2all_ep:
# Early exit. # Early exit.
return 0, None return 0, None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment