Commit d04683a4 authored by 王敏's avatar 王敏
Browse files

[feat]上传初版基于all2all通信的大EP代码

parent cfabf125
......@@ -40,6 +40,8 @@ except ImportError:
HAVE_TE = False
shared_experts_overlap_stream = torch.cuda.Stream()
@dataclass
class EpMoeConfig:
......@@ -48,18 +50,25 @@ class EpMoeConfig:
moe_shared_expert_overlap: bool = False
ep_size: int = 1
num_moe_experts: int = 256
apply_router_weight_on_input: bool = False
routed_scaling_factor: float = 1.0
@staticmethod
def make(moe_router_topk: int = 2,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False,
ep_size: int = 1,
num_moe_experts: int = 256) -> "EpMoeConfig":
num_moe_experts: int = 256,
routed_scaling_factor: float = 1.0,
apply_router_weight_on_input: bool = False) -> "EpMoeConfig":
return EpMoeConfig(moe_router_topk=moe_router_topk,
moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=ep_size,
num_moe_experts=num_moe_experts)
num_moe_experts=num_moe_experts,
routed_scaling_factor=routed_scaling_factor,
apply_router_weight_on_input=apply_router_weight_on_input)
class EPSharedExperts(nn.Module):
......@@ -99,7 +108,7 @@ class EPSharedExperts(nn.Module):
self.cached_output = None
self.gate_score = None
self.stream = torch.cuda.Stream()
self.stream = shared_experts_overlap_stream
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
......@@ -215,55 +224,35 @@ def permute(
routing_map,
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
The shape of mask is [tokens, num_experts], it indicates which experts were selected
by each token.
When drop_and_pad=True, in routing_map, the number of non-zeros in each column equals to
expert capacity. This function exploits this feature to use ops that support cuda graph.
Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
num_out_tokens (int, optional): The number of output tokens. If None, it's set to
the number of input tokens.
fused (bool, optional): Whether use the fused permute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
If set to true, routing_map has a fixed number of non-zeros
in each column.
"""
if fused:
if not HAVE_TE or fused_permute is None:
raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
return fused_permute(tokens, routing_map, num_out_tokens)
num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]
if drop_and_pad and not (num_out_tokens is None):
capacity = num_out_tokens // num_experts
assert not routing_map.requires_grad
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
# use argsort to put indices of all non-zeros in the beginning of list
# and keep the first `capacity` number of indices
sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
:, :capacity
].contiguous()
# flatten from [num_experts, capacity] to 1D
sorted_indices = sorted_indices.view(-1)
else:
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# use the mapping to permute the tokens
permuted_input = tokens.index_select(0, sorted_indices)
......@@ -278,7 +267,6 @@ def unpermute(
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""
Restore the original order of tokens after permutation. If probs are provided, it
......@@ -294,8 +282,6 @@ def unpermute(
routing_map (torch.Tensor, optional): Token to expert mapping, shape
[num_tokens, num_experts].
fused (bool, optional): Whether use the fused unpermute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
Returns:
torch.Tensor: The tokens restored to their original order.
......@@ -310,24 +296,7 @@ def unpermute(
if probs is not None:
assert routing_map is not None, "Mask must be provided to permute the probs."
if drop_and_pad:
num_experts = routing_map.size(1)
num_permuted_tokens = sorted_indices.size(0)
capacity = num_permuted_tokens // num_experts
num_unpermuted_tokens = probs.size(0)
# [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens
probs_T_1D = probs.T.contiguous().view(-1)
# get 1D indices of the probs selected by routing_map
indices_dim0 = torch.arange(num_experts, device=routing_map.device).unsqueeze(-1)
indices_dim1 = sorted_indices.view(num_experts, capacity)
indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
# get probs from indices
permuted_probs = probs_T_1D.index_select(0, indices_1D)
else:
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
# Here may promote permuted_tokens to higher precision (fp32/fp64) if probs is in
# higher precision due to moe_router_dtype being enabled. This can lead to
# additional GPU memory usage. Use --moe-permute-fusion flag to avoid this extra memory
......@@ -344,11 +313,6 @@ def unpermute(
def all_to_all(group, input, output_split_sizes, input_split_sizes):
# torch.cuda.synchronize()
# import sys
# sys.stderr.write(f"############all_to_all input_split_sizes:{input_split_sizes}\n output_split_sizes:{output_split_sizes}")
# sys.stderr.flush()
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
......
import logging
from typing import List, Optional
import torch
import triton
import triton.language as tl
logger = logging.getLogger(__name__)
@triton.jit
def compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
    """Scatter destination positions: src2dst[reorder_ids[d]] = d for d < num_toks.

    Each program covers BLOCK_SIZE consecutive destination positions; positions
    at or beyond num_toks are masked out of both the load and the store.
    """
    block_begin = tl.program_id(axis=0) * BLOCK_SIZE
    dst_pos = block_begin + tl.arange(0, BLOCK_SIZE)
    in_range = dst_pos < num_toks
    # reorder_ids[d] names the source slot that ends up at position d.
    src_pos = tl.load(reorder_ids + dst_pos, mask=in_range)
    tl.store(src2dst + src_pos, dst_pos, mask=in_range)
@triton.jit
def deepep_compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
):
    """Scatter shifted destination positions into src2dst.

    For every d < num_toks writes src2dst[reorder_ids[d]] = d - *num_minus_one,
    where num_minus_one is a pointer to a single scalar offset.
    """
    block_begin = tl.program_id(axis=0) * BLOCK_SIZE
    dst_pos = block_begin + tl.arange(0, BLOCK_SIZE)
    in_range = dst_pos < num_toks
    src_pos = tl.load(reorder_ids + dst_pos, mask=in_range)
    # The offset is read from device memory rather than passed by value.
    shift = tl.load(num_minus_one)
    tl.store(src2dst + src_pos, dst_pos - shift, mask=in_range)
def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
    """Sort routed expert ids and build the segment/permutation tables.

    Args:
        topk_ids: expert id per routed slot; flattened before sorting.
            Presumably negative ids (-1) mark slots that are not routed
            locally -- TODO confirm against the caller.
        num_experts: number of experts to build segment boundaries for.

    Returns:
        A tuple ``(reorder_topk_ids, src2dst, seg_indptr)``:
        the sorted ids with the leading ``seg_indptr[0]`` entries dropped,
        the flat-slot -> sorted-position map (shifted by the same amount),
        and the ``num_experts + 1`` segment boundaries rebased to start at 0.
    """
    device = topk_ids.device
    total_slots = topk_ids.numel()

    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

    seg_indptr = torch.empty(num_experts + 1, device=device, dtype=torch.int64)
    src2dst = torch.empty(total_slots, device=device, dtype=torch.int64)

    # Find offset: position of the first sorted id that is >= each boundary value.
    boundaries = torch.arange(
        num_experts + 1, device=device, dtype=reorder_topk_ids.dtype
    )
    torch.searchsorted(reorder_topk_ids, boundaries, out=seg_indptr)

    # Number of sorted ids below 0. This is a view into the *old* seg_indptr,
    # so the rebasing below must stay out-of-place.
    num_minus_one = seg_indptr[0]
    seg_indptr = seg_indptr - num_minus_one

    BLOCK_SIZE = 512
    launch_grid = (triton.cdiv(total_slots, BLOCK_SIZE),)
    deepep_compute_src2dst_triton_kernel[launch_grid](
        reorder_ids, src2dst, total_slots, num_minus_one, BLOCK_SIZE
    )

    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
    return reorder_topk_ids, src2dst, seg_indptr
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
    """Store, per expert, one past the last sorted position whose id is <= expert.

    One program per expert id. A binary search over the first num_toks entries
    of reorder_topk_ids finds the last index holding a value <= expert, and
    that index + 1 is written to seg_indptr[expert + 1] (0 when none exists).
    """
    expert = tl.program_id(0)
    lo = 0
    hi = num_toks - 1
    last_le = -1
    while lo <= hi:
        probe = (lo + hi) // 2
        if tl.load(reorder_topk_ids + probe) <= expert:
            # probe qualifies; keep looking to the right for a later one.
            last_le = probe
            lo = probe + 1
        else:
            hi = probe - 1
    tl.store(seg_indptr + expert + 1, last_le + 1)
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
    """Sort routed expert ids and build segment boundaries plus the slot map.

    Args:
        topk_ids: expert id per routed slot; flattened before sorting.
        num_experts: number of experts (one boundary-search program each).

    Returns:
        ``(reorder_topk_ids, src2dst, seg_indptr)``: the stably sorted ids,
        an int32 map from flat slot to sorted position, and an int64 tensor
        of ``num_experts + 1`` boundaries whose first entry stays 0.
    """
    device = topk_ids.device
    total_slots = topk_ids.numel()

    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

    # seg_indptr[0] is never written by the kernel, hence zeros rather than empty.
    seg_indptr = torch.zeros(num_experts + 1, device=device, dtype=torch.int64)
    src2dst = torch.empty(total_slots, device=device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, total_slots
    )

    BLOCK_SIZE = 512
    launch_grid = (triton.cdiv(total_slots, BLOCK_SIZE),)
    compute_src2dst_triton_kernel[launch_grid](
        reorder_ids, src2dst, total_slots, BLOCK_SIZE
    )
    return reorder_topk_ids, src2dst, seg_indptr
@triton.jit
def pre_reorder_triton_kernel(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy each source row to the sorted slot of every local expert it routes to.

    One program per source row. For each of the row's `topk` expert choices
    that falls in [start_expert_id, end_expert_id], the row is copied (in
    BLOCK_SIZE chunks) to row src2dst[row, k] of gateup_input_ptr. When
    a1_scales_ptr is given, values are multiplied by the reciprocal of the
    expert's scale before being cast to the output dtype.
    """
    out_dtype = gateup_input_ptr.dtype.element_ty

    row = tl.program_id(0)
    row_src2dst = src2dst_ptr + row * topk
    row_topk_ids = topk_ids_ptr + row * topk
    row_in = input_ptr + row * hidden_size

    for k in range(topk):
        expert_id = tl.load(row_topk_ids + k)
        if expert_id >= start_expert_id and expert_id <= end_expert_id:
            if a1_scales_ptr is not None:
                # Scales are indexed relative to the first local expert.
                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
            else:
                scale = 1.0

            dst_row = tl.load(row_src2dst + k)
            row_out = gateup_input_ptr + dst_row * hidden_size
            for chunk_begin in tl.range(0, hidden_size, BLOCK_SIZE):
                cols = chunk_begin + tl.arange(0, BLOCK_SIZE)
                col_mask = cols < hidden_size
                values = tl.load(row_in + cols, mask=col_mask).to(tl.float32)
                scaled = (values * scale).to(out_dtype)
                tl.store(row_out + cols, scaled, mask=col_mask)
@triton.jit
def silu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute silu(gate) * up for one row, optionally dividing by an expert scale.

    One program per row of gateup_output. The first half of the row is the
    gate and the second half is the up projection. Rows whose expert id lies
    outside [start_expert_id, end_expert_id] are left untouched.
    """
    in_dtype = gateup_output.dtype.element_ty
    out_dtype = down_input.dtype.element_ty

    half = hidden_size // 2

    row = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + row)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        row_base = gateup_output + row * hidden_size
        gate_base = row_base
        up_base = row_base + half
        out_base = down_input + row * half

        if scales is not None:
            # Scales are indexed relative to the first local expert.
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(in_dtype)
        else:
            scale = 1

        for chunk_begin in tl.range(0, half, BLOCK_SIZE):
            cols = chunk_begin + tl.arange(0, BLOCK_SIZE)
            col_mask = cols < half

            gate = tl.load(gate_base + cols, mask=col_mask).to(tl.float32)
            up = tl.load(up_base + cols, mask=col_mask)

            # silu & mul & quantize: the activation runs in fp32, the product
            # in the input dtype.
            gate = gate * tl.sigmoid(gate)
            gate = gate.to(in_dtype)
            result = gate * up * scale
            result = result.to(out_dtype)
            tl.store(out_base + cols, result, mask=col_mask)
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
    input_ptr,
    stride_input_0,
    stride_input_1,
    stride_input_2,
    output_ptr,
    stride_output_0,
    stride_output_1,
    stride_output_2,
    output_scale_ptr,
    stride_output_scale_0,
    stride_output_scale_1,
    stride_output_scale_2,
    masked_m_ptr,
    size_n,
    fp8_max,
    fp8_min,
    BLOCK_N: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    """Fused silu(gate) * up followed by per-block absmax quantization.

    Grid layout: axis 0 selects a BLOCK_N-wide slice of the output hidden
    dimension, axis 1 is a token lane, axis 2 is the expert. Each program
    walks the tokens `token_id, token_id + num_programs(1), ...` below
    masked_m[expert], so only the first masked_m[expert] tokens of an expert
    are processed.

    For each token the gate is read at the block's columns and the up value
    at the same columns shifted by size_n. The quantized product is written to
    output_ptr and one scale per (expert, token, block) to output_scale_ptr.

    Note: stride_input_2 and stride_output_2 are accepted but not used; the
    last dimension is addressed with unit stride.
    """
    expert_id = tl.program_id(2)
    token_id = tl.program_id(1)
    hidden_dim_block_index = tl.program_id(0)
    # Number of token lanes per expert == step of the token loop below.
    block_num_per_expert = tl.num_programs(1)
    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
    # Widen the strides to int64 before they are multiplied by expert/token indices.
    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
    offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
    output_scale_offs = (
        output_scale_ptr
        + expert_id * stride_output_scale_0
        + hidden_dim_block_index * stride_output_scale_2
    )
    for token_index in tl.range(
        token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
    ):
        gate = tl.load(
            input_ptr_offs + token_index * stride_input_1,
            mask=offs_in_d < size_n,
            other=0.0,
        ).to(tl.float32)
        # The up half sits size_n elements after the gate half of the same row.
        up = tl.load(
            input_ptr_offs + token_index * stride_input_1 + size_n,
            mask=offs_in_d < size_n,
            other=0.0,
        )
        # silu in fp32, then back to the input dtype for the multiply.
        gate = gate / (1 + tl.exp(-gate))
        gate = gate.to(input_ptr.dtype.element_ty)
        gate_up = up * gate
        # Floor the absmax at 1e-10 so the scale is never zero.
        _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
        output_s = _absmax / fp8_max
        output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
            output_ptr.dtype.element_ty
        )
        tl.store(
            output_ptr_offs + token_index * stride_output_1,
            output_q,
            mask=offs_in_d < size_n,
        )
        tl.store(
            output_scale_offs + token_index * stride_output_scale_1,
            output_s,
        )
def silu_and_mul_masked_post_quant_fwd(
input: torch.Tensor,
output: torch.Tensor,
output_scale: torch.Tensor,
quant_group_size: int,
masked_m: torch.Tensor,
):
"""
input shape [expert_num, token_num_padded, hidden_dim]
output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32
quant_group_size int,
masked_m shape [expert_num],
"""
assert input.is_contiguous()
assert output.dtype == torch.float8_e4m3fn
assert output.is_contiguous()
assert len(input.shape) == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0
size_n = input.shape[-1] // 2
assert size_n % quant_group_size == 0
expert_num = len(masked_m)
if expert_num < 4:
BLOCK_NUM_PER_EXPERT = 64
else:
BLOCK_NUM_PER_EXPERT = 32
BLOCK_N = quant_group_size
num_warps = 1
NUM_STAGES = 6
hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
assert BLOCK_N % quant_group_size == 0
grid = (
hidden_dim_split_block_num,
BLOCK_NUM_PER_EXPERT,
expert_num,
)
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max
_silu_and_mul_post_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
output_scale,
*output_scale.stride(),
masked_m,
size_n,
fp8_max,
fp8_min,
BLOCK_N=BLOCK_N,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
)
return
@triton.jit
def tanh(x):
    """Hyperbolic tangent expressed through sigmoid: tanh(x) = 2 * sigmoid(2x) - 1."""
    doubled = 2 * x
    return 2 * tl.sigmoid(doubled) - 1
@triton.jit
def gelu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute tanh-approximated gelu(gate) * up for one row, with optional scaling.

    One program per row of gateup_output. The first half of the row is the
    gate and the second half is the up projection. Rows whose expert id lies
    outside [start_expert_id, end_expert_id] are left untouched.
    """
    in_dtype = gateup_output.dtype.element_ty
    out_dtype = down_input.dtype.element_ty

    half = hidden_size // 2

    row = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + row)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        row_base = gateup_output + row * hidden_size
        gate_base = row_base
        up_base = row_base + half
        out_base = down_input + row * half

        if scales is not None:
            # Scales are indexed relative to the first local expert.
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(in_dtype)
        else:
            scale = 1

        for chunk_begin in tl.range(0, half, BLOCK_SIZE):
            cols = chunk_begin + tl.arange(0, BLOCK_SIZE)
            col_mask = cols < half

            gate = tl.load(gate_base + cols, mask=col_mask).to(tl.float32)
            up = tl.load(up_base + cols, mask=col_mask)

            # gelu & mul & quantize
            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
            # sqrt(2/pi)
            kAlpha = 0.7978845608028654
            cubic = gate + 0.044715 * gate * gate * gate
            squashed = tanh(kAlpha * cubic)
            gate = 0.5 * gate * (1 + squashed)
            gate = gate.to(in_dtype)
            result = gate * up * scale
            result = result.to(out_dtype)
            tl.store(out_base + cols, result, mask=col_mask)
@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Combine the expert outputs of one source row into its final output row.

    One program per source row. For every BLOCK_SIZE chunk of the hidden
    dimension, the rows of down_output_ptr addressed by src2dst[row, k] are
    scaled by topk_weights[row, k] and summed over all k whose expert id lies
    in [start_expert_id, end_expert_id]. The sum is accumulated in the dtype
    of down_output_ptr and stored to output_ptr[row].
    """
    InDtype = down_output_ptr.dtype.element_ty
    src_idx = tl.program_id(0)
    # Advance the per-row tables to this row's `topk` entries.
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk
    # Becomes True once any of the row's experts falls in the local range.
    computed = False
    store_ptr = output_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size
        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            if expert_id >= start_expert_id and expert_id <= end_expert_id:
                computed = True
                dst_idx = tl.load(src2dst_ptr + idx)
                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weigh_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)
    # No local expert contributed: write an all-zero row.
    # NOTE(review): the loop above already stores the zero-initialised sum_vec
    # in this case, so this pass looks redundant -- confirm before removing.
    if computed == False:
        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < hidden_size
            tl.store(
                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
            )
@triton.jit
def compute_m_range(
    pid,
    batch_size,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Map a flat M-tile id to the row range and weight index it belongs to.

    Scans m_num_tiles_indptr[0:batch_size] linearly and keeps the last segment
    whose first tile id is <= pid.

    Returns:
        (m_range_start, m_range_end, expert_id): the half-open row range of
        this tile, clipped to the end of its segment, and the entry of
        weight_indices for that segment.
    """
    idx = 0
    for bs in range(batch_size):
        tiles = tl.load(m_num_tiles_indptr + bs)
        if pid >= tiles:
            idx = bs

    idx_start = tl.load(m_num_tiles_indptr + idx)

    # Row offset of the tile inside its segment is (pid - idx_start) tiles.
    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
    expert_id = tl.load(weight_indices + idx)
    return m_range_start, m_range_end, expert_id
@triton.jit
def grouped_gemm_triton_kernel(
    a,
    b,
    c,
    batch_size,
    N,
    K,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    scale_a,
    scale_b,
    use_fp8_w8a8: tl.constexpr,
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    a_stride_0: tl.constexpr,
    b_stride_0: tl.constexpr,
    b_stride_1: tl.constexpr,
    as_stride_0: tl.constexpr,
    as_stride_1: tl.constexpr,
    bs_stride_0: tl.constexpr,
    bs_stride_2: tl.constexpr,
    bs_stride_1: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Segmented GEMM: each row segment of `a` is multiplied by its own weight.

    Program axis 0 enumerates M-tiles across all segments (resolved through
    compute_m_range), axis 1 enumerates N-tiles. For its tile the program
    computes a[rows, :] @ b[expert_id, cols, :].T with fp32 accumulation and
    stores the result into `c`.

    Scaling modes:
      * group_k > 0 and group_n > 0: a per-row scale from scale_a and a
        per-column-group scale from scale_b are applied to every K step.
      * otherwise, with use_fp8_w8a8: scale_a[expert_id] * scale_b[expert_id]
        is applied once after the K loop.

    Note the parameter order: bs_stride_2 precedes bs_stride_1.
    The innermost (K) dimension of `a` and `b` is addressed with unit stride,
    and `c` is addressed with row stride N.
    """
    c_dtype = c.dtype.element_ty

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Programs with an M-tile id beyond the real tile count have nothing to do.
    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
    if pid_m >= total_m_block:
        return

    m_range_start, m_range_end, expert_id = compute_m_range(
        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
    )
    if m_range_end - m_range_start == 0:
        return

    n_range_start = pid_n * BLOCK_SIZE_N
    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)

    offs_am = tl.arange(0, BLOCK_SIZE_M)
    offs_bn = tl.arange(0, BLOCK_SIZE_N)

    # Rows/columns past the end of the tile are redirected to offset 0 so the
    # loads stay in range; their results are discarded by the store mask.
    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
    # Compiler hints about the layout of the offset vectors.
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
    # [block_n, block_k]
    b_ptr = b + (
        (expert_id * b_stride_0)
        + (n_range_start + offs_bn[:, None]) * b_stride_1
        + offs_k[None, :]
    )

    if group_k > 0 and group_n > 0:
        a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
        offs_bsn = (n_range_start + offs_bn) // group_n
        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask only the K tail; out-of-range K elements contribute 0.
        a_tile = tl.load(
            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        # [block_n, block_k]
        b_tile = tl.load(
            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )

        if group_k > 0 and group_n > 0:
            k_start = k * BLOCK_SIZE_K
            offs_ks = k_start // group_k
            a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
            b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
            accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
        else:
            accumulator = tl.dot(a_tile, b_tile.T, accumulator)
        a_ptr += BLOCK_SIZE_K
        b_ptr += BLOCK_SIZE_K

    # Per-expert scalar scales are applied once, after accumulation.
    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
        scale_a_value = tl.load(scale_a + expert_id)
        scale_b_value = tl.load(scale_b + expert_id)
        accumulator *= scale_a_value * scale_b_value

    c_tile = accumulator.to(c_dtype)

    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
    tl.store(c_ptr, c_tile, mask=c_mask)
@triton.jit
def compute_m_num_tiles_indptr(
    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
    """Build the running tile count per segment.

    For each segment s, adds ceil(rows_in_segment / BLOCK_SIZE_M) to
    m_num_tiles_indptr[s] and stores the sum at m_num_tiles_indptr[s + 1].
    Entry 0 is read but never written, so the caller must initialise it.
    """
    for seg in range(batch_size):
        seg_end = tl.load(seg_indptr + seg + 1)
        seg_begin = tl.load(seg_indptr + seg)
        tiles_here = tl.cdiv(seg_end - seg_begin, BLOCK_SIZE_M)
        tiles_before = tl.load(m_num_tiles_indptr + seg)
        tl.store(m_num_tiles_indptr + seg + 1, tiles_before + tiles_here)
def grouped_gemm_triton(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    batch_size: int,
    weight_column_major: bool,
    seg_indptr: Optional[torch.Tensor] = None,
    weight_indices: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    scale_a: Optional[torch.Tensor] = None,
    scale_b: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
):
    """Run a segmented (grouped) GEMM, writing the result into ``c``.

    Rows ``seg_indptr[i]:seg_indptr[i + 1]`` of ``a`` are multiplied by the
    transposed weight ``b[weight_indices[i]]``.

    Args:
        a: activations; ``a.size(0)`` rows, row stride ``a.stride(0)``.
        b: stacked weights; ``b.size(1)`` is used as N and ``b.size(2)`` as K.
        c: output buffer, filled in place and returned.
        batch_size: number of segments.
        weight_column_major: must be truthy; other layouts are not implemented.
        seg_indptr: ``batch_size + 1`` segment boundaries into ``a``. Required
            despite the ``None`` default kept for signature compatibility.
        weight_indices: weight index per segment. Required, as above.
        use_fp8_w8a8: apply fp8 scaling inside the kernel.
        scale_a: activation scales (needed for per-tensor fp8).
        scale_b: weight scales (needed for per-tensor fp8).
        block_shape: ``[group_n, group_k]`` for block-wise scaling, or None.

    Returns:
        ``c``.

    Raises:
        ValueError: if ``seg_indptr`` or ``weight_indices`` is None.
        AssertionError: if ``weight_column_major`` is falsy, or per-tensor fp8
            is requested without both scales.
    """
    assert weight_column_major  # TODO: more
    # Both tensors are dereferenced unconditionally by the kernels below; fail
    # here with a clear message instead of inside the Triton launch.
    if seg_indptr is None or weight_indices is None:
        raise ValueError(
            "grouped_gemm_triton requires both seg_indptr and weight_indices"
        )
    if use_fp8_w8a8 and block_shape is None:
        assert scale_a is not None and scale_b is not None

    # NOTE: when block_shape is given, `a` is expected to be quantized already
    # and scale_a/scale_b to hold the matching block scales; no quantization
    # is performed here.

    # TODO: adjust config or tune kernel
    # Reduce block size to prevent L40 shared memory overflow.
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
    }

    # Entry 0 must start at zero: the kernel builds a running sum from it.
    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
    compute_m_num_tiles_indptr[(1,)](
        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
    )

    def grid(META):
        # Upper bound on the number of M-tiles (each segment may add one
        # partial tile); surplus programs return immediately in the kernel.
        return (
            triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
            triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
        )

    group_n = 0 if block_shape is None else block_shape[0]
    group_k = 0 if block_shape is None else block_shape[1]

    # Scale strides are only meaningful for multi-dimensional scale tensors.
    a_scale_is_2d = scale_a is not None and scale_a.ndim == 2
    b_scale_is_nd = scale_b is not None and scale_b.ndim >= 2
    b_scale_is_3d = scale_b is not None and scale_b.ndim == 3

    grouped_gemm_triton_kernel[grid](
        a,
        b,
        c,
        batch_size,
        b.size(1),
        b.size(2),
        seg_indptr,
        weight_indices,
        m_num_tiles_indptr,
        scale_a,
        scale_b,
        use_fp8_w8a8,
        group_n,
        group_k,
        a.stride(0),
        b.stride(0),
        b.stride(1),
        scale_a.stride(0) if a_scale_is_2d else 0,
        scale_a.stride(1) if a_scale_is_2d else 0,
        scale_b.stride(0) if b_scale_is_nd else 0,
        # The kernel signature takes bs_stride_2 before bs_stride_1.
        scale_b.stride(2) if b_scale_is_3d else 0,
        scale_b.stride(1) if b_scale_is_nd else 0,
        **config,
    )
    return c
import os
import logging
from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EPSharedExperts, EpMoeConfig
from vllm.model_executor.layers.fused_moe.ep_moe.kernels import grouped_gemm_triton
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
@CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
    """Unquantized MoE method that runs the local experts one after another.

    The input rows are expected to be grouped by expert already, with
    ``tokens_per_expert`` giving the length of each consecutive group.
    """

    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)
        self.topk_indices_dtype = None
        self.moe = moe
        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()

    def apply(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Forward to ``self.forward`` with keyword arguments."""
        return self.forward(
            layer=layer,
            hidden_states=hidden_states,
            tokens_per_expert=tokens_per_expert,
        )

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Apply each expert's gate/up projection, SiLU-and-mul and down projection.

        Args:
            layer: module providing ``w13_weight`` and ``w2_weight``, indexed
                by local expert.
            hidden_states: rows grouped by expert in expert order.
            tokens_per_expert: number of rows for each local expert; copied
                to the host before the loop.

        Returns:
            The concatenated expert outputs, or ``hidden_states`` itself when
            no expert received any row (in which case it must be empty).
        """
        # process MoE
        counts = tokens_per_expert.cpu().numpy()

        per_expert_outputs = []
        row_begin = 0
        for expert_idx, count in enumerate(counts):
            row_end = row_begin + count
            if count == 0:
                continue

            gate_up_weight = layer.w13_weight[expert_idx]
            down_weight = layer.w2_weight[expert_idx]

            expert_rows = hidden_states[row_begin:row_end]
            gateup_output = torch.matmul(expert_rows, gate_up_weight.T)

            # Act
            n_rows, n_cols = gateup_output.shape[0], gateup_output.shape[1]
            down_input = torch.zeros(
                n_rows,
                n_cols // 2,
                device=gateup_output.device,
                dtype=hidden_states.dtype,
            )
            torch.ops._C.silu_and_mul(
                down_input, gateup_output.view(-1, gate_up_weight.shape[0])
            )

            per_expert_outputs.append(torch.matmul(down_input, down_weight.T))
            row_begin = row_end

        if len(per_expert_outputs) > 0:
            return torch.cat(per_expert_outputs, dim=0)

        assert hidden_states.numel() == 0, f"sorted_tokens: should be empty, but got {hidden_states.shape}"
        return hidden_states

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
        **kwargs,
    ):
        """Not implemented for CPU."""
        raise NotImplementedError

    def forward_hpu(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Not implemented for HPU."""
        raise NotImplementedError

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """Not implemented for TPU."""
        raise NotImplementedError

    # Pick the native implementation once, at class-creation time.
    forward_native = (
        forward_tpu
        if current_platform.is_tpu()
        else (forward_cpu if current_platform.is_cpu() else forward_cuda)
    )
class EPMoE(FusedMoE):
"""
dp+ep MoE Expert Parallel Impl
......@@ -46,7 +157,7 @@ class EPMoE(FusedMoE):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
moe_permute_fusion: bool = False,
moe_permute_fusion: bool = True,
moe_shared_expert_overlap: bool = False
):
super().__init__(num_experts, top_k, hidden_size,
......@@ -68,7 +179,9 @@ class EPMoE(FusedMoE):
moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=self.ep_size,
num_moe_experts=self.global_num_experts
num_moe_experts=self.global_num_experts,
routed_scaling_factor=self.routed_scaling_factor,
apply_router_weight_on_input=self.apply_router_weight_on_input
)
local_expert_indices_offset = (
......@@ -78,149 +191,41 @@ class EPMoE(FusedMoE):
local_expert_indices_offset + i for i in range(self.local_num_experts)
]
self.shared_experts = None
self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices, config=self.ep_moe_config
)
self.shared_expert_overlap = moe_shared_expert_overlap
self.seg_indptr = None
if quant_config is None:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.w13_weight_scale = None
self.w2_weight_scale = None
else:
self.use_fp8_w8a8 = True
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
def set_shared_experts(self, shared_experts):
self.shared_experts = shared_experts
self.use_shared_expert = shared_experts is not None
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(shared_experts)
def triton_grouped_gemm_impl(self, hidden_states, tokens_per_expert, use_nn_moe):
torch.cumsum(tokens_per_expert,
dim=0,
out=self.seg_indptr[1:])
_, N, _ = self.w13_weight.shape
gateup_input = hidden_states
weight_indices_cur_rank = torch.arange(
0,
self.local_num_experts,
device=hidden_states.device,
dtype=torch.int64,
)
# GroupGemm-0
gateup_output = torch.empty(
gateup_input.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
self.shared_experts = None
gateup_output = grouped_gemm_triton(
a=gateup_input,
b=self.w13_weight,
c=gateup_output,
batch_size=self.local_num_experts,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale if self.quant_config is not None else None,
scale_b=(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
) if self.quant_config is not None else None,
block_shape=self.block_shape,
)
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
),
)
if self.quant_config is not None and self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.local_num_experts,
dtype=torch.float32,
device=hidden_states.device,
)
def set_shared_experts(self, shared_experts: torch.nn.Module):
if self.shared_experts is None:
self.shared_experts = shared_experts
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts)
if self.activation == "silu":
torch.ops._C.silu_and_mul(down_input,
gateup_output.view(-1, N))
elif self.activation == "gelu":
torch.ops._C.gelu_and_mul(down_input,
gateup_output.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {self.activation}")
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
down_output = grouped_gemm_triton(
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.local_num_experts,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale if self.quant_config is not None else None,
scale_b=(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
) if self.quant_config is not None else None,
block_shape=self.block_shape,
)
return down_output
def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedEPGroupedGemmMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if (
self.training
and self.config.tensor_model_parallel_size > 1
and not self.config.sequence_parallel
):
raise ValueError(
"During training, performance may degrade if MoE and tensor parallelism"
"are enabled without also enabling sequence parallelism."
)
def forward(self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """Run the layer through the registered ``ep_moe_forward`` custom op.

    The op receives this layer's ``layer_name`` so that it can locate the
    layer again on the other side of the op boundary.
    """
    ep_moe_op = torch.ops.vllm.ep_moe_forward
    return ep_moe_op(hidden_states, router_logits, self.layer_name)
def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if self.seg_indptr is None:
self.seg_indptr = torch.zeros(self.local_num_experts+1, device=hidden_states. device, dtype=torch.int64)
# process MoE
def custom_forward(hidden_states, router_logits):
topk_weights, topk_ids = self.select_experts(
topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
use_grouped_topk=self.use_grouped_topk,
......@@ -234,20 +239,60 @@ class EPMoE(FusedMoE):
indices_type=torch.int64,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate)
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
expert_output = self.triton_grouped_gemm_impl(dispatched_input, tokens_per_expert, self.use_nn_moe)
output = self.token_dispatcher.token_unpermutation(expert_output)
if self.use_shared_expert and not self.shared_expert_overlap:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
output = output + self.shared_experts(hidden_states)
return output
probs = torch.zeros_like(router_logits, dtype=topk_weights.dtype).scatter(1, topk_ids, topk_weights)
routing_map = torch.zeros_like(router_logits).int().scatter(1, topk_ids, 1).bool()
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
# Matrix multiply.
expert_output = self.quant_method.apply(
layer=self,
hidden_states=dispatched_input,
tokens_per_expert=tokens_per_expert
)
final_hidden_states = self.token_dispatcher.token_unpermutation(expert_output)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
shared_output = (
self.maybe_all_reduce_tensor_model_parallel(
shared_output))
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                   layer_name: str) -> torch.Tensor:
    """Resolve a layer by name and run its ``forward_impl``.

    The layer is looked up under ``layer_name`` in the current forward
    context's ``no_compile_layers`` mapping; its ``quant_method`` must be set.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    assert layer.quant_method is not None
    return layer.forward_impl(hidden_states, router_logits)
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                        layer_name: str) -> torch.Tensor:
    """Fake implementation of ``ep_moe_forward``.

    Returns an uninitialized tensor with the same shape, dtype and device as
    ``hidden_states``; ``router_logits`` and ``layer_name`` are not used.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
output = custom_forward(hidden_states, router_logits)
return output
\ No newline at end of file
# Register ep_moe_forward as a custom op named "ep_moe_forward", with
# ep_moe_forward_fake as its fake implementation. The layer's forward method
# in this file calls it as torch.ops.vllm.ep_moe_forward, so the helper
# presumably registers under the "vllm" namespace — confirm.
direct_register_custom_op(
op_name="ep_moe_forward",
op_func=ep_moe_forward,
# NOTE(review): hidden_states is declared as mutated, while the fake impl
# returns a fresh tensor — confirm this declaration is intended.
mutates_args=["hidden_states"],
fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
\ No newline at end of file
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
......@@ -21,6 +22,9 @@ from vllm.distributed import (tensor_model_parallel_all_gather,
expert_parallel_gather)
from vllm.platforms import current_platform
cuda_dtoh_stream = torch.cuda.Stream()
class MoETokenDispatcher:
"""
MoE Token Dispatcher
......@@ -31,7 +35,6 @@ class MoETokenDispatcher:
Initialize the MoE Token Dispatcher.
"""
self.config = config
self.shared_experts: Optional[EPSharedExperts] = None
self.tp_size = 1
self.ep_size = config.ep_size
......@@ -162,13 +165,14 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"no_sync": 4,
}
self.cuda_dtoh_point = "before_permutation_1"
self.cuda_dtoh_stream = torch.cuda.Stream()
self.shared_experts = None
#self.cuda_dtoh_stream = torch.cuda.Stream()
# Whether to use gather or all-gather to gather the logits.
self.use_all_gather = current_platform.use_all_gather()
self.probs = None
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
"""
Preprocess token routing map for AlltoAll communication and token permutation.
......@@ -264,7 +268,9 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
return num_tokens_per_local_expert
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
self, hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
......@@ -287,7 +293,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape
self.probs = probs
if self.config.apply_router_weight_on_input:
self.probs = probs
self.routing_map = routing_map
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
......@@ -295,50 +302,32 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(self.routing_map)
if self.shared_experts is not None:
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
import sys
# torch.cuda.synchronize()
# sys.stderr.write(f"token_permutation===============================================")
# sys.stderr.flush()
# Permutation 1: input to AlltoAll input
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_1", tokens_per_expert
)
# torch.cuda.synchronize()
# sys.stderr.write(f"before permute===============================================")
# sys.stderr.flush()
self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
routing_map,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion,
drop_and_pad=False,
fused=self.config.moe_permute_fusion
)
# torch.cuda.synchronize()
# sys.stderr.write(f"after permute===============================================")
# sys.stderr.flush()
# Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert
)
#torch.cuda.synchronize()
#print("###########################before permutation all_to_all output_splits:{} input_splits:{}".format(self.output_splits, self.input_splits))
global_input_tokens = all_to_all(
self.ep_group.device_group, permutated_local_input_tokens, self.output_splits, self.input_splits
)
#torch.cuda.synchronize()
#print("#######################permutation all_to_all end")
if self.shared_experts is not None:
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
# Permutation 2: Sort tokens by local expert.
......@@ -358,7 +347,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
return global_input_tokens, tokens_per_expert
def token_unpermutation(
self, hidden_states: torch.Tensor
self, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Reverse the token permutation to restore the original order.
......@@ -392,7 +381,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.ep_group.device_group, hidden_states, self.input_splits, self.output_splits
)
if self.shared_experts is not None:
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
self.shared_experts.post_forward_comm()
......@@ -404,16 +393,22 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
probs=self.probs,
routing_map=self.routing_map,
fused=self.config.moe_permute_fusion,
drop_and_pad=False,
)
# Reshape the output tensor
output = output.view(self.hidden_shape)
# Add shared experts output
if self.shared_experts is not None:
shared_expert_output = self.shared_experts.get_output()
output += shared_expert_output
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts.get_output()
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
output = output + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
output = output + shared_output \
* (1. / self.config.routed_scaling_factor)
return output
def _maybe_update_cuda_sync_point(self, point: str):
......@@ -435,10 +430,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
if point == self.cuda_dtoh_point:
# Move all possible GPU tensors to CPU at self.cuda_dtoh_point.
on_side_stream = torch.cuda.current_stream() != self.cuda_dtoh_stream
on_side_stream = torch.cuda.current_stream() != cuda_dtoh_stream
if on_side_stream:
self.cuda_dtoh_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.cuda_dtoh_stream):
cuda_dtoh_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_dtoh_stream):
# TODO: use MemcpyBatchAsync instead.
# tokens_per_expert = maybe_move_tensor_to_cpu(
# tokens_per_expert, record_stream=on_side_stream
......@@ -462,6 +457,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if point == self.cuda_sync_point:
# Synchronize with the dtoh stream at self.cuda_sync_point.
self.cuda_dtoh_stream.synchronize()
cuda_dtoh_stream.synchronize()
return tokens_per_expert
\ No newline at end of file
......@@ -772,20 +772,12 @@ class FusedMoE(torch.nn.Module):
self.moe_config = moe
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
self.quant_method = self.create_quant_method(moe, quant_config, prefix)
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
if not isinstance(self.quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
......@@ -852,6 +844,17 @@ class FusedMoE(torch.nn.Module):
dtype=moe.in_dtype,
device=torch.cuda.current_device())
def create_quant_method(self, moe, quant_config, prefix):
    """Select the quantization method for this fused-MoE layer.

    Kept as a separate method so the choice of method lives in one place
    that ``__init__`` calls.

    Args:
        moe: MoE configuration forwarded to the unquantized method.
        quant_config: Quantization config, or ``None`` for the unquantized
            path.
        prefix: Layer prefix passed through to ``get_quant_method``.

    Returns:
        The selected method; asserted to be a ``FusedMoEMethodBase``.
    """
    # Note: get_quant_method will look at the layer's local_num_experts
    # for heuristic purposes, so it must be initialized before this call.
    if quant_config is not None:
        chosen = quant_config.get_quant_method(self, prefix)
    else:
        chosen = UnquantizedFusedMoEMethod(moe)

    assert chosen is not None
    assert isinstance(chosen, FusedMoEMethodBase)
    return chosen
@property
def tp_size(self):
    """Return ``tp_size`` as recorded on ``self.moe_parallel_config``."""
    parallel_cfg = self.moe_parallel_config
    return parallel_cfg.tp_size
......
......@@ -156,7 +156,23 @@ class DeepseekV2MoE(nn.Module):
dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel
self.shared_experts = None
moe_cls = FusedMoE if not self.use_ep_opt else EPMoE
self.experts = moe_cls(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
......@@ -167,48 +183,13 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
prefix=f"{prefix}.shared_experts",
)
if not self.use_ep_opt:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
self.experts = EPMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor)
if self.use_ep_opt:
self.experts.set_shared_experts(self.shared_experts)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
......@@ -218,18 +199,22 @@ class DeepseekV2MoE(nn.Module):
if not self.use_ep_opt:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if not self.use_ep_opt:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
else:
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
router_logits=router_logits)
if not self.use_ep_opt:
if shared_output is not None:
......@@ -745,9 +730,7 @@ class DeepseekV2Model(nn.Module):
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)\
#ops.print_tensor(hidden_states)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -816,6 +799,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.tritonsingleton.topk = config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_ep_opt = dp_size > 1 and parallel_config.enable_expert_parallel
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
......@@ -897,6 +884,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
("gate_up_proj", "up_proj", 1),
]
if self.use_ep_opt:
ep_moe_shared_experts_keys = "mlp.shared_experts"
ep_moe_shared_experts_mapping = {ep_moe_shared_experts_keys:"mlp.experts.shared_experts"}
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
......@@ -929,6 +920,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
if self.use_ep_opt:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......@@ -955,6 +950,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if self.use_ep_opt:
name_mapped = name_mapped.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
if is_pp_missing_parameter(name_mapped, self):
continue
......@@ -979,7 +977,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# However it's not mapped locally to this rank
# So we simply skip it
continue
if self.use_ep_opt:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
......
......@@ -2052,7 +2052,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
#self.input_ids[:num_tokens] = torch.randint(0, 120000, (num_tokens,), dtype=torch.int32)
self.input_ids[:num_tokens] = torch.arange(num_tokens, dtype=torch.int32, device=self.input_ids.device)
#self.input_ids[:num_tokens] = torch.arange(num_tokens, dtype=torch.int32, device=self.input_ids.device)
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment