"tests/vscode:/vscode.git/clone" did not exist on "bc9d7b5595887d4a1358926579b638b42368efd7"
Commit dbd0bda6 authored by 王敏's avatar 王敏
Browse files

临时上传大ep代码

parent 15347448
...@@ -6,7 +6,7 @@ from typing import Any, Optional, Union ...@@ -6,7 +6,7 @@ from typing import Any, Optional, Union
import torch import torch
import torch.distributed import torch.distributed
from .parallel_state import get_tp_group from .parallel_state import get_tp_group, get_ep_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
...@@ -32,6 +32,17 @@ def tensor_model_parallel_gather(input_: torch.Tensor, ...@@ -32,6 +32,17 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
"""Gather the input tensor across model parallel group.""" """Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim) return get_tp_group().gather(input_, dst, dim)
def expert_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_ep_group().all_gather(input_, dim)
def expert_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_ep_group().gather(input_, dst, dim)
def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor, def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor,
Any]]] = None, Any]]] = None,
......
import math
from typing import Callable, List, Optional, Tuple, Union
from dataclasses import dataclass
import torch
from torch import nn
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
try:
from transformer_engine.pytorch.permutation import (
moe_permute,
moe_sort_chunks_by_index,
moe_unpermute,
)
fused_permute = moe_permute
fused_unpermute = moe_unpermute
fused_sort_chunks_by_index = moe_sort_chunks_by_index
HAVE_TE = True
except ImportError:
fused_permute = None
fused_unpermute = None
fused_sort_chunks_by_index = None
HAVE_TE = False
@dataclass
class EpMoeConfig:
moe_router_topk: int = 2
moe_permute_fusion: bool = False
moe_shared_expert_overlap: bool = False
ep_size: int = 1
num_moe_experts: int = 256
@staticmethod
def make(moe_router_topk: int = 2,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False,
ep_size: int = 1,
num_moe_experts: int = 256) -> "EpMoeConfig":
return EpMoeConfig(moe_router_topk=moe_router_topk,
moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=ep_size,
num_moe_experts=num_moe_experts)
class EPSharedExperts(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
moe_shared_expert_overlap: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
self.moe_shared_expert_overlap = moe_shared_expert_overlap
if self.moe_shared_expert_overlap:
self.cached_fc1_input = None
self.cached_fc2_input = None
self.cached_fc2_output = None
self.cached_output = None
self.gate_score = None
self.stream = torch.cuda.Stream()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
def linear_fc1_forward_and_act(self, overlapped_comm_output=None):
"""
Do Linear FC1 and activation function forward.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.moe_shared_expert_overlap
with torch.cuda.stream(self.stream):
# [s, b, 4 * h/p]
intermediate_parallel, bias_parallel = self.gate_up_proj(self.cached_fc1_input)
self.cached_fc1_input = None
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.act_fn(intermediate_parallel)
self.cached_fc2_input = intermediate_parallel
def linear_fc2_forward(self, overlapped_comm_output=None):
"""
Do Linear FC2 forward.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.moe_shared_expert_overlap
assert self.cached_fc2_input is not None
with torch.cuda.stream(self.stream):
# [s, b, h]
self.cached_fc2_output, _ = self.down_proj(self.cached_fc2_input)
self.cached_fc2_input = None
def pre_forward_comm(self, input):
"""
All Gather for SP before forward.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.cached_output is None
self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream):
self.cached_fc1_input = input
def post_forward_comm(self):
"""
Reduce scatter for SP after forward.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.moe_shared_expert_overlap
assert self.cached_fc2_output is not None
with torch.cuda.stream(self.stream):
self.cached_output = tensor_model_parallel_all_reduce(
self.cached_fc2_output
)
self.cached_fc2_output = None
def get_output(self):
"""
Gets the module forward output.
This function is used to overlap shared experts with the dispatcher.
It is only useful when --moe-shared-expert-overlap is set and may be changed.
"""
assert self.moe_shared_expert_overlap
assert self.cached_output is not None
with torch.cuda.stream(self.stream):
output = self.cached_output
self.cached_output = None
torch.cuda.current_stream().wait_stream(self.stream)
return output
def maybe_move_tensor_to_cpu(tensor, as_numpy=False, record_stream=False):
"""Move a tensor to CPU if it is on GPU.
Args:
tensor (torch.Tensor or None): The tensor to move to CPU.
as_numpy (bool): Whether to convert the tensor to a numpy array.
record_stream (bool): Whether to record the stream of the tensor, to prevent memory leak
when the DtoH data transfer is on a side stream.
"""
if torch.is_tensor(tensor) and tensor.is_cuda:
cpu_tensor = tensor.to(torch.device("cpu"), non_blocking=True)
if as_numpy:
cpu_tensor = cpu_tensor.numpy()
if record_stream:
tensor.record_stream(torch.cuda.current_stream())
tensor = cpu_tensor
return tensor
def sort_chunks_by_idxs(
input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, fused: bool = False
):
"""Split and sort the input tensor based on the split_sizes and sorted indices."""
if fused:
if not HAVE_TE or fused_sort_chunks_by_index is None:
raise ValueError(
"fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0."
)
return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs)
input = torch.split(input, split_sizes.tolist(), dim=0)
output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0)
return output
def permute(
tokens,
routing_map,
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
The shape of mask is [tokens, num_experts], it indicates which experts were selected
by each token.
When drop_and_pad=True, in routing_map, the number of non-zeros in each column equals to
expert capacity. This function exploits this feature to use ops that support cuda graph.
Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
num_out_tokens (int, optional): The number of output tokens. If None, it's set to
the number of input tokens.
fused (bool, optional): Whether use the fused permute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
If set to true, routing_map has a fixed number of non-zeros
in each column.
"""
if fused:
if not HAVE_TE or fused_permute is None:
raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
return fused_permute(tokens, routing_map, num_out_tokens)
num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]
if drop_and_pad and not (num_out_tokens is None):
capacity = num_out_tokens // num_experts
assert not routing_map.requires_grad
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
# use argsort to put indices of all non-zeros in the beginning of list
# and keep the first `capacity` number of indices
sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
:, :capacity
].contiguous()
# flatten from [num_experts, capacity] to 1D
sorted_indices = sorted_indices.view(-1)
else:
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# use the mapping to permute the tokens
permuted_input = tokens.index_select(0, sorted_indices)
return permuted_input, sorted_indices
def unpermute(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
restore_shape: torch.Size,
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""
Restore the original order of tokens after permutation. If probs are provided, it
will also apply them to the tokens before restoring the order.
This function exploits these features to use ops that support cuda graph.
Args:
permuted_tokens (torch.Tensor): The permuted token tensor.
sorted_indices (torch.Tensor): The indices used to sort the tokens.
restore_shape (torch.Size): The shape of the unpermuted tensor.
probs (torch.Tensor, optional): The unpermuted probs tensor,
routing_map (torch.Tensor, optional): Token to expert mapping, shape
[num_tokens, num_experts].
fused (bool, optional): Whether use the fused unpermute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
Returns:
torch.Tensor: The tokens restored to their original order.
"""
if fused:
if not HAVE_TE or fused_unpermute is None:
raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
return fused_unpermute(permuted_tokens, sorted_indices, probs, restore_shape)
_, hidden = restore_shape
input_dtype = permuted_tokens.dtype
if probs is not None:
assert routing_map is not None, "Mask must be provided to permute the probs."
if drop_and_pad:
num_experts = routing_map.size(1)
num_permuted_tokens = sorted_indices.size(0)
capacity = num_permuted_tokens // num_experts
num_unpermuted_tokens = probs.size(0)
# [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens
probs_T_1D = probs.T.contiguous().view(-1)
# get 1D indices of the probs selected by routing_map
indices_dim0 = torch.arange(num_experts, device=routing_map.device).unsqueeze(-1)
indices_dim1 = sorted_indices.view(num_experts, capacity)
indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
# get probs from indices
permuted_probs = probs_T_1D.index_select(0, indices_1D)
else:
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
# Here may promote permuted_tokens to higher precision (fp32/fp64) if probs is in
# higher precision due to moe_router_dtype being enabled. This can lead to
# additional GPU memory usage. Use --moe-permute-fusion flag to avoid this extra memory
# allocation.
permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
# Create an output tensor filled with zeros
output_tokens = torch.zeros(
restore_shape, dtype=permuted_tokens.dtype, device=permuted_tokens.device
)
# Scatter add the permuted_input back to the original positions
output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens)
return output_tokens.to(dtype=input_dtype)
def all_to_all(group, input, output_split_sizes, input_split_sizes):
# torch.cuda.synchronize()
# import sys
# sys.stderr.write(f"############all_to_all input_split_sizes:{input_split_sizes}\n output_split_sizes:{output_split_sizes}")
# sys.stderr.flush()
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input
input = input.contiguous()
if output_split_sizes is None:
# Equal split (all2all)
output = torch.empty_like(input)
else:
# Unequal split (all2all-v)
output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype,
device=torch.cuda.current_device(),
)
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
return output
import logging
from typing import List, Optional
import torch
import triton
import triton.language as tl
logger = logging.getLogger(__name__)
@triton.jit
def compute_src2dst_triton_kernel(
reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
tl.store(src2dst + src_id, dst_id, mask=mask)
@triton.jit
def deepep_compute_src2dst_triton_kernel(
reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
num_invalid = tl.load(num_minus_one)
tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
# Find offet
expert_ids = torch.arange(
num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
)
torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
num_minus_one = seg_indptr[0]
seg_indptr = seg_indptr - num_minus_one
BLOCK_SIZE = 512
grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
deepep_compute_src2dst_triton_kernel[grid](
reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
)
reorder_topk_ids = reorder_topk_ids[num_minus_one:]
return reorder_topk_ids, src2dst, seg_indptr
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
expert = tl.program_id(0)
low = 0
high = num_toks - 1
target_location = -1
while low <= high:
mid = (low + high) // 2
if tl.load(reorder_topk_ids + mid) > expert:
high = mid - 1
else:
low = mid + 1
target_location = mid
tl.store(seg_indptr + expert + 1, target_location + 1)
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
compute_seg_indptr_triton_kernel[(num_experts,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)
BLOCK_SIZE = 512
grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
compute_src2dst_triton_kernel[grid](
reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
)
return reorder_topk_ids, src2dst, seg_indptr
@triton.jit
def pre_reorder_triton_kernel(
input_ptr,
gateup_input_ptr,
src2dst_ptr,
topk_ids_ptr,
a1_scales_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
OutDtype = gateup_input_ptr.dtype.element_ty
src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
if a1_scales_ptr is not None:
scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
else:
scale = 1.0
dst_idx = tl.load(src2dst_ptr + idx)
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
out_data = (in_data * scale).to(OutDtype)
tl.store(dst_ptr + offset, out_data, mask=mask)
@triton.jit
def silu_and_mul_triton_kernel(
gateup_output,
down_input,
hidden_size,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
BLOCK_SIZE: tl.constexpr,
):
InDtype = gateup_output.dtype.element_ty
OutDtype = down_input.dtype.element_ty
half_hidden_size = hidden_size // 2
pid = tl.program_id(0)
expert_id = tl.load(reorder_topk_ids + pid)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
gateup_output_ptr = gateup_output + pid * hidden_size
gate_output_ptr = gateup_output_ptr
up_output_ptr = gateup_output_ptr + half_hidden_size
down_input_ptr = down_input + pid * half_hidden_size
if scales is not None:
scale = tl.load(scales + expert_id - start_expert_id)
scale = (1 / scale).to(InDtype)
else:
scale = 1
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < half_hidden_size
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
up_output = tl.load(up_output_ptr + offset, mask=mask)
# silu & mul & quantize
gate_output = gate_output * tl.sigmoid(gate_output)
gate_output = gate_output.to(InDtype)
silu_mul_output = gate_output * up_output * scale
silu_mul_output = silu_mul_output.to(OutDtype)
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
input_ptr,
stride_input_0,
stride_input_1,
stride_input_2,
output_ptr,
stride_output_0,
stride_output_1,
stride_output_2,
output_scale_ptr,
stride_output_scale_0,
stride_output_scale_1,
stride_output_scale_2,
masked_m_ptr,
size_n,
fp8_max,
fp8_min,
BLOCK_N: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
expert_id = tl.program_id(2)
token_id = tl.program_id(1)
hidden_dim_block_index = tl.program_id(0)
block_num_per_expert = tl.num_programs(1)
token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
output_scale_offs = (
output_scale_ptr
+ expert_id * stride_output_scale_0
+ hidden_dim_block_index * stride_output_scale_2
)
for token_index in tl.range(
token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
):
gate = tl.load(
input_ptr_offs + token_index * stride_input_1,
mask=offs_in_d < size_n,
other=0.0,
).to(tl.float32)
up = tl.load(
input_ptr_offs + token_index * stride_input_1 + size_n,
mask=offs_in_d < size_n,
other=0.0,
)
gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)
gate_up = up * gate
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
output_s = _absmax / fp8_max
output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
output_ptr.dtype.element_ty
)
tl.store(
output_ptr_offs + token_index * stride_output_1,
output_q,
mask=offs_in_d < size_n,
)
tl.store(
output_scale_offs + token_index * stride_output_scale_1,
output_s,
)
def silu_and_mul_masked_post_quant_fwd(
input: torch.Tensor,
output: torch.Tensor,
output_scale: torch.Tensor,
quant_group_size: int,
masked_m: torch.Tensor,
):
"""
input shape [expert_num, token_num_padded, hidden_dim]
output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32
quant_group_size int,
masked_m shape [expert_num],
"""
assert input.is_contiguous()
assert output.dtype == torch.float8_e4m3fn
assert output.is_contiguous()
assert len(input.shape) == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0
size_n = input.shape[-1] // 2
assert size_n % quant_group_size == 0
expert_num = len(masked_m)
if expert_num < 4:
BLOCK_NUM_PER_EXPERT = 64
else:
BLOCK_NUM_PER_EXPERT = 32
BLOCK_N = quant_group_size
num_warps = 1
NUM_STAGES = 6
hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
assert BLOCK_N % quant_group_size == 0
grid = (
hidden_dim_split_block_num,
BLOCK_NUM_PER_EXPERT,
expert_num,
)
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max
_silu_and_mul_post_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
output_scale,
*output_scale.stride(),
masked_m,
size_n,
fp8_max,
fp8_min,
BLOCK_N=BLOCK_N,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
)
return
@triton.jit
def tanh(x):
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def gelu_and_mul_triton_kernel(
gateup_output,
down_input,
hidden_size,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
BLOCK_SIZE: tl.constexpr,
):
InDtype = gateup_output.dtype.element_ty
OutDtype = down_input.dtype.element_ty
half_hidden_size = hidden_size // 2
pid = tl.program_id(0)
expert_id = tl.load(reorder_topk_ids + pid)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
gateup_output_ptr = gateup_output + pid * hidden_size
gate_output_ptr = gateup_output_ptr
up_output_ptr = gateup_output_ptr + half_hidden_size
down_input_ptr = down_input + pid * half_hidden_size
if scales is not None:
scale = tl.load(scales + expert_id - start_expert_id)
scale = (1 / scale).to(InDtype)
else:
scale = 1
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < half_hidden_size
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
up_output = tl.load(up_output_ptr + offset, mask=mask)
# gelu & mul & quantize
# https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
# sqrt(2/pi)
kAlpha = 0.7978845608028654
gate_output = (
0.5
* gate_output
* (
1
+ tanh(
kAlpha
* (
gate_output
+ 0.044715 * gate_output * gate_output * gate_output
)
)
)
)
gate_output = gate_output.to(InDtype)
gelu_mul_output = gate_output * up_output * scale
gelu_mul_output = gelu_mul_output.to(OutDtype)
tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
@triton.jit
def post_reorder_triton_kernel(
down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty
src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk
computed = False
store_ptr = output_ptr + src_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
computed = True
dst_idx = tl.load(src2dst_ptr + idx)
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
sum_vec += in_data * weigh_scale
tl.store(store_ptr + offset, sum_vec, mask=mask)
if computed == False:
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
tl.store(
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
)
@triton.jit
def compute_m_range(
pid,
batch_size,
seg_indptr,
weight_indices,
m_num_tiles_indptr,
BLOCK_SIZE_M: tl.constexpr,
):
idx = 0
for bs in range(batch_size):
tiles = tl.load(m_num_tiles_indptr + bs)
if pid >= tiles:
idx = bs
idx_start = tl.load(m_num_tiles_indptr + idx)
m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
expert_id = tl.load(weight_indices + idx)
return m_range_start, m_range_end, expert_id
@triton.jit
def grouped_gemm_triton_kernel(
a,
b,
c,
batch_size,
N,
K,
seg_indptr,
weight_indices,
m_num_tiles_indptr,
scale_a,
scale_b,
use_fp8_w8a8: tl.constexpr,
group_n: tl.constexpr,
group_k: tl.constexpr,
a_stride_0: tl.constexpr,
b_stride_0: tl.constexpr,
b_stride_1: tl.constexpr,
as_stride_0: tl.constexpr,
as_stride_1: tl.constexpr,
bs_stride_0: tl.constexpr,
bs_stride_2: tl.constexpr,
bs_stride_1: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
c_dtype = c.dtype.element_ty
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
total_m_block = tl.load(m_num_tiles_indptr + batch_size)
if pid_m >= total_m_block:
return
m_range_start, m_range_end, expert_id = compute_m_range(
pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
)
if m_range_end - m_range_start == 0:
return
n_range_start = pid_n * BLOCK_SIZE_N
n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
offs_am = tl.arange(0, BLOCK_SIZE_M)
offs_bn = tl.arange(0, BLOCK_SIZE_N)
offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
# [blcok_n, block_k]
b_ptr = b + (
(expert_id * b_stride_0)
+ (n_range_start + offs_bn[:, None]) * b_stride_1
+ offs_k[None, :]
)
if group_k > 0 and group_n > 0:
a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
offs_bsn = (n_range_start + offs_bn) // group_n
b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a_tile = tl.load(
a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
)
# [block_n, blcok_k]
b_tile = tl.load(
b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
)
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
else:
accumulator = tl.dot(a_tile, b_tile.T, accumulator)
a_ptr += BLOCK_SIZE_K
b_ptr += BLOCK_SIZE_K
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
scale_a_value = tl.load(scale_a + expert_id)
scale_b_value = tl.load(scale_b + expert_id)
accumulator *= scale_a_value * scale_b_value
c_tile = accumulator.to(c_dtype)
offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
tl.store(c_ptr, c_tile, mask=c_mask)
@triton.jit
def compute_m_num_tiles_indptr(
m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
for bs in range(batch_size):
m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
def grouped_gemm_triton(
a: torch.Tensor,
b: torch.Tensor,
c: torch.Tensor,
batch_size: int,
weight_column_major: bool,
seg_indptr: Optional[torch.Tensor] = None,
weight_indices: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
scale_a: torch.Tensor = None,
scale_b: torch.Tensor = None,
block_shape: Optional[List[int]] = None,
):
assert weight_column_major == True # TODO: more
if use_fp8_w8a8 and block_shape is None:
assert scale_a is not None and scale_b is not None
# if block_shape is not None:
# assert len(block_shape) == 2
# block_n, block_k = block_shape[0], block_shape[1]
# if _is_cuda:
# a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
# else:
# a, scale_a = per_token_group_quant_fp8(a, block_k)
# assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
# assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
# assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
# TODO: adjust config or tune kernel
# Reduce block size to prevent L40 shared memory overflow.
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
}
m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
compute_m_num_tiles_indptr[(1,)](
m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
)
grid = lambda META: (
triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
)
grouped_gemm_triton_kernel[grid](
a,
b,
c,
batch_size,
b.size(1),
b.size(2),
seg_indptr,
weight_indices,
m_num_tiles_indptr,
scale_a,
scale_b,
use_fp8_w8a8,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
a.stride(0),
b.stride(0),
b.stride(1),
scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
**config,
)
return c
import logging
from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EPSharedExperts, EpMoeConfig
from vllm.model_executor.layers.fused_moe.ep_moe.kernels import grouped_gemm_triton
logger = init_logger(__name__)
class EPMoE(FusedMoE):
"""
dp+ep MoE Expert Parallel Impl
"""
def __init__(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False
):
super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype,
reduce_results, renormalize,
use_grouped_topk, num_expert_group,
topk_group, quant_config, tp_size,
ep_size, dp_size, prefix,
custom_routing_function, scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
routed_scaling_factor=routed_scaling_factor
)
self.ep_moe_config: EpMoeConfig = EpMoeConfig.make(
moe_router_topk=self.top_k,
# TODO: support fusion permute
moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=self.ep_size,
num_moe_experts=self.global_num_experts
)
local_expert_indices_offset = (
self.ep_rank * self.local_num_experts
)
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.local_num_experts)
]
self.shared_experts = None
self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices, config=self.ep_moe_config
)
self.shared_expert_overlap = moe_shared_expert_overlap
self.seg_indptr = None
if quant_config is None:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.w13_weight_scale = None
self.w2_weight_scale = None
else:
self.use_fp8_w8a8 = True
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
def set_shared_experts(self, shared_experts):
self.shared_experts = shared_experts
self.use_shared_expert = shared_experts is not None
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(shared_experts)
def triton_grouped_gemm_impl(self, hidden_states, tokens_per_expert, use_nn_moe):
torch.cumsum(tokens_per_expert,
dim=0,
out=self.seg_indptr[1:])
_, N, _ = self.w13_weight.shape
gateup_input = hidden_states
weight_indices_cur_rank = torch.arange(
0,
self.local_num_experts,
device=hidden_states.device,
dtype=torch.int64,
)
# GroupGemm-0
gateup_output = torch.empty(
gateup_input.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
gateup_output = grouped_gemm_triton(
a=gateup_input,
b=self.w13_weight,
c=gateup_output,
batch_size=self.local_num_experts,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale if self.quant_config is not None else None,
scale_b=(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
) if self.quant_config is not None else None,
block_shape=self.block_shape,
)
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
),
)
if self.quant_config is not None and self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.local_num_experts,
dtype=torch.float32,
device=hidden_states.device,
)
if self.activation == "silu":
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, N))
elif self.activation == "gelu":
torch.ops._C.gelu_and_mul(down_input,
gateup_output.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {self.activation}")
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
down_output = grouped_gemm_triton(
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.local_num_experts,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale if self.quant_config is not None else None,
scale_b=(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
) if self.quant_config is not None else None,
block_shape=self.block_shape,
)
return down_output
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if (
self.training
and self.config.tensor_model_parallel_size > 1
and not self.config.sequence_parallel
):
raise ValueError(
"During training, performance may degrade if MoE and tensor parallelism"
"are enabled without also enabling sequence parallelism."
)
if self.seg_indptr is None:
self.seg_indptr = torch.zeros(self.local_num_experts+1, device=hidden_states. device, dtype=torch.int64)
# process MoE
def custom_forward(hidden_states, router_logits):
topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
use_grouped_topk=self.use_grouped_topk,
top_k=self.top_k,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
indices_type=torch.int64,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate)
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
expert_output = self.triton_grouped_gemm_impl(dispatched_input, tokens_per_expert, self.use_nn_moe)
output = self.token_dispatcher.token_unpermutation(expert_output)
if self.use_shared_expert and not self.shared_expert_overlap:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
output = output + self.shared_experts(hidden_states)
return output
output = custom_forward(hidden_states, router_logits)
return output
\ No newline at end of file
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
get_ep_group,
get_tensor_model_parallel_rank)
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import (EPSharedExperts,
maybe_move_tensor_to_cpu,
permute,
sort_chunks_by_idxs,
unpermute,
all_to_all,
EpMoeConfig)
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather,
expert_parallel_all_gather,
expert_parallel_gather)
from vllm.platforms import current_platform
class MoETokenDispatcher:
"""
MoE Token Dispatcher
"""
def __init__(self, config: EpMoeConfig) -> None:
"""
Initialize the MoE Token Dispatcher.
"""
self.config = config
self.shared_experts: Optional[EPSharedExperts] = None
self.tp_size = 1
self.ep_size = config.ep_size
@property
def ep_group(self):
"""Get expert model parallel group."""
return get_ep_group()
@property
def tp_group(self):
"""Get expert tensor parallel group."""
return get_tp_group()
@property
def tp_rank(self):
"""Get expert tensor parallel rank."""
return 0#get_tensor_model_parallel_rank()
@property
def tp_ep_group(self):
"""Get expert tensor and model parallel group."""
return get_ep_group()
@abstractmethod
def token_permutation(
self, tokens: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
):
"""Dispatch tokens to experts.
Args:
tokens (torch.Tensor): Input tokens.
probs (torch.Tensor): The routing probability tensor [num_tokens, num_experts].
routing_map (torch.Tensor): Token to expert mapping tensor.
Returns:
torch.Tensor: Tokens tensor.
"""
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_unpermutation(self, expert_output: torch.Tensor, bias: torch.Tensor = None):
"""Restores the expert output to its original ordering.
Args:
expert_output (torch.Tensor): The output tensor from the expert models.
bias (torch.Tensor): The bias tensor.
Returns:
(torch.Tensor, torch.Tensor): Unpermuted activation and optional bias.
"""
raise NotImplementedError("Restore function not implemented.")
def set_shared_experts(self, shared_experts):
"""Set shared expert to the dispatcher."""
assert self.config.moe_shared_expert_overlap
self.shared_experts = shared_experts
class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
AlltoAll-based token dispatcher.
The workflow of AlltoAll token dispatcher is as follows:
(1) preprocess(): calculate necessary metadata for communication and permute
(2) token_permutation(): permute->A2A(EP)->AG(TP)->sort_chunk(if num_local_experts>1)
(3) token_unpermutation(): sort_chunk(if num_local_experts>1)->RS(TP)->A2A(EP)->unpermute
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: EpMoeConfig
) -> None:
"""
Initialize the AlltoAll token dispatcher.
Args:
num_local_experts (int): Number of local experts on the current device.
local_expert_indices (List[int]): Indices of local experts on the current device.
config (TransformerConfig): Configuration for the transformer model.
"""
super().__init__(config=config)
self.num_local_experts = num_local_experts
assert config.num_moe_experts is not None
self.num_experts = config.num_moe_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert (
len(self.local_expert_indices) == self.num_local_experts
), "Invalid local expert indices"
for i in range(len(self.local_expert_indices) - 1):
assert (
self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1
), "local_expert_indices must be continous"
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self.input_splits = None
# [ep_size]. Represents the number of tokens received by the current rank from
# other EP ranks.
self.output_splits = None
# [tp_size]. Represents the number of tokens received by the current rank from
# other TP ranks.
self.output_splits_tp = None
self.permute_idx_device = torch.device("cuda") if self.config.moe_permute_fusion else None
input_chunk_idxs = torch.arange(
self.num_experts * self.tp_size, device=self.permute_idx_device
)
# [num_local_experts, tp_size * ep_size]. Sort the input chunks by local experts.
self.sort_input_by_local_experts = input_chunk_idxs.reshape(
-1, self.num_local_experts
).T.ravel()
# [tp_size * ep_size, num_local_experts]. Restore the output chunks by local experts.
self.restore_output_by_local_experts = input_chunk_idxs.reshape(
self.num_local_experts, -1
).T.ravel()
# A cuda stream synchronization is needed in self.token_permutation() in some cases,
# because there are several non-blocking DtoH data transfers called at
# `self.cuda_dtoh_point`. The synchronization happens at `self.cuda_sync_point`, which is
# decided based on the MoE and parallel settings. Valid points are "before_permutation_1",
# "before_ep_alltoall", "before_permutation_2", "before_finish", and "no_sync".
self.cuda_sync_point = "no_sync"
self.cuda_sync_point_priority = {
"before_permutation_1": 0,
"before_ep_alltoall": 1,
"before_permutation_2": 2,
"before_finish": 3,
"no_sync": 4,
}
self.cuda_dtoh_point = "before_permutation_1"
self.cuda_dtoh_stream = torch.cuda.Stream()
self.shared_experts = None
# Whether to use gather or all-gather to gather the logits.
self.use_all_gather = current_platform.use_all_gather()
def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
"""
Preprocess token routing map for AlltoAll communication and token permutation.
This method computes the number of tokens assigned to each expert based on the routing_map.
It also initializes the necessary data structures for AlltoAll communication, such as input
and output splits, and the mapping between global tokens and local experts. This method
should not call any DtoH data copying due to performance consideration. The necessary DtoH
copies are made on the `self.cuda_dtoh_stream` at `self.cuda_dtoh_point`.
Args:
routing_map (torch.Tensor): The mapping of tokens to experts, with shape
[num_tokens, num_experts].
Returns:
torch.Tensor: Tensor containing the number of tokens assigned to local expert.
"""
# [num_experts], number of tokens assigned to each expert from the current rank's input.
num_local_tokens_per_expert = routing_map.sum(dim=0).long()
self.num_out_tokens = routing_map.size(0) * self.config.moe_router_topk
if self.ep_size > 1 or self.tp_size > 1:
# ===================================================
# Calculate input_splits, output_splits for alltoall/allgather in variable size.
# ===================================================
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self.input_splits = num_local_tokens_per_expert.reshape(
self.ep_size, self.num_local_experts
).sum(axis=1)
# Gather the global distribution of tokens across ranks.
# num_global_tokens_per_expert represents the number of tokens sent to each
# expert by all ranks.
# [tp_size, ep_size, num_experts]
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
num_global_tokens_per_expert = expert_parallel_all_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1)
else:
# None may be returned for rank > 0
num_global_tokens_per_expert = expert_parallel_gather(num_local_tokens_per_expert) \
.reshape(self.ep_size, self.tp_size, self.num_experts) \
.transpose(0, 1)
# [tp_size, ep_size, num_experts] -> [tp_size, ep_size, num_local_experts]
num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
].contiguous()
# [tp_size, ep_size, num_local_experts] -> [tp_size, ep_size]
num_global_tokens_per_rank = num_global_tokens_per_local_expert.sum(axis=2)
# [tp_size, ep_size] -> [ep_size]
# self.output_splits represents the number of tokens received by the current rank
# from other EP rank.
self.output_splits = num_global_tokens_per_rank[self.tp_rank]
# [tp_size, ep_size] -> [tp_size]
# self.output_splits_tp represents the number of tokens received by the current
# rank from other TP rank.
self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
# [tp_size, ep_size, num_local_experts] -> [num_local_experts]
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1))
# A synchronization is needed before expert parallel AlltoAll communication
# to get the `input_splits` and `output_splits` CPU values.
self._maybe_update_cuda_sync_point("before_ep_alltoall")
else:
num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
self.num_experts
)
num_tokens_per_local_expert = num_local_tokens_per_expert
# A synchronization is needed before the returns
# to get the `num_tokens_per_local_expert` CPU value.
self._maybe_update_cuda_sync_point("before_finish")
if self.num_local_experts > 1:
# [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view(
-1, self.num_local_experts
)
if not self.config.moe_permute_fusion:
# A synchronization is needed before permutation 2
# to get the `num_global_tokens_per_local_expert` CPU value.
self._maybe_update_cuda_sync_point("before_permutation_2")
assert (
self.cuda_sync_point_priority[self.cuda_dtoh_point]
<= self.cuda_sync_point_priority[self.cuda_sync_point]
), "cuda_sync_point must be after cuda_dtoh_point."
return num_tokens_per_local_expert
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
This method performs the following steps:
1. Preprocess the routing map to get metadata for communication and permutation.
2. Permute input tokens for AlltoAll communication.
3. Perform expert parallel AlltoAll communication.
4. Sort tokens by local expert (if multiple local experts exist).
Args:
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): The probabilities of token to experts assignment.
routing_map (torch.Tensor): The mapping of token to experts assignment.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- Permuted token embeddings for local experts.
- Number of tokens per expert.
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape
self.probs = probs
self.routing_map = routing_map
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(self.routing_map)
if self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
import sys
# torch.cuda.synchronize()
# sys.stderr.write(f"token_permutation===============================================")
# sys.stderr.flush()
# Permutation 1: input to AlltoAll input
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_1", tokens_per_expert
)
# torch.cuda.synchronize()
# sys.stderr.write(f"before permute===============================================")
# sys.stderr.flush()
self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
routing_map,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion,
drop_and_pad=False,
)
# torch.cuda.synchronize()
# sys.stderr.write(f"after permute===============================================")
# sys.stderr.flush()
# Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert
)
#torch.cuda.synchronize()
#print("###########################before permutation all_to_all output_splits:{} input_splits:{}".format(self.output_splits, self.input_splits))
global_input_tokens = all_to_all(
self.ep_group.device_group, permutated_local_input_tokens, self.output_splits, self.input_splits
)
#torch.cuda.synchronize()
#print("#######################permutation all_to_all end")
if self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
# Permutation 2: Sort tokens by local expert.
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_2", tokens_per_expert
)
if self.num_local_experts > 1:
global_input_tokens = sort_chunks_by_idxs(
global_input_tokens,
self.num_global_tokens_per_local_expert.ravel(),
self.sort_input_by_local_experts,
fused=self.config.moe_permute_fusion,
)
tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return global_input_tokens, tokens_per_expert
def token_unpermutation(
self, hidden_states: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Reverse the token permutation to restore the original order.
This method performs the following steps:
1. Unsort tokens by local expert (if multiple local experts exist).
2. Perform expert parallel AlltoAll communication to restore the original order.
3. Unpermute tokens to restore the original order.
Args:
hidden_states (torch.Tensor): Output from local experts.
bias (torch.Tensor, optional): Bias tensor (not supported).
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- Unpermuted token embeddings in the original order.
- None (bias is not supported).
"""
# Unpermutation 2: Unsort tokens by local expert.
if self.num_local_experts > 1:
hidden_states = sort_chunks_by_idxs(
hidden_states,
self.num_global_tokens_per_local_expert.T.ravel(),
self.restore_output_by_local_experts,
fused=self.config.moe_permute_fusion,
)
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
permutated_local_input_tokens = all_to_all(
self.ep_group.device_group, hidden_states, self.input_splits, self.output_splits
)
if self.shared_experts is not None:
self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
self.shared_experts.post_forward_comm()
# Unpermutation 1: AlltoAll output to output
output = unpermute(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
restore_shape=self.hidden_shape_before_permute,
probs=self.probs,
routing_map=self.routing_map,
fused=self.config.moe_permute_fusion,
drop_and_pad=False,
)
# Reshape the output tensor
output = output.view(self.hidden_shape)
# Add shared experts output
if self.shared_experts is not None:
shared_expert_output = self.shared_experts.get_output()
output += shared_expert_output
return output
def _maybe_update_cuda_sync_point(self, point: str):
"""
Update the CUDA sync point if the priority of the new point is higher than the current
sync point, which means the new point is reached earlier than the current sync point.
"""
if (
self.cuda_sync_point_priority[point]
< self.cuda_sync_point_priority[self.cuda_sync_point]
):
self.cuda_sync_point = point
def _maybe_dtoh_and_synchronize(
self, point: str, tokens_per_expert: torch.Tensor = None
) -> torch.Tensor:
"""
Move all possible GPU tensors to CPU and make a synchronization at the expected point.
"""
if point == self.cuda_dtoh_point:
# Move all possible GPU tensors to CPU at self.cuda_dtoh_point.
on_side_stream = torch.cuda.current_stream() != self.cuda_dtoh_stream
if on_side_stream:
self.cuda_dtoh_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.cuda_dtoh_stream):
# TODO: use MemcpyBatchAsync instead.
# tokens_per_expert = maybe_move_tensor_to_cpu(
# tokens_per_expert, record_stream=on_side_stream
# )
self.input_splits = maybe_move_tensor_to_cpu(
self.input_splits, as_numpy=True, record_stream=on_side_stream
)
self.output_splits = maybe_move_tensor_to_cpu(
self.output_splits, as_numpy=True, record_stream=on_side_stream
)
self.output_splits_tp = maybe_move_tensor_to_cpu(
self.output_splits_tp, as_numpy=True, record_stream=on_side_stream
)
self.num_out_tokens = maybe_move_tensor_to_cpu(
self.num_out_tokens, record_stream=on_side_stream
)
if self.num_local_experts > 1 and not self.config.moe_permute_fusion:
self.num_global_tokens_per_local_expert = maybe_move_tensor_to_cpu(
self.num_global_tokens_per_local_expert, record_stream=on_side_stream
)
if point == self.cuda_sync_point:
# Synchronize with the dtoh stream at self.cuda_sync_point.
self.cuda_dtoh_stream.synchronize()
return tokens_per_expert
\ No newline at end of file
...@@ -39,10 +39,12 @@ from vllm.attention import Attention ...@@ -39,10 +39,12 @@ from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig, from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config) get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group, from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.ep_moe.layer import EPMoE
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EPSharedExperts
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -152,6 +154,24 @@ class DeepseekV2MoE(nn.Module): ...@@ -152,6 +154,24 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_end = (self.physical_expert_start + self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts) self.n_local_physical_experts)
dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel
self.shared_experts = None
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_ep_opt else EPSharedExperts
self.shared_experts = shared_expert_cls(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
if not self.use_ep_opt:
self.experts = FusedMoE( self.experts = FusedMoE(
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
...@@ -169,25 +189,33 @@ class DeepseekV2MoE(nn.Module): ...@@ -169,25 +189,33 @@ class DeepseekV2MoE(nn.Module):
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor) routed_scaling_factor=self.routed_scaling_factor)
else:
if config.n_shared_experts is not None: self.experts = EPMoE(
intermediate_size = (config.moe_intermediate_size * num_experts=config.n_routed_experts,
config.n_shared_experts) top_k=config.num_experts_per_tok,
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act, reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs( use_grouped_topk=True,
), num_expert_group=config.n_group,
prefix=f"{prefix}.shared_experts", topk_group=config.topk_group,
) prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor)
if self.use_ep_opt:
self.experts.set_shared_experts(self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_ep_opt:
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
...@@ -203,6 +231,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -203,6 +231,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if not self.use_ep_opt:
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
...@@ -619,6 +648,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -619,6 +648,8 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
) )
#ops.print_tensor(hidden_states)
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick: if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow # Fix FP16 overflow
# We scale both hidden_states and residual before # We scale both hidden_states and residual before
...@@ -714,7 +745,9 @@ class DeepseekV2Model(nn.Module): ...@@ -714,7 +745,9 @@ class DeepseekV2Model(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]: for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)\
#ops.print_tensor(hidden_states)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
...@@ -244,11 +244,18 @@ class CoreEngineActorManager: ...@@ -244,11 +244,18 @@ class CoreEngineActorManager:
local_engine_count = \ local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local vllm_config.parallel_config.data_parallel_size_local
nodes = sorted(list_nodes(), # nodes = sorted(list_nodes(),
key=lambda node: node.node_ip != dp_master_ip) # key=lambda node: node.node_ip != dp_master_ip)
assert nodes[0].node_ip == dp_master_ip, ( # assert nodes[0].node_ip == dp_master_ip, (
# "The first node must be the head node")
# assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
# "There can only be one head node")
nodes = ray.nodes()
nodes = sorted(nodes,
key=lambda node: node["NodeManagerAddress"] != dp_master_ip)
assert nodes[0]["NodeManagerAddress"] == dp_master_ip, (
"The first node must be the head node") "The first node must be the head node")
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( assert len(nodes) == 1 or nodes[1]["NodeManagerAddress"] != dp_master_ip, (
"There can only be one head node") "There can only be one head node")
available_resources = available_resources_per_node() available_resources = available_resources_per_node()
...@@ -257,8 +264,11 @@ class CoreEngineActorManager: ...@@ -257,8 +264,11 @@ class CoreEngineActorManager:
local_dp_ranks: list[int] = [] local_dp_ranks: list[int] = []
for node in nodes: for node in nodes:
node_ip = node.node_ip # node_ip = node.node_ip
node_resources = available_resources[node.node_id] # node_resources = available_resources[node.node_id]
node_ip = node["NodeManagerAddress"]
node_resources = available_resources[node["NodeID"]]
# For now, each DP rank can only be assigned to one node # For now, each DP rank can only be assigned to one node
# TODO(rui): support allocating a single DP rank # TODO(rui): support allocating a single DP rank
# to multiple nodes # to multiple nodes
...@@ -428,6 +438,9 @@ def launch_core_engines( ...@@ -428,6 +438,9 @@ def launch_core_engines(
else: else:
local_engine_manager = None local_engine_manager = None
import torch
torch.cuda.synchronize()
logger.info(("launch_core_engines end==============================="))
yield local_engine_manager, coordinator, addresses yield local_engine_manager, coordinator, addresses
# Now wait for engines to start. # Now wait for engines to start.
...@@ -440,6 +453,8 @@ def launch_core_engines( ...@@ -440,6 +453,8 @@ def launch_core_engines(
local_engine_manager, local_engine_manager,
coordinator.proc if coordinator else None, coordinator.proc if coordinator else None,
) )
torch.cuda.synchronize()
logger.info(("engine startup==============================="))
def wait_for_engine_startup( def wait_for_engine_startup(
......
...@@ -2051,6 +2051,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2051,6 +2051,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens] inputs_embeds = self.inputs_embeds[:num_tokens]
else: else:
#self.input_ids[:num_tokens] = torch.randint(0, 120000, (num_tokens,), dtype=torch.int32)
self.input_ids[:num_tokens] = torch.arange(num_tokens, dtype=torch.int32, device=self.input_ids.device)
input_ids = self.input_ids[:num_tokens] input_ids = self.input_ids[:num_tokens]
inputs_embeds = None inputs_embeds = None
if self.uses_mrope: if self.uses_mrope:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment