Commit d698d6f2 authored by 王敏's avatar 王敏
Browse files

[feat]整合mori和deepep相关代码

parent 7293a072
File added
...@@ -171,16 +171,16 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -171,16 +171,16 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_qps_per_rank = None num_qps_per_rank = None
if self.internode: if self.internode:
num_rdma_bytes = int(1e9/2)#1024 * 1024 * 1024 num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = 30#self.num_sms // 2 num_qps_per_rank = 30 #self.num_sms // 2
import deep_ep # import deep_ep
num_nvl_bytes, num_rdma_bytes = 0, 0 # num_nvl_bytes, num_rdma_bytes = 0, 0
hidden_size = 7168 # hidden_size = 7168
hidden_bytes = hidden_size * 2 # hidden_bytes = hidden_size * 2
for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())): # for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())):
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes) # num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes)
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes) # num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes)
else: else:
num_rdma_bytes = 0 num_rdma_bytes = 0
num_qps_per_rank = 1 num_qps_per_rank = 1
......
...@@ -175,6 +175,7 @@ if TYPE_CHECKING: ...@@ -175,6 +175,7 @@ if TYPE_CHECKING:
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_ENABLE_MOE_GROUP_GEMM: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1151,6 +1152,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1151,6 +1152,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# pd separation p2p async buf tokens # pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS": "VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")), lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# pd separation p2p async buf tokens
"VLLM_ENABLE_MOE_GROUP_GEMM":
lambda: (os.environ.get("VLLM_ENABLE_MOE_GROUP_GEMM", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
get_ep_group,
get_tensor_model_parallel_rank)
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import (EPSharedExperts,
maybe_move_tensor_to_cpu,
maybe_move_tensor_to_cpu_block,
permute,
sort_chunks_by_idxs,
unpermute,
all_to_all,
EpMoeConfig)
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather,
expert_parallel_all_gather,
expert_parallel_gather)
from vllm.platforms import current_platform
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.config import get_current_vllm_config
from lightop import groupgemm_permute, groupgemm_unpermute
cuda_dtoh_stream = torch.cuda.Stream()
cuda_dtoh_sync_event = torch.cuda.Event(enable_timing=False)
class MoETokenDispatcher(nn.Module):
"""
MoE Token Dispatcher
"""
def __init__(self, config: EpMoeConfig) -> None:
"""
Initialize the MoE Token Dispatcher.
"""
super().__init__()
self.config = config
self.tp_size = 1
self.ep_size = config.ep_size
@property
def ep_group(self):
"""Get expert model parallel group."""
return get_ep_group()
@property
def tp_group(self):
"""Get expert tensor parallel group."""
return get_tp_group()
@property
def tp_rank(self):
"""Get expert tensor parallel rank."""
return 0#get_tensor_model_parallel_rank()
@property
def tp_ep_group(self):
"""Get expert tensor and model parallel group."""
return get_ep_group()
@abstractmethod
def token_permutation(
self, tokens: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
):
"""Dispatch tokens to experts.
Args:
tokens (torch.Tensor): Input tokens.
probs (torch.Tensor): The routing probability tensor [num_tokens, num_experts].
routing_map (torch.Tensor): Token to expert mapping tensor.
Returns:
torch.Tensor: Tokens tensor.
"""
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_unpermutation(self, expert_output: torch.Tensor, bias: torch.Tensor = None):
"""Restores the expert output to its original ordering.
Args:
expert_output (torch.Tensor): The output tensor from the expert models.
bias (torch.Tensor): The bias tensor.
Returns:
(torch.Tensor, torch.Tensor): Unpermuted activation and optional bias.
"""
raise NotImplementedError("Restore function not implemented.")
def set_shared_experts(self, shared_experts):
"""Set shared expert to the dispatcher."""
assert self.config.moe_shared_expert_overlap
self.shared_experts = shared_experts
class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
AlltoAll-based token dispatcher.
The workflow of AlltoAll token dispatcher is as follows:
(1) preprocess(): calculate necessary metadata for communication and permute
(2) token_permutation(): permute->A2A(EP)->AG(TP)->sort_chunk(if num_local_experts>1)
(3) token_unpermutation(): sort_chunk(if num_local_experts>1)->RS(TP)->A2A(EP)->unpermute
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: EpMoeConfig, layer_name: str=""
) -> None:
"""
Initialize the AlltoAll token dispatcher.
Args:
num_local_experts (int): Number of local experts on the current device.
local_expert_indices (List[int]): Indices of local experts on the current device.
config (TransformerConfig): Configuration for the transformer model.
"""
super().__init__(config=config)
self.num_local_experts = num_local_experts
assert config.num_moe_experts is not None
self.num_experts = config.num_moe_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert (
len(self.local_expert_indices) == self.num_local_experts
), "Invalid local expert indices"
for i in range(len(self.local_expert_indices) - 1):
assert (
self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1
), "local_expert_indices must be continous"
self.layer_name = layer_name
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self.input_splits = None
# [ep_size]. Represents the number of tokens received by the current rank from
# other EP ranks.
self.output_splits = None
# [tp_size]. Represents the number of tokens received by the current rank from
# other TP ranks.
#self.output_splits_tp = None
self.permute_idx_device = torch.device("cuda") if self.config.moe_permute_fusion else None
input_chunk_idxs = torch.arange(
self.num_experts * self.tp_size, device=self.permute_idx_device
)
# [num_local_experts, tp_size * ep_size]. Sort the input chunks by local experts.
self.sort_input_by_local_experts = input_chunk_idxs.reshape(
-1, self.num_local_experts
).T.ravel()
# [tp_size * ep_size, num_local_experts]. Restore the output chunks by local experts.
self.restore_output_by_local_experts = input_chunk_idxs.reshape(
self.num_local_experts, -1
).T.ravel()
# A cuda stream synchronization is needed in self.token_permutation() in some cases,
# because there are several non-blocking DtoH data transfers called at
# `self.cuda_dtoh_point`. The synchronization happens at `self.cuda_sync_point`, which is
# decided based on the MoE and parallel settings. Valid points are "before_permutation_1",
# "before_ep_alltoall", "before_permutation_2", "before_finish", and "no_sync".
self.cuda_sync_point = "no_sync"
self.cuda_sync_point_priority = {
"before_permutation_1": 0,
"before_ep_alltoall": 1,
"before_permutation_2": 2,
"before_finish": 3,
"no_sync": 4,
}
self.cuda_dtoh_point = "before_permutation_1"
#self.cuda_dtoh_stream = torch.cuda.Stream()
# Whether to use gather or all-gather to gather the logits.
self.use_all_gather = current_platform.use_all_gather()
self.probs = None
# For smuggling this layer into the fused moe custom op
vllm_config = get_current_vllm_config()
compilation_config = vllm_config.compilation_config
if layer_name in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(layer_name))
compilation_config.static_forward_context[layer_name] = self
def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
"""
Preprocess token routing map for AlltoAll communication and token permutation.
This method computes the number of tokens assigned to each expert based on the routing_map.
It also initializes the necessary data structures for AlltoAll communication, such as input
and output splits, and the mapping between global tokens and local experts. This method
should not call any DtoH data copying due to performance consideration. The necessary DtoH
copies are made on the `self.cuda_dtoh_stream` at `self.cuda_dtoh_point`.
Args:
routing_map (torch.Tensor): The mapping of tokens to experts, with shape
[num_tokens, num_experts].
Returns:
torch.Tensor: Tensor containing the number of tokens assigned to local expert.
"""
# [num_experts], number of tokens assigned to each expert from the current rank's input.
num_local_tokens_per_expert = routing_map.sum(dim=0).long()
self.num_out_tokens = routing_map.size(0) * self.config.moe_router_topk
# ===================================================
# Calculate input_splits, output_splits for alltoall/allgather in variable size.
# ===================================================
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self.input_splits = num_local_tokens_per_expert.reshape(
self.ep_size, self.num_local_experts
).sum(axis=1)
# Gather the global distribution of tokens across ranks.
# num_global_tokens_per_expert represents the number of tokens sent to each
# expert by all ranks.
# [tp_size, ep_size, num_experts]
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
num_global_tokens_per_expert = expert_parallel_all_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1)
else:
# None may be returned for rank > 0
num_global_tokens_per_expert = expert_parallel_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1)
# [tp_size, ep_size, num_experts] -> [tp_size, ep_size, num_local_experts]
num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
].contiguous()
# [tp_size, ep_size, num_local_experts] -> [tp_size, ep_size]
num_global_tokens_per_rank = num_global_tokens_per_local_expert.sum(axis=2)
# [tp_size, ep_size] -> [ep_size]
# self.output_splits represents the number of tokens received by the current rank
# from other EP rank.
self.output_splits = num_global_tokens_per_rank[self.tp_rank]
# [tp_size, ep_size] -> [tp_size]
# self.output_splits_tp represents the number of tokens received by the current
# rank from other TP rank.
#self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
# [tp_size, ep_size, num_local_experts] -> [num_local_experts]
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1))
# A synchronization is needed before expert parallel AlltoAll communication
# to get the `input_splits` and `output_splits` CPU values.
#self._maybe_update_cuda_sync_point("before_ep_alltoall")
if self.num_local_experts > 1:
# [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view(
-1, self.num_local_experts
)
# if not self.config.moe_permute_fusion:
# # A synchronization is needed before permutation 2
# # to get the `num_global_tokens_per_local_expert` CPU value.
# self._maybe_update_cuda_sync_point("before_permutation_2")
# assert (
# self.cuda_sync_point_priority[self.cuda_dtoh_point]
# <= self.cuda_sync_point_priority[self.cuda_sync_point]
# ), "cuda_sync_point must be after cuda_dtoh_point."
return num_tokens_per_local_expert
def token_permutation(
self, hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.routing_map = routing_map
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
tokens_per_expert = self.preprocess(self.routing_map)
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
global_input_tokens = torch.ops.vllm.token_permutation_forward(tokens_per_expert, hidden_states,
probs, routing_map, self.layer_name)
return global_input_tokens, tokens_per_expert
def token_permutation_impl(
self,
tokens_per_expert: torch.Tensor,
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
This method performs the following steps:
1. Preprocess the routing map to get metadata for communication and permutation.
2. Permute input tokens for AlltoAll communication.
3. Perform expert parallel AlltoAll communication.
4. Sort tokens by local expert (if multiple local experts exist).
Args:
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): The probabilities of token to experts assignment.
routing_map (torch.Tensor): The mapping of token to experts assignment.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- Permuted token embeddings for local experts.
- Number of tokens per expert.
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
# Permutation 1: input to AlltoAll input
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_1", tokens_per_expert
)
self.hidden_shape = hidden_states.shape
if self.config.apply_router_weight_on_input:
self.probs = probs
assert probs.dim() == 2, "Expected 2D tensor for probs"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
self.hidden_shape_before_permute = hidden_states.shape
if False:
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
routing_map,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion
)
else:
cuda_permute_result = groupgemm_permute(hidden_states, routing_map)
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping, \
self.expert_m_count = cuda_permute_result
# Perform expert parallel AlltoAll communication
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
# "before_ep_alltoall", tokens_per_expert
# )
###test##############
#cuda_dtoh_stream.synchronize()
cuda_dtoh_sync_event.synchronize()
###test##############
global_input_tokens = all_to_all(
self.ep_group.device_group, permutated_local_input_tokens, self.output_splits, self.input_splits
)
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
# Permutation 2: Sort tokens by local expert.
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
# "before_permutation_2", tokens_per_expert
# )
if self.num_local_experts > 1:
global_input_tokens = sort_chunks_by_idxs(
global_input_tokens,
self.num_global_tokens_per_local_expert.ravel(),
self.sort_input_by_local_experts,
fused=self.config.moe_permute_fusion,
)
#tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return global_input_tokens
def token_unpermutation(
self, hidden_states: torch.Tensor,
) -> torch.Tensor:
return torch.ops.vllm.token_unpermutation_forward(hidden_states, self.layer_name)
def token_unpermutation_impl(
self, hidden_states: torch.Tensor,
) -> torch.Tensor:
"""
Reverse the token permutation to restore the original order.
This method performs the following steps:
1. Unsort tokens by local expert (if multiple local experts exist).
2. Perform expert parallel AlltoAll communication to restore the original order.
3. Unpermute tokens to restore the original order.
Args:
hidden_states (torch.Tensor): Output from local experts.
bias (torch.Tensor, optional): Bias tensor (not supported).
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- Unpermuted token embeddings in the original order.
- None (bias is not supported).
"""
# Unpermutation 2: Unsort tokens by local expert.
if self.num_local_experts > 1:
hidden_states = sort_chunks_by_idxs(
hidden_states,
self.num_global_tokens_per_local_expert.T.ravel(),
self.restore_output_by_local_experts,
fused=self.config.moe_permute_fusion,
)
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
permutated_local_input_tokens = all_to_all(
self.ep_group.device_group, hidden_states, self.input_splits, self.output_splits
)
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
self.shared_experts.post_forward_comm()
# Unpermutation 1: AlltoAll output to output
if False:
output = unpermute(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
restore_shape=self.hidden_shape_before_permute,
probs=self.probs,
routing_map=self.routing_map,
fused=self.config.moe_permute_fusion,
)
else:
output = groupgemm_unpermute(permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
list(self.hidden_shape_before_permute),
self.probs,
self.routing_map,
self.expert_m_count)
# Reshape the output tensor
output = output.view(self.hidden_shape)
# Add shared experts output
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts.get_output()
if hidden_states.dtype != torch.float16:
output = output + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
output = output + shared_output \
* (1. / self.config.routed_scaling_factor)
return output
def _maybe_update_cuda_sync_point(self, point: str):
"""
Update the CUDA sync point if the priority of the new point is higher than the current
sync point, which means the new point is reached earlier than the current sync point.
"""
if (
self.cuda_sync_point_priority[point]
< self.cuda_sync_point_priority[self.cuda_sync_point]
):
self.cuda_sync_point = point
def _maybe_dtoh_and_synchronize(
self, point: str, tokens_per_expert: torch.Tensor = None
) -> torch.Tensor:
"""
Move all possible GPU tensors to CPU and make a synchronization at the expected point.
"""
if point == self.cuda_dtoh_point:
# Move all possible GPU tensors to CPU at self.cuda_dtoh_point.
on_side_stream = torch.cuda.current_stream() != cuda_dtoh_stream
if on_side_stream:
cuda_dtoh_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_dtoh_stream):
# TODO: use MemcpyBatchAsync instead.
# tokens_per_expert = maybe_move_tensor_to_cpu(
# tokens_per_expert, record_stream=on_side_stream
# )
self.input_splits = maybe_move_tensor_to_cpu(
self.input_splits, as_numpy=True, record_stream=on_side_stream
)
self.output_splits = maybe_move_tensor_to_cpu(
self.output_splits, as_numpy=True, record_stream=on_side_stream
)
# self.output_splits_tp = maybe_move_tensor_to_cpu(
# self.output_splits_tp, as_numpy=True, record_stream=on_side_stream
# )
self.num_out_tokens = maybe_move_tensor_to_cpu(
self.num_out_tokens, record_stream=on_side_stream
)
if self.num_local_experts > 1 and not self.config.moe_permute_fusion:
self.num_global_tokens_per_local_expert = maybe_move_tensor_to_cpu(
self.num_global_tokens_per_local_expert, record_stream=on_side_stream
)
cuda_dtoh_sync_event.record()
# if point == self.cuda_sync_point:
# # Synchronize with the dtoh stream at self.cuda_sync_point.
# cuda_dtoh_stream.synchronize()
return tokens_per_expert
def token_permutation_forward(tokens_per_expert: torch.Tensor,
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
return self.token_permutation_impl(tokens_per_expert, hidden_states, probs, routing_map)
def token_permutation_forward_fake(tokens_per_expert: torch.Tensor,
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="token_permutation_forward",
op_func=token_permutation_forward,
mutates_args=["tokens_per_expert", "hidden_states", "probs", "routing_map"],
fake_impl=token_permutation_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def token_unpermutation_forward(hidden_states: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
return self.token_unpermutation_impl(hidden_states)
def token_unpermutation_forward_fake(hidden_states: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="token_unpermutation_forward",
op_func=token_unpermutation_forward,
mutates_args=["hidden_states"],
fake_impl=token_unpermutation_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
\ No newline at end of file
import os from typing import Callable, Optional
import logging
from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass
from collections.abc import Iterable from collections.abc import Iterable
import torch import torch
import torch.nn.functional as F import torch.distributed as dist
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -18,10 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -18,10 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EpMoeConfig
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
import torch.distributed as dist
try: try:
import mori import mori
...@@ -35,8 +30,8 @@ logger = init_logger(__name__) ...@@ -35,8 +30,8 @@ logger = init_logger(__name__)
_MORI_OP = None _MORI_OP = None
@CustomOp.register("unquantized_ep_moe") @CustomOp.register("unquantized_mori_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): class UnquantizedMoriMoeMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization.""" """MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
...@@ -44,9 +39,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -44,9 +39,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.topk_indices_dtype = None self.topk_indices_dtype = None
self.moe = moe self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = False
def apply_ep( def apply_mori_ep(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -162,7 +157,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -162,7 +157,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
forward_native = forward_cuda forward_native = forward_cuda
class EPMoE(FusedMoE): class MoriMoE(FusedMoE):
""" """
dp+ep MoE Expert Parallel Impl dp+ep MoE Expert Parallel Impl
...@@ -194,7 +189,6 @@ class EPMoE(FusedMoE): ...@@ -194,7 +189,6 @@ class EPMoE(FusedMoE):
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
moe_permute_fusion: bool = False, moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False
): ):
super().__init__(num_experts, top_k, hidden_size, super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype, intermediate_size, params_dtype,
...@@ -215,7 +209,6 @@ class EPMoE(FusedMoE): ...@@ -215,7 +209,6 @@ class EPMoE(FusedMoE):
moe_router_topk=self.top_k, moe_router_topk=self.top_k,
# TODO: support fusion permute # TODO: support fusion permute
moe_permute_fusion=moe_permute_fusion, moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=self.ep_size, ep_size=self.ep_size,
num_moe_experts=self.global_num_experts, num_moe_experts=self.global_num_experts,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
...@@ -229,20 +222,13 @@ class EPMoE(FusedMoE): ...@@ -229,20 +222,13 @@ class EPMoE(FusedMoE):
local_expert_indices_offset + i for i in range(self.local_num_experts) local_expert_indices_offset + i for i in range(self.local_num_experts)
] ]
self.use_shared_expert = False
# self.token_dispatcher = MoEAlltoAllTokenDispatcher(
# self.local_num_experts, self.local_expert_indices,
# config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
# )
self.shared_expert_overlap = moe_shared_expert_overlap
self.shared_experts = None self.shared_experts = None
self.scales = None self.scales = None
self.use_int8_dispatch = True self.use_int8_dispatch = True
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.max_num_inp_token_per_rank = 1024#vllm_config.scheduler_config.max_num_seqs self.max_num_inp_token_per_rank = 1024 #vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op() self.mori_op = self.get_mori_op()
def get_mori_op(self): def get_mori_op(self):
...@@ -252,10 +238,6 @@ class EPMoE(FusedMoE): ...@@ -252,10 +238,6 @@ class EPMoE(FusedMoE):
assert world_group is not None assert world_group is not None
torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group) torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori.shmem.shmem_torch_process_group_init("mori_ep") mori.shmem.shmem_torch_process_group_init("mori_ep")
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1 multi_node = self.ep_size / 8 > 1
...@@ -278,7 +260,6 @@ class EPMoE(FusedMoE): ...@@ -278,7 +260,6 @@ class EPMoE(FusedMoE):
max_token_type_size=2, max_token_type_size=2,
block_num=80, block_num=80,
warp_num_per_block=4, warp_num_per_block=4,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \ kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode mori.ops.EpDispatchCombineKernelType.IntraNode
) )
...@@ -290,14 +271,11 @@ class EPMoE(FusedMoE): ...@@ -290,14 +271,11 @@ class EPMoE(FusedMoE):
if self.shared_experts is None: if self.shared_experts is None:
self.shared_experts = shared_experts self.shared_experts = shared_experts
# if self.shared_expert_overlap:
# self.token_dispatcher.set_shared_experts(self.shared_experts)
def create_quant_method(self, moe, quant_config, prefix): def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedEPGroupedGemmMethod(moe) if quant_config is None quant_method = (UnquantizedMoriMoeMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix)) else quant_config.get_quant_method(self, prefix))
assert quant_method is not None assert quant_method is not None
...@@ -310,7 +288,7 @@ class EPMoE(FusedMoE): ...@@ -310,7 +288,7 @@ class EPMoE(FusedMoE):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits, return torch.ops.vllm.mori_moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name)
def get_expert_weights(self) -> Iterable[torch.Tensor]: def get_expert_weights(self) -> Iterable[torch.Tensor]:
...@@ -350,7 +328,7 @@ class EPMoE(FusedMoE): ...@@ -350,7 +328,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate) use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
if self.use_int8_dispatch: if self.use_int8_dispatch:
...@@ -377,11 +355,10 @@ class EPMoE(FusedMoE): ...@@ -377,11 +355,10 @@ class EPMoE(FusedMoE):
hidden_states, hidden_states,
topk_weights, topk_weights,
scales, scales,
topk_ids, topk_ids
#layer_idx=int(self.layer_name.split('.')[2])
) )
expert_output = self.quant_method.apply_ep( expert_output = self.quant_method.apply_mori_ep(
layer=self, layer=self,
x=dispatch_output, x=dispatch_output,
topk_weights=dispatch_weights, topk_weights=dispatch_weights,
...@@ -394,7 +371,6 @@ class EPMoE(FusedMoE): ...@@ -394,7 +371,6 @@ class EPMoE(FusedMoE):
num_local_tokens=dispatch_recv_num_token, num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size, config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size,
scales=dispatch_scales if self.use_int8_dispatch else None scales=dispatch_scales if self.use_int8_dispatch else None
# routed_scaling_factor=self.routed_scaling_factor,
) )
# self.sync() # self.sync()
...@@ -404,11 +380,7 @@ class EPMoE(FusedMoE): ...@@ -404,11 +380,7 @@ class EPMoE(FusedMoE):
# self.sync() # self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if self.shared_experts is not None:
# shared_output = (
# self.maybe_all_reduce_tensor_model_parallel(
# shared_output))
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
else: else:
...@@ -420,7 +392,7 @@ class EPMoE(FusedMoE): ...@@ -420,7 +392,7 @@ class EPMoE(FusedMoE):
return final_hidden_states return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def mori_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
...@@ -429,16 +401,16 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, ...@@ -429,16 +401,16 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def mori_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
direct_register_custom_op( direct_register_custom_op(
op_name="ep_moe_forward", op_name="mori_moe_forward",
op_func=ep_moe_forward, op_func=mori_moe_forward,
mutates_args=["hidden_states", "router_logits"], mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake, fake_impl=mori_moe_forward_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order,), tags=(torch.Tag.needs_fixed_stride_order,),
) )
\ No newline at end of file
...@@ -167,6 +167,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -167,6 +167,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.enable_moe_group_gemm = parallel_config.enable_expert_parallel and envs.VLLM_ENABLE_MOE_GROUP_GEMM
def create_weights( def create_weights(
self, self,
...@@ -250,6 +252,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -250,6 +252,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
**_ ): **_ ):
if not self.enable_moe_group_gemm:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin( return fused_experts_impl_w4a8_marlin(
x, x,
...@@ -274,8 +277,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -274,8 +277,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
) )
else:
# TODO:
return None
def apply_ep( #dp+ep def apply_mori_ep(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
...@@ -310,12 +316,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -310,12 +316,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale), w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale), w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=scales,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens, num_local_tokens=num_local_tokens,
config_select_bs=config_select_bs, config_select_bs=config_select_bs,
q_scales=scales
) )
def apply( def apply(
......
...@@ -43,8 +43,8 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group, ...@@ -43,8 +43,8 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.ep_moe.layer import EPMoE from vllm.model_executor.layers.fused_moe.mori_moe.layer import MoriMoE
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EPSharedExperts from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EPSharedExperts
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -167,10 +167,10 @@ class DeepseekV2MoE(nn.Module): ...@@ -167,10 +167,10 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts) self.n_local_physical_experts)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.use_mori_ep = envs.VLLM_ALL2ALL_BACKEND == 'mori' and dp_size > 1 and parallel_config.enable_expert_parallel self.use_mori_ep = parallel_config.enable_expert_parallel and dp_size > 1 and envs.VLLM_ALL2ALL_BACKEND == 'mori'
self.enable_expert_parallel = parallel_config.enable_expert_parallel self.enable_expert_parallel = parallel_config.enable_expert_parallel
moe_cls = FusedMoE if not self.use_mori_ep else EPMoE moe_cls = FusedMoE if not self.use_mori_ep else MoriMoE
self.experts = moe_cls( self.experts = moe_cls(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -225,7 +225,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -225,7 +225,7 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if not self.use_mori_ep: if not self.enable_expert_parallel:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -249,6 +249,20 @@ class DeepseekV2MoE(nn.Module): ...@@ -249,6 +249,20 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else:
if not self.use_mori_ep:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else: else:
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment