Commit fe6d3b05 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove fused_moe of quantization

parent 68826ce6
...@@ -14,255 +14,17 @@ from vllm import _custom_ops as ops ...@@ -14,255 +14,17 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
device_name = current_platform.get_device_name().replace(" ", "_")
if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120:
stage1_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"kpack": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 4,"kpack": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 8},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#12
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32
]
stage2_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#12
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #16
]
else:
stage1_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 8},#11
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2},#12
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 4,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #32
]
stage2_best_config=[
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #0
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #1
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #2
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#3
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #4
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#5
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#6
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#7
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#8
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#9
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#10
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4},#11
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 8},#12
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #16
]
@triton.jit @triton.jit
def fused_moe_kernel_awq(
# Pointers to matrices
a_ptr, # [4, 7168]
b_ptr, # [256, 512, 3584]
c_ptr, # (8, 8, 512)
b_scale_ptr, # (256, 512, 56)
b_zp_ptr, # (256, 256, 56)
topk_weights_ptr,
sorted_token_ids_ptr, # [0, 1, 2, 3, 4]
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N: tl.constexpr,
K: tl.constexpr,
EM, # pading后的总索引长度
num_valid_tokens, # 有效索引的上限
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk, #1
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsk,#1
stride_bsn,
stride_bze,
stride_bzk,
stride_bzn,
block_k_diviable: tl.constexpr,
group_size: tl.constexpr, # 128
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) # [block_m]
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N)) % N # [block_n]
offs_k = tl.arange(0, BLOCK_SIZE_K) # 0, 1, 2, ...... , 127 # # [block_k]
offs_k2 = tl.arange(0, BLOCK_SIZE_K // 2) # 0, 1, 2, ...... , 127 # # [block_k]
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) # [block_m, block_k]
off_experts = tl.load(expert_ids_ptr + pid_m)
if use_int4_w4a16:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
# b_ptrs = b_ptr + off_experts * stride_be + \
# (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_ptrs = b_ptr + off_experts * stride_be + \
offs_bn[:, None] * stride_bn + (offs_k2[None, :]) * stride_bk
# tl.device_print("stride_bn",stride_bsn)>1
# tl.device_print("stride_bk",stride_bk)=1
b_shifter = (offs_k[:, None] % 2) * 4 # 0, 4
elif use_int8_w8a16:
b_ptrs = b_ptr + off_experts * stride_be + \
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
if not has_zp and use_int4_w4a16:
b_zp_num = 8
if not has_zp and use_int8_w8a16:
b_zp_num = 128
elif has_zp and use_int4_w4a16:
b_zp_shifter = (offs_bn[None, :] % 2) * 4 # 0, 4
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if not block_k_diviable:
k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
k_other = 0.0
else:
k_mask = None
k_other = None
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs)
if use_int4_w4a16:
b = tl.interleave(b, b)
b= b.trans()
b = (b >> b_shifter) & 0xF
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
offs_bn[None, :] * stride_bsk + \
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsn
qzeros_scles = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
scales_int16 = tl.cast(qzeros_scles,tl.uint16)
b_scale = tl.cast(scales_int16,tl.float16,bitcast=True)
# tl.device_print("b_scale dequant",b_scale)
mid = qzeros_scles >> 16
# b_zp = tl.cast(mid,tl.float16,bitcast=False)
b_zp = tl.cast(mid,tl.float16)
# b_zp = tl.cast(zeros_int16,tl.float16,bitcast=False)
# tl.device_print("bzp",b_zp)
# We accumulate along the K dimension.
b = ((b - b_zp) * b_scale).to(tl.float16)
accumulator = tl.dot(a, b, acc=accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
if use_int4_w4a16:
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
else:
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
compute_type): compute_type):
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :] None, :]
...@@ -525,7 +287,6 @@ def fused_moe_kernel( ...@@ -525,7 +287,6 @@ def fused_moe_kernel(
top_k: tl.constexpr, top_k: tl.constexpr,
compute_type: tl.constexpr, compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr, use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr): use_int8_w8a16: tl.constexpr):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
...@@ -579,6 +340,7 @@ def fused_moe_kernel( ...@@ -579,6 +340,7 @@ def fused_moe_kernel(
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m pid_n = (pid % num_pid_in_group) // group_size_m
# ---------------------------------------------------------- # ----------------------------------------------------------
# Create pointers for the first blocks of A and B. # Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction # We will advance this pointer as we move in the K direction
...@@ -616,7 +378,7 @@ def fused_moe_kernel( ...@@ -616,7 +378,7 @@ def fused_moe_kernel(
None, :] * stride_bsn None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs) b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8 or use_int8_w8a8: if use_fp8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n offs_bsn = offs_bn // group_n
...@@ -645,7 +407,7 @@ def fused_moe_kernel( ...@@ -645,7 +407,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension. # We accumulate along the K dimension.
if use_int8_w8a16: if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8 or use_int8_w8a8: elif use_fp8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k offs_ks = k_start // group_k
...@@ -671,7 +433,7 @@ def fused_moe_kernel( ...@@ -671,7 +433,7 @@ def fused_moe_kernel(
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16: if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type) accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8 or use_int8_w8a8: elif use_fp8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
else: else:
...@@ -829,8 +591,7 @@ def moe_align_block_size( ...@@ -829,8 +591,7 @@ def moe_align_block_size(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
block_size: int, block_size: int,
num_experts: int, num_experts: int,
expert_map: torch.Tensor = None, expert_map: torch.Tensor = None
num_token: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
...@@ -873,18 +634,11 @@ def moe_align_block_size( ...@@ -873,18 +634,11 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible - The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations. by block_size for proper block matrix operations.
""" """
if num_token: max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if num_token < block_size: sorted_ids = torch.empty((max_num_tokens_padded, ),
max_num_tokens_padded = min(topk_ids.numel() * block_size, topk_ids.numel() + num_experts * (block_size - 1)) dtype=torch.int32,
else: device=topk_ids.device)
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sorted_ids.fill_(topk_ids.numel())
sorted_ids = torch.full((max_num_tokens_padded,), fill_value=topk_ids.numel(), dtype=torch.int32, device=topk_ids.device)
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while # Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism. # mapping global expert ids to local expert ids in expert parallelism.
...@@ -939,7 +693,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -939,7 +693,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config: Dict[str, Any], config: Dict[str, Any],
compute_type: tl.dtype, compute_type: tl.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
...@@ -958,19 +711,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -958,19 +711,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
A, A_scale = ops.scaled_int8_quant(A, A_scale)
else:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16: elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None assert B_scale is not None
assert block_shape is None or block_shape[0] == 0 assert block_shape is None or block_shape[0] == 0
...@@ -1021,82 +761,43 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -1021,82 +761,43 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config["BLOCK_SIZE_K"], bit) config["BLOCK_SIZE_K"], bit)
return return
if os.environ.get('AWQ_MOE_SZ') == '1': fused_moe_kernel_gptq_awq[grid](
fused_moe_kernel_awq[grid]( A,
A, B,
B, C,
C, B_scale,
B_scale, B_zp,
B_zp, topk_weights,
topk_weights, sorted_token_ids,
sorted_token_ids, expert_ids,
expert_ids, num_tokens_post_padded,
num_tokens_post_padded, B.shape[1],
B.shape[1], A.shape[1],
A.shape[1], EM,
EM, topk_ids.numel(),
topk_ids.numel(), A.stride(0),
A.stride(0), A.stride(1),
A.stride(1), B.stride(0),
B.stride(0), B.stride(2),
B.stride(2), B.stride(1),
B.stride(1), C.stride(1),
C.stride(1), C.stride(2),
C.stride(2), B_scale.stride(0),
B_scale.stride(0), B_scale.stride(2),
B_scale.stride(2), B_scale.stride(1),
B_scale.stride(1), B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(0) if B_zp is not None else 0, B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0, B_zp.stride(1) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0, block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, group_size=block_shape[1],
group_size=block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight,
MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k,
top_k=top_k, compute_type=compute_type,
compute_type=compute_type, has_zp=B_zp is not None,
has_zp=B_zp is not None, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a16=use_int4_w4a16, use_int8_w8a16=use_int8_w8a16,
use_int8_w8a16=use_int8_w8a16, **config,
**config, )
)
else:
fused_moe_kernel_gptq_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
EM,
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
**config,
)
else: else:
config = config.copy() config = config.copy()
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
...@@ -1140,7 +841,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -1140,7 +841,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_K=BLOCK_SIZE_K,
**config, **config,
...@@ -1161,7 +861,6 @@ def get_config_file_name(E: int, ...@@ -1161,7 +861,6 @@ def get_config_file_name(E: int,
else: else:
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}_nn.json" return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}_nn.json"
# Adapted from: https://github.com/sgl-project/sglang/pull/2628 # Adapted from: https://github.com/sgl-project/sglang/pull/2628
@functools.lru_cache @functools.lru_cache
def get_moe_configs( def get_moe_configs(
...@@ -1170,7 +869,7 @@ def get_moe_configs( ...@@ -1170,7 +869,7 @@ def get_moe_configs(
dtype: Optional[str], dtype: Optional[str],
block_n: Optional[int] = None, block_n: Optional[int] = None,
block_k: Optional[int] = None, block_k: Optional[int] = None,
use_nn_moe: Optional[bool] = False use_nn_moe: Optional[bool] = False,
) -> Optional[Dict[int, Any]]: ) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
...@@ -1188,15 +887,6 @@ def get_moe_configs( ...@@ -1188,15 +887,6 @@ def get_moe_configs(
config_file_path = os.path.join( config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120:
config_file_path_120 = config_file_path.replace(".json","_120.json")
if os.path.exists(config_file_path_120):
with open(config_file_path_120) as f:
logger.info("Using configuration from %s for MoE layer.",
config_file_path_120)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
with open(config_file_path) as f: with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.", logger.info("Using configuration from %s for MoE layer.",
...@@ -1285,7 +975,7 @@ def get_default_config( ...@@ -1285,7 +975,7 @@ def get_default_config(
dtype: Optional[str], dtype: Optional[str],
is_marlin: bool, is_marlin: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False use_nn_moe: Optional[bool]=False,
) -> Dict[str, int]: ) -> Dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None: if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
...@@ -1341,7 +1031,7 @@ def try_get_optimal_moe_config( ...@@ -1341,7 +1031,7 @@ def try_get_optimal_moe_config(
M: int, M: int,
is_marlin: bool = False, is_marlin: bool = False,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False use_nn_moe: Optional[bool] = False,
): ):
from vllm.model_executor.layers.fused_moe import get_config from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config() override_config = get_config()
...@@ -1469,12 +1159,9 @@ def grouped_topk(hidden_states: torch.Tensor, ...@@ -1469,12 +1159,9 @@ def grouped_topk(hidden_states: torch.Tensor,
def get_config_dtype_str(dtype: torch.dtype, def get_config_dtype_str(dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False, use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False):
use_int8_w8a8: Optional[bool] = False):
if use_fp8_w8a8: if use_fp8_w8a8:
return "fp8_w8a8" return "fp8_w8a8"
elif use_int8_w8a8:
return "int8_w8a8"
elif use_int8_w8a16: elif use_int8_w8a16:
return "int8_w8a16" return "int8_w8a16"
elif use_int4_w4a16: elif use_int4_w4a16:
...@@ -1493,7 +1180,6 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1493,7 +1180,6 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1505,12 +1191,12 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1505,12 +1191,12 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,) -> None: use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map, use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe,) block_shape, use_nn_moe)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1521,7 +1207,6 @@ def inplace_fused_experts_fake( ...@@ -1521,7 +1207,6 @@ def inplace_fused_experts_fake(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1533,7 +1218,7 @@ def inplace_fused_experts_fake( ...@@ -1533,7 +1218,7 @@ def inplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,) -> None: use_nn_moe: Optional[bool] = False) -> None:
pass pass
...@@ -1553,7 +1238,6 @@ def outplace_fused_experts( ...@@ -1553,7 +1238,6 @@ def outplace_fused_experts(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1565,13 +1249,12 @@ def outplace_fused_experts( ...@@ -1565,13 +1249,12 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,) -> torch.Tensor: use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, False, activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map, use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, a2_scale, block_shape, use_nn_moe)
use_nn_moe,)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -1582,7 +1265,6 @@ def outplace_fused_experts_fake( ...@@ -1582,7 +1265,6 @@ def outplace_fused_experts_fake(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: Optional[str] = None, activation: Optional[str] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1594,7 +1276,7 @@ def outplace_fused_experts_fake( ...@@ -1594,7 +1276,7 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,) -> torch.Tensor: use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1614,7 +1296,6 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1614,7 +1296,6 @@ def fused_experts(hidden_states: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1626,23 +1307,21 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1626,23 +1307,21 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,) -> torch.Tensor: use_nn_moe: Optional[bool] = False) -> torch.Tensor:
if inplace: if inplace:
torch.ops.vllm.inplace_fused_experts( torch.ops.vllm.inplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation, hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, block_shape, use_nn_moe)
use_nn_moe,)
return hidden_states return hidden_states
else: else:
return torch.ops.vllm.outplace_fused_experts( return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation, hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, block_shape, use_nn_moe)
use_nn_moe,)
def fused_experts_impl(hidden_states: torch.Tensor, def fused_experts_impl(hidden_states: torch.Tensor,
...@@ -1653,7 +1332,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1653,7 +1332,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1665,7 +1343,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1665,7 +1343,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,): use_nn_moe: Optional[bool] = False):
# Check constraints. # Check constraints.
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[ assert hidden_states.shape[1] // 2 == w1.shape[
...@@ -1684,12 +1362,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1684,12 +1362,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
] ]
num_tokens, _ = hidden_states.shape num_tokens, _ = hidden_states.shape
if use_nn_moe: if use_nn_moe:
E, _, N = w1.shape E, _, N = w1.shape
else: else:
E, N, _ = w1.shape E, N, _ = w1.shape
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
top_k_num = topk_ids.shape[1] top_k_num = topk_ids.shape[1]
...@@ -1697,34 +1373,32 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1697,34 +1373,32 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938 # https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
if not use_int8_w8a8: config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16,
use_int8_w8a8=use_int8_w8a8, use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16, dtype=hidden_states.dtype)
use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype) get_config_func = functools.partial(
try_get_optimal_moe_config,
get_config_func = functools.partial( w1.shape,
try_get_optimal_moe_config, w2.shape,
w1.shape, top_k_num,
w2.shape, config_dtype,
topk_ids.shape[1], block_shape=block_shape,
config_dtype, use_nn_moe=use_nn_moe,
block_shape=block_shape, )
use_nn_moe=use_nn_moe,
)
config = get_config_func(M) config = get_config_func(M)
# We can reuse the memory between these because by the time we need # We can reuse the memory between these because by the time we need
# cache3, we're done with cache1 # cache3, we're done with cache1
cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1] if not use_nn_moe else w2.shape[2]),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
intermediate_cache1 = cache13[:M * top_k_num * N].view( intermediate_cache1 = cache13[:M * top_k_num * N].view(
(M, topk_ids.shape[1], N)) (M, topk_ids.shape[1], N))
intermediate_cache3 = cache13[:M * top_k_num * (w2.shape[1] if not use_nn_moe else w2.shape[2])].view( intermediate_cache3 = cache13[:M * top_k_num * (w2.shape[1] if not use_nn_moe else w2.shape[2])].view(
(M, topk_ids.shape[1], w2.shape[1])) (M, topk_ids.shape[1], w2.shape[1] if not use_nn_moe else w2.shape[2]))
# This needs separate memory since it's used concurrently with cache1 # This needs separate memory since it's used concurrently with cache1
intermediate_cache2 = torch.empty((M * top_k_num, N // 2), intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
...@@ -1768,40 +1442,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1768,40 +1442,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
if use_int8_w8a8: sorted_token_ids, expert_ids, num_tokens_post_padded = (
m=curr_hidden_states.shape[0] moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
if m<=16: global_num_experts, expert_map))
config =stage1_best_config[m-1]
elif m<=32:
config =stage1_best_config[15]
elif m<=64:
config =stage1_best_config[16]
elif m<256:
config ={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0,
"num_warps": 4
}
else:
config ={
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
if use_int4_w4a16:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map, curr_hidden_states.shape[0]))
else:
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map))
invoke_fused_moe_kernel(curr_hidden_states, invoke_fused_moe_kernel(curr_hidden_states,
w1, w1,
...@@ -1819,7 +1463,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1819,7 +1463,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape, block_shape=block_shape,
...@@ -1833,33 +1476,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1833,33 +1476,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
else: else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}") raise ValueError(f"Unsupported FusedMoe activation: {activation}")
if use_int8_w8a8:
m=curr_hidden_states.shape[0]
if m<=16:
config =stage2_best_config[m-1]
elif m<=32:
config =stage2_best_config[15]
elif m<=64:
config =stage2_best_config[16]
elif m<256:
config ={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0,
"num_warps": 4
}
else:
config ={
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
invoke_fused_moe_kernel(intermediate_cache2, invoke_fused_moe_kernel(intermediate_cache2,
w2, w2,
...@@ -1877,7 +1493,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1877,7 +1493,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape, block_shape=block_shape,
...@@ -1902,7 +1517,6 @@ def fused_moe( ...@@ -1902,7 +1517,6 @@ def fused_moe(
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1984,7 +1598,6 @@ def fused_moe( ...@@ -1984,7 +1598,6 @@ def fused_moe(
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -1996,4 +1609,4 @@ def fused_moe( ...@@ -1996,4 +1609,4 @@ def fused_moe(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe,) use_nn_moe=use_nn_moe)
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment