Commit fe70dcb2 authored by 王敏's avatar 王敏
Browse files

临时添加cudagraph代码,目前还有问题

parent 121db653
......@@ -4320,6 +4320,9 @@ class CompilationConfig:
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.token_permutation_forward",
"vllm.token_unpermutation_forward",
"vllm.ep_moe_forward",
]
......
......@@ -948,6 +948,7 @@ def init_distributed_environment(
"Fallback Gloo backend is not available.")
backend = "gloo"
# this backend is used for WORLD
backend="cpu:gloo,cuda:nccl"
torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
......
......@@ -204,6 +204,24 @@ def maybe_move_tensor_to_cpu(tensor, as_numpy=False, record_stream=False):
return tensor
def maybe_move_tensor_to_cpu_block(tensor, as_numpy=False, record_stream=False):
"""Move a tensor to CPU if it is on GPU.
Args:
tensor (torch.Tensor or None): The tensor to move to CPU.
as_numpy (bool): Whether to convert the tensor to a numpy array.
record_stream (bool): Whether to record the stream of the tensor, to prevent memory leak
when the DtoH data transfer is on a side stream.
"""
if torch.is_tensor(tensor) and tensor.is_cuda:
cpu_tensor = tensor.to(torch.device("cpu"))
if as_numpy:
cpu_tensor = cpu_tensor.numpy()
if record_stream:
tensor.record_stream(torch.cuda.current_stream())
tensor = cpu_tensor
return tensor
def sort_chunks_by_idxs(
input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, fused: bool = False
):
......
......@@ -8,8 +8,10 @@ import torch.nn.functional as F
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
......@@ -19,10 +21,13 @@ from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAllt
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op
from lightop import groupgemm
#import mori
import torch.distributed as dist
logger = init_logger(__name__)
_MORI_OP = None
@CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
......@@ -36,7 +41,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
self.zero_token_count = None
def apply(
def apply_ep(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
......@@ -218,7 +223,8 @@ class EPMoE(FusedMoE):
self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices, config=self.ep_moe_config
self.local_num_experts, self.local_expert_indices,
config=self.ep_moe_config, #layer_name=f"{self.layer_name}.token_dispatcher",
)
self.shared_expert_overlap = moe_shared_expert_overlap
......@@ -226,6 +232,36 @@ class EPMoE(FusedMoE):
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
if False:
self.mori_op = self.get_mori_op()
def get_mori_op(self):
global _MORI_OP
if _MORI_OP is None:
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori.shmem.shmem_torch_process_group_init("mori_ep")
vllm_config = get_current_vllm_config()
config = mori.ops.EpDispatchCombineConfig(
data_type=vllm_config.model_config.dtype,
rank=self.ep_rank,
world_size=self.ep_size,
hidden_dim=self.hidden_size,
scale_dim=0,
scale_type_size=vllm_config.model_config.dtype.itemsize,
max_num_inp_token_per_rank=10000,
num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k,
max_token_type_size=4,
# block_num=40,
# warp_num_per_block=8,
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
)
_MORI_OP = mori.ops.EpDispatchCombineOp(config)
return _MORI_OP
def set_shared_experts(self, shared_experts: torch.nn.Module):
if self.shared_experts is None:
self.shared_experts = shared_experts
......@@ -243,6 +279,10 @@ class EPMoE(FusedMoE):
assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method
def sync(self):
torch.cuda.synchronize()
dist.barrier()
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
......@@ -267,21 +307,57 @@ class EPMoE(FusedMoE):
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
if True:
probs = None
if self.apply_router_weight_on_input:
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
else:
topk_ids = topk_ids.to(torch.int32)
scales = torch.rand(
hidden_states.shape[0],
0,
dtype=torch.float32,
device=hidden_states.device,
)
(
dispatched_input,
dispatch_weights,
dispatch_scales,
dispatch_indices,
dispatch_recv_num_token,
) = self.mori_op.dispatch(
hidden_states,
topk_weights,
scales,
topk_ids,
)
tokens_per_expert = dispatch_recv_num_token
self.sync()
print("######################dispatched_input shape:{} dispatch_weights.shape:{} dispatch_indices shape:{}".format(dispatched_input.shape,
dispatch_weights.shape, dispatch_indices.shape))
print("####################dispatch_recv_num_token:", dispatch_recv_num_token.tolist())
#print("####################dispatch_weights:", dispatch_weights.tolist())
#print("####################dispatch_indices:", dispatch_indices.tolist())
# Matrix multiply.
expert_output = self.quant_method.apply(
expert_output = self.quant_method.apply_ep(
layer=self,
hidden_states=dispatched_input,
tokens_per_expert=tokens_per_expert
)
final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
if True:
final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
else:
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
final_hidden_states = combine_output[:hidden_states.shape[0], :]
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
......
......@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
......@@ -11,6 +12,7 @@ from vllm.distributed.parallel_state import (get_dp_group,
get_tensor_model_parallel_rank)
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import (EPSharedExperts,
maybe_move_tensor_to_cpu,
maybe_move_tensor_to_cpu_block,
permute,
sort_chunks_by_idxs,
unpermute,
......@@ -21,12 +23,16 @@ from vllm.distributed import (tensor_model_parallel_all_gather,
expert_parallel_all_gather,
expert_parallel_gather)
from vllm.platforms import current_platform
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.config import get_current_vllm_config
cuda_dtoh_stream = torch.cuda.Stream()
cuda_dtoh_sync_event = torch.cuda.Event(enable_timing=False)
class MoETokenDispatcher:
class MoETokenDispatcher(nn.Module):
"""
MoE Token Dispatcher
"""
......@@ -35,6 +41,7 @@ class MoETokenDispatcher:
"""
Initialize the MoE Token Dispatcher.
"""
super().__init__()
self.config = config
self.tp_size = 1
......@@ -106,7 +113,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: EpMoeConfig
self, num_local_experts: int, local_expert_indices: List[int], config: EpMoeConfig, layer_name: str=""
) -> None:
"""
Initialize the AlltoAll token dispatcher.
......@@ -130,6 +137,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1
), "local_expert_indices must be continous"
self.layer_name = layer_name
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self.input_splits = None
......@@ -174,6 +182,13 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.probs = None
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
# For smuggling this layer into the fused moe custom op
vllm_config = get_current_vllm_config()
compilation_config = vllm_config.compilation_config
if layer_name in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(layer_name))
compilation_config.static_forward_context[layer_name] = self
def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
"""
Preprocess token routing map for AlltoAll communication and token permutation.
......@@ -196,60 +211,50 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.num_out_tokens = routing_map.size(0) * self.config.moe_router_topk
if self.ep_size > 1 or self.tp_size > 1:
# ===================================================
# Calculate input_splits, output_splits for alltoall/allgather in variable size.
# ===================================================
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self.input_splits = num_local_tokens_per_expert.reshape(
self.ep_size, self.num_local_experts
).sum(axis=1)
# Gather the global distribution of tokens across ranks.
# num_global_tokens_per_expert represents the number of tokens sent to each
# expert by all ranks.
# [tp_size, ep_size, num_experts]
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
num_global_tokens_per_expert = expert_parallel_all_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1)
else:
# None may be returned for rank > 0
num_global_tokens_per_expert = expert_parallel_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1)
# [tp_size, ep_size, num_experts] -> [tp_size, ep_size, num_local_experts]
num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
].contiguous()
# [tp_size, ep_size, num_local_experts] -> [tp_size, ep_size]
num_global_tokens_per_rank = num_global_tokens_per_local_expert.sum(axis=2)
# [tp_size, ep_size] -> [ep_size]
# self.output_splits represents the number of tokens received by the current rank
# from other EP rank.
self.output_splits = num_global_tokens_per_rank[self.tp_rank]
# [tp_size, ep_size] -> [tp_size]
# self.output_splits_tp represents the number of tokens received by the current
# rank from other TP rank.
#self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
# [tp_size, ep_size, num_local_experts] -> [num_local_experts]
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1))
# A synchronization is needed before expert parallel AlltoAll communication
# to get the `input_splits` and `output_splits` CPU values.
self._maybe_update_cuda_sync_point("before_ep_alltoall")
# ===================================================
# Calculate input_splits, output_splits for alltoall/allgather in variable size.
# ===================================================
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self.input_splits = num_local_tokens_per_expert.reshape(
self.ep_size, self.num_local_experts
).sum(axis=1)
# Gather the global distribution of tokens across ranks.
# num_global_tokens_per_expert represents the number of tokens sent to each
# expert by all ranks.
# [tp_size, ep_size, num_experts]
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
num_global_tokens_per_expert = expert_parallel_all_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1)
else:
num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
self.num_experts
)
num_tokens_per_local_expert = num_local_tokens_per_expert
# A synchronization is needed before the returns
# to get the `num_tokens_per_local_expert` CPU value.
self._maybe_update_cuda_sync_point("before_finish")
# None may be returned for rank > 0
num_global_tokens_per_expert = expert_parallel_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1)
# [tp_size, ep_size, num_experts] -> [tp_size, ep_size, num_local_experts]
num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
].contiguous()
# [tp_size, ep_size, num_local_experts] -> [tp_size, ep_size]
num_global_tokens_per_rank = num_global_tokens_per_local_expert.sum(axis=2)
# [tp_size, ep_size] -> [ep_size]
# self.output_splits represents the number of tokens received by the current rank
# from other EP rank.
self.output_splits = num_global_tokens_per_rank[self.tp_rank]
# [tp_size, ep_size] -> [tp_size]
# self.output_splits_tp represents the number of tokens received by the current
# rank from other TP rank.
#self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
# [tp_size, ep_size, num_local_experts] -> [num_local_experts]
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1))
# A synchronization is needed before expert parallel AlltoAll communication
# to get the `input_splits` and `output_splits` CPU values.
#self._maybe_update_cuda_sync_point("before_ep_alltoall")
if self.num_local_experts > 1:
# [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
......@@ -257,21 +262,40 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view(
-1, self.num_local_experts
)
if not self.config.moe_permute_fusion:
# A synchronization is needed before permutation 2
# to get the `num_global_tokens_per_local_expert` CPU value.
self._maybe_update_cuda_sync_point("before_permutation_2")
assert (
self.cuda_sync_point_priority[self.cuda_dtoh_point]
<= self.cuda_sync_point_priority[self.cuda_sync_point]
), "cuda_sync_point must be after cuda_dtoh_point."
# if not self.config.moe_permute_fusion:
# # A synchronization is needed before permutation 2
# # to get the `num_global_tokens_per_local_expert` CPU value.
# self._maybe_update_cuda_sync_point("before_permutation_2")
# assert (
# self.cuda_sync_point_priority[self.cuda_dtoh_point]
# <= self.cuda_sync_point_priority[self.cuda_sync_point]
# ), "cuda_sync_point must be after cuda_dtoh_point."
return num_tokens_per_local_expert
def token_permutation(
self, hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.routing_map = routing_map
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
tokens_per_expert = self.preprocess(self.routing_map)
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
global_input_tokens = torch.ops.vllm.token_permutation_forward(tokens_per_expert, hidden_states,
probs, routing_map, self.layer_name)
return global_input_tokens, tokens_per_expert
def token_permutation_impl(
self,
tokens_per_expert: torch.Tensor,
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
......@@ -293,23 +317,16 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
- Number of tokens per expert.
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape
if self.config.apply_router_weight_on_input:
self.probs = probs
self.routing_map = routing_map
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(self.routing_map)
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
# Permutation 1: input to AlltoAll input
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_1", tokens_per_expert
)
self.hidden_shape = hidden_states.shape
if self.config.apply_router_weight_on_input:
self.probs = probs
assert probs.dim() == 2, "Expected 2D tensor for probs"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
......@@ -350,11 +367,16 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
#tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return global_input_tokens, tokens_per_expert
return global_input_tokens
def token_unpermutation(
self, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> torch.Tensor:
return torch.ops.vllm.token_unpermutation_forward(hidden_states, self.layer_name)
def token_unpermutation_impl(
self, hidden_states: torch.Tensor,
) -> torch.Tensor:
"""
Reverse the token permutation to restore the original order.
......@@ -463,8 +485,60 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
cuda_dtoh_sync_event.record()
# if point == self.cuda_sync_point:
# # Synchronize with the dtoh stream at self.cuda_sync_point.
# cuda_dtoh_stream.synchronize()
# if point == self.cuda_sync_point:
# # Synchronize with the dtoh stream at self.cuda_sync_point.
# cuda_dtoh_stream.synchronize()
return tokens_per_expert
def token_permutation_forward(tokens_per_expert: torch.Tensor,
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
return self.token_permutation_impl(tokens_per_expert, hidden_states, probs, routing_map)
def token_permutation_forward_fake(tokens_per_expert: torch.Tensor,
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="token_permutation_forward",
op_func=token_permutation_forward,
mutates_args=["tokens_per_expert", "hidden_states", "probs", "routing_map"],
fake_impl=token_permutation_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def token_unpermutation_forward(hidden_states: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
return self.token_unpermutation_impl(hidden_states)
def token_unpermutation_forward_fake(hidden_states: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
return tokens_per_expert
\ No newline at end of file
direct_register_custom_op(
op_name="token_unpermutation_forward",
op_func=token_unpermutation_forward,
mutates_args=["hidden_states"],
fake_impl=token_unpermutation_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment