test_per_token_group_quant_8bit.py
import itertools

import pytest
import torch

from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.layers.quantization.utils import assert_fp8_all_close
from sglang.srt.utils import is_hip

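# ROCm exposes FP8 as e4m3fnuz; CUDA uses e4m3fn.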
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn


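# Sweep token counts, hidden sizes, group sizes, and scale-layout flags:
# row-major scales, column-major scales, TMA-aligned column-major scales,
# and UE8M0-packed scales.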
@pytest.mark.parametrize(
    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
    list(
        itertools.product(
            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
            [256, 512, 1024, 2048, 4096],  # hidden_dim
            [8, 16, 32, 64, 128],  # group_size
            # TODO test int8
            [fp8_type_],  # dtype
            [
                dict(
                    column_major_scales=False,
                    scale_tma_aligned=False,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=False,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=True,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=True,
                    scale_ue8m0=True,
                ),
            ],
        )
    ),
)
def test_per_token_group_quant_with_column_major(
    num_tokens,
    hidden_dim,
    group_size,
    dst_dtype,
    flags,
):
    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
        pytest.skip("scale_ue8m0 requires group_size == 128 and hidden_dim divisible by 512")
        return
    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
        pytest.skip("scale_ue8m0 only supported on Blackwell")
        return

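    # Random bfloat16 activations as quantization input.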
    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)

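    # Identical arguments for both implementations under test.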
    execute_kwargs = dict(
        x=x,
        group_size=group_size,
        eps=1e-10,
        dst_dtype=dst_dtype,
        **flags,
    )

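    # Quantize with the Triton reference kernel and the SGLang kernel.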
    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)

    # torch.set_printoptions(profile="full")
    # print(f"{x_q_triton=}")
    # print(f"{x_s_triton=}")
    # print(f"{x_q_sglang=}")
    # print(f"{x_s_sglang=}")
    # torch.set_printoptions(profile="default")

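    # Quantized values and scales from both kernels should agree.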
    assert_fp8_all_close(x_q_triton, x_q_sglang)
    torch.testing.assert_close(
        x_s_triton.contiguous(),
        x_s_sglang.contiguous(),
        rtol=1e-3,
        atol=1e-5,
        msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
    )


if __name__ == "__main__":
    pytest.main([__file__])