# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_int8_kernel.py
import itertools

import pytest
import torch

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.platforms import current_platform

# The per-token int8 activation quantizer comes from vLLM's int8_utils
# everywhere except ROCm, where the ``lmslim`` implementation is used instead.
if not current_platform.is_rocm():
    from vllm.model_executor.layers.quantization.utils.int8_utils import (
        per_token_quant_int8,
    )
else:
    from lmslim.layers.gemm.int8_utils import per_token_quant_int8

# Skip the whole module on devices below compute capability 7.0.
# NOTE(review): get_device_capability() may return None on platforms that do
# not report a capability (confirm against vllm.platforms); treat that as
# unsupported rather than crashing on ``None < (7, 0)``.
_device_capability = current_platform.get_device_capability()
if _device_capability is None or _device_capability < (7, 0):
    pytest.skip(
        "INT8 Triton requires CUDA compute capability 7.0 or higher",
        allow_module_level=True,
    )


def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Reference int8 GEMM with per-token input and per-column weight scales.

    Computes ``(A @ B.T) * As * Bs`` in float32 and casts the result to
    ``output_dtype``.

    Args:
        A: Quantized activations of shape ``[..., K_in]``. All leading dims
            are flattened into ``M`` token rows.
        B: Quantized weight, a 2D contiguous tensor of shape
            ``[N_out, K_in]``.
        As: Per-token activation scales. One scale per token (``M``
            elements, e.g. ``[M, 1]``, ``[M]`` or ``A.shape[:-1] + (1,)``) is
            reshaped to ``[M, 1]``; any other shape is used as given and must
            broadcast against ``[M, N_out]``.
        Bs: Per-output-column weight scales with ``N_out`` elements.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N_out,)`` with dtype
        ``output_dtype``.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"

    k_in = A.shape[-1]
    n_out = B.shape[0]
    m = A.numel() // k_in
    out_shape = A.shape[:-1] + (n_out,)

    # Flatten leading dims so each row is one token, then multiply by B^T.
    C = torch.matmul(A.reshape(m, k_in), B.t())  # [m, n_out]

    # One scale per token: lay it out as a column so it scales rows of C.
    if As.numel() == m:
        As = As.reshape(m, 1)
    C = As * C * Bs.view(1, -1)  # per-token rows, per-column scales

    return C.reshape(out_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, topk_ids):
    """Reference fused MoE with int8 weights, implemented in plain torch.

    Activations are quantized per token with ``per_token_quant_int8`` before
    each of the two expert matmuls. Weights use per-column scales
    (``w1_s[e]`` / ``w2_s[e]``).

    Args:
        a: Input activations of shape ``[B, D]``.
        w1: First-layer int8 weights indexed per expert as ``w1[e]``.
        w2: Second-layer int8 weights indexed per expert as ``w2[e]``.
            ``w2.shape[1]`` is the output width.
        w1_s: Per-column scales for ``w1``, indexed per expert.
        w2_s: Per-column scales for ``w2``, indexed per expert.
        topk: Number of experts each token is routed to.
        topk_weight: Routing weights, ``B * topk`` elements in ``[B, topk]``
            order.
        topk_ids: Expert ids, laid out like ``topk_weight``.

    Returns:
        Tensor of shape ``[B, w2.shape[1]]`` with dtype ``a.dtype``: the
        routing-weighted sum of each token's expert outputs.
    """
    B, D = a.shape

    # Quantize once per token, then replicate each token (and its scale)
    # ``topk`` times. After flattening, row ``b * topk + j`` pairs with
    # ``topk_ids[b, j]``.
    a_q, a_s = per_token_quant_int8(a)
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Built once; presumably maps [..., 2N] -> [..., N] (TODO confirm).
    act_fn = SiluAndMul()

    for expert in range(w1.shape[0]):
        mask = topk_ids == expert
        if not mask.any():
            continue  # no token routed to this expert

        # First MLP layer on the int8 tokens routed here.
        inter_out = native_w8a8_per_token_matmul(
            a_q[mask], w1[expert], a_s[mask], w1_s[expert], output_dtype=a.dtype
        )
        act_out = act_fn.forward_native(inter_out)

        # Re-quantize the activation per token before the second layer.
        act_out_q, act_out_s = per_token_quant_int8(act_out)
        out[mask] = native_w8a8_per_token_matmul(
            act_out_q, w2[expert], act_out_s, w2_s[expert], output_dtype=a.dtype
        )

    # Weight each of a token's topk expert outputs and sum them.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Make CUDA the default torch device for every test in this module.

    The previous default device is restored at module teardown. Without this,
    the global setting would leak into test modules that run later in the
    same process.
    """
    previous_device = torch.get_default_device()
    torch.set_default_device("cuda")
    yield
    torch.set_default_device(previous_device)


# Parameter grid for the fused-MoE test: token counts (M), intermediate sizes
# (N, w1 has 2*N rows), hidden sizes (K), expert counts (E), experts per token
# (TOP_KS), activation dtypes and RNG seeds.
M = [1, 33]
N = [128, 1024]
K = [256, 4096]
E = [8]
TOP_KS = [2, 6]
DTYPES = [torch.float16, torch.bfloat16]
SEEDS = [0]


@pytest.mark.parametrize(
    "M, N, K, E, topk, dtype, seed",
    itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
    """Compare int8 W8A8 ``fused_experts`` against the torch reference MoE.

    NOTE(review): despite ``fp8`` in its name, this test quantizes weights
    and activations to int8. The name is kept so test IDs stay stable.
    """
    torch.manual_seed(seed)
    # Initialize int8 quantization parameters
    factor_for_scale = 1e-2
    int8_max = 127
    int8_min = -128

    # Input activations: [M, K]
    a = torch.randn((M, K), dtype=dtype) / 10

    # Random int8 weights: w1 is [E, 2N, K], w2 is [E, K, N]
    w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
    w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

    w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
    w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

    # One scale per output column of each expert's weight (per-column quant)
    w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
    w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale

    # Random routing: softmax over experts, keep the topk per token
    score = torch.randn((M, E), dtype=dtype)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(score, topk)

    ref_out = torch_w8a8_per_column_moe(
        a, w1, w2, w1_s, w2_s, topk, topk_weights, topk_ids
    )

    quant_config = FusedMoEQuantConfig.make(
        torch.int8,
        per_act_token_quant=True,
        block_shape=None,
        w1_scale=w1_s,
        w2_scale=w2_s,
    )

    out = fused_experts(
        a,
        w1,
        w2,
        topk_weights,
        topk_ids,
        quant_config=quant_config,
    )

    # Mean absolute error relative to the reference's mean magnitude
    ref_fp32 = ref_out.to(torch.float32)
    rel_diff = torch.mean(torch.abs(out.to(torch.float32) - ref_fp32)) / torch.mean(
        torch.abs(ref_fp32)
    )
    assert rel_diff < 0.05, f"relative difference {rel_diff.item():.4f} >= 0.05"