test_awq.py 1.76 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import pytest
5
6
7
8
9
10
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401


11
12
13
14
@pytest.mark.skipif(
    not hasattr(torch.ops._C, "awq_dequantize"),
    reason="AWQ is not supported on this GPU type.",
)
def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch):
    """Run ``opcheck`` against the compiled ``awq_dequantize`` custom op.

    The op receives, in order: a packed int32 weight, float16 scales, packed
    int32 zeros, and three integer launch parameters (all zero here). No
    reference dequantization is computed; the inputs are simply handed to
    ``opcheck`` from ``tests.kernels.utils`` (presumably a wrapper around
    ``torch.library.opcheck`` — confirm there).
    """
    # Tensor geometry. The int32 tensors have 256 columns and the float16
    # scales have 2048 = 256 * 8; presumably each int32 packs eight 4-bit
    # values — confirm against the AWQ kernel.
    weight_rows, packed_cols = 8192, 256
    group_rows, unpacked_cols = 64, 2048
    int_lo, int_hi = -2000000000, 2000000000
    dev = "cuda"

    with monkeypatch.context() as patch:
        # NOTE(review): the op is invoked through torch.ops._C directly, so
        # this env var presumably does not change which kernel runs here.
        patch.setenv("VLLM_USE_TRITON_AWQ", "0")

        packed_weight = torch.randint(
            int_lo, int_hi, (weight_rows, packed_cols), device=dev, dtype=torch.int32
        )
        scale_tensor = torch.rand(
            (group_rows, unpacked_cols), device=dev, dtype=torch.float16
        )
        # Contents are uninitialized; only dtype and shape are controlled.
        zero_tensor = torch.empty(
            (group_rows, packed_cols), device=dev, dtype=torch.int32
        )

        # split_k_iters, thx and thy are all passed as 0.
        split_k_iters, thx, thy = 0, 0, 0
        op_args = (packed_weight, scale_tensor, zero_tensor, split_k_iters, thx, thy)
        opcheck(torch.ops._C.awq_dequantize, op_args)
30
31


32
@pytest.mark.skip(reason="Not working; needs investigation.")
@pytest.mark.skipif(
    not hasattr(torch.ops._C, "awq_gemm"),
    reason="AWQ is not supported on this GPU type.",
)
def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
    """Run ``opcheck`` against the compiled ``awq_gemm`` custom op.

    Positional arguments are passed as
    ``(activations, packed weight, scales, packed zeros, split_k_iters)``.
    The scales and zeros use the same dtypes and shapes as in
    ``test_awq_dequantize_opcheck``:

    * scales: float16 ``(64, 2048)``
    * zeros: int32 ``(64, 256)``

    No reference matmul is computed. The inputs are handed to ``opcheck``
    from ``tests.kernels.utils``, which is presumably a wrapper around
    ``torch.library.opcheck`` (confirm there).

    NOTE(review): this test is skipped as "not working". The reason is not
    visible here; re-run it before removing the skip now that every input
    tensor is initialized.
    """
    with monkeypatch.context() as m:
        # NOTE(review): the op is invoked through torch.ops._C directly, so
        # this env var presumably does not change which kernel runs here.
        m.setenv("VLLM_USE_TRITON_AWQ", "0")

        # Activations. This was named `input`, which shadowed the builtin.
        x = torch.rand((2, 8192), device="cuda", dtype=torch.float16)
        qweight = torch.randint(
            -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
        )
        # Packed zeros: int32, with the same column count as qweight.
        qzeros = torch.randint(
            -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
        )
        # Scales: float16. torch.rand is used instead of torch.empty because
        # uninitialized fp16 memory may hold NaN/Inf, which would make the
        # check nondeterministic.
        scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16)
        split_k_iters = 8

        opcheck(torch.ops._C.awq_gemm, (x, qweight, scales, qzeros, split_k_iters))