Commit d2e57a90 authored by 王敏's avatar 王敏
Browse files

[feat]优化mori计算逻辑,支持cudagraph,按照bs*ep_size截断fused_moe的输入,共享专家不tp切分,去掉最后的allreduce

parent 8824ae6a
...@@ -4320,9 +4320,6 @@ class CompilationConfig: ...@@ -4320,9 +4320,6 @@ class CompilationConfig:
self.splitting_ops = [] if self.full_cuda_graph else [ self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention", "vllm.unified_attention",
"vllm.unified_attention_with_output", "vllm.unified_attention_with_output",
"vllm.token_permutation_forward",
"vllm.token_unpermutation_forward",
"vllm.ep_moe_forward",
] ]
......
...@@ -948,6 +948,10 @@ def init_distributed_environment( ...@@ -948,6 +948,10 @@ def init_distributed_environment(
"Fallback Gloo backend is not available.") "Fallback Gloo backend is not available.")
backend = "gloo" backend = "gloo"
# this backend is used for WORLD # this backend is used for WORLD
data_parallel_size = parallel_config.data_parallel_size
use_mori_ep = envs.VLLM_USE_MORI_EP and data_parallel_size > 1 and parallel_config.enable_expert_parallel
if use_mori_ep:
backend="cpu:gloo,cuda:nccl" backend="cpu:gloo,cuda:nccl"
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=backend, backend=backend,
......
...@@ -168,7 +168,7 @@ if TYPE_CHECKING: ...@@ -168,7 +168,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_CAT: bool = False VLLM_USE_TRITON_CAT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
VLLM_USE_ALLTOALL_EP: bool = False VLLM_USE_MORI_EP: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1112,8 +1112,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1112,8 +1112,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
# vLLM will use all_to_all ep mode # vLLM will use all_to_all ep mode
"VLLM_USE_ALLTOALL_EP": "VLLM_USE_MORI_EP":
lambda: (os.environ.get("VLLM_USE_ALLTOALL_EP", "True").lower() in lambda: (os.environ.get("VLLM_USE_MORI_EP", "True").lower() in
("true", "1")), ("true", "1")),
} }
......
...@@ -136,8 +136,8 @@ def set_forward_context( ...@@ -136,8 +136,8 @@ def set_forward_context(
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None dp_metadata: Optional[DPMetadata] = None
dp_size = vllm_config.parallel_config.data_parallel_size dp_size = vllm_config.parallel_config.data_parallel_size
use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if not use_all2all_ep and dp_size > 1 and ( if not use_mori_ep and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None) : attn_metadata is not None or num_tokens is not None) :
dp_metadata = DPMetadata.make(vllm_config.parallel_config, dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0, attn_metadata, num_tokens or 0,
......
...@@ -88,13 +88,16 @@ class EPSharedExperts(nn.Module): ...@@ -88,13 +88,16 @@ class EPSharedExperts(nn.Module):
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj") prefix=f"{prefix}.gate_up_proj",
expect_tp_size=1)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, reduce_results=reduce_results,
prefix=f"{prefix}.down_proj") prefix=f"{prefix}.down_proj",
expect_tp_size=1)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
......
...@@ -7,7 +7,6 @@ from collections.abc import Iterable ...@@ -7,7 +7,6 @@ from collections.abc import Iterable
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
...@@ -17,6 +16,7 @@ from vllm.distributed.parallel_state import get_ep_group, get_node_count ...@@ -17,6 +16,7 @@ from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.distributed import expert_parallel_all_gather, expert_parallel_all_reduce
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
...@@ -25,6 +25,10 @@ from vllm.utils import direct_register_custom_op ...@@ -25,6 +25,10 @@ from vllm.utils import direct_register_custom_op
import mori import mori
import torch.distributed as dist import torch.distributed as dist
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -40,7 +44,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -40,7 +44,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.moe = moe self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
self.zero_token_count = None
def apply_ep( def apply_ep(
self, self,
...@@ -235,9 +238,11 @@ class EPMoE(FusedMoE): ...@@ -235,9 +238,11 @@ class EPMoE(FusedMoE):
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1' self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
self.mori_op = self.get_mori_op() self.scales = None
self.use_int8_dispatch = True
self.zero_token_count = None self.mori_op = self.get_mori_op()
self.first = True
def get_mori_op(self): def get_mori_op(self):
...@@ -253,20 +258,28 @@ class EPMoE(FusedMoE): ...@@ -253,20 +258,28 @@ class EPMoE(FusedMoE):
mori.shmem.shmem_torch_process_group_init("default") mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1
mori_data_type=vllm_config.model_config.dtype
mori_scale_type_size = vllm_config.model_config.dtype.itemsize
if self.use_int8_dispatch:
mori_scale_type_size = 4
config = mori.ops.EpDispatchCombineConfig( config = mori.ops.EpDispatchCombineConfig(
data_type=vllm_config.model_config.dtype, data_type=mori_data_type,
rank=self.ep_rank, rank=self.ep_rank,
world_size=self.ep_size, world_size=self.ep_size,
hidden_dim=self.hidden_size, hidden_dim=self.hidden_size,
scale_dim=0, scale_dim=1 if self.use_int8_dispatch else 0,
scale_type_size=vllm_config.model_config.dtype.itemsize, scale_type_size=mori_scale_type_size,
max_num_inp_token_per_rank=4096, max_num_inp_token_per_rank=2048,
num_experts_per_rank=self.local_num_experts, num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k, num_experts_per_token=self.top_k,
max_token_type_size=2, max_token_type_size=2,
block_num=64, block_num=80,
warp_num_per_block=16, warp_num_per_block=16,
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode
) )
_MORI_OP = mori.ops.EpDispatchCombineOp(config) _MORI_OP = mori.ops.EpDispatchCombineOp(config)
...@@ -291,13 +304,11 @@ class EPMoE(FusedMoE): ...@@ -291,13 +304,11 @@ class EPMoE(FusedMoE):
return quant_method return quant_method
def sync(self): def sync(self):
torch.cuda.synchronize() #torch.cuda.synchronize()
dist.barrier() dist.barrier()
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor):
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits, return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name)
...@@ -322,9 +333,7 @@ class EPMoE(FusedMoE): ...@@ -322,9 +333,7 @@ class EPMoE(FusedMoE):
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor):
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None):
topk_weights, topk_ids = self.select_experts( topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -337,25 +346,27 @@ class EPMoE(FusedMoE): ...@@ -337,25 +346,27 @@ class EPMoE(FusedMoE):
custom_routing_function=self.custom_routing_function, custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
indices_type=torch.int64, indices_type=torch.int32,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate) use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
topk_ids = topk_ids.to(torch.int32) if self.use_int8_dispatch:
scales = torch.rand( hidden_states, scales = per_token_quant_int8(hidden_states)
else:
if self.scales is None:
self.scales = torch.rand(
hidden_states.shape[0], hidden_states.shape[0],
0, 0,
dtype=torch.float32, dtype=torch.float32,
device=hidden_states.device, device=hidden_states.device,
) )
scales = self.scales
#dist.barrier()
#self.sync()
( (
dispatch_output, dispatch_output,
...@@ -369,49 +380,54 @@ class EPMoE(FusedMoE): ...@@ -369,49 +380,54 @@ class EPMoE(FusedMoE):
scales, scales,
topk_ids, topk_ids,
) )
#self.sync() #self.sync()
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0] expect_m = hidden_states.shape[0] * self.ep_size
# #dispatch_recv_num_token = dispatch_recv_num_token.item() dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_output = dispatch_output[:dispatch_recv_num_token] dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_weights = dispatch_weights[:dispatch_recv_num_token] dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_indices = dispatch_indices[:dispatch_recv_num_token] dispatch_scales_clip = dispatch_scales[:expect_m]
# dispatch_recv_num_token = dispatch_recv_num_token.item()
# dispatch_output = torch.narrow(dispatch_output, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_weights = torch.narrow(dispatch_weights, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_indices = torch.narrow(dispatch_indices, dim=0, start=0, length=dispatch_recv_num_token)
# valid_mask = ((dispatch_indices <= 255) & (dispatch_indices >= 0)).all(dim=1)
# dispatch_output = dispatch_output[valid_mask]
# dispatch_indices = dispatch_indices[valid_mask]
# dispatch_weights = dispatch_weights[valid_mask]
# dispatch_recv_num_token = dispatch_indices.shape[0]
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# has_greater_than_255 = torch.any(dispatch_indices > 255).item()
# has_less_than_0 = torch.any(dispatch_indices < 0).item()
# print("##################################has_greater_than_255:{} has_less_than_0:{}".format(has_greater_than_255, has_less_than_0))
# if has_greater_than_255 or has_less_than_0:
# print("###################dispatch_indices:", dispatch_indices.tolist())
if dispatch_recv_num_token > 0:
# Matrix multiply.
expert_output = self.quant_method.apply_ep( expert_output = self.quant_method.apply_ep(
layer=self, layer=self,
x=dispatch_output, x=dispatch_output_clip,
topk_weights=dispatch_weights, topk_weights=dispatch_weights_clip,
topk_ids=dispatch_indices, topk_ids=dispatch_indices_clip,
global_num_experts=self.global_num_experts, global_num_experts=self.global_num_experts,
expert_map=self.expert_map, expert_map=self.expert_map,
activation=self.activation, activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0],
scales=dispatch_scales_clip if self.use_int8_dispatch else None
#routed_scaling_factor=self.routed_scaling_factor,
) )
else:
expert_output = dispatch_output#[:dispatch_recv_num_token] # if self.first and hidden_states.shape[0] == 2:
# self.first = False
# import numpy as np
# np.save(f'/work/vllm_profile/ep{self.ep_rank}_topk_ids.npy', dispatch_indices_clip.cpu().numpy())
# print("##################config_select_bs:{} topk_ids shape:{} num_local_tokens:{}".format(hidden_states.shape[0],
# topk_ids.shape,
# dispatch_recv_num_token))
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
# topk_weights=dispatch_weights,
# topk_ids=dispatch_indices,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0]*2,
# scales=dispatch_scales if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
#self.sync() #self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids) combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
...@@ -422,9 +438,9 @@ class EPMoE(FusedMoE): ...@@ -422,9 +438,9 @@ class EPMoE(FusedMoE):
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in # if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations # the token_dispatcher to overlap communications and computations
shared_output = ( # shared_output = (
self.maybe_all_reduce_tensor_model_parallel( # self.maybe_all_reduce_tensor_model_parallel(
shared_output)) # shared_output))
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
...@@ -434,31 +450,26 @@ class EPMoE(FusedMoE): ...@@ -434,31 +450,26 @@ class EPMoE(FusedMoE):
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
if envs.USE_FUSED_RMS_QUANT: return final_hidden_states
return final_hidden_states, new_resi
else:
return final_hidden_states, None
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, rms_weight: Optional[torch.Tensor] = None, layer_name: str) -> torch.Tensor:
residual: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits, rms_weight, residual) return self.forward_impl(hidden_states, router_logits)
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, rms_weight: Optional[torch.Tensor] = None, layer_name: str) -> torch.Tensor:
residual: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(hidden_states)
return torch.empty_like(hidden_states), torch.empty_like(hidden_states)
direct_register_custom_op( direct_register_custom_op(
op_name="ep_moe_forward", op_name="ep_moe_forward",
op_func=ep_moe_forward, op_func=ep_moe_forward,
mutates_args=["hidden_states", "router_logits", "rms_weight", "residual"], mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake, fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
......
...@@ -1257,13 +1257,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1257,13 +1257,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map, per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe, num_local_tokens, true_bs)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1289,7 +1291,9 @@ def inplace_fused_experts_fake( ...@@ -1289,7 +1291,9 @@ def inplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> None:
pass pass
...@@ -1325,14 +1329,16 @@ def outplace_fused_experts( ...@@ -1325,14 +1329,16 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16,use_int4_w4a8, per_channel_quant, use_int4_w4a16,use_int4_w4a8, per_channel_quant,
global_num_experts, expert_map, w1_scale, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe, num_local_tokens, true_bs)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -1357,7 +1363,9 @@ def outplace_fused_experts_fake( ...@@ -1357,7 +1363,9 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1414,7 +1422,9 @@ def fused_experts( ...@@ -1414,7 +1422,9 @@ def fused_experts(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False, allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better # For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available. # permute/unpermute ops are available.
N = w1.size(1) N = w1.size(1)
...@@ -1472,7 +1482,9 @@ def fused_experts( ...@@ -1472,7 +1482,9 @@ def fused_experts(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens,
true_bs=true_bs)
def fused_experts_impl( def fused_experts_impl(
...@@ -1500,6 +1512,8 @@ def fused_experts_impl( ...@@ -1500,6 +1512,8 @@ def fused_experts_impl(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
if use_nn_moe: if use_nn_moe:
...@@ -1544,7 +1558,9 @@ def fused_experts_impl( ...@@ -1544,7 +1558,9 @@ def fused_experts_impl(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=False use_nn_moe=False,
num_local_tokens=num_local_tokens,
true_bs=true_bs,
) )
elif use_int4_w4a8 is True: elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states, return fused_experts_impl_w4a8(hidden_states=hidden_states,
......
...@@ -152,7 +152,8 @@ def moe_align_block_size( ...@@ -152,7 +152,8 @@ def moe_align_block_size(
num_experts: int, num_experts: int,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False, pad_sorted_ids: bool = False,
num_token: Optional[int] = None num_token: Optional[int] = None,
num_local_tokens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
...@@ -234,7 +235,7 @@ def moe_align_block_size( ...@@ -234,7 +235,7 @@ def moe_align_block_size(
if envs.VLLM_USE_LIGHT_OP: if envs.VLLM_USE_LIGHT_OP:
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, None) expert_ids, num_tokens_post_pad, expert_map, None, num_local_tokens)
else: else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad) expert_ids, num_tokens_post_pad)
......
...@@ -486,9 +486,13 @@ class ColumnParallelLinear(LinearBase): ...@@ -486,9 +486,13 @@ class ColumnParallelLinear(LinearBase):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None,
): ):
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size
self.input_size_per_partition = input_size self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size) self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition] self.output_partition_sizes = [self.output_size_per_partition]
...@@ -728,10 +732,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -728,10 +732,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None,
): ):
self.eps = eps self.eps = eps
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
tp_size = expect_tp_size
self.expect_tp_size = expect_tp_size
self.expect_tp_size = expect_tp_size
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size, super().__init__(input_size=input_size,
output_size=sum(output_sizes), output_size=sum(output_sizes),
...@@ -741,7 +753,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -741,7 +753,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias) return_bias=return_bias,
expect_tp_size=expect_tp_size)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -838,6 +851,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -838,6 +851,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes) assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_rank = 0
tp_size = 1
if output_dim is not None: if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size
...@@ -1384,10 +1402,16 @@ class RowParallelLinear(LinearBase): ...@@ -1384,10 +1402,16 @@ class RowParallelLinear(LinearBase):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
expect_tp_size: Optional[int] = None,
): ):
# Divide the weight matrix along the first dimension. # Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.tp_rank = 0
self.tp_size = 1
self.expect_tp_size = expect_tp_size
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size] self.output_partition_sizes = [output_size]
...@@ -1433,6 +1457,10 @@ class RowParallelLinear(LinearBase): ...@@ -1433,6 +1457,10 @@ class RowParallelLinear(LinearBase):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if self.expect_tp_size is not None:
tp_rank = 0
tp_size = 1
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False) is_sharded_weight = getattr(param, "is_sharded_weight", False)
......
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import os import os
import torch import torch
import vllm.envs as envs from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase) from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
...@@ -16,6 +18,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, ...@@ -16,6 +18,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter) ModelWeightParameter)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
try: try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
...@@ -23,6 +29,9 @@ except Exception: ...@@ -23,6 +29,9 @@ except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = init_logger(__name__)
class MarlinMoeWorkspace: class MarlinMoeWorkspace:
""" """
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE. Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
...@@ -220,6 +229,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -220,6 +229,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
config_select_bs: Optional[int] = None,
routed_scaling_factor: Optional[float] = None,
scales: Optional[torch.Tensor] = None,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
...@@ -243,6 +256,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -243,6 +256,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens,
config_select_bs=config_select_bs,
q_scales=scales
) )
def apply( def apply(
...@@ -309,3 +325,43 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -309,3 +325,43 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
) )
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedGroupedGemmExperts, GroupedGemmGemmExperts)
assert not self.rocm_aiter_moe_enabled, (
"ROCm AITER are not supported with all2all yet.")
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
assert max_num_tokens_per_rank is not None
logger.debug(
"BatchedGroupedGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
self.quant_config.weight_block_size, False)
return BatchedGroupedGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=False,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=True,
allow_deep_gemm=False,
)
else:
logger.debug(
"GroupedGemmGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, self.quant_config.weight_block_size,
False)
return GroupedGemmGemmExperts(
use_fp8_w8a8=False,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=False,
)
...@@ -178,7 +178,7 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -178,7 +178,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel self.use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def forward( def forward(
...@@ -211,7 +211,7 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -211,7 +211,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
if self.use_all2all_ep: if self.use_mori_ep:
ep_moe_shared_experts_keys = "mlp.shared_experts" ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"} ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
...@@ -244,7 +244,7 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -244,7 +244,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if self.use_all2all_ep: if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys]) name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
...@@ -261,7 +261,7 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -261,7 +261,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if self.use_all2all_ep: if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys]) name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
param = params_dict[name] param = params_dict[name]
...@@ -273,7 +273,7 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -273,7 +273,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
expert_id=expert_id) expert_id=expert_id)
break break
else: else:
if self.use_all2all_ep: if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys]) name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
......
...@@ -165,9 +165,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -165,9 +165,9 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts) self.n_local_physical_experts)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel self.use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and parallel_config.enable_expert_parallel
moe_cls = FusedMoE if not self.use_all2all_ep else EPMoE moe_cls = FusedMoE if not self.use_mori_ep else EPMoE
self.experts = moe_cls( self.experts = moe_cls(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -189,8 +189,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -189,8 +189,8 @@ class DeepseekV2MoE(nn.Module):
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts) config.n_shared_experts)
#shared_expert_cls = DeepseekV2MLP if not self.use_all2all_ep else EPSharedExperts shared_expert_cls = DeepseekV2MLP if not self.use_mori_ep else EPSharedExperts
self.shared_experts = DeepseekV2MLP( self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
...@@ -199,7 +199,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -199,7 +199,7 @@ class DeepseekV2MoE(nn.Module):
), ),
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
if self.use_all2all_ep: if self.use_mori_ep:
self.experts.set_shared_experts(self.shared_experts) self.experts.set_shared_experts(self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
...@@ -212,7 +212,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -212,7 +212,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_all2all_ep: if not self.use_mori_ep:
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
...@@ -222,7 +222,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -222,7 +222,7 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if not self.use_all2all_ep: if not self.use_mori_ep:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -233,10 +233,10 @@ class DeepseekV2MoE(nn.Module): ...@@ -233,10 +233,10 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
else: else:
final_hidden_states, new_resi = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if not self.use_all2all_ep: if not self.use_mori_ep:
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
...@@ -917,7 +917,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -917,7 +917,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel self.use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def set_eplb_state( def set_eplb_state(
self, self,
...@@ -1000,7 +1000,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1000,7 +1000,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
if self.use_all2all_ep: if self.use_mori_ep:
ep_moe_shared_experts_keys = "mlp.shared_experts" ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"} ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
...@@ -1037,7 +1037,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1037,7 +1037,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if self.use_all2all_ep: if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys]) name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
...@@ -1066,7 +1066,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1066,7 +1066,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# Instead, create a new variable # Instead, create a new variable
name_mapped = name.replace(weight_name, param_name) name_mapped = name.replace(weight_name, param_name)
if self.use_all2all_ep: if self.use_mori_ep:
name_mapped = name_mapped.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys]) name_mapped = name_mapped.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
if is_pp_missing_parameter(name_mapped, self): if is_pp_missing_parameter(name_mapped, self):
...@@ -1094,7 +1094,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1094,7 +1094,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# So we simply skip it # So we simply skip it
continue continue
if self.use_all2all_ep: if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys]) name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
......
...@@ -320,7 +320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -320,7 +320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.shared_kv_cache_layers: dict[str, str] = {} self.shared_kv_cache_layers: dict[str, str] = {}
dp_size = self.vllm_config.parallel_config.data_parallel_size dp_size = self.vllm_config.parallel_config.data_parallel_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel self.use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
...@@ -1234,7 +1234,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1234,7 +1234,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for # TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or self.use_all2all_ep: if dp_size == 1 or self.vllm_config.model_config.enforce_eager or self.use_mori_ep:
# Early exit. # Early exit.
return 0, None return 0, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment