Commit d2e57a90 authored by 王敏's avatar 王敏
Browse files

[feat]优化mori计算逻辑,支持cudagraph,按照bs*ep_size截断fused_moe的输入,共享专家不tp切分,去掉最后的allreduce

parent 8824ae6a
......@@ -4320,9 +4320,6 @@ class CompilationConfig:
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.token_permutation_forward",
"vllm.token_unpermutation_forward",
"vllm.ep_moe_forward",
]
......
......@@ -948,7 +948,11 @@ def init_distributed_environment(
"Fallback Gloo backend is not available.")
backend = "gloo"
# this backend is used for WORLD
backend="cpu:gloo,cuda:nccl"
data_parallel_size = parallel_config.data_parallel_size
use_mori_ep = envs.VLLM_USE_MORI_EP and data_parallel_size > 1 and parallel_config.enable_expert_parallel
if use_mori_ep:
backend="cpu:gloo,cuda:nccl"
torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
......
......@@ -168,7 +168,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_CAT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
VLLM_USE_ALLTOALL_EP: bool = False
VLLM_USE_MORI_EP: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1112,8 +1112,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")),
# vLLM will use all_to_all ep mode
"VLLM_USE_ALLTOALL_EP":
lambda: (os.environ.get("VLLM_USE_ALLTOALL_EP", "True").lower() in
"VLLM_USE_MORI_EP":
lambda: (os.environ.get("VLLM_USE_MORI_EP", "True").lower() in
("true", "1")),
}
......
......@@ -136,8 +136,8 @@ def set_forward_context(
forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None
dp_size = vllm_config.parallel_config.data_parallel_size
use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if not use_all2all_ep and dp_size > 1 and (
use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if not use_mori_ep and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None) :
dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0,
......
......@@ -88,13 +88,16 @@ class EPSharedExperts(nn.Module):
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
prefix=f"{prefix}.gate_up_proj",
expect_tp_size=1)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
prefix=f"{prefix}.down_proj",
expect_tp_size=1)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......
......@@ -7,7 +7,6 @@ from collections.abc import Iterable
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config
......@@ -17,6 +16,7 @@ from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.distributed import expert_parallel_all_gather, expert_parallel_all_reduce
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
......@@ -25,6 +25,10 @@ from vllm.utils import direct_register_custom_op
import mori
import torch.distributed as dist
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
logger = init_logger(__name__)
......@@ -40,7 +44,6 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
self.zero_token_count = None
def apply_ep(
self,
......@@ -235,10 +238,12 @@ class EPMoE(FusedMoE):
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
self.mori_op = self.get_mori_op()
self.zero_token_count = None
self.scales = None
self.use_int8_dispatch = True
self.mori_op = self.get_mori_op()
self.first = True
def get_mori_op(self):
global _MORI_OP
......@@ -253,20 +258,28 @@ class EPMoE(FusedMoE):
mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1
mori_data_type=vllm_config.model_config.dtype
mori_scale_type_size = vllm_config.model_config.dtype.itemsize
if self.use_int8_dispatch:
mori_scale_type_size = 4
config = mori.ops.EpDispatchCombineConfig(
data_type=vllm_config.model_config.dtype,
data_type=mori_data_type,
rank=self.ep_rank,
world_size=self.ep_size,
hidden_dim=self.hidden_size,
scale_dim=0,
scale_type_size=vllm_config.model_config.dtype.itemsize,
max_num_inp_token_per_rank=4096,
scale_dim=1 if self.use_int8_dispatch else 0,
scale_type_size=mori_scale_type_size,
max_num_inp_token_per_rank=2048,
num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k,
max_token_type_size=2,
block_num=64,
block_num=80,
warp_num_per_block=16,
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode
)
_MORI_OP = mori.ops.EpDispatchCombineOp(config)
......@@ -291,13 +304,11 @@ class EPMoE(FusedMoE):
return quant_method
def sync(self):
torch.cuda.synchronize()
#torch.cuda.synchronize()
dist.barrier()
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None):
router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
self.layer_name)
......@@ -322,9 +333,7 @@ class EPMoE(FusedMoE):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None):
router_logits: torch.Tensor):
topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states,
......@@ -337,25 +346,27 @@ class EPMoE(FusedMoE):
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
indices_type=torch.int64,
indices_type=torch.int32,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states)
topk_ids = topk_ids.to(torch.int32)
scales = torch.rand(
hidden_states.shape[0],
0,
dtype=torch.float32,
device=hidden_states.device,
)
shared_output = self.shared_experts(hidden_states)
if self.use_int8_dispatch:
hidden_states, scales = per_token_quant_int8(hidden_states)
else:
if self.scales is None:
self.scales = torch.rand(
hidden_states.shape[0],
0,
dtype=torch.float32,
device=hidden_states.device,
)
scales = self.scales
#dist.barrier()
#self.sync()
(
dispatch_output,
......@@ -369,49 +380,54 @@ class EPMoE(FusedMoE):
scales,
topk_ids,
)
#self.sync()
expect_m = hidden_states.shape[0] * self.ep_size
dispatch_output_clip = dispatch_output[:expect_m]
dispatch_weights_clip = dispatch_weights[:expect_m]
dispatch_indices_clip = dispatch_indices[:expect_m]
dispatch_scales_clip = dispatch_scales[:expect_m]
expert_output = self.quant_method.apply_ep(
layer=self,
x=dispatch_output_clip,
topk_weights=dispatch_weights_clip,
topk_ids=dispatch_indices_clip,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0],
scales=dispatch_scales_clip if self.use_int8_dispatch else None
#routed_scaling_factor=self.routed_scaling_factor,
)
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# #dispatch_recv_num_token = dispatch_recv_num_token.item()
# dispatch_output = dispatch_output[:dispatch_recv_num_token]
# dispatch_weights = dispatch_weights[:dispatch_recv_num_token]
# dispatch_indices = dispatch_indices[:dispatch_recv_num_token]
# dispatch_recv_num_token = dispatch_recv_num_token.item()
# dispatch_output = torch.narrow(dispatch_output, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_weights = torch.narrow(dispatch_weights, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_indices = torch.narrow(dispatch_indices, dim=0, start=0, length=dispatch_recv_num_token)
# valid_mask = ((dispatch_indices <= 255) & (dispatch_indices >= 0)).all(dim=1)
# dispatch_output = dispatch_output[valid_mask]
# dispatch_indices = dispatch_indices[valid_mask]
# dispatch_weights = dispatch_weights[valid_mask]
# dispatch_recv_num_token = dispatch_indices.shape[0]
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# has_greater_than_255 = torch.any(dispatch_indices > 255).item()
# has_less_than_0 = torch.any(dispatch_indices < 0).item()
# print("##################################has_greater_than_255:{} has_less_than_0:{}".format(has_greater_than_255, has_less_than_0))
# if has_greater_than_255 or has_less_than_0:
# print("###################dispatch_indices:", dispatch_indices.tolist())
if dispatch_recv_num_token > 0:
# Matrix multiply.
expert_output = self.quant_method.apply_ep(
layer=self,
x=dispatch_output,
topk_weights=dispatch_weights,
topk_ids=dispatch_indices,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
)
else:
expert_output = dispatch_output#[:dispatch_recv_num_token]
# if self.first and hidden_states.shape[0] == 2:
# self.first = False
# import numpy as np
# np.save(f'/work/vllm_profile/ep{self.ep_rank}_topk_ids.npy', dispatch_indices_clip.cpu().numpy())
# print("##################config_select_bs:{} topk_ids shape:{} num_local_tokens:{}".format(hidden_states.shape[0],
# topk_ids.shape,
# dispatch_recv_num_token))
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
# topk_weights=dispatch_weights,
# topk_ids=dispatch_indices,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0]*2,
# scales=dispatch_scales if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
#self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
......@@ -422,9 +438,9 @@ class EPMoE(FusedMoE):
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
shared_output = (
self.maybe_all_reduce_tensor_model_parallel(
shared_output))
# shared_output = (
# self.maybe_all_reduce_tensor_model_parallel(
# shared_output))
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
......@@ -434,31 +450,26 @@ class EPMoE(FusedMoE):
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states, new_resi
else:
return final_hidden_states, None
return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits, rms_weight, residual)
return self.forward_impl(hidden_states, router_logits)
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(hidden_states), torch.empty_like(hidden_states)
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="ep_moe_forward",
op_func=ep_moe_forward,
mutates_args=["hidden_states", "router_logits", "rms_weight", "residual"],
mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
......
......@@ -1257,13 +1257,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
block_shape, use_nn_moe, num_local_tokens, true_bs)
def inplace_fused_experts_fake(
......@@ -1289,7 +1291,9 @@ def inplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> None:
pass
......@@ -1325,14 +1329,16 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16,use_int4_w4a8, per_channel_quant,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
block_shape, use_nn_moe, num_local_tokens, true_bs)
def outplace_fused_experts_fake(
......@@ -1357,7 +1363,9 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1414,7 +1422,9 @@ def fused_experts(
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.size(1)
......@@ -1472,7 +1482,9 @@ def fused_experts(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens,
true_bs=true_bs)
def fused_experts_impl(
......@@ -1500,6 +1512,8 @@ def fused_experts_impl(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
true_bs: Optional[int] = None,
) -> torch.Tensor:
num_tokens = hidden_states.size(0)
if use_nn_moe:
......@@ -1544,7 +1558,9 @@ def fused_experts_impl(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=False
use_nn_moe=False,
num_local_tokens=num_local_tokens,
true_bs=true_bs,
)
elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states,
......
......@@ -152,7 +152,8 @@ def moe_align_block_size(
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False,
num_token: Optional[int] = None
num_token: Optional[int] = None,
num_local_tokens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
......@@ -234,11 +235,11 @@ def moe_align_block_size(
if envs.VLLM_USE_LIGHT_OP:
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, None)
expert_ids, num_tokens_post_pad, expert_map, None, num_local_tokens)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
\ No newline at end of file
......@@ -486,9 +486,13 @@ class ColumnParallelLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.expect_tp_size = expect_tp_size
self.tp_size = self.expect_tp_size
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
......@@ -728,10 +732,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
):
self.eps = eps
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
tp_size = expect_tp_size
self.expect_tp_size = expect_tp_size
self.expect_tp_size = expect_tp_size
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size=input_size,
output_size=sum(output_sizes),
......@@ -741,7 +753,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
expect_tp_size=expect_tp_size)
def weight_loader(self,
param: Parameter,
......@@ -838,6 +851,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if self.expect_tp_size is not None and self.expect_tp_size == 1:
tp_rank = 0
tp_size = 1
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
......@@ -1384,10 +1402,16 @@ class RowParallelLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
if expect_tp_size is not None:
self.tp_rank = 0
self.tp_size = 1
self.expect_tp_size = expect_tp_size
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
......@@ -1433,6 +1457,10 @@ class RowParallelLinear(LinearBase):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if self.expect_tp_size is not None:
tp_rank = 0
tp_size = 1
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
......
from typing import Any, Callable, Dict, List, Optional
import os
import torch
import vllm.envs as envs
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
......@@ -16,6 +18,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
......@@ -23,6 +29,9 @@ except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = init_logger(__name__)
class MarlinMoeWorkspace:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
......@@ -220,6 +229,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
config_select_bs: Optional[int] = None,
routed_scaling_factor: Optional[float] = None,
scales: Optional[torch.Tensor] = None,
**_
) -> torch.Tensor:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
......@@ -243,6 +256,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens,
config_select_bs=config_select_bs,
q_scales=scales
)
def apply(
......@@ -309,3 +325,43 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedGroupedGemmExperts, GroupedGemmGemmExperts)
assert not self.rocm_aiter_moe_enabled, (
"ROCm AITER are not supported with all2all yet.")
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
assert max_num_tokens_per_rank is not None
logger.debug(
"BatchedGroupedGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
self.quant_config.weight_block_size, False)
return BatchedGroupedGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=False,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=True,
allow_deep_gemm=False,
)
else:
logger.debug(
"GroupedGemmGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, self.quant_config.weight_block_size,
False)
return GroupedGemmGemmExperts(
use_fp8_w8a8=False,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=False,
)
......@@ -178,7 +178,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
self.use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def forward(
......@@ -211,7 +211,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1),
]
if self.use_all2all_ep:
if self.use_mori_ep:
ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
......@@ -244,7 +244,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
continue
name = name.replace(weight_name, param_name)
if self.use_all2all_ep:
if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
......@@ -261,7 +261,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
continue
name = name.replace(weight_name, param_name)
if self.use_all2all_ep:
if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
param = params_dict[name]
......@@ -273,7 +273,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
expert_id=expert_id)
break
else:
if self.use_all2all_ep:
if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
......
......@@ -165,9 +165,9 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts)
dp_size = get_dp_group().world_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
self.use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and parallel_config.enable_expert_parallel
moe_cls = FusedMoE if not self.use_all2all_ep else EPMoE
moe_cls = FusedMoE if not self.use_mori_ep else EPMoE
self.experts = moe_cls(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
......@@ -189,8 +189,8 @@ class DeepseekV2MoE(nn.Module):
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
#shared_expert_cls = DeepseekV2MLP if not self.use_all2all_ep else EPSharedExperts
self.shared_experts = DeepseekV2MLP(
shared_expert_cls = DeepseekV2MLP if not self.use_mori_ep else EPSharedExperts
self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
......@@ -199,7 +199,7 @@ class DeepseekV2MoE(nn.Module):
),
prefix=f"{prefix}.shared_experts",
)
if self.use_all2all_ep:
if self.use_mori_ep:
self.experts.set_shared_experts(self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
......@@ -212,7 +212,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_all2all_ep:
if not self.use_mori_ep:
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
......@@ -222,7 +222,7 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if not self.use_all2all_ep:
if not self.use_mori_ep:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
......@@ -233,10 +233,10 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
else:
final_hidden_states, new_resi = self.experts(hidden_states=hidden_states,
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if not self.use_all2all_ep:
if not self.use_mori_ep:
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
......@@ -917,7 +917,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
self.use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def set_eplb_state(
self,
......@@ -1000,7 +1000,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
("gate_up_proj", "up_proj", 1),
]
if self.use_all2all_ep:
if self.use_mori_ep:
ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
......@@ -1037,7 +1037,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
continue
name = name.replace(weight_name, param_name)
if self.use_all2all_ep:
if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
......@@ -1066,7 +1066,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if self.use_all2all_ep:
if self.use_mori_ep:
name_mapped = name_mapped.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
if is_pp_missing_parameter(name_mapped, self):
......@@ -1094,7 +1094,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# So we simply skip it
continue
if self.use_all2all_ep:
if self.use_mori_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
......
......@@ -320,7 +320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.shared_kv_cache_layers: dict[str, str] = {}
dp_size = self.vllm_config.parallel_config.data_parallel_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
self.use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and parallel_config.enable_expert_parallel
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
......@@ -1234,7 +1234,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or self.use_all2all_ep:
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or self.use_mori_ep:
# Early exit.
return 0, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment