Commit d698d6f2 authored by 王敏
Browse files

[feat]整合mori和deepep相关代码

parent 7293a072
File added
......@@ -171,16 +171,16 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_qps_per_rank = None
if self.internode:
num_rdma_bytes = int(1e9/2)#1024 * 1024 * 1024
num_qps_per_rank = 30#self.num_sms // 2
import deep_ep
num_nvl_bytes, num_rdma_bytes = 0, 0
hidden_size = 7168
hidden_bytes = hidden_size * 2
for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())):
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes)
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes)
num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = 30 #self.num_sms // 2
# import deep_ep
# num_nvl_bytes, num_rdma_bytes = 0, 0
# hidden_size = 7168
# hidden_bytes = hidden_size * 2
# for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())):
# num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes)
# num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes)
else:
num_rdma_bytes = 0
num_qps_per_rank = 1
......
......@@ -175,6 +175,7 @@ if TYPE_CHECKING:
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_ENABLE_MOE_GROUP_GEMM: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1151,6 +1152,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# enable the MoE group GEMM path ("true" or "1" turns it on; off by default)
"VLLM_ENABLE_MOE_GROUP_GEMM":
lambda: (os.environ.get("VLLM_ENABLE_MOE_GROUP_GEMM", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
import os
import logging
from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass
from typing import Callable, Optional
from collections.abc import Iterable
import torch
import torch.nn.functional as F
import torch.distributed as dist
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -18,10 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op
import torch.distributed as dist
try:
import mori
......@@ -35,8 +30,8 @@ logger = init_logger(__name__)
_MORI_OP = None
@CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
@CustomOp.register("unquantized_mori_moe")
class UnquantizedMoriMoeMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig):
......@@ -44,9 +39,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.topk_indices_dtype = None
self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = False
def apply_ep(
def apply_mori_ep(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
......@@ -162,7 +157,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
forward_native = forward_cuda
class EPMoE(FusedMoE):
class MoriMoE(FusedMoE):
"""
dp+ep MoE Expert Parallel Impl
......@@ -194,7 +189,6 @@ class EPMoE(FusedMoE):
enable_eplb: bool = False,
num_redundant_experts: int = 0,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False
):
super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype,
......@@ -215,7 +209,6 @@ class EPMoE(FusedMoE):
moe_router_topk=self.top_k,
# TODO: support fusion permute
moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=self.ep_size,
num_moe_experts=self.global_num_experts,
routed_scaling_factor=self.routed_scaling_factor,
......@@ -228,21 +221,14 @@ class EPMoE(FusedMoE):
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.local_num_experts)
]
self.use_shared_expert = False
# self.token_dispatcher = MoEAlltoAllTokenDispatcher(
# self.local_num_experts, self.local_expert_indices,
# config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
# )
self.shared_expert_overlap = moe_shared_expert_overlap
self.shared_experts = None
self.scales = None
self.use_int8_dispatch = True
vllm_config = get_current_vllm_config()
self.max_num_inp_token_per_rank = 1024#vllm_config.scheduler_config.max_num_seqs
self.max_num_inp_token_per_rank = 1024 #vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op()
def get_mori_op(self):
......@@ -252,10 +238,6 @@ class EPMoE(FusedMoE):
assert world_group is not None
torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori.shmem.shmem_torch_process_group_init("mori_ep")
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1
......@@ -278,7 +260,6 @@ class EPMoE(FusedMoE):
max_token_type_size=2,
block_num=80,
warp_num_per_block=4,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode
)
......@@ -290,14 +271,11 @@ class EPMoE(FusedMoE):
if self.shared_experts is None:
self.shared_experts = shared_experts
# if self.shared_expert_overlap:
# self.token_dispatcher.set_shared_experts(self.shared_experts)
def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedEPGroupedGemmMethod(moe) if quant_config is None
quant_method = (UnquantizedMoriMoeMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
......@@ -310,7 +288,7 @@ class EPMoE(FusedMoE):
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
return torch.ops.vllm.mori_moe_forward(hidden_states, router_logits,
self.layer_name)
def get_expert_weights(self) -> Iterable[torch.Tensor]:
......@@ -350,7 +328,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if self.use_int8_dispatch:
......@@ -377,11 +355,10 @@ class EPMoE(FusedMoE):
hidden_states,
topk_weights,
scales,
topk_ids,
#layer_idx=int(self.layer_name.split('.')[2])
topk_ids
)
expert_output = self.quant_method.apply_ep(
expert_output = self.quant_method.apply_mori_ep(
layer=self,
x=dispatch_output,
topk_weights=dispatch_weights,
......@@ -394,7 +371,6 @@ class EPMoE(FusedMoE):
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size,
scales=dispatch_scales if self.use_int8_dispatch else None
# routed_scaling_factor=self.routed_scaling_factor,
)
# self.sync()
......@@ -404,11 +380,7 @@ class EPMoE(FusedMoE):
# self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# shared_output = (
# self.maybe_all_reduce_tensor_model_parallel(
# shared_output))
if self.shared_experts is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
......@@ -420,7 +392,7 @@ class EPMoE(FusedMoE):
return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def mori_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
......@@ -429,16 +401,16 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
return self.forward_impl(hidden_states, router_logits)
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def mori_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="ep_moe_forward",
op_func=ep_moe_forward,
op_name="mori_moe_forward",
op_func=mori_moe_forward,
mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake,
fake_impl=mori_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
......@@ -166,6 +166,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self.use_deepep = parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.enable_moe_group_gemm = parallel_config.enable_expert_parallel and envs.VLLM_ENABLE_MOE_GROUP_GEMM
def create_weights(
......@@ -250,32 +252,36 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
**_ ):
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
w1,
w2,
topk_ids=topk_ids,
topk_weights=topk_weights,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def apply_ep( #dp+ep
if not self.enable_moe_group_gemm:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
w1,
w2,
topk_ids=topk_ids,
topk_weights=topk_weights,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else:
# TODO: the MoE group GEMM path is not implemented yet; this branch currently returns None.
return None
def apply_mori_ep(
self,
layer: torch.nn.Module,
x: torch.Tensor,
......@@ -310,12 +316,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a1_scale=scales,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens,
config_select_bs=config_select_bs,
q_scales=scales
)
def apply(
......
......@@ -43,8 +43,8 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.ep_moe.layer import EPMoE
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EPSharedExperts
from vllm.model_executor.layers.fused_moe.mori_moe.layer import MoriMoE
from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EPSharedExperts
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
......@@ -167,10 +167,10 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts)
dp_size = get_dp_group().world_size
self.use_mori_ep = envs.VLLM_ALL2ALL_BACKEND == 'mori' and dp_size > 1 and parallel_config.enable_expert_parallel
self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori'
self.enable_expert_parallel = parallel_config.enable_expert_parallel
moe_cls = FusedMoE if not self.use_mori_ep else EPMoE
moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE
self.experts = moe_cls(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
......@@ -225,12 +225,12 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if not self.use_mori_ep:
if not self.enable_expert_parallel:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
......@@ -249,8 +249,22 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
final_hidden_states = self.experts(hidden_states=hidden_states,
else:
if not self.use_mori_ep:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if not self.use_mori_ep:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment