Commit 5dcc5cb8 authored by 王敏's avatar 王敏
Browse files

优化mori ep

parent e0ba23b5
......@@ -948,7 +948,7 @@ def init_distributed_environment(
"Fallback Gloo backend is not available.")
backend = "gloo"
# this backend is used for WORLD
parallel_config = config.parallel_config
data_parallel_size = parallel_config.data_parallel_size
use_mori_ep = envs.VLLM_USE_MORI_EP and data_parallel_size > 1 and parallel_config.enable_expert_parallel
if use_mori_ep:
......
......@@ -10,30 +10,31 @@ import torch.nn.functional as F
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group
from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.distributed import expert_parallel_all_gather, expert_parallel_all_reduce
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op
import mori
import torch.distributed as dist
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
try:
import mori
from lmslim.layers.gemm.int8_utils import (
per_token_quant_int8)
except ImportError:
is_mori_available = False
logger = init_logger(__name__)
_MORI_OP = None
@CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization."""
......@@ -166,6 +167,7 @@ class EPMoE(FusedMoE):
dp+ep MoE Expert Parallel Impl
"""
def __init__(
self,
num_experts: int, # Global number of experts
......@@ -241,26 +243,27 @@ class EPMoE(FusedMoE):
self.scales = None
self.use_int8_dispatch = True
vllm_config = get_current_vllm_config()
self.max_num_inp_token_per_rank = vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op()
self.first = True
def get_mori_op(self):
global _MORI_OP
if _MORI_OP is None:
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
#torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
#mori.shmem.shmem_torch_process_group_init("mori_ep")
world_group = torch.distributed.group.WORLD
assert world_group is not None
torch._C._distributed_c10d._register_process_group("default", world_group)
mori.shmem.shmem_torch_process_group_init("default")
torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori.shmem.shmem_torch_process_group_init("mori_ep")
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1
mori_data_type=vllm_config.model_config.dtype
mori_data_type = vllm_config.model_config.dtype
mori_scale_type_size = vllm_config.model_config.dtype.itemsize
if self.use_int8_dispatch:
mori_scale_type_size = 4
......@@ -272,12 +275,13 @@ class EPMoE(FusedMoE):
hidden_dim=self.hidden_size,
scale_dim=1 if self.use_int8_dispatch else 0,
scale_type_size=mori_scale_type_size,
max_num_inp_token_per_rank=2048,
max_num_inp_token_per_rank=self.max_num_inp_token_per_rank,
num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k,
max_token_type_size=2,
block_num=80,
warp_num_per_block=16,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode
)
......@@ -304,7 +308,7 @@ class EPMoE(FusedMoE):
return quant_method
def sync(self):
#torch.cuda.synchronize()
# torch.cuda.synchronize()
dist.barrier()
def forward(self, hidden_states: torch.Tensor,
......@@ -331,7 +335,6 @@ class EPMoE(FusedMoE):
if name not in NON_EXPERT_WEIGHTS
]
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
......@@ -365,8 +368,7 @@ class EPMoE(FusedMoE):
)
scales = self.scales
#self.sync()
# self.sync()
(
dispatch_output,
......@@ -380,60 +382,50 @@ class EPMoE(FusedMoE):
scales,
topk_ids,
)
#self.sync()
expect_m = hidden_states.shape[0] * self.ep_size
dispatch_output_clip = dispatch_output[:expect_m]
dispatch_weights_clip = dispatch_weights[:expect_m]
dispatch_indices_clip = dispatch_indices[:expect_m]
dispatch_scales_clip = dispatch_scales[:expect_m]
expert_output = self.quant_method.apply_ep(
layer=self,
x=dispatch_output_clip,
topk_weights=dispatch_weights_clip,
topk_ids=dispatch_indices_clip,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0],
scales=dispatch_scales_clip if self.use_int8_dispatch else None
#routed_scaling_factor=self.routed_scaling_factor,
)
# if self.first and hidden_states.shape[0] == 2:
# self.first = False
# import numpy as np
# np.save(f'/work/vllm_profile/ep{self.ep_rank}_topk_ids.npy', dispatch_indices_clip.cpu().numpy())
# print("##################config_select_bs:{} topk_ids shape:{} num_local_tokens:{}".format(hidden_states.shape[0],
# topk_ids.shape,
# dispatch_recv_num_token))
# self.sync()
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_scales_clip = dispatch_scales[:expect_m]
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
# topk_weights=dispatch_weights,
# topk_ids=dispatch_indices,
# x=dispatch_output_clip,
# topk_weights=dispatch_weights_clip,
# topk_ids=dispatch_indices_clip,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0]*2,
# scales=dispatch_scales if self.use_int8_dispatch else None
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales_clip if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
#self.sync()
expert_output = self.quant_method.apply_ep(
layer=self,
x=dispatch_output,
topk_weights=dispatch_weights,
topk_ids=dispatch_indices,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0],
scales=dispatch_scales if self.use_int8_dispatch else None
# routed_scaling_factor=self.routed_scaling_factor,
)
# self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
final_hidden_states = combine_output[:hidden_states.shape[0], :]
#self.sync()
# self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in
......@@ -452,6 +444,7 @@ class EPMoE(FusedMoE):
return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
......@@ -472,5 +465,5 @@ direct_register_custom_op(
mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment