Commit 1a56d6cb authored by 王敏
Browse files

添加mori ep

parent 3cb11400
......@@ -2004,10 +2004,10 @@ class ParallelConfig:
logger.info("Disabling V1 multiprocessing for external launcher.")
if self.enable_eplb:
if not current_platform.is_cuda():
raise ValueError(
"Expert parallelism load balancing is only supported on "
"CUDA devices now.")
# if not current_platform.is_cuda():
# raise ValueError(
# "Expert parallelism load balancing is only supported on "
# "CUDA devices now.")
if self.num_redundant_experts < 0:
raise ValueError(
"num_redundant_experts must be non-negative, but got "
......
......@@ -2,6 +2,7 @@ import os
import logging
from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass
from collections.abc import Iterable
import torch
import torch.nn.functional as F
......@@ -11,7 +12,7 @@ from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.distributed.parallel_state import get_ep_group, get_node_count, is_use_cuda_graph
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
......@@ -20,7 +21,6 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, Unqua
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op
from lightop import groupgemm
import mori
import torch.distributed as dist
......@@ -45,7 +45,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
......@@ -59,7 +58,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
return self.forward(
hidden_states=hidden_states,
layer=layer,
tokens_per_expert=tokens_per_expert,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
......@@ -73,7 +71,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
......@@ -85,80 +82,24 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
) -> torch.Tensor:
# process MoE
def custom_forward(layer, hidden_states, tokens_per_expert):
if False:
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
w1 = layer.w13_weight[i]
w2 = layer.w2_weight[i]
tokens_for_this_expert = hidden_states[start_idx:end_idx]
gateup_output = torch.matmul(tokens_for_this_expert, w1)
# Act
down_input = torch.zeros(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states.dtype
)
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, w1.shape[1]))
expert_out = torch.matmul(down_input, w2)
outputs.append(expert_out)
start_idx = end_idx
if len(outputs) > 0:
expert_output = torch.cat(outputs, dim=0)
else:
assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}"
expert_output = hidden_states
else:
if topk_ids is None:
if self.zero_token_count is None:
self.zero_token_count = torch.zeros(1, dtype=torch.int64, device=hidden_states.device)
total_tokens = tokens_per_expert.sum()
print("#################total_tokens:", total_tokens.tolist())
if total_tokens > self.zero_token_count:
gateup_output = groupgemm(hidden_states, layer.w13_weight, tokens_per_expert, False)
# Act
down_input = torch.zeros(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states.dtype
)
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, layer.w13_weight.shape[2]))
expert_output = groupgemm(down_input, layer.w2_weight, tokens_per_expert, False)
else :
expert_output = hidden_states
else:
expert_output = self.fused_experts(
hidden_states=hidden_states,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=use_nn_moe
)
def custom_forward(layer, hidden_states):
    """Run ``self.fused_experts`` on ``hidden_states`` using ``layer``'s weights.

    ``layer.w13_weight`` and ``layer.w2_weight`` are passed as ``w1`` and
    ``w2``. Every name that is not a parameter of this closure (``self``,
    the routing tensors and the MoE options) is resolved from the
    enclosing scope.
    """
    # `inplace` is always requested as True here, matching the original call.
    return self.fused_experts(
        hidden_states=hidden_states,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        use_nn_moe=use_nn_moe,
    )
output = custom_forward(layer, hidden_states, tokens_per_expert)
output = custom_forward(layer, hidden_states)
return output
......@@ -166,7 +107,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
......@@ -183,7 +123,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
......@@ -199,7 +138,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
......@@ -249,7 +187,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False,
num_redundant_experts: int = 0,
moe_permute_fusion: bool = True,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False
):
super().__init__(num_experts, top_k, hidden_size,
......@@ -296,8 +234,10 @@ class EPMoE(FusedMoE):
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
if True:
self.mori_op = self.get_mori_op()
self.mori_op = self.get_mori_op()
self.zero_token_count = None
def get_mori_op(self):
global _MORI_OP
......@@ -319,7 +259,7 @@ class EPMoE(FusedMoE):
hidden_dim=self.hidden_size,
scale_dim=0,
scale_type_size=vllm_config.model_config.dtype.itemsize,
max_num_inp_token_per_rank=4096,
max_num_inp_token_per_rank=20480,
num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k,
max_token_type_size=2,
......@@ -334,6 +274,7 @@ class EPMoE(FusedMoE):
def set_shared_experts(self, shared_experts: torch.nn.Module):
if self.shared_experts is None:
self.shared_experts = shared_experts
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts)
......@@ -355,8 +296,28 @@ class EPMoE(FusedMoE):
# Layer entry point: hands `hidden_states`, `router_logits` and `self.layer_name`
# to the `torch.ops.vllm.ep_moe_forward` custom op and returns its result.
# NOTE(review): the `self.layer_name)` continuation line appears twice below.
# This span is a diff view with +/- markers and indentation stripped, so the
# two copies are presumably the before/after of a whitespace-only change —
# confirm against the repository.
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
self.layer_name)
self.layer_name)
def get_expert_weights(self) -> Iterable[torch.Tensor]:
    """Return this layer's expert parameters, flattened per local expert.

    Every entry yielded by ``self.named_parameters()`` whose name is not in
    the exclusion set below is reshaped with
    ``view(self.local_num_experts, -1)``.

    Returns:
        A list of 2-D views, one per retained parameter, in the order
        ``named_parameters()`` yields them.
    """
    # Filter out the non-expert weights.
    # `e_score_correction_bias` is a bias for each logical expert,
    # with shape (num_logical_experts,), not an expert weight.
    # The `shared_experts.*` names are matched exactly; presumably they
    # belong to the shared-expert submodule — verify against the model.
    non_expert_weights = frozenset({
        "e_score_correction_bias",
        "shared_experts.gate_up_proj.weight",
        "shared_experts.gate_up_proj.weight_scale",
        "shared_experts.down_proj.weight",
        "shared_experts.down_proj.weight_scale",
    })

    # Iterate the parameter generator directly; it is consumed exactly once,
    # so there is no need to materialise it into a list first.
    return [
        weight.view(self.local_num_experts, -1)
        for name, weight in self.named_parameters()
        if name not in non_expert_weights
    ]
def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
topk_weights, topk_ids = self.select_experts(
......@@ -376,129 +337,71 @@ class EPMoE(FusedMoE):
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
topk_ids = topk_ids.to(torch.int32)
scales = torch.rand(
hidden_states.shape[0],
0,
dtype=torch.float32,
device=hidden_states.device,
)
########################test#########################
# probs = None
# if self.apply_router_weight_on_input:
# probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
# routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
# (dispatch_output, tokens_per_expert) = self.token_dispatcher.token_permutation(
# hidden_states, probs, routing_map
# )
# torch.cuda.synchronize()
# print("###########################all2all dispatch_output shape:", dispatch_output.shape)
# print("###########################all2all dispatch_output:", dispatch_output[:10, :10])
# expert_output = self.quant_method.apply_ep(
# layer=self,
# hidden_states=dispatch_output,
# tokens_per_expert=tokens_per_expert,
# topk_weights=None,
# topk_ids=None,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# )
# torch.cuda.synchronize()
# print("###########################grouped gemm out:", expert_output[:10, :10])
# final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
# final_hidden_states_all2all = final_hidden_states
# torch.cuda.synchronize()
# print("####################all2all unpermute output:", final_hidden_states[:10, :10].tolist())
########################test##########################
if False:
probs = None
if self.apply_router_weight_on_input:
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
(dispatch_output, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
else:
topk_ids = topk_ids.to(torch.int32)
scales = torch.rand(
hidden_states.shape[0],
0,
dtype=torch.float32,
device=hidden_states.device,
)
self.sync()
print("##########################topk_weights shape:{} topk_ids shape:{}".format(topk_weights.shape, topk_ids.shape))
(
dispatch_output,
dispatch_weights,
dispatch_scales,
dispatch_indices,
dispatch_recv_num_token,
) = self.mori_op.dispatch(
hidden_states.contiguous(),
topk_weights.contiguous(),
scales.contiguous(),
topk_ids.contiguous(),
)
self.sync()
# with torch.inference_mode():
# src_token_pos = self.mori_op.get_dispatch_src_token_pos().tolist()
# print("##################src_token_pos:", src_token_pos[:10].tolist())
tokens_per_expert = dispatch_recv_num_token
print("######################dispatched_input shape:{} dispatch_weights.shape:{} dispatch_indices shape:{}".format(dispatch_output.shape,
dispatch_weights.shape, dispatch_indices.shape))
print("####################dispatch_recv_num_token:", dispatch_recv_num_token)
dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
print("########################dispatch_output:", dispatch_output[:10, :10].tolist())
print("########################dispatch_indices:", dispatch_indices[:10, :].tolist())
print("#########################start fused_moe")
has_greater_than_255 = torch.any(dispatch_indices > 255).item()
has_less_than_0 = torch.any(dispatch_indices < 0).item()
print("##################################has_greater_than_255:{} has_less_than_0:{}".format(has_greater_than_255, has_less_than_0))
if dispatch_recv_num_token > 0:
# Matrix multiply.
#expert_output = self.quant_method.apply_ep(
expert_output = self.quant_method.apply(
layer=self,
x=dispatch_output[:dispatch_recv_num_token].contiguous(),
tokens_per_expert=tokens_per_expert,
topk_weights=dispatch_weights[:dispatch_recv_num_token].contiguous(),
topk_ids=dispatch_indices[:dispatch_recv_num_token].contiguous(),
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
)
else:
expert_output = dispatch_output[:dispatch_recv_num_token]
self.sync()
print("####################fused_moe expert_output:", expert_output[:10, :10].tolist())
(
dispatch_output,
dispatch_weights,
dispatch_scales,
dispatch_indices,
dispatch_recv_num_token,
) = self.mori_op.dispatch(
hidden_states,
topk_weights,
scales,
topk_ids,
)
if False:
final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
#self.sync()
#dispatch_recv_num_token = dispatch_recv_num_token[0].item()
dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
dispatch_output = dispatch_output[:dispatch_recv_num_token]
dispatch_weights = dispatch_weights[:dispatch_recv_num_token]
dispatch_indices = dispatch_indices[:dispatch_recv_num_token]
valid_mask = ((dispatch_indices <= 255) & (dispatch_indices >= 0)).all(dim=1)
dispatch_output = dispatch_output[valid_mask]
dispatch_indices = dispatch_indices[valid_mask]
dispatch_weights = dispatch_weights[valid_mask]
dispatch_recv_num_token = dispatch_indices.shape[0]
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# has_greater_than_255 = torch.any(dispatch_indices > 255).item()
# has_less_than_0 = torch.any(dispatch_indices < 0).item()
# print("##################################has_greater_than_255:{} has_less_than_0:{}".format(has_greater_than_255, has_less_than_0))
# if has_greater_than_255 or has_less_than_0:
# print("###################dispatch_indices:", dispatch_indices.tolist())
if dispatch_recv_num_token > 0:
# Matrix multiply.
expert_output = self.quant_method.apply_ep(
layer=self,
x=dispatch_output,
topk_weights=dispatch_weights,
topk_ids=dispatch_indices,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
)
else:
combine_output, _ = self.mori_op.combine(expert_output.contiguous(), dispatch_weights.contiguous(), topk_ids.contiguous())
final_hidden_states = combine_output[:hidden_states.shape[0], :]
torch.cuda.synchronize()
print("####################mori combine_output:", combine_output[:10, :10].tolist())
expert_output = dispatch_output#[:dispatch_recv_num_token]
#self.sync()
self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
final_hidden_states = combine_output[:hidden_states.shape[0], :]
####################test#################
# final_hidden_states_close = torch.allclose(final_hidden_states, final_hidden_states_all2all, rtol=1e-2, atol=1e-2)
# print(f"final_hidden_states_close: {final_hidden_states_close}")
#####################test################
#self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in
......
......@@ -331,7 +331,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.hidden_shape_before_permute = hidden_states.shape
if True:
if False:
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
routing_map,
......@@ -339,15 +339,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
fused=self.config.moe_permute_fusion
)
else:
torch.cuda.synchronize()
print("########################hidden_states shape:{} \n #####################routing_map shape:{}\n".format(hidden_states.shape,
routing_map.shape))
print("########################hidden_states:{} \n #####################routing_map:{}\n".format(hidden_states[0, :10].tolist(),
routing_map[0, :10].tolist()))
cuda_permute_result = groupgemm_permute(hidden_states, routing_map)
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping, \
expert_m, self.expert_m_count, expert_m_max = cuda_permute_result
self.expert_m_count = cuda_permute_result
# Perform expert parallel AlltoAll communication
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
......@@ -427,7 +422,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.shared_experts.post_forward_comm()
# Unpermutation 1: AlltoAll output to output
if True:
if False:
output = unpermute(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
......
......@@ -778,19 +778,19 @@ class FusedMoE(torch.nn.Module):
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.")
# if self.enable_eplb:
# from vllm.model_executor.layers.quantization.fp8 import (
# Fp8MoEMethod)
# if not isinstance(quant_method, Fp8MoEMethod):
# # TODO: Add support for additional quantization methods.
# # The implementation for other quantization methods does not
# # contain essential differences, but the current quant API
# # design causes duplicated work when extending to new
# # quantization methods, so I'm leaving it for now.
# # If you plan to add support for more quantization methods,
# # please refer to the implementation in `Fp8MoEMethod`.
# raise NotImplementedError("EPLB is only supported for FP8 "
# "quantization for now.")
if quant_config is None:
# Not considering quant for now, temporarily
......
......@@ -334,29 +334,59 @@ class SlimQuantW4A8Int8MoEMethod:
# Expert-parallel (dp+ep) entry point of the W4A8 MoE quantization method.
# NOTE(review): this span is a diff hunk with +/- markers and indentation
# stripped, so two versions of the method are interleaved. One takes
# `hidden_states` / `tokens_per_expert` and returns
# `fused_experts_impl_w4a8_ep(...)`; the other takes `x`, `topk_weights`,
# `topk_ids` plus MoE options and returns `fused_experts(...)`. The hunk header
# (-334,29 +334,59) and the visible `apply_ep(layer=self, x=..., ...)` call
# site suggest the `fused_experts` form is the post-commit one — confirm
# against the repository.
def apply_ep( #dp+ep
self,
layer: torch.nn.Module,
# Parameters of the `fused_experts_impl_w4a8_ep` variant.
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
# Parameters of the `fused_experts` variant.
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
# Any further keyword arguments are accepted and ignored.
**_
) -> torch.Tensor:
# Variant 1: positional call with both weights, both weight scales and
# the per-expert token counts.
return fused_experts_impl_w4a8_ep(hidden_states,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
tokens_per_expert)
# Variant 2: function-scope import, then the generic fused-experts call with
# the int4-W4A8 and per-channel-quant flags set and the layer's weight and
# input scales passed through.
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
)
def apply(# tp
self,
layer: torch.nn.Module,
x: torch.Tensor,
tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -364,6 +394,20 @@ class SlimQuantW4A8Int8MoEMethod:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
return fused_experts(
x,
......
......@@ -102,12 +102,7 @@ def with_amdsmi_context(fn):
# Translate a logical device index into a physical device index.
# NOTE(review): this span is a diff hunk with +/- markers and indentation
# stripped, so two versions of the body are interleaved:
#   * a lookup that, when CUDA_VISIBLE_DEVICES is set, splits it on "," and
#     returns int() of the entry at position `device_id`, otherwise returns
#     `device_id` unchanged;
#   * a bare `return device_id` (identity mapping).
# The hunk header (-102,12 +102,7) suggests the lookup is the removed code and
# the bare return is its replacement — confirm against the repository.
def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id
return device_id
@cache
......
......@@ -441,7 +441,7 @@ class EagleProposer:
# [batch_size]
num_accepted_tokens_tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
cu_num_tokens = torch.arange(cu_target_query_lens.shape[0], device=cu_target_query_lens.device)
cu_num_tokens = torch.arange(cu_target_query_lens.shape[0], device=cu_target_query_lens.device, dtype=torch.int32)
token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1]
return cu_num_tokens, token_indices
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment