import itertools

import pytest
import torch
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops


def vllm_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    # The trailing arguments fill vLLM's split_k_iters/thx/thy parameters;
    # this test passes 0 for each.
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)


def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return awq_dequantize(qweight, scales, qzeros)
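

# Pure-PyTorch reference for the dequantization the two kernels are expected
# to perform: a minimal sketch, not part of either library's API. It assumes
# the usual AWQ layout (eight 4-bit values packed per int32, interleaved in
# nibble order [0, 4, 1, 5, 2, 6, 3, 7]) and computes w = (q - z) * s per
# quantization group.
def reference_awq_dequantize(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    # Pull the eight 4-bit nibbles out of every int32.
    shifts = torch.arange(0, 32, 4, device=qweight.device, dtype=torch.int32)
    iweight = (qweight.unsqueeze(-1) >> shifts) & 0xF
    izeros = (qzeros.unsqueeze(-1) >> shifts) & 0xF
    # Undo the interleaved packing so columns come out in logical order.
    order = [0, 4, 1, 5, 2, 6, 3, 7]
    iweight = iweight[..., order].reshape(qweight.shape[0], -1)
    izeros = izeros[..., order].reshape(qzeros.shape[0], -1)
    # Broadcast the per-group zero points and scales over each group's rows.
    izeros = izeros.repeat_interleave(group_size, dim=0)
    grouped_scales = scales.repeat_interleave(group_size, dim=0)
    return (iweight - izeros).to(scales.dtype) * grouped_scales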


# Shapes mix model-sized dimensions (e.g. 3584 rows x 448 packed columns)
# with small edge cases down to 128 x 16.
@pytest.mark.parametrize(
    "qweight_row,qweight_col",
    list(
        itertools.product(
            [3584, 18944, 128, 256, 512, 1024], [448, 576, 4736, 16, 32, 64, 128]
        )
    ),
)
def test_awq_dequant_compare_implementations(
    qweight_row: int,
    qweight_col: int,
):
    device = torch.device("cuda")

    # Packed weights: each int32 holds eight 4-bit quantized values, so the
    # logical weight matrix is (qweight_row, qweight_col * 8).
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # Use a single quantization group spanning all rows, so there is exactly
    # one row of scales and one row of packed zero points.
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    # Run both implementations
    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    # Compare in float32 so the tolerance check is not skewed by fp16 rounding
    torch.testing.assert_close(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )


if __name__ == "__main__":
    # Run all tests in this file under pytest
    pytest.main([__file__])