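"""Compare sgl_kernel's fused AWQ dequantize kernel against a pure-PyTorch
reference implementation across a grid of weight shapes and scale dtypes."""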
import itertools

import pytest
import torch
from sgl_kernel import awq_dequantize


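# AWQ packs eight 4-bit values into each int32 word in the interleaved
# physical order [0, 2, 4, 6, 1, 3, 5, 7]; gathering columns through
# AWQ_REVERSE_ORDER below applies the inverse permutation that restores
# the logical column order.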
def reverse_awq_order(t: torch.Tensor) -> torch.Tensor:
    bits = 4
    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
    # Build a column-gather index: within every group of 32 // bits = 8
    # columns, permute by AWQ_REVERSE_ORDER.
    reverse_order_tensor = torch.arange(
        t.shape[-1],
        dtype=torch.int32,
        device=t.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    # Gather the reordered columns and keep only the low 4 bits.
    t = t[:, reverse_order_tensor] & 0xF
    return t


# qweights - [R     , C // 8], int32 (eight packed 4-bit values per word)
# scales   - [R // G, C     ], float16 or bfloat16
# zeros    - [R // G, C // 8], int32
def awq_dequantize_torch(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
) -> torch.Tensor:
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    # Unpack each int32 into eight 4-bit weights by shifting out the nibbles.
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    iweights = iweights.view(iweights.shape[0], -1)

    # Unpack the zero points the same way, then undo the AWQ interleaving.
    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    zeros = zeros.view(qzeros.shape[0], -1)
    zeros = reverse_awq_order(zeros)

    iweights = reverse_awq_order(iweights)

    # Mask to 4 bits, then apply the per-group affine dequantization.
    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)

    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales
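
# A quick illustration of the unpack arithmetic on a single packed word
# (hypothetical value, not used by the test): 0x76543210 stores the nibbles
# 0..7 from least to most significant, so shifting by 0, 4, ..., 28 and
# masking with 0xF recovers them in order:
#   w = torch.tensor([[0x76543210]], dtype=torch.int32)
#   (w[:, :, None] >> torch.arange(0, 32, 4)) & 0xF  # -> [[[0, 1, ..., 7]]]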


def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    # Thin wrapper around the sgl_kernel CUDA kernel under test.
    return awq_dequantize(qweight, scales, qzeros)


@pytest.mark.parametrize(
    "qweight_row,qweight_col,is_bf16_act",
    list(
        itertools.product(
            [3584, 18944, 128, 256, 512, 1024, 1536],
            [448, 576, 4736, 16, 32, 64, 128, 72],
            [True, False],
        )
    ),
)
def test_awq_dequant_compare_implementations(
    qweight_row: int, qweight_col: int, is_bf16_act: bool
):
    device = torch.device("cuda")
    # Random packed weights: each int32 column holds eight 4-bit values.
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # A single quantization group spans all rows, so there is one scale row;
    # unpacking widens each int32 column into eight output columns.
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8

    # Scales are generated in the activation dtype under test.
    if is_bf16_act:
        scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device)
    else:
        scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)

    # Random packed zero points, one group per scale row.
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    # Run both implementations
    torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    # Compare results
    torch.testing.assert_close(
        torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )


if __name__ == "__main__":
    pytest.main([__file__])