Commit d04683a4 authored by 王敏's avatar 王敏
Browse files

[feat]上传初版基于all2all通信的大EP代码

parent cfabf125
......@@ -40,6 +40,8 @@ except ImportError:
HAVE_TE = False
shared_experts_overlap_stream = torch.cuda.Stream()
@dataclass
class EpMoeConfig:
......@@ -48,18 +50,25 @@ class EpMoeConfig:
moe_shared_expert_overlap: bool = False
ep_size: int = 1
num_moe_experts: int = 256
apply_router_weight_on_input: bool = False
routed_scaling_factor: float = 1.0
@staticmethod
def make(moe_router_topk: int = 2,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False,
ep_size: int = 1,
num_moe_experts: int = 256) -> "EpMoeConfig":
num_moe_experts: int = 256,
routed_scaling_factor: float = 1.0,
apply_router_weight_on_input: bool = False) -> "EpMoeConfig":
return EpMoeConfig(moe_router_topk=moe_router_topk,
moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=ep_size,
num_moe_experts=num_moe_experts)
num_moe_experts=num_moe_experts,
routed_scaling_factor=routed_scaling_factor,
apply_router_weight_on_input=apply_router_weight_on_input)
class EPSharedExperts(nn.Module):
......@@ -99,7 +108,7 @@ class EPSharedExperts(nn.Module):
self.cached_output = None
self.gate_score = None
self.stream = torch.cuda.Stream()
self.stream = shared_experts_overlap_stream
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
......@@ -215,55 +224,35 @@ def permute(
routing_map,
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
The shape of mask is [tokens, num_experts], it indicates which experts were selected
by each token.
When drop_and_pad=True, in routing_map, the number of non-zeros in each column equals to
expert capacity. This function exploits this feature to use ops that support cuda graph.
Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
num_out_tokens (int, optional): The number of output tokens. If None, it's set to
the number of input tokens.
fused (bool, optional): Whether use the fused permute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
If set to true, routing_map has a fixed number of non-zeros
in each column.
"""
if fused:
if not HAVE_TE or fused_permute is None:
raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
return fused_permute(tokens, routing_map, num_out_tokens)
num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]
if drop_and_pad and not (num_out_tokens is None):
capacity = num_out_tokens // num_experts
assert not routing_map.requires_grad
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
# use argsort to put indices of all non-zeros in the beginning of list
# and keep the first `capacity` number of indices
sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
:, :capacity
].contiguous()
# flatten from [num_experts, capacity] to 1D
sorted_indices = sorted_indices.view(-1)
else:
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# use the mapping to permute the tokens
permuted_input = tokens.index_select(0, sorted_indices)
......@@ -278,7 +267,6 @@ def unpermute(
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""
Restore the original order of tokens after permutation. If probs are provided, it
......@@ -294,8 +282,6 @@ def unpermute(
routing_map (torch.Tensor, optional): Token to expert mapping, shape
[num_tokens, num_experts].
fused (bool, optional): Whether use the fused unpermute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
Returns:
torch.Tensor: The tokens restored to their original order.
......@@ -310,24 +296,7 @@ def unpermute(
if probs is not None:
assert routing_map is not None, "Mask must be provided to permute the probs."
if drop_and_pad:
num_experts = routing_map.size(1)
num_permuted_tokens = sorted_indices.size(0)
capacity = num_permuted_tokens // num_experts
num_unpermuted_tokens = probs.size(0)
# [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens
probs_T_1D = probs.T.contiguous().view(-1)
# get 1D indices of the probs selected by routing_map
indices_dim0 = torch.arange(num_experts, device=routing_map.device).unsqueeze(-1)
indices_dim1 = sorted_indices.view(num_experts, capacity)
indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
# get probs from indices
permuted_probs = probs_T_1D.index_select(0, indices_1D)
else:
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
# Here may promote permuted_tokens to higher precision (fp32/fp64) if probs is in
# higher precision due to moe_router_dtype being enabled. This can lead to
# additional GPU memory usage. Use --moe-permute-fusion flag to avoid this extra memory
......@@ -344,11 +313,6 @@ def unpermute(
def all_to_all(group, input, output_split_sizes, input_split_sizes):
# torch.cuda.synchronize()
# import sys
# sys.stderr.write(f"############all_to_all input_split_sizes:{input_split_sizes}\n output_split_sizes:{output_split_sizes}")
# sys.stderr.flush()
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
......
This diff is collapsed.
import os
import logging
from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EPSharedExperts, EpMoeConfig
from vllm.model_executor.layers.fused_moe.ep_moe.kernels import grouped_gemm_triton
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
@CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.topk_indices_dtype = None
self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
def apply(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
return self.forward(
hidden_states=hidden_states,
layer=layer,
tokens_per_expert=tokens_per_expert)
def forward_cuda(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
# process MoE
def custom_forward(layer, hidden_states, tokens_per_expert):
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
w1 = layer.w13_weight[i]
w2 = layer.w2_weight[i]
tokens_for_this_expert = hidden_states[start_idx:end_idx]
gateup_output = torch.matmul(tokens_for_this_expert, w1.T)
# Act
down_input = torch.zeros(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states.dtype
)
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, w1.shape[0]))
expert_out = torch.matmul(down_input, w2.T)
outputs.append(expert_out)
start_idx = end_idx
if len(outputs) > 0:
expert_output = torch.cat(outputs, dim=0)
else:
assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}"
expert_output = hidden_states
return expert_output
output = custom_forward(layer, hidden_states, tokens_per_expert)
return output
def forward_cpu(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
**kwargs,
):
raise NotImplementedError
def forward_hpu(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
def forward_tpu(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
if current_platform.is_tpu():
forward_native = forward_tpu
elif current_platform.is_cpu():
forward_native = forward_cpu
else:
forward_native = forward_cuda
class EPMoE(FusedMoE):
"""
dp+ep MoE Expert Parallel Impl
......@@ -46,7 +157,7 @@ class EPMoE(FusedMoE):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
moe_permute_fusion: bool = False,
moe_permute_fusion: bool = True,
moe_shared_expert_overlap: bool = False
):
super().__init__(num_experts, top_k, hidden_size,
......@@ -68,7 +179,9 @@ class EPMoE(FusedMoE):
moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=self.ep_size,
num_moe_experts=self.global_num_experts
num_moe_experts=self.global_num_experts,
routed_scaling_factor=self.routed_scaling_factor,
apply_router_weight_on_input=self.apply_router_weight_on_input
)
local_expert_indices_offset = (
......@@ -78,149 +191,41 @@ class EPMoE(FusedMoE):
local_expert_indices_offset + i for i in range(self.local_num_experts)
]
self.shared_experts = None
self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices, config=self.ep_moe_config
)
self.shared_expert_overlap = moe_shared_expert_overlap
self.seg_indptr = None
if quant_config is None:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.w13_weight_scale = None
self.w2_weight_scale = None
else:
self.use_fp8_w8a8 = True
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
def set_shared_experts(self, shared_experts):
self.shared_experts = shared_experts
self.use_shared_expert = shared_experts is not None
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(shared_experts)
def triton_grouped_gemm_impl(self, hidden_states, tokens_per_expert, use_nn_moe):
torch.cumsum(tokens_per_expert,
dim=0,
out=self.seg_indptr[1:])
_, N, _ = self.w13_weight.shape
gateup_input = hidden_states
weight_indices_cur_rank = torch.arange(
0,
self.local_num_experts,
device=hidden_states.device,
dtype=torch.int64,
)
# GroupGemm-0
gateup_output = torch.empty(
gateup_input.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
self.shared_experts = None
gateup_output = grouped_gemm_triton(
a=gateup_input,
b=self.w13_weight,
c=gateup_output,
batch_size=self.local_num_experts,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale if self.quant_config is not None else None,
scale_b=(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
) if self.quant_config is not None else None,
block_shape=self.block_shape,
)
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
),
)
if self.quant_config is not None and self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.local_num_experts,
dtype=torch.float32,
device=hidden_states.device,
)
def set_shared_experts(self, shared_experts: torch.nn.Module):
if self.shared_experts is None:
self.shared_experts = shared_experts
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts)
if self.activation == "silu":
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, N))
elif self.activation == "gelu":
torch.ops._C.gelu_and_mul(down_input,
gateup_output.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {self.activation}")
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
down_output = grouped_gemm_triton(
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.local_num_experts,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale if self.quant_config is not None else None,
scale_b=(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
) if self.quant_config is not None else None,
block_shape=self.block_shape,
)
return down_output
def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedEPGroupedGemmMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if (
self.training
and self.config.tensor_model_parallel_size > 1
and not self.config.sequence_parallel
):
raise ValueError(
"During training, performance may degrade if MoE and tensor parallelism"
"are enabled without also enabling sequence parallelism."
)
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
self.layer_name)
def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if self.seg_indptr is None:
self.seg_indptr = torch.zeros(self.local_num_experts+1, device=hidden_states. device, dtype=torch.int64)
# process MoE
def custom_forward(hidden_states, router_logits):
topk_weights, topk_ids = self.select_experts(
topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
use_grouped_topk=self.use_grouped_topk,
......@@ -234,20 +239,60 @@ class EPMoE(FusedMoE):
indices_type=torch.int64,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate)
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
expert_output = self.triton_grouped_gemm_impl(dispatched_input, tokens_per_expert, self.use_nn_moe)
output = self.token_dispatcher.token_unpermutation(expert_output)
if self.use_shared_expert and not self.shared_expert_overlap:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
output = output + self.shared_experts(hidden_states)
return output
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
# Matrix multiply.
expert_output = self.quant_method.apply(
layer=self,
hidden_states=dispatched_input,
tokens_per_expert=tokens_per_expert
)
final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
shared_output = (
self.maybe_all_reduce_tensor_model_parallel(
shared_output))
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits)
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
output = custom_forward(hidden_states, router_logits)
return output
\ No newline at end of file
direct_register_custom_op(
op_name="ep_moe_forward",
op_func=ep_moe_forward,
mutates_args=["hidden_states"],
fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
\ No newline at end of file
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
......@@ -21,6 +22,9 @@ from vllm.distributed import (tensor_model_parallel_all_gather,
expert_parallel_gather)
from vllm.platforms import current_platform
cuda_dtoh_stream = torch.cuda.Stream()
class MoETokenDispatcher:
"""
MoE Token Dispatcher
......@@ -31,7 +35,6 @@ class MoETokenDispatcher:
Initialize the MoE Token Dispatcher.
"""
self.config = config
self.shared_experts: Optional[EPSharedExperts] = None
self.tp_size = 1
self.ep_size = config.ep_size
......@@ -162,13 +165,14 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"no_sync": 4,
}
self.cuda_dtoh_point = "before_permutation_1"
self.cuda_dtoh_stream = torch.cuda.Stream()
self.shared_experts = None
#self.cuda_dtoh_stream = torch.cuda.Stream()
# Whether to use gather or all-gather to gather the logits.
self.use_all_gather = current_platform.use_all_gather()
self.probs = None
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
"""
Preprocess token routing map for AlltoAll communication and token permutation.
......@@ -264,7 +268,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
return num_tokens_per_local_expert
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
self, hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
......@@ -287,7 +293,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape
self.probs = probs
if self.config.apply_router_weight_on_input:
self.probs = probs
self.routing_map = routing_map
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
......@@ -295,50 +302,32 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(self.routing_map)
if self.shared_experts is not None:
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
import sys
# torch.cuda.synchronize()
# sys.stderr.write(f"token_permutation===============================================")
# sys.stderr.flush()
# Permutation 1: input to AlltoAll input
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_1", tokens_per_expert
)
# torch.cuda.synchronize()
# sys.stderr.write(f"before permute===============================================")
# sys.stderr.flush()
self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
routing_map,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion,
drop_and_pad=False,
fused=self.config.moe_permute_fusion
)
# torch.cuda.synchronize()
# sys.stderr.write(f"after permute===============================================")
# sys.stderr.flush()
# Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert
)
#torch.cuda.synchronize()
#print("###########################before permutation all_to_all output_splits:{} input_splits:{}".format(self.output_splits, self.input_splits))
global_input_tokens = all_to_all(
self.ep_group.device_group, permutated_local_input_tokens, self.output_splits, self.input_splits
)
#torch.cuda.synchronize()
#print("#######################permutation all_to_all end")
if self.shared_experts is not None:
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
# Permutation 2: Sort tokens by local expert.
......@@ -358,7 +347,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
return global_input_tokens, tokens_per_expert
def token_unpermutation(
self, hidden_states: torch.Tensor
self, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Reverse the token permutation to restore the original order.
......@@ -392,7 +381,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.ep_group.device_group, hidden_states, self.input_splits, self.output_splits
)
if self.shared_experts is not None:
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
self.shared_experts.post_forward_comm()
......@@ -404,16 +393,22 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
probs=self.probs,
routing_map=self.routing_map,
fused=self.config.moe_permute_fusion,
drop_and_pad=False,
)
# Reshape the output tensor
output = output.view(self.hidden_shape)
# Add shared experts output
if self.shared_experts is not None:
shared_expert_output = self.shared_experts.get_output()
output += shared_expert_output
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts.get_output()
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
output = output + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
output = output + shared_output \
* (1. / self.config.routed_scaling_factor)
return output
def _maybe_update_cuda_sync_point(self, point: str):
......@@ -435,10 +430,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
if point == self.cuda_dtoh_point:
# Move all possible GPU tensors to CPU at self.cuda_dtoh_point.
on_side_stream = torch.cuda.current_stream() != self.cuda_dtoh_stream
on_side_stream = torch.cuda.current_stream() != cuda_dtoh_stream
if on_side_stream:
self.cuda_dtoh_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.cuda_dtoh_stream):
cuda_dtoh_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_dtoh_stream):
# TODO: use MemcpyBatchAsync instead.
# tokens_per_expert = maybe_move_tensor_to_cpu(
# tokens_per_expert, record_stream=on_side_stream
......@@ -462,6 +457,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if point == self.cuda_sync_point:
# Synchronize with the dtoh stream at self.cuda_sync_point.
self.cuda_dtoh_stream.synchronize()
cuda_dtoh_stream.synchronize()
return tokens_per_expert
\ No newline at end of file
......@@ -772,20 +772,12 @@ class FusedMoE(torch.nn.Module):
self.moe_config = moe
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
self.quant_method = self.create_quant_method(moe, quant_config, prefix)
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
if not isinstance(self.quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
......@@ -852,6 +844,17 @@ class FusedMoE(torch.nn.Module):
dtype=moe.in_dtype,
device=torch.cuda.current_device())
def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
......
......@@ -156,7 +156,23 @@ class DeepseekV2MoE(nn.Module):
dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel
self.shared_experts = None
moe_cls = FusedMoE if not self.use_ep_opt else EPMoE
self.experts = moe_cls(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
......@@ -167,48 +183,13 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
prefix=f"{prefix}.shared_experts",
)
if not self.use_ep_opt:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
self.experts = EPMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor)
if self.use_ep_opt:
self.experts.set_shared_experts(self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
......@@ -218,18 +199,22 @@ class DeepseekV2MoE(nn.Module):
if not self.use_ep_opt:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if not self.use_ep_opt:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
else:
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
router_logits=router_logits)
if not self.use_ep_opt:
if shared_output is not None:
......@@ -745,9 +730,7 @@ class DeepseekV2Model(nn.Module):
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)\
#ops.print_tensor(hidden_states)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -816,6 +799,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.tritonsingleton.topk = config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
......@@ -897,6 +884,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
("gate_up_proj", "up_proj", 1),
]
if self.use_ep_opt:
ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
......@@ -929,6 +920,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
if self.use_ep_opt:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......@@ -955,6 +950,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if self.use_ep_opt:
name_mapped = name_mapped.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
if is_pp_missing_parameter(name_mapped, self):
continue
......@@ -979,7 +977,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# However it's not mapped locally to this rank
# So we simply skip it
continue
if self.use_ep_opt:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......
......@@ -2052,7 +2052,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
#self.input_ids[:num_tokens] = torch.randint(0, 120000, (num_tokens,), dtype=torch.int32)
self.input_ids[:num_tokens] = torch.arange(num_tokens, dtype=torch.int32, device=self.input_ids.device)
#self.input_ids[:num_tokens] = torch.arange(num_tokens, dtype=torch.int32, device=self.input_ids.device)
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment