Commit bb94d2e5 authored by yangql's avatar yangql
Browse files

增加fused-moe int4/int8的支持,以及deepseek精度问题的修复

parent 087254b9
...@@ -14,17 +14,248 @@ from vllm import _custom_ops as ops ...@@ -14,17 +14,248 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
device_name = current_platform.get_device_name().replace(" ", "_")
if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120:
stage1_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 4,"kpack": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 8},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#12
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32
]
stage2_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#12
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #16
]
else:
stage1_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 8},#11
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2},#12
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 4,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #32
]
stage2_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 8},#12
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #16
]
@triton.jit @triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, def fused_moe_kernel_awq(
token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, # Pointers to matrices
compute_type): a_ptr, # [4, 7168]
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) b_ptr, # [256, 512, 3584]
c_ptr, # (8, 8, 512)
b_scale_ptr, # (256, 512, 56)
b_zp_ptr, # (256, 256, 56)
topk_weights_ptr,
sorted_token_ids_ptr, # [0, 1, 2, 3, 4]
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N: tl.constexpr,
K: tl.constexpr,
EM, # pading后的总索引长度
num_valid_tokens, # 有效索引的上限
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk, #1
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsk,#1
stride_bsn,
stride_bze,
stride_bzk,
stride_bzn,
block_k_diviable: tl.constexpr,
group_size: tl.constexpr, # 128
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) # [block_m]
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N)) % N # [block_n]
offs_k = tl.arange(0, BLOCK_SIZE_K) # 0, 1, 2, ...... , 127 # # [block_k]
offs_k2 = tl.arange(0, BLOCK_SIZE_K // 2) # 0, 1, 2, ...... , 127 # # [block_k]
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) # [block_m, block_k]
off_experts = tl.load(expert_ids_ptr + pid_m)
if use_int4_w4a16:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
# b_ptrs = b_ptr + off_experts * stride_be + \
# (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_ptrs = b_ptr + off_experts * stride_be + \
offs_bn[:, None] * stride_bn + (offs_k2[None, :]) * stride_bk
# tl.device_print("stride_bn",stride_bsn)>1
# tl.device_print("stride_bk",stride_bk)=1
b_shifter = (offs_k[:, None] % 2) * 4 # 0, 4
elif use_int8_w8a16:
b_ptrs = b_ptr + off_experts * stride_be + \
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
if not has_zp and use_int4_w4a16:
b_zp_num = 8
if not has_zp and use_int8_w8a16:
b_zp_num = 128
elif has_zp and use_int4_w4a16:
b_zp_shifter = (offs_bn[None, :] % 2) * 4 # 0, 4
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if not block_k_diviable:
k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
k_other = 0.0
else:
k_mask = None
k_other = None
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs)
if use_int4_w4a16:
b = tl.interleave(b, b)
b= b.trans()
b = (b >> b_shifter) & 0xF
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
offs_bn[None, :] * stride_bsk + \
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsn
qzeros_scles = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
scales_int16 = tl.cast(qzeros_scles,tl.uint16)
b_scale = tl.cast(scales_int16,tl.float16,bitcast=True)
# tl.device_print("b_scale dequant",b_scale)
mid = qzeros_scles >> 16
# b_zp = tl.cast(mid,tl.float16,bitcast=False)
b_zp = tl.cast(mid,tl.float16)
# b_zp = tl.cast(zeros_int16,tl.float16,bitcast=False)
# tl.device_print("bzp",b_zp)
# We accumulate along the K dimension.
b = ((b - b_zp) * b_scale).to(tl.float16)
accumulator = tl.dot(a, b, acc=accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
if use_int4_w4a16:
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
else:
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :] None, :]
...@@ -287,6 +518,7 @@ def fused_moe_kernel( ...@@ -287,6 +518,7 @@ def fused_moe_kernel(
top_k: tl.constexpr, top_k: tl.constexpr,
compute_type: tl.constexpr, compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr, use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr): use_int8_w8a16: tl.constexpr):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
...@@ -340,7 +572,6 @@ def fused_moe_kernel( ...@@ -340,7 +572,6 @@ def fused_moe_kernel(
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m pid_n = (pid % num_pid_in_group) // group_size_m
# ---------------------------------------------------------- # ----------------------------------------------------------
# Create pointers for the first blocks of A and B. # Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction # We will advance this pointer as we move in the K direction
...@@ -355,22 +586,13 @@ def fused_moe_kernel( ...@@ -355,22 +586,13 @@ def fused_moe_kernel(
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
offs_token, token_mask, BLOCK_SIZE_M,
BLOCK_SIZE_N, compute_type)
return
offs_bn = (pid_n * BLOCK_SIZE_N + offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn) offs_bn[None, :] * stride_bn)
if use_int8_w8a16: if use_int8_w8a16:
...@@ -378,7 +600,7 @@ def fused_moe_kernel( ...@@ -378,7 +600,7 @@ def fused_moe_kernel(
None, :] * stride_bsn None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs) b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8: if use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n offs_bsn = offs_bn // group_n
...@@ -407,7 +629,7 @@ def fused_moe_kernel( ...@@ -407,7 +629,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension. # We accumulate along the K dimension.
if use_int8_w8a16: if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8: elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k offs_ks = k_start // group_k
...@@ -433,7 +655,7 @@ def fused_moe_kernel( ...@@ -433,7 +655,7 @@ def fused_moe_kernel(
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16: if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type) accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8: elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
else: else:
...@@ -591,7 +813,8 @@ def moe_align_block_size( ...@@ -591,7 +813,8 @@ def moe_align_block_size(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
block_size: int, block_size: int,
num_experts: int, num_experts: int,
expert_map: torch.Tensor = None expert_map: torch.Tensor = None,
num_token: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
...@@ -634,15 +857,20 @@ def moe_align_block_size( ...@@ -634,15 +857,20 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible - The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations. by block_size for proper block matrix operations.
""" """
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) if num_token:
sorted_ids = torch.empty((max_num_tokens_padded, ), if num_token < block_size:
dtype=torch.int32, max_num_tokens_padded = min(topk_ids.numel() * block_size, topk_ids.numel() + num_experts * (block_size - 1))
device=topk_ids.device) else:
sorted_ids.fill_(topk_ids.numel()) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.full((max_num_tokens_padded,), fill_value=topk_ids.numel(), dtype=torch.int32, device=topk_ids.device)
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while expert_ids = torch.empty((max_num_m_blocks, ),
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
num_tokens_post_pad = torch.empty((1), num_tokens_post_pad = torch.empty((1),
...@@ -693,6 +921,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -693,6 +921,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config: Dict[str, Any], config: Dict[str, Any],
compute_type: tl.dtype, compute_type: tl.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
...@@ -711,6 +940,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -711,6 +940,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
A, A_scale = ops.scaled_int8_quant(A, A_scale)
else:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16: elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None assert B_scale is not None
assert block_shape is None or block_shape[0] == 0 assert block_shape is None or block_shape[0] == 0
...@@ -733,77 +975,117 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -733,77 +975,117 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
block_shape is not None and block_shape[1] > 0: block_shape is not None and block_shape[1] > 0:
assert B_scale is not None and B_scale.ndim == 3 assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3 assert B_zp is None or B_zp.ndim == 3
if os.environ.get('moe_wna16_use_cuda') == '1':
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=topk_ids.numel(),
group_size=block_shape[1],
num_experts=B.shape[0],
bit=4 if use_int4_w4a16 else 8)
config = config.copy()
config.update(
get_moe_wna16_block_config(config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=topk_ids.numel(),
size_k=A.shape[1],
size_n=B.shape[1],
num_experts=B.shape[1],
group_size=block_shape[1],
real_top_k=topk_ids.shape[1],
block_size_m=config["BLOCK_SIZE_M"]))
if use_moe_wna16_cuda:
bit = 4 if use_int4_w4a16 else 8
ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
topk_weights if mul_routed_weight else None,
sorted_token_ids, expert_ids,
num_tokens_post_padded, top_k,
config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"], bit)
return
if os.environ.get('AWQ_MOE_SZ') == '1':
fused_moe_kernel_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
EM,
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
**config,
)
else:
fused_moe_kernel_gptq_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
EM,
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
**config,
)
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=topk_ids.numel(),
group_size=block_shape[1],
num_experts=B.shape[0],
bit=4 if use_int4_w4a16 else 8)
config = config.copy()
config.update(
get_moe_wna16_block_config(config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=topk_ids.numel(),
size_k=A.shape[1],
size_n=B.shape[1],
num_experts=B.shape[1],
group_size=block_shape[1],
real_top_k=topk_ids.shape[1],
block_size_m=config["BLOCK_SIZE_M"]))
if use_moe_wna16_cuda:
bit = 4 if use_int4_w4a16 else 8
ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
topk_weights if mul_routed_weight else None,
sorted_token_ids, expert_ids,
num_tokens_post_padded, top_k,
config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"], bit)
return
fused_moe_kernel_gptq_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
EM,
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
**config,
)
else: else:
config = config.copy() # config = config.copy()
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") # BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None: # if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], # BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0],
block_shape[1])) # block_shape[1]))
fused_moe_kernel[grid]( fused_moe_kernel[grid](
A, A,
B, B,
...@@ -841,8 +1123,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -841,8 +1123,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
BLOCK_SIZE_K=BLOCK_SIZE_K, #BLOCK_SIZE_K=BLOCK_SIZE_K,
**config, **config,
) )
...@@ -869,7 +1152,7 @@ def get_moe_configs( ...@@ -869,7 +1152,7 @@ def get_moe_configs(
dtype: Optional[str], dtype: Optional[str],
block_n: Optional[int] = None, block_n: Optional[int] = None,
block_k: Optional[int] = None, block_k: Optional[int] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False
) -> Optional[Dict[int, Any]]: ) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
...@@ -887,6 +1170,15 @@ def get_moe_configs( ...@@ -887,6 +1170,15 @@ def get_moe_configs(
config_file_path = os.path.join( config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120:
config_file_path_120 = config_file_path.replace(".json","_120.json")
if os.path.exists(config_file_path_120):
with open(config_file_path_120) as f:
logger.info("Using configuration from %s for MoE layer.",
config_file_path_120)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
with open(config_file_path) as f: with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.", logger.info("Using configuration from %s for MoE layer.",
...@@ -975,7 +1267,7 @@ def get_default_config( ...@@ -975,7 +1267,7 @@ def get_default_config(
dtype: Optional[str], dtype: Optional[str],
is_marlin: bool, is_marlin: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False, use_nn_moe: Optional[bool]=False
) -> Dict[str, int]: ) -> Dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None: if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
...@@ -988,21 +1280,21 @@ def get_default_config( ...@@ -988,21 +1280,21 @@ def get_default_config(
"num_warps": 4, "num_warps": 4,
"num_stages": 3, "num_stages": 3,
} }
elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: # elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
# moe wna16 kernels # # moe wna16 kernels
# only set BLOCK_SIZE_M # # only set BLOCK_SIZE_M
# BLOCK_SIZE_N and BLOCK_SIZE_K would be set later # # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
bit = 4 if dtype == "int4_w4a16" else 8 # bit = 4 if dtype == "int4_w4a16" else 8
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, # use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
block_shape[1], E, bit) # block_shape[1], E, bit)
if use_moe_wna16_cuda: # if use_moe_wna16_cuda:
config = {"BLOCK_SIZE_M": min(16, M)} # config = {"BLOCK_SIZE_M": min(16, M)}
elif M <= 20: # elif M <= 20:
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1} # config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
elif M <= 40: # elif M <= 40:
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} # config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
else: # else:
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} # config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
else: else:
config = { config = {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
...@@ -1043,8 +1335,8 @@ def try_get_optimal_moe_config( ...@@ -1043,8 +1335,8 @@ def try_get_optimal_moe_config(
E, _, N = w2_shape E, _, N = w2_shape
else: else:
E, N, _ = w2_shape E, N, _ = w2_shape
if dtype == "int4_w4a16": # if dtype == "int4_w4a16":
N = N * 2 # N = N * 2
block_n = block_shape[0] if block_shape else 0 block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0 block_k = block_shape[1] if block_shape else 0
configs = get_moe_configs(E, N, dtype, block_n, block_k, use_nn_moe=use_nn_moe) configs = get_moe_configs(E, N, dtype, block_n, block_k, use_nn_moe=use_nn_moe)
...@@ -1159,9 +1451,12 @@ def grouped_topk(hidden_states: torch.Tensor, ...@@ -1159,9 +1451,12 @@ def grouped_topk(hidden_states: torch.Tensor,
def get_config_dtype_str(dtype: torch.dtype, def get_config_dtype_str(dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False, use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False): use_fp8_w8a8: Optional[bool] = False,
use_int8_w8a8: Optional[bool] = False):
if use_fp8_w8a8: if use_fp8_w8a8:
return "fp8_w8a8" return "fp8_w8a8"
elif use_int8_w8a8:
return "int8_w8a8"
elif use_int8_w8a16: elif use_int8_w8a16:
return "int8_w8a16" return "int8_w8a16"
elif use_int4_w4a16: elif use_int4_w4a16:
...@@ -1180,6 +1475,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1180,6 +1475,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1193,7 +1489,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1193,7 +1489,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, use_fp8_w8a8, use_int8_w8a16, activation, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map, use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe)
...@@ -1207,6 +1503,7 @@ def inplace_fused_experts_fake( ...@@ -1207,6 +1503,7 @@ def inplace_fused_experts_fake(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1238,6 +1535,7 @@ def outplace_fused_experts( ...@@ -1238,6 +1535,7 @@ def outplace_fused_experts(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1251,7 +1549,7 @@ def outplace_fused_experts( ...@@ -1251,7 +1549,7 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, use_fp8_w8a8, use_int8_w8a16, False, activation, use_fp8_w8a8,use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map, use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, use_nn_moe) a2_scale, block_shape, use_nn_moe)
...@@ -1265,6 +1563,7 @@ def outplace_fused_experts_fake( ...@@ -1265,6 +1563,7 @@ def outplace_fused_experts_fake(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1296,6 +1595,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1296,6 +1595,7 @@ def fused_experts(hidden_states: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1312,14 +1612,14 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1312,14 +1612,14 @@ def fused_experts(hidden_states: torch.Tensor,
if inplace: if inplace:
torch.ops.vllm.inplace_fused_experts( torch.ops.vllm.inplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation, hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe)
return hidden_states return hidden_states
else: else:
return torch.ops.vllm.outplace_fused_experts( return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation, hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe)
...@@ -1332,6 +1632,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1332,6 +1632,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1373,22 +1674,24 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1373,22 +1674,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938 # https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, if not use_int8_w8a8:
use_int8_w8a16=use_int8_w8a16, config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int4_w4a16=use_int4_w4a16, use_int8_w8a8=use_int8_w8a8,
dtype=hidden_states.dtype) use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
get_config_func = functools.partial( dtype=hidden_states.dtype)
try_get_optimal_moe_config,
w1.shape, get_config_func = functools.partial(
w2.shape, try_get_optimal_moe_config,
top_k_num, w1.shape,
config_dtype, w2.shape,
block_shape=block_shape, top_k_num,
use_nn_moe=use_nn_moe, config_dtype,
) block_shape=block_shape,
use_nn_moe=use_nn_moe,
)
config = get_config_func(M) config = get_config_func(M)
# We can reuse the memory between these because by the time we need # We can reuse the memory between these because by the time we need
# cache3, we're done with cache1 # cache3, we're done with cache1
...@@ -1442,10 +1745,43 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1442,10 +1745,43 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
if m<=16:
config =stage1_best_config[m-1]
elif m<=32:
config =stage1_best_config[15]
elif m<=64:
config =stage1_best_config[16]
elif m<256:
config ={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0,
"num_warps": 4
}
else:
config ={
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
# sorted_token_ids, expert_ids, num_tokens_post_padded = (
# moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
# global_num_experts, expert_map))
if use_int4_w4a16:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map, curr_hidden_states.shape[0]))
else:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map))
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
invoke_fused_moe_kernel(curr_hidden_states, invoke_fused_moe_kernel(curr_hidden_states,
w1, w1,
...@@ -1463,6 +1799,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1463,6 +1799,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape, block_shape=block_shape,
...@@ -1476,7 +1813,33 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1476,7 +1813,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
else: else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}") raise ValueError(f"Unsupported FusedMoe activation: {activation}")
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
if m<=16:
config =stage2_best_config[m-1]
elif m<=32:
config =stage2_best_config[15]
elif m<=64:
config =stage2_best_config[16]
elif m<256:
config ={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0,
"num_warps": 4
}
else:
config ={
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
invoke_fused_moe_kernel(intermediate_cache2, invoke_fused_moe_kernel(intermediate_cache2,
w2, w2,
intermediate_cache3, intermediate_cache3,
...@@ -1493,6 +1856,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1493,6 +1856,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape, block_shape=block_shape,
...@@ -1517,6 +1881,7 @@ def fused_moe( ...@@ -1517,6 +1881,7 @@ def fused_moe(
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1598,6 +1963,7 @@ def fused_moe( ...@@ -1598,6 +1963,7 @@ def fused_moe(
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
......
...@@ -164,21 +164,28 @@ class DeepseekV2MoE(nn.Module): ...@@ -164,21 +164,28 @@ class DeepseekV2MoE(nn.Module):
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16: # if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( # final_hidden_states = self.experts(
hidden_states=hidden_states, # hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor # router_logits=router_logits) * self.routed_scaling_factor
else: # else:
# This is a special case to avoid FP16 overflow # # This is a special case to avoid FP16 overflow
final_hidden_states = self.experts(hidden_states=hidden_states, # final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) # router_logits=router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16: final_hidden_states = final_hidden_states + shared_output
final_hidden_states = final_hidden_states + shared_output
else: # if shared_output is not None:
# This is a special case to avoid FP16 overflow # if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output \ # final_hidden_states = final_hidden_states + shared_output
* (1. / self.routed_scaling_factor) # else:
# # This is a special case to avoid FP16 overflow
# final_hidden_states = final_hidden_states + shared_output \
# * (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states) final_hidden_states)
...@@ -571,18 +578,18 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -571,18 +578,18 @@ class DeepseekV2DecoderLayer(nn.Module):
) )
# Fully Connected # Fully Connected
if isinstance(self.mlp, DeepseekV2MoE) and \ # if isinstance(self.mlp, DeepseekV2MoE) and \
hidden_states.dtype == torch.float16: # hidden_states.dtype == torch.float16:
# This is a special case to avoid FP16 overflow # # This is a special case to avoid FP16 overflow
hidden_states *= 1. / self.routed_scaling_factor # hidden_states *= 1. / self.routed_scaling_factor
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp, DeepseekV2MLP) and \ # if isinstance(self.mlp, DeepseekV2MLP) and \
hidden_states.dtype == torch.float16: # hidden_states.dtype == torch.float16:
# This is a special case to avoid FP16 overflow # # This is a special case to avoid FP16 overflow
hidden_states *= 1. / self.routed_scaling_factor # hidden_states *= 1. / self.routed_scaling_factor
residual *= 1. / self.routed_scaling_factor # residual *= 1. / self.routed_scaling_factor
return hidden_states, residual return hidden_states, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment