Unverified Commit a4d5d663 authored by Runkai Tao's avatar Runkai Tao Committed by GitHub
Browse files

Add unpermute-aware fused MoE path and small-batch fallback (#29354)


Signed-off-by: default avatarRunkai Tao <rt572@physics.rutgers.edu>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 657e9c0e
......@@ -64,8 +64,10 @@ from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager
NUM_EXPERTS = [8, 64, 192]
NUM_EXPERTS_LARGE = [128, 256]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
TOP_KS_SMALL = [1, 2]
MOE_MARLIN_QUANT_TEST_CONFIGS = [
# AWQ-INT4
......@@ -133,6 +135,13 @@ FUSED_MOE_MNK_FACTORS = [
(40000, 1024, 1024),
]
FUSED_MOE_MNK_FACTORS_SMALL_M = [
(1, 128, 128),
(1, 2048, 128),
(2, 2048, 128),
(2, 2048, 511),
]
FUSED_MOE_WN16_MNK_FACTORS = [
(1, 128, 128),
(1, 1024, 1024),
......@@ -330,6 +339,111 @@ def test_fused_moe(
)
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS_SMALL_M)
@pytest.mark.parametrize("e", NUM_EXPERTS_LARGE)
@pytest.mark.parametrize("topk", TOP_KS_SMALL)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_naive_block_assignment_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
padding: bool,
chunk_size: int,
monkeypatch,
workspace_init,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
#
# Setup test data
#
#
# Setup test data
#
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
e_map = None
#
# Setup test functions
#
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
m_fused_moe_fn = modular_triton_fused_moe(quant_config)
def m_fused_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
return m_fused_moe_fn(
a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
fused_moe_fn = functools.partial(fused_moe, renormalize=False)
#
# Run tests
#
runner = functools.partial(
run_moe_test,
a=a,
w1=w1,
w2=w2,
score=score,
topk=topk,
global_num_experts=e,
expert_map=e_map,
padding=padding,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()
with set_current_vllm_config(vllm_config):
baseline_output = runner(torch_moe, iterative_moe)
runner(
baseline_output,
fused_moe_fn,
use_compile=use_compile,
use_cudagraph=use_cudagraph,
)
runner(
baseline_output,
m_fused_moe,
use_compile=use_compile,
use_cudagraph=use_cudagraph,
)
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
......
......@@ -351,6 +351,7 @@ def fused_moe_kernel(
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
naive_block_assignment: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
......@@ -386,6 +387,9 @@ def fused_moe_kernel(
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
- naive_block_assignment: A boolean flag indicating whether to use naive
token wise block assignment. If True, each block corresponds to a
single token.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
......@@ -411,11 +415,20 @@ def fused_moe_kernel(
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
if not naive_block_assignment:
offs_token_id = pid_m * BLOCK_SIZE_M + offs
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
else:
offs_token = tl.where(
offs == 0,
pid_m, # first element = pid_m
num_valid_tokens, # remaining elements = constant
)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
......@@ -557,7 +570,7 @@ def invoke_fused_moe_wna16_cuda_kernel(
B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
......@@ -705,7 +718,7 @@ def invoke_fused_moe_triton_kernel(
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
......@@ -722,7 +735,7 @@ def invoke_fused_moe_triton_kernel(
):
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
......@@ -741,14 +754,18 @@ def invoke_fused_moe_triton_kernel(
M = A.size(0)
num_tokens = M * top_k
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
if sorted_token_ids is not None:
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(
sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]
)
else:
EM = num_tokens * config["BLOCK_SIZE_M"]
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
......@@ -798,6 +815,7 @@ def invoke_fused_moe_triton_kernel(
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
naive_block_assignment=(sorted_token_ids is None),
HAS_BIAS=HAS_BIAS,
BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
......@@ -812,7 +830,7 @@ def dispatch_fused_moe_kernel(
B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
......@@ -829,7 +847,7 @@ def dispatch_fused_moe_kernel(
) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
M = A.size(0)
num_tokens = M * top_k
......@@ -2165,14 +2183,37 @@ def fused_experts_impl(
block_shape=block_shape,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids,
config["BLOCK_SIZE_M"],
global_num_experts,
expert_map,
ignore_invalid_experts=True,
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
# activates only a small fraction of total experts
SPARSITY_FACTOR = 4
# block quantized code path is not implemented yet.
naive_block_assignment = (
expert_map is None
and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
and not (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
)
)
if not naive_block_assignment:
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids,
config["BLOCK_SIZE_M"],
global_num_experts,
expert_map,
ignore_invalid_experts=True,
)
else:
max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
expert_ids = curr_topk_ids.view(-1)
num_tokens_post_padded = torch.empty(
(1), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_padded.fill_(max_num_tokens_padded)
sorted_token_ids = None
dispatch_fused_moe_kernel(
qcurr_hidden_states,
w1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment