# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the triton_scaled_mm kernel.

Run `pytest tests/kernels/test_triton_scaled_mm.py`.
"""
import os
import importlib
from typing import Optional

import pytest
import torch

from vllm.platforms import current_platform
from ...utils import models_path_prefix

# Device string handed to every tensor constructor in test_scaled_mm.
device = "cuda"

# Load the kernel module object by its dotted path, then pull the kernel
# entry point off of it.
triton_scaled_mm_module = importlib.import_module(
    "vllm.model_executor.layers.quantization.compressed_tensors."
    "triton_scaled_mm")
triton_scaled_mm = getattr(triton_scaled_mm_module, "triton_scaled_mm")

def torch_scaled_mm(a: torch.Tensor,
                    b: torch.Tensor,
                    scale_a: torch.Tensor,
                    scale_b: torch.Tensor,
                    out_dtype: torch.dtype,
                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Pure-PyTorch reference for a scaled matrix multiply.

    Computes ``(scale_a * (a @ b) * scale_b.T).to(out_dtype)`` with the
    matmul and both scalings done in float32, then adds ``bias``.

    Args:
        a: Left operand, (M, K). Cast to float32 before the matmul.
        b: Right operand, (K, N). Cast to float32 before the matmul.
        scale_a: Row scale: scalar-like (0-d, (1,) or (1, 1)) or one value
            per output row, shaped (M,) or (M, 1).
        scale_b: Column scale: scalar-like or one value per output column,
            shaped (N,) or (N, 1).
        out_dtype: dtype the scaled product is converted to.
        bias: Optional tensor added after the conversion to ``out_dtype``
            (normal type promotion applies if its dtype differs).

    Returns:
        The (M, N) scaled product.
    """
    # A 1-D scale is turned into a column vector. Otherwise a (M,) per-row
    # scale would broadcast along the last (column) dimension, and `.T` on
    # a 1-D tensor is deprecated. 0-d and 2-D scales are left untouched.
    if scale_a.dim() == 1:
        scale_a = scale_a.reshape(-1, 1)
    if scale_b.dim() == 1:
        scale_b = scale_b.reshape(-1, 1)

    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = scale_a * out
    # scale_b is a column (N, 1); transposing gives a (1, N) row that
    # scales each output column.
    out = scale_b.T * out
    out = out.to(out_dtype)
    # Bias is applied after the dtype conversion.
    if bias is not None:
        out = out + bias

    return out


def get_8bit_types():
    """Return the 8-bit input dtypes to exercise on the current platform.

    ``torch.int8`` is always present; the platform's FP8 dtype is appended
    after it only when ``current_platform.supports_fp8()`` is true.
    """
    if not current_platform.supports_fp8():
        return [torch.int8]
    return [torch.int8, current_platform.fp8_dtype()]


# Regression test for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [
    os.path.join(models_path_prefix,
                 "neuralmagic/Llama-3.2-1B-quantized.w8a8"),
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
                                      max_tokens, num_logprobs):
    """Load the model in bfloat16 and run greedy generation with logprobs.

    No assertion is made on the generated output: the test passes as long as
    model loading and generation complete without raising.
    """
    with vllm_runner(model_path, dtype="bfloat16") as runner:
        runner.generate_greedy_logprobs(example_prompts, max_tokens,
                                        num_logprobs)


@pytest.mark.parametrize("M", [1, 33, 64, 512])
@pytest.mark.parametrize("N", [256, 971, 20486])
@pytest.mark.parametrize("K", [128, 496, 1024])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("in_dtype", get_8bit_types())
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
                   use_scalar_scale_b, use_bias):
    """Check triton_scaled_mm against the torch_scaled_mm reference.

    Sweeps matrix sizes, the 8-bit input dtypes from get_8bit_types(),
    float16/bfloat16 outputs, scalar (1, 1) vs per-row (M, 1) /
    per-column (N, 1) scales, and an optional (N,) bias.
    """
    current_platform.seed_everything(0)

    # NOTE: When the matrix is large enough, an output such as 65504.4 can be
    # produced, which easily turns into inf once converted to
    # float16/bfloat16. The function under test and the reference (golden)
    # function can then disagree, one producing a finite value and the other
    # inf, which makes assert_close/allclose fail even though the values
    # would have been "close" had no overflow occurred.
    #
    # So, the values here are kept small enough to avoid this situation.
    if in_dtype.is_floating_point:
        a = (0.25 * torch.rand(
            (M, K), dtype=torch.float32, device=device)).to(in_dtype)
        b = (0.25 * torch.rand(
            (K, N), dtype=torch.float32, device=device)).to(in_dtype)
    else:
        a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
        b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)

    if use_scalar_scale_a:
        scale_a = torch.rand((1, 1), device=device)
    else:
        scale_a = 0.25 * torch.rand((M, 1), device=device)

    if use_scalar_scale_b:
        scale_b = torch.rand((1, 1), device=device)
    else:
        scale_b = 0.25 * torch.rand((N, 1), device=device)

    bias = None
    if use_bias:
        bias = torch.rand((N, ), device=device, dtype=out_dtype)

    # Kernel under test first, then the pure-PyTorch reference.
    out_actual = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    out_expected = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out_actual, out_expected, rtol=1e-1, atol=1e-1)