# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.experts.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize.batched import (
    BatchedPrepareAndFinalize,
)
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported

from .test_deepgemm import make_block_quant_fp8_weights
from .utils import make_dummy_moe_config

# Tile shape for FP8 block-wise quantization. It is used both to build the
# test weights and as ``block_shape`` in the MoE quant config.
BLOCK_SIZE = [128] * 2


def _run_batched_experts(
    prep_finalize: BatchedPrepareAndFinalize,
    experts,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int,
) -> torch.Tensor:
    """Run one MoE forward pass of *experts* behind *prep_finalize*.

    Both implementations under test go through this helper. That guarantees
    they are wrapped in the same ``FusedMoEKernel`` configuration
    (``inplace=False``) and receive exactly the same ``apply`` arguments:
    SiLU activation, no expert map, and router weights not applied on the
    input.
    """
    kernel = FusedMoEKernel(
        prep_finalize,
        experts,
        inplace=False,
    )
    return kernel.apply(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=MoEActivation.SILU,
        global_num_experts=global_num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
    )


@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
@pytest.mark.parametrize("E", [16, 32])  # number of experts
@pytest.mark.parametrize("T", [256, 512])  # tokens per expert
@pytest.mark.parametrize("K", [128, 256])  # hidden dim
@pytest.mark.parametrize("N", [512, 1024])  # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(
    E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init
):
    """Compare BatchedDeepGemmExperts to BatchedTritonExperts.

    Both expert implementations share one ``BatchedPrepareAndFinalize``. They
    see the same FP8 block-quantized weights, the same bf16 activations and
    the same random top-k routing. The test requires
    ``calc_diff(deepgemm, triton) < 1e-3``.
    """
    # Force the DeepGEMM path on. monkeypatch restores the env afterwards.
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
    # Weights, activations and routing are all random. Seed so that a failing
    # parametrization can be reproduced exactly.
    torch.manual_seed(7)

    device = "cuda"
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(E, N, K, BLOCK_SIZE)

    M = E * T  # total tokens
    a = torch.randn(M, K, device=device, dtype=torch.bfloat16) / 10.0
    # Defensive clamp to the float8_e4m3fn range. With std 0.1 it is
    # effectively a no-op.
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    a.clamp_(fp8_info.min, fp8_info.max)

    # Random router logits give top-k expert ids per token. The selected
    # logits are softmax-normalized into routing weights.
    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # Count tokens routed to each expert. max_num_tokens is the busiest
    # expert's load rounded up to a power of two (e.g. 1 -> 1, 5 -> 8, 8 -> 8).
    cnt = torch.bincount(topk_ids.flatten(), minlength=E)
    max_cnt = int(cnt.max().item())
    max_num_tokens = 1 << (max_cnt - 1).bit_length()

    prep_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        num_local_experts=E,
        num_dispatchers=1,
        rank=0,
    )

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_s,
        w2_scale=w2_s,
        per_act_token_quant=False,
        block_shape=BLOCK_SIZE,
    )

    # Both experts are built with identical settings; only the backend differs.
    triton_experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        quant_config=quant_config,
        moe_config=make_dummy_moe_config(),
    )
    deepgemm_experts = BatchedDeepGemmExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        quant_config=quant_config,
        moe_config=make_dummy_moe_config(),
    )

    # triton (reference)
    out_triton = _run_batched_experts(
        prep_finalize, triton_experts, a, w1, w2, topk_weights, topk_ids, E
    )
    # deepgemm
    out_deepgemm = _run_batched_experts(
        prep_finalize, deepgemm_experts, a, w1, w2, topk_weights, topk_ids, E
    )

    diff = calc_diff(out_deepgemm, out_triton)
    assert diff < 1e-3, f"Output diff too large: {diff}"