# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
9
10
11
12
13
    ALLSPARK_AMPERE_K_ALIGN,
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    ALLSPARK_AMPERE_N_ALIGN,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
14
15
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
16
from vllm.utils.platform_utils import num_compute_units
17
18


19
def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool:
    """Check that the current device capability lies within a given range.

    Args:
        min_capability: Lowest accepted capability, as the integer produced
            by the capability object's ``to_int()``.
        max_capability: Highest accepted capability, in the same encoding.

    Returns:
        ``False`` when the current platform is not CUDA; otherwise whether
        ``min_capability <= capability <= max_capability`` (both inclusive).
    """
    if current_platform.is_cuda():
        device_capability = current_platform.get_device_capability()
        assert device_capability is not None

        capability_code = device_capability.to_int()
        return min_capability <= capability_code <= max_capability

    return False
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


# Problem-size factors as (m, n_factor, k_factor). The test below uses m
# directly and multiplies n_factor / k_factor by ALLSPARK_AMPERE_N_ALIGN /
# ALLSPARK_AMPERE_K_ALIGN to obtain the GEMM's n and k.
MNK_FACTORS = [
    (1, 4, 8),
    (13, 17, 67),
    (26, 37, 13),
    (48, 16, 24),
    (67, 13, 88),
    (257, 13, 11),
    (658, 13, 11),
    (1033, 9, 17),
]

# dtypes used for the random activation and weight tensors in the test.
DTYPES = [torch.float16, torch.bfloat16]
# Values of the has_zp flag that the test forwards to quantize_weights and to
# the AllSpark ops.
HAS_ZP_OPTS = [False, True]


def compute_max_diff(output, output_ref):
    """Return the mean absolute error of ``output`` relative to ``output_ref``.

    Despite the name, no maximum is taken: the result is
    ``mean(|output - output_ref|) / mean(|output_ref|)``.
    """
    mean_abs_error = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_error / mean_abs_ref
50
51
52
53
54
55
56
57


def rand_data(shape, dtype=torch.float16, *, device="cuda"):
    """Create a tensor filled with samples from ``torch.randn``.

    Args:
        shape: Shape of the tensor to create.
        dtype: Element type of the tensor; defaults to ``torch.float16``.
        device: Device to allocate the tensor on. Defaults to ``"cuda"``,
            which is what existing callers rely on; keyword-only so the
            positional signature is unchanged.

    Returns:
        A new tensor of the requested shape, dtype and device.
    """
    return torch.randn(shape, dtype=dtype, device=device)


@pytest.mark.skipif(
    not is_gptq_allspark_supported(80, 89),
    reason="AllSpark Ampere kernel is not supported on this GPU type.",
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
    """Compare the AllSpark W8A16 GEMM op against a ``torch.matmul`` reference.

    Random activations and weights are generated, the weights are quantized
    with ``quantize_weights`` and repacked with ``ops.allspark_repack_weight``,
    and both custom ops are run through ``opcheck``. The GEMM output must then
    stay within a relative mean error of 0.04 of ``activations @ w_ref``, where
    ``w_ref`` is the reference weight returned by ``quantize_weights``.
    """
    # m is used as-is; n and k are scaled up to the kernel's alignment units.
    m, n_factor, k_factor = mnk_factors
    n = n_factor * ALLSPARK_AMPERE_N_ALIGN
    k = k_factor * ALLSPARK_AMPERE_K_ALIGN

    activations = rand_data((m, k), dtype=dtype)
    dense_weight = rand_data((k, n), dtype=dtype)

    # Quantize the dense weight; w_ref is kept for the reference matmul.
    w_ref, qw, s, zp = quantize_weights(
        dense_weight, scalar_types.uint8b128, group_size, has_zp
    )
    qw = qw.to(torch.uint8)
    if has_zp:
        zp = zp.to(dtype)

    # Device facts that the GEMM op takes as explicit arguments.
    device_index = qw.device.index
    device_props = torch.cuda.get_device_properties(device_index)
    sm_count = num_compute_units(device_index)
    sm_version = device_props.major * 10 + device_props.minor

    # n rounded up to the next multiple of 32.
    n_32align = ((n + 31) // 32) * 32

    qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp)
    repack_args = (
        qw,
        s,
        zp,
        has_zp,
        qw_reorder,
        s_reorder,
        zp_reorder,
        k,
        n,
        n_32align,
    )
    opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order, repack_args)

    # The same positional arguments feed both the opcheck and the real call.
    gemm_args = (
        activations,
        qw_reorder,
        s_reorder,
        zp_reorder,
        n,
        group_size,
        sm_count,
        sm_version,
        ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
        has_zp,
        True,
    )
    opcheck(
        torch.ops._C.allspark_w8a16_gemm,
        gemm_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    output = ops.allspark_w8a16_gemm(*gemm_args)

    output_ref = torch.matmul(activations, w_ref)
    torch.cuda.synchronize()
    relative_error = compute_max_diff(output, output_ref)

    assert relative_error < 0.04