Commit 3cb11400 authored by 王敏's avatar 王敏
Browse files

临时添加mori代码

parent 22a4e07b
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAllt ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAllt
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from lightop import groupgemm from lightop import groupgemm
#import mori import mori
import torch.distributed as dist import torch.distributed as dist
...@@ -46,18 +46,42 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -46,18 +46,42 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor, tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward( return self.forward(
hidden_states=hidden_states, hidden_states=hidden_states,
layer=layer, layer=layer,
tokens_per_expert=tokens_per_expert) tokens_per_expert=tokens_per_expert,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=use_nn_moe)
def forward_cuda( def forward_cuda(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor, tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# process MoE # process MoE
...@@ -97,23 +121,39 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -97,23 +121,39 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}" assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}"
expert_output = hidden_states expert_output = hidden_states
else: else:
if self.zero_token_count is None: if topk_ids is None:
self.zero_token_count = torch.zeros(1, dtype=torch.int64, device=hidden_states.device) if self.zero_token_count is None:
total_tokens = tokens_per_expert.sum() self.zero_token_count = torch.zeros(1, dtype=torch.int64, device=hidden_states.device)
if total_tokens > self.zero_token_count: total_tokens = tokens_per_expert.sum()
gateup_output = groupgemm(hidden_states, layer.w13_weight, tokens_per_expert, False) print("#################total_tokens:", total_tokens.tolist())
# Act if total_tokens > self.zero_token_count:
down_input = torch.zeros( gateup_output = groupgemm(hidden_states, layer.w13_weight, tokens_per_expert, False)
gateup_output.shape[0], # Act
gateup_output.shape[1] // 2, down_input = torch.zeros(
device=gateup_output.device, gateup_output.shape[0],
dtype=hidden_states.dtype gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states.dtype
)
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, layer.w13_weight.shape[2]))
expert_output = groupgemm(down_input, layer.w2_weight, tokens_per_expert, False)
else :
expert_output = hidden_states
else:
expert_output = self.fused_experts(
hidden_states=hidden_states,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=use_nn_moe
) )
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, layer.w13_weight.shape[2]))
expert_output = groupgemm(down_input, layer.w2_weight, tokens_per_expert, False)
else :
expert_output = hidden_states
return expert_output return expert_output
...@@ -127,6 +167,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -127,6 +167,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor, tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
**kwargs, **kwargs,
): ):
raise NotImplementedError raise NotImplementedError
...@@ -136,6 +184,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -136,6 +184,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor, tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -144,6 +200,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -144,6 +200,14 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor, tokens_per_expert: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=True,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -224,7 +288,7 @@ class EPMoE(FusedMoE): ...@@ -224,7 +288,7 @@ class EPMoE(FusedMoE):
self.use_shared_expert = False self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher( self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices, self.local_num_experts, self.local_expert_indices,
config=self.ep_moe_config, #layer_name=f"{self.layer_name}.token_dispatcher", config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
) )
self.shared_expert_overlap = moe_shared_expert_overlap self.shared_expert_overlap = moe_shared_expert_overlap
...@@ -232,7 +296,7 @@ class EPMoE(FusedMoE): ...@@ -232,7 +296,7 @@ class EPMoE(FusedMoE):
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1' self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
if False: if True:
self.mori_op = self.get_mori_op() self.mori_op = self.get_mori_op()
def get_mori_op(self): def get_mori_op(self):
...@@ -240,8 +304,13 @@ class EPMoE(FusedMoE): ...@@ -240,8 +304,13 @@ class EPMoE(FusedMoE):
if _MORI_OP is None: if _MORI_OP is None:
# world_group = torch.distributed.group.WORLD # world_group = torch.distributed.group.WORLD
# assert world_group is not None # assert world_group is not None
torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group) #torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori.shmem.shmem_torch_process_group_init("mori_ep") #mori.shmem.shmem_torch_process_group_init("mori_ep")
world_group = torch.distributed.group.WORLD
assert world_group is not None
torch._C._distributed_c10d._register_process_group("default", world_group)
mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
config = mori.ops.EpDispatchCombineConfig( config = mori.ops.EpDispatchCombineConfig(
data_type=vllm_config.model_config.dtype, data_type=vllm_config.model_config.dtype,
...@@ -250,12 +319,12 @@ class EPMoE(FusedMoE): ...@@ -250,12 +319,12 @@ class EPMoE(FusedMoE):
hidden_dim=self.hidden_size, hidden_dim=self.hidden_size,
scale_dim=0, scale_dim=0,
scale_type_size=vllm_config.model_config.dtype.itemsize, scale_type_size=vllm_config.model_config.dtype.itemsize,
max_num_inp_token_per_rank=10000, max_num_inp_token_per_rank=4096,
num_experts_per_rank=self.local_num_experts, num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k, num_experts_per_token=self.top_k,
max_token_type_size=4, max_token_type_size=2,
# block_num=40, block_num=64,
# warp_num_per_block=8, warp_num_per_block=16,
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
) )
_MORI_OP = mori.ops.EpDispatchCombineOp(config) _MORI_OP = mori.ops.EpDispatchCombineOp(config)
...@@ -307,14 +376,50 @@ class EPMoE(FusedMoE): ...@@ -307,14 +376,50 @@ class EPMoE(FusedMoE):
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
if True:
########################test#########################
# probs = None
# if self.apply_router_weight_on_input:
# probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
# routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
# (dispatch_output, tokens_per_expert) = self.token_dispatcher.token_permutation(
# hidden_states, probs, routing_map
# )
# torch.cuda.synchronize()
# print("###########################all2all dispatch_output shape:", dispatch_output.shape)
# print("###########################all2all dispatch_output:", dispatch_output[:10, :10])
# expert_output = self.quant_method.apply_ep(
# layer=self,
# hidden_states=dispatch_output,
# tokens_per_expert=tokens_per_expert,
# topk_weights=None,
# topk_ids=None,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# )
# torch.cuda.synchronize()
# print("###########################grouped gemm out:", expert_output[:10, :10])
# final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
# final_hidden_states_all2all = final_hidden_states
# torch.cuda.synchronize()
# print("####################all2all unpermute output:", final_hidden_states[:10, :10].tolist())
########################test##########################
if False:
probs = None probs = None
if self.apply_router_weight_on_input: if self.apply_router_weight_on_input:
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights) probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool() routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( (dispatch_output, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map hidden_states, probs, routing_map
) )
else: else:
...@@ -325,39 +430,76 @@ class EPMoE(FusedMoE): ...@@ -325,39 +430,76 @@ class EPMoE(FusedMoE):
dtype=torch.float32, dtype=torch.float32,
device=hidden_states.device, device=hidden_states.device,
) )
self.sync()
print("##########################topk_weights shape:{} topk_ids shape:{}".format(topk_weights.shape, topk_ids.shape))
( (
dispatched_input, dispatch_output,
dispatch_weights, dispatch_weights,
dispatch_scales, dispatch_scales,
dispatch_indices, dispatch_indices,
dispatch_recv_num_token, dispatch_recv_num_token,
) = self.mori_op.dispatch( ) = self.mori_op.dispatch(
hidden_states, hidden_states.contiguous(),
topk_weights, topk_weights.contiguous(),
scales, scales.contiguous(),
topk_ids, topk_ids.contiguous(),
) )
tokens_per_expert = dispatch_recv_num_token
self.sync() self.sync()
print("######################dispatched_input shape:{} dispatch_weights.shape:{} dispatch_indices shape:{}".format(dispatched_input.shape, # with torch.inference_mode():
# src_token_pos = self.mori_op.get_dispatch_src_token_pos().tolist()
# print("##################src_token_pos:", src_token_pos[:10].tolist())
tokens_per_expert = dispatch_recv_num_token
print("######################dispatched_input shape:{} dispatch_weights.shape:{} dispatch_indices shape:{}".format(dispatch_output.shape,
dispatch_weights.shape, dispatch_indices.shape)) dispatch_weights.shape, dispatch_indices.shape))
print("####################dispatch_recv_num_token:", dispatch_recv_num_token.tolist()) print("####################dispatch_recv_num_token:", dispatch_recv_num_token)
#print("####################dispatch_weights:", dispatch_weights.tolist())
#print("####################dispatch_indices:", dispatch_indices.tolist()) dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
print("########################dispatch_output:", dispatch_output[:10, :10].tolist())
# Matrix multiply. print("########################dispatch_indices:", dispatch_indices[:10, :].tolist())
expert_output = self.quant_method.apply_ep( print("#########################start fused_moe")
layer=self, has_greater_than_255 = torch.any(dispatch_indices > 255).item()
hidden_states=dispatched_input, has_less_than_0 = torch.any(dispatch_indices < 0).item()
tokens_per_expert=tokens_per_expert print("##################################has_greater_than_255:{} has_less_than_0:{}".format(has_greater_than_255, has_less_than_0))
)
if dispatch_recv_num_token > 0:
# Matrix multiply.
#expert_output = self.quant_method.apply_ep(
expert_output = self.quant_method.apply(
layer=self,
x=dispatch_output[:dispatch_recv_num_token].contiguous(),
tokens_per_expert=tokens_per_expert,
topk_weights=dispatch_weights[:dispatch_recv_num_token].contiguous(),
topk_ids=dispatch_indices[:dispatch_recv_num_token].contiguous(),
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
)
else:
expert_output = dispatch_output[:dispatch_recv_num_token]
self.sync()
print("####################fused_moe expert_output:", expert_output[:10, :10].tolist())
if True:
if False:
final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output) final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
else: else:
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids) combine_output, _ = self.mori_op.combine(expert_output.contiguous(), dispatch_weights.contiguous(), topk_ids.contiguous())
final_hidden_states = combine_output[:hidden_states.shape[0], :] final_hidden_states = combine_output[:hidden_states.shape[0], :]
torch.cuda.synchronize()
print("####################mori combine_output:", combine_output[:10, :10].tolist())
self.sync()
####################test#################
# final_hidden_states_close = torch.allclose(final_hidden_states, final_hidden_states_all2all, rtol=1e-2, atol=1e-2)
# print(f"final_hidden_states_close: {final_hidden_states_close}")
#####################test################
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in # if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations # the token_dispatcher to overlap communications and computations
......
...@@ -26,6 +26,7 @@ from vllm.platforms import current_platform ...@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from lightop import groupgemm_permute, groupgemm_unpermute
cuda_dtoh_stream = torch.cuda.Stream() cuda_dtoh_stream = torch.cuda.Stream()
...@@ -329,12 +330,24 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -329,12 +330,24 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
self.hidden_shape_before_permute = hidden_states.shape self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states, if True:
routing_map, permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
num_out_tokens=self.num_out_tokens, hidden_states,
fused=self.config.moe_permute_fusion routing_map,
) num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion
)
else:
torch.cuda.synchronize()
print("########################hidden_states shape:{} \n #####################routing_map shape:{}\n".format(hidden_states.shape,
routing_map.shape))
print("########################hidden_states:{} \n #####################routing_map:{}\n".format(hidden_states[0, :10].tolist(),
routing_map[0, :10].tolist()))
cuda_permute_result = groupgemm_permute(hidden_states, routing_map)
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping, \
expert_m, self.expert_m_count, expert_m_max = cuda_permute_result
# Perform expert parallel AlltoAll communication # Perform expert parallel AlltoAll communication
# tokens_per_expert = self._maybe_dtoh_and_synchronize( # tokens_per_expert = self._maybe_dtoh_and_synchronize(
...@@ -414,14 +427,22 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -414,14 +427,22 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.shared_experts.post_forward_comm() self.shared_experts.post_forward_comm()
# Unpermutation 1: AlltoAll output to output # Unpermutation 1: AlltoAll output to output
output = unpermute( if True:
permutated_local_input_tokens, output = unpermute(
self.reversed_local_input_permutation_mapping, permutated_local_input_tokens,
restore_shape=self.hidden_shape_before_permute, self.reversed_local_input_permutation_mapping,
probs=self.probs, restore_shape=self.hidden_shape_before_permute,
routing_map=self.routing_map, probs=self.probs,
fused=self.config.moe_permute_fusion, routing_map=self.routing_map,
) fused=self.config.moe_permute_fusion,
)
else:
output = groupgemm_unpermute(permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
list(self.hidden_shape_before_permute),
self.probs,
self.routing_map,
self.expert_m_count)
# Reshape the output tensor # Reshape the output tensor
output = output.view(self.hidden_shape) output = output.view(self.hidden_shape)
......
...@@ -349,23 +349,14 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -349,23 +349,14 @@ class SlimQuantW4A8Int8MoEMethod:
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, tokens_per_expert: torch.Tensor,
top_k: int, topk_weights: torch.Tensor,
renormalize: bool, topk_ids: torch.Tensor,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False, enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -373,20 +364,6 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -373,20 +364,6 @@ class SlimQuantW4A8Int8MoEMethod:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.") "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
return fused_experts( return fused_experts(
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment