# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.kernels.utils import to_int8
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

# Every test in this module targets the CPU oneDNN ops, so the whole module
# is skipped on any other platform.
if not current_platform.is_cpu():
    pytest.skip("skipping CPU-only tests", allow_module_level=True)

# (n, k) pairs used for the "n,k" parametrization of the tests in this module.
NK_FACTORS = [
    (256, 128),
    (4096, 4096),
    (16384, 4096),
    (1023, 491),
    (1001, 15),
]
# Each tuple is a single "m_list" parameter: a test iterates over its entries
# and runs the helper once per m value.
M_FACTORS = [
    (16, 1, 32, 128, 64),
    (1, 17, 1, 31, 17),
]
# Values forwarded as ``primitive_cache_size`` when a handler is created.
CACHE_SIZES = [2]
# dtypes under test: the output dtype of the int8 test and the operand dtype
# of the plain GEMM test.
DTYPE = [torch.bfloat16]


def rand_int8(shape: tuple, device: str = "cpu"):
    """Return a random tensor of ``shape`` converted with ``to_int8``.

    Samples are drawn uniformly from [0, 1) and mapped affinely onto
    [-128, 127) before being handed to ``to_int8``.
    """
    uniform = torch.rand(shape, device=device)
    shifted = uniform * 255 - 128
    return to_int8(shifted)


def ref_int8_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
38
39
    azp: torch.Tensor | None,
    bias: torch.Tensor | None,
40
41
42
43
    output_type: torch.dtype,
):
    if azp is not None:
        a = a.to(dtype=torch.float32) - azp.to(dtype=torch.float32)
44
45
46
    output = torch.mm(
        (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
    )
47
48
49
50
51
52
    if bias is not None:
        output += bias.float()

    return output.to(dtype=output_type)


53
54
55
56
57
58
59
60
61
62
63
64
def onednn_int8_gemm_test_helper(
    primitive_cache_size: int,
    m: int,
    n: int,
    k: int,
    per_tensor_a_quant: bool,
    per_tensor_b_quant: bool,
    use_azp: bool,
    use_bias: bool,
    out_dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu",
):
    """Compare ``ops.onednn_scaled_mm`` with ``ref_int8_scaled_mm``.

    Builds random int8 operands ``a`` of shape (m, k) and ``b`` of shape
    (k, n), random float32 scales, and optionally a zero point and a bias,
    then checks the kernel output against the float32 reference.

    Args:
        primitive_cache_size: Forwarded to ``ops.create_onednn_scaled_mm``.
        m: Number of rows of ``a``.
        n: Number of columns of ``b``.
        k: Shared inner dimension.
        per_tensor_a_quant: Use one (1, 1) scale for ``a`` rather than one
            scale per row, shape (m, 1).
        per_tensor_b_quant: Use one (1, 1) scale for ``b`` rather than one
            scale per column, shape (1, n).
        use_azp: Also generate a zero point for ``a`` and its adjustment term.
        use_bias: Also generate a bias; the handler is then exercised a
            second time without it.
        out_dtype: dtype of the kernel output and of the bias.
        device: Device on which every tensor is allocated.

    Raises:
        AssertionError: If the kernel output is not close to the reference.
    """
    # Test for a oneDNN kernel with per-tensor / per-token activation
    # quantization and per-tensor / per-output channel weight quantization.
    a = to_int8(torch.randn((m, k), device=device) * 5)
    b = to_int8(torch.randn((n, k), device=device).t() * 5)

    a_scales_shape = (1, 1) if per_tensor_a_quant else (m, 1)
    b_scales_shape = (1, 1) if per_tensor_b_quant else (1, n)

    scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
    scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)

    if use_azp:
        # Draw the zero point in the scaled domain, then express it in
        # quantized units by dividing by the activation scale.
        azp = torch.rand(a_scales_shape, device=device, dtype=torch.float32) * 10 + 1.5
        azp = (azp / scale_a).round().to(dtype=torch.int32)
        # Scaled per-column sum of b, passed to the kernel next to azp.
        azp_adj = scale_b * b.sum(dim=0, keepdim=True, dtype=torch.float32)
    else:
        azp = None
        azp_adj = None

    bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None

    handler = ops.create_onednn_scaled_mm(
        b,
        scale_b,
        out_dtype,
        # Set exactly when the activation scales are per-row.
        not per_tensor_a_quant,
        use_azp,
        primitive_cache_size,
    )

    # With a bias, the same handler is run a second time without one to test
    # runtime bias setting.
    bias_cases = (bias, None) if use_bias else (None,)
    for case_bias in bias_cases:
        out = torch.zeros((m, n), dtype=out_dtype, device=device)
        ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, case_bias)
        baseline = ref_int8_scaled_mm(
            a, b, scale_a, scale_b, azp, case_bias, out_dtype
        )

        torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


110
111
112
113
114
115
116
117
118
119
def onednn_gemm_test_helper(
    primitive_cache_size: int,
    m: int,
    n: int,
    k: int,
    use_bias: bool,
    use_stride: bool,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu",
):
    """Compare ``ops.onednn_mm`` with ``torch.nn.functional.linear``.

    Builds a random (m, k) input ``a`` and a random (n, k) weight ``b``,
    runs the kernel through a handler created from ``b.t()``, and checks the
    result against a float32 ``linear`` reference cast back to ``dtype``.

    Args:
        primitive_cache_size: Forwarded to ``ops.create_onednn_mm``.
        m: Number of rows of ``a``.
        n: Number of rows of ``b`` (output features).
        k: Shared inner dimension.
        use_bias: Also generate a bias; the handler is then exercised a
            second time without it.
        use_stride: Make ``a`` the first ``k`` columns of an (m, 2 * k)
            tensor, so its row stride is ``2 * k`` instead of ``k``.
        dtype: dtype of ``a``, ``b`` and the bias.
        device: Device on which every tensor is allocated.

    Raises:
        AssertionError: If the kernel output is not close to the reference.
    """
    if use_stride:
        a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5
        a = a[:, :k]
    else:
        a = torch.rand((m, k), dtype=dtype, device=device) * 1.5

    b = torch.rand((n, k), dtype=dtype, device=device) * 1.5

    bias = torch.rand((n,), device=device, dtype=dtype) * 5 if use_bias else None

    handler = ops.create_onednn_mm(
        b.t(),
        primitive_cache_size,
    )

    # With a bias, the same handler is run a second time without one to test
    # runtime bias setting.
    bias_cases = (bias, None) if use_bias else (None,)
    for case_bias in bias_cases:
        out = ops.onednn_mm(handler, a, case_bias)

        # The reference is evaluated in float32 and then cast back.
        bias_f32 = None if case_bias is None else case_bias.float()
        baseline = torch.nn.functional.linear(a.float(), b.float(), bias_f32).to(
            dtype=a.dtype
        )

        torch.testing.assert_close(out, baseline)


157
158
159
160
161
162
163
164
165
166
167
@pytest.mark.parametrize("n,k", NK_FACTORS)
@pytest.mark.parametrize("m_list", M_FACTORS)
@pytest.mark.parametrize("per_tensor_a_scale", [True, False])
@pytest.mark.parametrize("per_tensor_b_scale", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_azp", [True, False])
@pytest.mark.parametrize("output_type", DTYPE)
@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES)
def test_onednn_int8_scaled_gemm(
    n: int,
    k: int,
    m_list: tuple[int, ...],
    per_tensor_a_scale: bool,
    per_tensor_b_scale: bool,
    use_bias: bool,
    use_azp: bool,
    output_type: torch.dtype,
    primitive_cache_size: int,
):
    """Run the int8 scaled GEMM check for every ``m`` in ``m_list``.

    All ``m`` values of one parameter set are run back to back inside a
    single test invocation, each with the same (n, k) and
    ``primitive_cache_size``.
    """
    for m in m_list:
        onednn_int8_gemm_test_helper(
            primitive_cache_size=primitive_cache_size,
            m=m,
            n=n,
            k=k,
            per_tensor_a_quant=per_tensor_a_scale,
            per_tensor_b_quant=per_tensor_b_scale,
            use_bias=use_bias,
            use_azp=use_azp,
            out_dtype=output_type,
        )
188
189
190
191
192
193
194
195
196
197
198


@pytest.mark.parametrize("n,k", NK_FACTORS)
@pytest.mark.parametrize("m_list", M_FACTORS)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_stride", [True, False])
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES)
def test_onednn_gemm(
    n: int,
    k: int,
    m_list: tuple[int, ...],
    use_bias: bool,
    use_stride: bool,
    dtype: torch.dtype,
    primitive_cache_size: int,
):
    """Run the plain GEMM check for every ``m`` in ``m_list``.

    All ``m`` values of one parameter set are run back to back inside a
    single test invocation, each with the same (n, k) and
    ``primitive_cache_size``.
    """
    for m in m_list:
        onednn_gemm_test_helper(
            primitive_cache_size=primitive_cache_size,
            m=m,
            n=n,
            k=k,
            use_bias=use_bias,
            use_stride=use_stride,
            dtype=dtype,
        )