Commit d04683a4 authored by 王敏's avatar 王敏
Browse files

[feat]上传初版基于all2all通信的大EP代码

parent cfabf125
...@@ -40,6 +40,8 @@ except ImportError: ...@@ -40,6 +40,8 @@ except ImportError:
HAVE_TE = False HAVE_TE = False
# Single CUDA side stream shared by every EPSharedExperts instance (each one
# stores it as `self.stream`) instead of one stream per layer.
# NOTE(review): constructed at import time, so importing this module
# presumably requires a usable CUDA runtime — confirm this is acceptable.
shared_experts_overlap_stream = torch.cuda.Stream()
@dataclass @dataclass
class EpMoeConfig: class EpMoeConfig:
...@@ -48,18 +50,25 @@ class EpMoeConfig: ...@@ -48,18 +50,25 @@ class EpMoeConfig:
moe_shared_expert_overlap: bool = False moe_shared_expert_overlap: bool = False
ep_size: int = 1 ep_size: int = 1
num_moe_experts: int = 256 num_moe_experts: int = 256
apply_router_weight_on_input: bool = False
routed_scaling_factor: float = 1.0
@staticmethod @staticmethod
def make(moe_router_topk: int = 2, def make(moe_router_topk: int = 2,
moe_permute_fusion: bool = False, moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False, moe_shared_expert_overlap: bool = False,
ep_size: int = 1, ep_size: int = 1,
num_moe_experts: int = 256) -> "EpMoeConfig": num_moe_experts: int = 256,
routed_scaling_factor: float = 1.0,
apply_router_weight_on_input: bool = False) -> "EpMoeConfig":
return EpMoeConfig(moe_router_topk=moe_router_topk, return EpMoeConfig(moe_router_topk=moe_router_topk,
moe_permute_fusion=moe_permute_fusion, moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap, moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=ep_size, ep_size=ep_size,
num_moe_experts=num_moe_experts) num_moe_experts=num_moe_experts,
routed_scaling_factor=routed_scaling_factor,
apply_router_weight_on_input=apply_router_weight_on_input)
class EPSharedExperts(nn.Module): class EPSharedExperts(nn.Module):
...@@ -99,7 +108,7 @@ class EPSharedExperts(nn.Module): ...@@ -99,7 +108,7 @@ class EPSharedExperts(nn.Module):
self.cached_output = None self.cached_output = None
self.gate_score = None self.gate_score = None
self.stream = torch.cuda.Stream() self.stream = shared_experts_overlap_stream
def forward(self, x): def forward(self, x):
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
...@@ -215,47 +224,27 @@ def permute( ...@@ -215,47 +224,27 @@ def permute(
routing_map, routing_map,
num_out_tokens: Optional[int] = None, num_out_tokens: Optional[int] = None,
fused: bool = False, fused: bool = False,
drop_and_pad: bool = False,
): ):
"""Permute the tokens and probs based on the mask. """Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together. Tokens with the same designated expert will be grouped together.
The shape of mask is [tokens, num_experts], it indicates which experts were selected The shape of mask is [tokens, num_experts], it indicates which experts were selected
by each token. by each token.
When drop_and_pad=True, in routing_map, the number of non-zeros in each column equals to
expert capacity. This function exploits this feature to use ops that support cuda graph.
Args: Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden]. tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts]. routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
num_out_tokens (int, optional): The number of output tokens. If None, it's set to num_out_tokens (int, optional): The number of output tokens. If None, it's set to
the number of input tokens. the number of input tokens.
fused (bool, optional): Whether use the fused permute function. fused (bool, optional): Whether use the fused permute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
If set to true, routing_map has a fixed number of non-zeros
in each column.
""" """
if fused: if fused:
if not HAVE_TE or fused_permute is None: if not HAVE_TE or fused_permute is None:
raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.") raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
return fused_permute(tokens, routing_map, num_out_tokens) return fused_permute(tokens, routing_map, num_out_tokens)
num_tokens, hidden = tokens.shape num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1] num_experts = routing_map.shape[1]
if drop_and_pad and not (num_out_tokens is None):
capacity = num_out_tokens // num_experts
assert not routing_map.requires_grad
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
# use argsort to put indices of all non-zeros in the beginning of list
# and keep the first `capacity` number of indices
sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
:, :capacity
].contiguous()
# flatten from [num_experts, capacity] to 1D
sorted_indices = sorted_indices.view(-1)
else:
# mask [num_tokens, num_experts] -> [num_experts, num_tokens] # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous() routing_map = routing_map.bool().T.contiguous()
...@@ -278,7 +267,6 @@ def unpermute( ...@@ -278,7 +267,6 @@ def unpermute(
probs: torch.Tensor = None, probs: torch.Tensor = None,
routing_map: torch.Tensor = None, routing_map: torch.Tensor = None,
fused: bool = False, fused: bool = False,
drop_and_pad: bool = False,
): ):
""" """
Restore the original order of tokens after permutation. If probs are provided, it Restore the original order of tokens after permutation. If probs are provided, it
...@@ -294,8 +282,6 @@ def unpermute( ...@@ -294,8 +282,6 @@ def unpermute(
routing_map (torch.Tensor, optional): Token to expert mapping, shape routing_map (torch.Tensor, optional): Token to expert mapping, shape
[num_tokens, num_experts]. [num_tokens, num_experts].
fused (bool, optional): Whether use the fused unpermute function. fused (bool, optional): Whether use the fused unpermute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
Returns: Returns:
torch.Tensor: The tokens restored to their original order. torch.Tensor: The tokens restored to their original order.
...@@ -310,23 +296,6 @@ def unpermute( ...@@ -310,23 +296,6 @@ def unpermute(
if probs is not None: if probs is not None:
assert routing_map is not None, "Mask must be provided to permute the probs." assert routing_map is not None, "Mask must be provided to permute the probs."
if drop_and_pad:
num_experts = routing_map.size(1)
num_permuted_tokens = sorted_indices.size(0)
capacity = num_permuted_tokens // num_experts
num_unpermuted_tokens = probs.size(0)
# [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens
probs_T_1D = probs.T.contiguous().view(-1)
# get 1D indices of the probs selected by routing_map
indices_dim0 = torch.arange(num_experts, device=routing_map.device).unsqueeze(-1)
indices_dim1 = sorted_indices.view(num_experts, capacity)
indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
# get probs from indices
permuted_probs = probs_T_1D.index_select(0, indices_1D)
else:
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous()) permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
# Here may promote permuted_tokens to higher precision (fp32/fp64) if probs is in # Here may promote permuted_tokens to higher precision (fp32/fp64) if probs is in
# higher precision due to moe_router_dtype being enabled. This can lead to # higher precision due to moe_router_dtype being enabled. This can lead to
...@@ -344,11 +313,6 @@ def unpermute( ...@@ -344,11 +313,6 @@ def unpermute(
def all_to_all(group, input, output_split_sizes, input_split_sizes): def all_to_all(group, input, output_split_sizes, input_split_sizes):
# torch.cuda.synchronize()
# import sys
# sys.stderr.write(f"############all_to_all input_split_sizes:{input_split_sizes}\n output_split_sizes:{output_split_sizes}")
# sys.stderr.flush()
world_size = torch.distributed.get_world_size(group=group) world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size == 1: if world_size == 1:
......
This diff is collapsed.
import os
import logging import logging
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EPSharedExperts, EpMoeConfig from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.model_executor.layers.fused_moe.ep_moe.kernels import grouped_gemm_triton from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
@CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
    """Unquantized MoE method that applies the local experts one after another.

    ``hidden_states`` must hold the tokens grouped contiguously by local
    expert, in expert-index order: expert ``i`` is applied to the next
    ``tokens_per_expert[i]`` rows. The result keeps that same row order.
    """

    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)
        self.topk_indices_dtype = None
        self.moe = moe
        # The AITER path is forced off for this method.
        # NOTE(review): presumably the parent derives this flag from
        # is_rocm_aiter_moe_enabled() — confirm before re-enabling.
        self.rocm_aiter_moe_enabled = False

    def apply(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Run the experts of ``layer`` through the platform-specific forward.

        Args:
            layer: Module exposing ``w13_weight`` and ``w2_weight``, indexed
                by local expert.
            hidden_states: Tokens grouped by local expert.
            tokens_per_expert: Number of rows belonging to each local expert.

        Returns:
            The expert outputs, in the same row order as ``hidden_states``.
        """
        return self.forward(
            hidden_states=hidden_states,
            layer=layer,
            tokens_per_expert=tokens_per_expert)

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Apply gate/up projection, SiLU-and-mul and down projection per expert.

        Experts that received no tokens are skipped. When no expert received
        any token, ``hidden_states`` (which must then be empty) is returned
        unchanged.
        """
        # The counts drive Python-level slicing, so they are needed on the
        # host as plain ints.
        counts = tokens_per_expert.tolist()

        outputs = []
        start_idx = 0
        for expert_idx, num_tokens in enumerate(counts):
            if num_tokens == 0:
                continue
            end_idx = start_idx + num_tokens

            w13 = layer.w13_weight[expert_idx]
            w2 = layer.w2_weight[expert_idx]

            gateup_output = torch.matmul(hidden_states[start_idx:end_idx],
                                         w13.T)

            # silu_and_mul writes every element of its output buffer, so the
            # buffer does not need to be zero-filled first.
            down_input = torch.empty(
                gateup_output.shape[0],
                gateup_output.shape[1] // 2,
                device=gateup_output.device,
                dtype=hidden_states.dtype,
            )
            torch.ops._C.silu_and_mul(down_input,
                                      gateup_output.view(-1, w13.shape[0]))

            outputs.append(torch.matmul(down_input, w2.T))
            start_idx = end_idx

        if outputs:
            return torch.cat(outputs, dim=0)

        assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}"
        return hidden_states

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
        **kwargs,
    ):
        """Not implemented for CPU."""
        raise NotImplementedError

    def forward_hpu(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Not implemented for HPU."""
        raise NotImplementedError

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Not implemented for TPU."""
        raise NotImplementedError

    # Pick the implementation exposed as forward_native once, at class
    # creation time, based on the current platform.
    if current_platform.is_tpu():
        forward_native = forward_tpu
    elif current_platform.is_cpu():
        forward_native = forward_cpu
    else:
        forward_native = forward_cuda
class EPMoE(FusedMoE): class EPMoE(FusedMoE):
""" """
dp+ep MoE Expert Parallel Impl dp+ep MoE Expert Parallel Impl
...@@ -46,7 +157,7 @@ class EPMoE(FusedMoE): ...@@ -46,7 +157,7 @@ class EPMoE(FusedMoE):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
moe_permute_fusion: bool = False, moe_permute_fusion: bool = True,
moe_shared_expert_overlap: bool = False moe_shared_expert_overlap: bool = False
): ):
super().__init__(num_experts, top_k, hidden_size, super().__init__(num_experts, top_k, hidden_size,
...@@ -68,7 +179,9 @@ class EPMoE(FusedMoE): ...@@ -68,7 +179,9 @@ class EPMoE(FusedMoE):
moe_permute_fusion=moe_permute_fusion, moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap, moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=self.ep_size, ep_size=self.ep_size,
num_moe_experts=self.global_num_experts num_moe_experts=self.global_num_experts,
routed_scaling_factor=self.routed_scaling_factor,
apply_router_weight_on_input=self.apply_router_weight_on_input
) )
local_expert_indices_offset = ( local_expert_indices_offset = (
...@@ -78,148 +191,40 @@ class EPMoE(FusedMoE): ...@@ -78,148 +191,40 @@ class EPMoE(FusedMoE):
local_expert_indices_offset + i for i in range(self.local_num_experts) local_expert_indices_offset + i for i in range(self.local_num_experts)
] ]
self.shared_experts = None
self.use_shared_expert = False self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher( self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices, config=self.ep_moe_config self.local_num_experts, self.local_expert_indices, config=self.ep_moe_config
) )
self.shared_expert_overlap = moe_shared_expert_overlap self.shared_expert_overlap = moe_shared_expert_overlap
self.seg_indptr = None self.shared_experts = None
if quant_config is None:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.w13_weight_scale = None
self.w2_weight_scale = None
else:
self.use_fp8_w8a8 = True
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
def set_shared_experts(self, shared_experts): self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
def set_shared_experts(self, shared_experts: torch.nn.Module):
if self.shared_experts is None:
self.shared_experts = shared_experts self.shared_experts = shared_experts
self.use_shared_expert = shared_experts is not None
if self.shared_expert_overlap: if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(shared_experts) self.token_dispatcher.set_shared_experts(self.shared_experts)
def triton_grouped_gemm_impl(self, hidden_states, tokens_per_expert, use_nn_moe):
    """Run the local experts as two Triton grouped GEMMs with an activation between.

    Args:
        hidden_states: Input rows for the first grouped GEMM.
            NOTE(review): presumably already grouped by local expert, since
            segments are derived from cumulative token counts — confirm.
        tokens_per_expert: Per-expert row counts; their cumulative sum is
            written into ``self.seg_indptr[1:]``.
        use_nn_moe: Not read anywhere in this method.

    Returns:
        The output of the second grouped GEMM, with ``hidden_states.dtype``.

    Raises:
        ValueError: If ``self.activation`` is neither "silu" nor "gelu".
    """
    # self.seg_indptr must already be allocated by the caller; the cumulative
    # counts are written in place from index 1 onwards, index 0 is left as is.
    torch.cumsum(tokens_per_expert,
                 dim=0,
                 out=self.seg_indptr[1:])
    # N is the second dimension of w13_weight, i.e. the width of the
    # gate/up projection output.
    _, N, _ = self.w13_weight.shape
    gateup_input = hidden_states
    # One weight index per local expert: 0 .. local_num_experts - 1.
    weight_indices_cur_rank = torch.arange(
        0,
        self.local_num_experts,
        device=hidden_states.device,
        dtype=torch.int64,
    )
    # GroupGemm-0
    gateup_output = torch.empty(
        gateup_input.shape[0],
        self.w13_weight.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    # Scales are only passed when a quant config is present; block-quantized
    # weights use the *_scale_inv tensor instead of the plain weight scale.
    gateup_output = grouped_gemm_triton(
        a=gateup_input,
        b=self.w13_weight,
        c=gateup_output,
        batch_size=self.local_num_experts,
        weight_column_major=True,
        seg_indptr=self.seg_indptr,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=self.use_fp8_w8a8,
        scale_a=self.w13_input_scale if self.quant_config is not None else None,
        scale_b=(
            self.w13_weight_scale_inv
            if self.use_block_quant
            else self.w13_weight_scale
        ) if self.quant_config is not None else None,
        block_shape=self.block_shape,
    )
    # Act
    # The activation output is half as wide as the gate/up output. It is
    # allocated in the FP8 dtype only for FP8 without block quantization.
    down_input = torch.empty(
        gateup_output.shape[0],
        gateup_output.shape[1] // 2,
        device=gateup_output.device,
        dtype=(
            self.fp8_dtype
            if (self.use_fp8_w8a8 and not self.use_block_quant)
            else hidden_states.dtype
        ),
    )
    # Lazily default the second GEMM's input scale to ones (one per local
    # expert) when quantized, not block-quantized, and no scale was loaded.
    if self.quant_config is not None and self.w2_input_scale is None and not self.use_block_quant:
        self.w2_input_scale = torch.ones(
            self.local_num_experts,
            dtype=torch.float32,
            device=hidden_states.device,
        )
    if self.activation == "silu":
        torch.ops._C.silu_and_mul(down_input,
                                  gateup_output.view(-1, N))
    elif self.activation == "gelu":
        torch.ops._C.gelu_and_mul(down_input,
                                  gateup_output.view(-1, N))
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {self.activation}")
    # GroupGemm-1
    down_output = torch.empty(
        down_input.shape[0],
        self.w2_weight.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    # Reuses the segment pointers and weight indices of the first GEMM.
    down_output = grouped_gemm_triton(
        a=down_input,
        b=self.w2_weight,
        c=down_output,
        batch_size=self.local_num_experts,
        weight_column_major=True,
        seg_indptr=self.seg_indptr,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=self.use_fp8_w8a8,
        scale_a=self.w2_input_scale if self.quant_config is not None else None,
        scale_b=(
            self.w2_weight_scale_inv
            if self.use_block_quant
            else self.w2_weight_scale
        ) if self.quant_config is not None else None,
        block_shape=self.block_shape,
    )
    return down_output
def create_quant_method(self, moe, quant_config, prefix):
    """Choose the MoE method object for this layer.

    Without a quantization config the unquantized expert-parallel method is
    used; otherwise the config is asked for the method matching this layer
    and ``prefix``. The result must be a ``FusedMoEMethodBase``.
    """
    # Note: get_quant_method will look at the layer's local_num_experts
    # for heuristic purposes, so it must be initialized first.
    chosen: Optional[QuantizeMethodBase] = None
    if quant_config is None:
        chosen = UnquantizedEPGroupedGemmMethod(moe)
    else:
        chosen = quant_config.get_quant_method(self, prefix)

    assert chosen is not None
    assert isinstance(chosen, FusedMoEMethodBase)
    return chosen
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): def forward(self, hidden_states: torch.Tensor,
if ( router_logits: torch.Tensor):
self.training return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
and self.config.tensor_model_parallel_size > 1 self.layer_name)
and not self.config.sequence_parallel
):
raise ValueError(
"During training, performance may degrade if MoE and tensor parallelism"
"are enabled without also enabling sequence parallelism."
)
if self.seg_indptr is None: def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
self.seg_indptr = torch.zeros(self.local_num_experts+1, device=hidden_states. device, dtype=torch.int64)
# process MoE
def custom_forward(hidden_states, router_logits):
topk_weights, topk_ids = self.select_experts( topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
...@@ -234,20 +239,60 @@ class EPMoE(FusedMoE): ...@@ -234,20 +239,60 @@ class EPMoE(FusedMoE):
indices_type=torch.int64, indices_type=torch.int64,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate) use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights) probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool() routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map hidden_states, probs, routing_map
) )
expert_output = self.triton_grouped_gemm_impl(dispatched_input, tokens_per_expert, self.use_nn_moe)
output = self.token_dispatcher.token_unpermutation(expert_output) # Matrix multiply.
if self.use_shared_expert and not self.shared_expert_overlap: expert_output = self.quant_method.apply(
layer=self,
hidden_states=dispatched_input,
tokens_per_expert=tokens_per_expert
)
final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in # if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations # the token_dispatcher to overlap communications and computations
output = output + self.shared_experts(hidden_states) shared_output = (
return output self.maybe_all_reduce_tensor_model_parallel(
shared_output))
output = custom_forward(hidden_states, router_logits) if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
return output return final_hidden_states
\ No newline at end of file
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                   layer_name: str) -> torch.Tensor:
    """Look up the layer registered as ``layer_name`` and run its ``forward_impl``.

    The layer is taken from ``no_compile_layers`` of the current forward
    context and must have a quant method set.
    """
    context: ForwardContext = get_forward_context()
    layer = context.no_compile_layers[layer_name]
    assert layer.quant_method is not None
    result = layer.forward_impl(hidden_states, router_logits)
    return result
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                        layer_name: str) -> torch.Tensor:
    """Return an uninitialised tensor shaped like ``hidden_states``.

    ``router_logits`` and ``layer_name`` are accepted only to mirror the
    signature of the real op; neither is read.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
# Register ep_moe_forward as a custom op, with ep_moe_forward_fake as its
# fake implementation, for the current platform's dispatch key.
# NOTE(review): hidden_states is declared as mutated, but nothing visible in
# ep_moe_forward writes to it in place — confirm the declaration is intended.
direct_register_custom_op(
    op_name="ep_moe_forward",
    op_func=ep_moe_forward,
    mutates_args=["hidden_states"],
    fake_impl=ep_moe_forward_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
\ No newline at end of file
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
...@@ -21,6 +22,9 @@ from vllm.distributed import (tensor_model_parallel_all_gather, ...@@ -21,6 +22,9 @@ from vllm.distributed import (tensor_model_parallel_all_gather,
expert_parallel_gather) expert_parallel_gather)
from vllm.platforms import current_platform from vllm.platforms import current_platform
# Module-level CUDA side stream used by the token dispatcher for its
# device-to-host copies and the matching synchronize; being module-level, it
# is shared by all dispatcher instances.
# NOTE(review): constructed at import time, so importing this module
# presumably requires a usable CUDA runtime — confirm this is acceptable.
cuda_dtoh_stream = torch.cuda.Stream()
class MoETokenDispatcher: class MoETokenDispatcher:
""" """
MoE Token Dispatcher MoE Token Dispatcher
...@@ -31,7 +35,6 @@ class MoETokenDispatcher: ...@@ -31,7 +35,6 @@ class MoETokenDispatcher:
Initialize the MoE Token Dispatcher. Initialize the MoE Token Dispatcher.
""" """
self.config = config self.config = config
self.shared_experts: Optional[EPSharedExperts] = None
self.tp_size = 1 self.tp_size = 1
self.ep_size = config.ep_size self.ep_size = config.ep_size
...@@ -162,13 +165,14 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -162,13 +165,14 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"no_sync": 4, "no_sync": 4,
} }
self.cuda_dtoh_point = "before_permutation_1" self.cuda_dtoh_point = "before_permutation_1"
self.cuda_dtoh_stream = torch.cuda.Stream() #self.cuda_dtoh_stream = torch.cuda.Stream()
self.shared_experts = None
# Whether to use gather or all-gather to gather the logits. # Whether to use gather or all-gather to gather the logits.
self.use_all_gather = current_platform.use_all_gather() self.use_all_gather = current_platform.use_all_gather()
self.probs = None
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
""" """
Preprocess token routing map for AlltoAll communication and token permutation. Preprocess token routing map for AlltoAll communication and token permutation.
...@@ -264,7 +268,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -264,7 +268,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
return num_tokens_per_local_expert return num_tokens_per_local_expert
def token_permutation( def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor self, hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Dispatch tokens to local experts using AlltoAll communication. Dispatch tokens to local experts using AlltoAll communication.
...@@ -287,6 +293,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -287,6 +293,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
""" """
# Preprocess: Get the metadata for communication, permutation and computation operations. # Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape self.hidden_shape = hidden_states.shape
if self.config.apply_router_weight_on_input:
self.probs = probs self.probs = probs
self.routing_map = routing_map self.routing_map = routing_map
assert probs.dim() == 2, "Expected 2D tensor for probs" assert probs.dim() == 2, "Expected 2D tensor for probs"
...@@ -295,50 +302,32 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -295,50 +302,32 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(self.routing_map) tokens_per_expert = self.preprocess(self.routing_map)
if self.shared_experts is not None: if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape)) self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
import sys
# torch.cuda.synchronize()
# sys.stderr.write(f"token_permutation===============================================")
# sys.stderr.flush()
# Permutation 1: input to AlltoAll input # Permutation 1: input to AlltoAll input
tokens_per_expert = self._maybe_dtoh_and_synchronize( tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_1", tokens_per_expert "before_permutation_1", tokens_per_expert
) )
# torch.cuda.synchronize()
# sys.stderr.write(f"before permute===============================================")
# sys.stderr.flush()
self.hidden_shape_before_permute = hidden_states.shape self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute( permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states, hidden_states,
routing_map, routing_map,
num_out_tokens=self.num_out_tokens, num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion, fused=self.config.moe_permute_fusion
drop_and_pad=False,
) )
# torch.cuda.synchronize()
# sys.stderr.write(f"after permute===============================================")
# sys.stderr.flush()
# Perform expert parallel AlltoAll communication # Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize( tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert "before_ep_alltoall", tokens_per_expert
) )
#torch.cuda.synchronize()
#print("###########################before permutation all_to_all output_splits:{} input_splits:{}".format(self.output_splits, self.input_splits))
global_input_tokens = all_to_all( global_input_tokens = all_to_all(
self.ep_group.device_group, permutated_local_input_tokens, self.output_splits, self.input_splits self.ep_group.device_group, permutated_local_input_tokens, self.output_splits, self.input_splits
) )
#torch.cuda.synchronize()
#print("#######################permutation all_to_all end")
if self.shared_experts is not None: if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens) self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
# Permutation 2: Sort tokens by local expert. # Permutation 2: Sort tokens by local expert.
...@@ -358,7 +347,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -358,7 +347,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
return global_input_tokens, tokens_per_expert return global_input_tokens, tokens_per_expert
def token_unpermutation( def token_unpermutation(
self, hidden_states: torch.Tensor self, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
""" """
Reverse the token permutation to restore the original order. Reverse the token permutation to restore the original order.
...@@ -392,7 +381,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -392,7 +381,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.ep_group.device_group, hidden_states, self.input_splits, self.output_splits self.ep_group.device_group, hidden_states, self.input_splits, self.output_splits
) )
if self.shared_experts is not None: if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.linear_fc2_forward(permutated_local_input_tokens) self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
self.shared_experts.post_forward_comm() self.shared_experts.post_forward_comm()
...@@ -404,16 +393,22 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -404,16 +393,22 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
probs=self.probs, probs=self.probs,
routing_map=self.routing_map, routing_map=self.routing_map,
fused=self.config.moe_permute_fusion, fused=self.config.moe_permute_fusion,
drop_and_pad=False,
) )
# Reshape the output tensor # Reshape the output tensor
output = output.view(self.hidden_shape) output = output.view(self.hidden_shape)
# Add shared experts output # Add shared experts output
if self.shared_experts is not None: if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_expert_output = self.shared_experts.get_output() shared_output = self.shared_experts.get_output()
output += shared_expert_output
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
output = output + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
output = output + shared_output \
* (1. / self.config.routed_scaling_factor)
return output return output
def _maybe_update_cuda_sync_point(self, point: str): def _maybe_update_cuda_sync_point(self, point: str):
...@@ -435,10 +430,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -435,10 +430,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
""" """
if point == self.cuda_dtoh_point: if point == self.cuda_dtoh_point:
# Move all possible GPU tensors to CPU at self.cuda_dtoh_point. # Move all possible GPU tensors to CPU at self.cuda_dtoh_point.
on_side_stream = torch.cuda.current_stream() != self.cuda_dtoh_stream on_side_stream = torch.cuda.current_stream() != cuda_dtoh_stream
if on_side_stream: if on_side_stream:
self.cuda_dtoh_stream.wait_stream(torch.cuda.current_stream()) cuda_dtoh_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.cuda_dtoh_stream): with torch.cuda.stream(cuda_dtoh_stream):
# TODO: use MemcpyBatchAsync instead. # TODO: use MemcpyBatchAsync instead.
# tokens_per_expert = maybe_move_tensor_to_cpu( # tokens_per_expert = maybe_move_tensor_to_cpu(
# tokens_per_expert, record_stream=on_side_stream # tokens_per_expert, record_stream=on_side_stream
...@@ -462,6 +457,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -462,6 +457,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if point == self.cuda_sync_point: if point == self.cuda_sync_point:
# Synchronize with the dtoh stream at self.cuda_sync_point. # Synchronize with the dtoh stream at self.cuda_sync_point.
self.cuda_dtoh_stream.synchronize() cuda_dtoh_stream.synchronize()
return tokens_per_expert return tokens_per_expert
\ No newline at end of file
...@@ -772,20 +772,12 @@ class FusedMoE(torch.nn.Module): ...@@ -772,20 +772,12 @@ class FusedMoE(torch.nn.Module):
self.moe_config = moe self.moe_config = moe
self.quant_config = quant_config self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts self.quant_method = self.create_quant_method(moe, quant_config, prefix)
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb: if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod) Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod): if not isinstance(self.quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods. # TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not # The implementation for other quantization methods does not
# contain essential differences, but the current quant API # contain essential differences, but the current quant API
...@@ -852,6 +844,17 @@ class FusedMoE(torch.nn.Module): ...@@ -852,6 +844,17 @@ class FusedMoE(torch.nn.Module):
dtype=moe.in_dtype, dtype=moe.in_dtype,
device=torch.cuda.current_device()) device=torch.cuda.current_device())
def create_quant_method(self, moe, quant_config, prefix):
    """Select the fused-MoE quantization method for this layer.

    When ``quant_config`` is ``None`` an ``UnquantizedFusedMoEMethod``
    built from ``moe`` is used; otherwise the choice is delegated to
    ``quant_config.get_quant_method(self, prefix)``.

    Note: ``get_quant_method`` will look at the layer's
    ``local_num_experts`` for heuristic purposes, so that attribute must
    be initialized before this method is called.

    Returns:
        The selected method; asserted to be a non-``None``
        ``FusedMoEMethodBase`` instance.
    """
    if quant_config is None:
        method = UnquantizedFusedMoEMethod(moe)
    else:
        method = quant_config.get_quant_method(self, prefix)

    # Same sanity checks as before: a method must have been resolved and
    # it must implement the fused-MoE interface.
    assert method is not None
    assert isinstance(method, FusedMoEMethodBase)
    return method
@property @property
def tp_size(self): def tp_size(self):
return self.moe_parallel_config.tp_size return self.moe_parallel_config.tp_size
......
...@@ -156,23 +156,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -156,23 +156,9 @@ class DeepseekV2MoE(nn.Module):
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel
self.shared_experts = None
if config.n_shared_experts is not None: moe_cls = FusedMoE if not self.use_ep_opt else EPMoE
intermediate_size = (config.moe_intermediate_size * self.experts = moe_cls(
config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_ep_opt else EPSharedExperts
self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
if not self.use_ep_opt:
self.experts = FusedMoE(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
...@@ -186,29 +172,24 @@ class DeepseekV2MoE(nn.Module): ...@@ -186,29 +172,24 @@ class DeepseekV2MoE(nn.Module):
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias, e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor) routed_scaling_factor=self.routed_scaling_factor)
else:
self.experts = EPMoE( if config.n_shared_experts is not None:
num_experts=config.n_routed_experts, intermediate_size = (config.moe_intermediate_size *
top_k=config.num_experts_per_tok, config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_ep_opt else EPSharedExperts
self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=intermediate_size,
reduce_results=False, hidden_act=config.hidden_act,
renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
use_grouped_topk=True, reduce_results=self.experts.must_reduce_shared_expert_outputs(
num_expert_group=config.n_group, ),
topk_group=config.topk_group, prefix=f"{prefix}.shared_experts",
prefix=f"{prefix}.experts", )
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor)
if self.use_ep_opt:
self.experts.set_shared_experts(self.shared_experts) self.experts.set_shared_experts(self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
...@@ -218,10 +199,11 @@ class DeepseekV2MoE(nn.Module): ...@@ -218,10 +199,11 @@ class DeepseekV2MoE(nn.Module):
if not self.use_ep_opt: if not self.use_ep_opt:
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if not self.use_ep_opt:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor router_logits=router_logits) * self.routed_scaling_factor
...@@ -230,6 +212,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -230,6 +212,9 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
else:
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if not self.use_ep_opt: if not self.use_ep_opt:
if shared_output is not None: if shared_output is not None:
...@@ -745,9 +730,7 @@ class DeepseekV2Model(nn.Module): ...@@ -745,9 +730,7 @@ class DeepseekV2Model(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)\ hidden_states, residual = layer(positions, hidden_states, residual)
#ops.print_tensor(hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -816,6 +799,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -816,6 +799,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.tritonsingleton.topk = config.num_experts_per_tok self.tritonsingleton.topk = config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method self.tritonsingleton.quant_method=self.quant_method
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel
def set_eplb_state( def set_eplb_state(
self, self,
expert_load_view: torch.Tensor, expert_load_view: torch.Tensor,
...@@ -897,6 +884,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -897,6 +884,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
if self.use_ep_opt:
ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
...@@ -929,6 +920,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -929,6 +920,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if (("mlp.experts." in name) and name not in params_dict): if (("mlp.experts." in name) and name not in params_dict):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if self.use_ep_opt:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
...@@ -955,6 +950,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -955,6 +950,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# Instead, create a new variable # Instead, create a new variable
name_mapped = name.replace(weight_name, param_name) name_mapped = name.replace(weight_name, param_name)
if self.use_ep_opt:
name_mapped = name_mapped.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
if is_pp_missing_parameter(name_mapped, self): if is_pp_missing_parameter(name_mapped, self):
continue continue
...@@ -980,6 +978,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -980,6 +978,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# So we simply skip it # So we simply skip it
continue continue
if self.use_ep_opt:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
......
...@@ -2052,7 +2052,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2052,7 +2052,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds = self.inputs_embeds[:num_tokens] inputs_embeds = self.inputs_embeds[:num_tokens]
else: else:
#self.input_ids[:num_tokens] = torch.randint(0, 120000, (num_tokens,), dtype=torch.int32) #self.input_ids[:num_tokens] = torch.randint(0, 120000, (num_tokens,), dtype=torch.int32)
self.input_ids[:num_tokens] = torch.arange(num_tokens, dtype=torch.int32, device=self.input_ids.device) #self.input_ids[:num_tokens] = torch.arange(num_tokens, dtype=torch.int32, device=self.input_ids.device)
input_ids = self.input_ids[:num_tokens] input_ids = self.input_ids[:num_tokens]
inputs_embeds = None inputs_embeds = None
if self.uses_mrope: if self.uses_mrope:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment