Commit 7bf7df7f authored by 王敏's avatar 王敏
Browse files

[fix]修复不开启ep报错

parent 0ff29dbf
......@@ -10,13 +10,12 @@ import torch.nn.functional as F
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group
from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.distributed import expert_parallel_all_gather, expert_parallel_all_reduce
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
......@@ -26,7 +25,6 @@ import mori
import torch.distributed as dist
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
......@@ -248,14 +246,14 @@ class EPMoE(FusedMoE):
def get_mori_op(self):
global _MORI_OP
if _MORI_OP is None:
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
#torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
#mori.shmem.shmem_torch_process_group_init("mori_ep")
world_group = torch.distributed.group.WORLD
assert world_group is not None
torch._C._distributed_c10d._register_process_group("default", world_group)
mori.shmem.shmem_torch_process_group_init("default")
torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori.shmem.shmem_torch_process_group_init("mori_ep")
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1
......@@ -272,12 +270,13 @@ class EPMoE(FusedMoE):
hidden_dim=self.hidden_size,
scale_dim=1 if self.use_int8_dispatch else 0,
scale_type_size=mori_scale_type_size,
max_num_inp_token_per_rank=2048,
max_num_inp_token_per_rank=512,
num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k,
max_token_type_size=2,
block_num=80,
warp_num_per_block=16,
#kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode
)
......@@ -403,16 +402,6 @@ class EPMoE(FusedMoE):
#routed_scaling_factor=self.routed_scaling_factor,
)
# if self.first and hidden_states.shape[0] == 2:
# self.first = False
# import numpy as np
# np.save(f'/work/vllm_profile/ep{self.ep_rank}_topk_ids.npy', dispatch_indices_clip.cpu().numpy())
# print("##################config_select_bs:{} topk_ids shape:{} num_local_tokens:{}".format(hidden_states.shape[0],
# topk_ids.shape,
# dispatch_recv_num_token))
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment