"tests/vscode:/vscode.git/clone" did not exist on "fb35feea6eef33d8287c52b8600313ffcb7139f1"
Commit d997afc4 authored by 王敏
Browse files

1.优化大EP,合入grouped gemm

2.解决mtp >1 大EP推理all gather卡住问题
parent 6f5d76dc
......@@ -164,6 +164,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_ALLTOALL_EP: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1089,7 +1090,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")),
# vLLM will use all_to_all ep mode
"VLLM_USE_ALLTOALL_EP":
lambda: (os.environ.get("VLLM_USE_ALLTOALL_EP", "True").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -135,8 +135,10 @@ def set_forward_context(
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None
if vllm_config.parallel_config.data_parallel_size > 1 and (
attn_metadata is not None or num_tokens is not None):
dp_size = vllm_config.parallel_config.data_parallel_size
use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if not use_all2all_ep and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None) :
dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0,
num_tokens_across_dp)
......
......@@ -327,7 +327,8 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype,
device=torch.cuda.current_device(),
#device=torch.cuda.current_device(),
device=input.device,
)
torch.distributed.all_to_all_single(
......@@ -336,6 +337,7 @@ def all_to_all(group, input, output_split_sizes, input_split_sizes):
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=True
)
return output
......@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, Unqua
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op
from lightop import groupgemm
logger = init_logger(__name__)
......@@ -33,6 +34,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
self.zero_token_count = None
def apply(
self,
......@@ -55,8 +57,8 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
# process MoE
def custom_forward(layer, hidden_states, tokens_per_expert):
if False:
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
......@@ -67,7 +69,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
w2 = layer.w2_weight[i]
tokens_for_this_expert = hidden_states[start_idx:end_idx]
gateup_output = torch.matmul(tokens_for_this_expert, w1.T)
gateup_output = torch.matmul(tokens_for_this_expert, w1)
# Act
down_input = torch.zeros(
......@@ -77,9 +79,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
dtype=hidden_states.dtype
)
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, w1.shape[0]))
gateup_output.view(-1, w1.shape[1]))
expert_out = torch.matmul(down_input, w2.T)
expert_out = torch.matmul(down_input, w2)
outputs.append(expert_out)
start_idx = end_idx
......@@ -89,6 +91,25 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
else:
assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}"
expert_output = hidden_states
else:
if self.zero_token_count is None:
self.zero_token_count = torch.zeros(1, dtype=torch.int64, device=hidden_states.device)
total_tokens = tokens_per_expert.sum()
if total_tokens > self.zero_token_count:
gateup_output = groupgemm(hidden_states, layer.w13_weight, tokens_per_expert, False)
# Act
down_input = torch.zeros(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states.dtype
)
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, layer.w13_weight.shape[2]))
expert_output = groupgemm(down_input, layer.w2_weight, tokens_per_expert, False)
else :
expert_output = hidden_states
return expert_output
......@@ -157,6 +178,8 @@ class EPMoE(FusedMoE):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False,
num_redundant_experts: int = 0,
moe_permute_fusion: bool = True,
moe_shared_expert_overlap: bool = False
):
......@@ -170,7 +193,9 @@ class EPMoE(FusedMoE):
e_score_correction_bias,
apply_router_weight_on_input,
activation,
routed_scaling_factor=routed_scaling_factor
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts,
)
self.ep_moe_config: EpMoeConfig = EpMoeConfig.make(
......
......@@ -24,6 +24,7 @@ from vllm.platforms import current_platform
cuda_dtoh_stream = torch.cuda.Stream()
cuda_dtoh_sync_event = torch.cuda.Event(enable_timing=False)
class MoETokenDispatcher:
"""
......@@ -137,7 +138,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.output_splits = None
# [tp_size]. Represents the number of tokens received by the current rank from
# other TP ranks.
self.output_splits_tp = None
#self.output_splits_tp = None
self.permute_idx_device = torch.device("cuda") if self.config.moe_permute_fusion else None
input_chunk_idxs = torch.arange(
self.num_experts * self.tp_size, device=self.permute_idx_device
......@@ -211,6 +212,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
num_global_tokens_per_expert = expert_parallel_all_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1)
......@@ -233,7 +235,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
# [tp_size, ep_size] -> [tp_size]
# self.output_splits_tp represents the number of tokens received by the current
# rank from other TP rank.
self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
#self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
# [tp_size, ep_size, num_local_experts] -> [num_local_experts]
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1))
......@@ -319,9 +321,14 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
)
# Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert
)
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
# "before_ep_alltoall", tokens_per_expert
# )
###test##############
#cuda_dtoh_stream.synchronize()
#cuda_dtoh_sync_event.synchronize()
###test##############
global_input_tokens = all_to_all(
self.ep_group.device_group, permutated_local_input_tokens, self.output_splits, self.input_splits
......@@ -331,9 +338,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
# Permutation 2: Sort tokens by local expert.
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_2", tokens_per_expert
)
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
# "before_permutation_2", tokens_per_expert
# )
if self.num_local_experts > 1:
global_input_tokens = sort_chunks_by_idxs(
global_input_tokens,
......@@ -342,7 +349,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
fused=self.config.moe_permute_fusion,
)
tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
#tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return global_input_tokens, tokens_per_expert
......@@ -444,9 +451,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.output_splits = maybe_move_tensor_to_cpu(
self.output_splits, as_numpy=True, record_stream=on_side_stream
)
self.output_splits_tp = maybe_move_tensor_to_cpu(
self.output_splits_tp, as_numpy=True, record_stream=on_side_stream
)
# self.output_splits_tp = maybe_move_tensor_to_cpu(
# self.output_splits_tp, as_numpy=True, record_stream=on_side_stream
# )
self.num_out_tokens = maybe_move_tensor_to_cpu(
self.num_out_tokens, record_stream=on_side_stream
)
......@@ -455,6 +462,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.num_global_tokens_per_local_expert, record_stream=on_side_stream
)
#cuda_dtoh_sync_event.record()
if point == self.cuda_sync_point:
# Synchronize with the dtoh stream at self.cuda_sync_point.
cuda_dtoh_stream.synchronize()
......
......@@ -772,12 +772,16 @@ class FusedMoE(torch.nn.Module):
self.moe_config = moe
self.quant_config = quant_config
self.quant_method = self.create_quant_method(moe, quant_config, prefix)
quant_method = self.create_quant_method(moe, quant_config, prefix)
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(self.quant_method, Fp8MoEMethod):
if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
......
......@@ -11,6 +11,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -24,6 +25,7 @@ from vllm.sequence import IntermediateTensors
from vllm.compilation.decorators import support_torch_compile
from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name)
from vllm.distributed import get_dp_group
from .interfaces import SupportsPP
from .utils import maybe_prefix
from vllm import _custom_ops as ops
......@@ -174,6 +176,10 @@ class DeepSeekMTP(nn.Module, SupportsPP):
prefix, "model"))
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def forward(
self,
......@@ -205,6 +211,10 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1),
]
if self.use_all2all_ep:
ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
......@@ -233,6 +243,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......@@ -248,6 +261,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
continue
name = name.replace(weight_name, param_name)
if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
......@@ -257,6 +273,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
expert_id=expert_id)
break
else:
if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......
......@@ -155,9 +155,9 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts)
dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
moe_cls = FusedMoE if not self.use_ep_opt else EPMoE
moe_cls = FusedMoE if not self.use_all2all_ep else EPMoE
self.experts = moe_cls(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
......@@ -172,12 +172,14 @@ class DeepseekV2MoE(nn.Module):
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_ep_opt else EPSharedExperts
shared_expert_cls = DeepseekV2MLP if not self.use_all2all_ep else EPSharedExperts
self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
......@@ -187,7 +189,7 @@ class DeepseekV2MoE(nn.Module):
),
prefix=f"{prefix}.shared_experts",
)
if self.use_all2all_ep:
self.experts.set_shared_experts(self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
......@@ -196,13 +198,13 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_ep_opt:
if not self.use_all2all_ep:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
if not self.use_ep_opt:
if not self.use_all2all_ep:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
......@@ -216,7 +218,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if not self.use_ep_opt:
if not self.use_all2all_ep:
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
......@@ -637,8 +639,6 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states=hidden_states,
)
#ops.print_tensor(hidden_states)
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# We scale both hidden_states and residual before
......@@ -808,7 +808,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def set_eplb_state(
self,
......@@ -891,7 +891,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
("gate_up_proj", "up_proj", 1),
]
if self.use_ep_opt:
if self.use_all2all_ep:
ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
......@@ -928,7 +928,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
continue
name = name.replace(weight_name, param_name)
if self.use_ep_opt:
if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
......@@ -957,7 +957,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if self.use_ep_opt:
if self.use_all2all_ep:
name_mapped = name_mapped.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
if is_pp_missing_parameter(name_mapped, self):
......@@ -985,7 +985,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# So we simply skip it
continue
if self.use_ep_opt:
if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
......
......@@ -87,6 +87,9 @@ class EagleProposer:
device=device,
dtype=torch.int32)
self.dp_size = self.vllm_config.parallel_config.data_parallel_size
self.enable_expert_parallel = vllm_config.parallel_config.enable_expert_parallel
def propose(
self,
# [num_tokens]
......@@ -510,6 +513,14 @@ class EagleProposer:
self.hidden_states[:num_tokens],
)
if self.dp_size > 1 and self.enable_expert_parallel and self.num_speculative_tokens > 1:
for _ in range(self.num_speculative_tokens - 1):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)
def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None:
"""
......
......@@ -319,6 +319,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
dp_size = self.vllm_config.parallel_config.data_parallel_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
......@@ -1229,7 +1232,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or self.use_all2all_ep:
# Early exit.
return 0, None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment