# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for QuantFP8 Group Quantization implementation."""

import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.utils.torch_utils import set_random_seed


@pytest.mark.parametrize(
    "batch_size,hidden_dim,group_size",
    [
        (16, 256, 32),  # Small
        (64, 1024, 64),  # Medium
        (128, 2048, 128),  # Large
        (8, 513, 64),  # Non-divisible (native only)
    ],
)
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(
    batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool
) -> None:
    """Test QuantFP8 group quantization with various configurations.

    Tests both CUDA and native implementations, column-major scales,
    and verifies consistency between implementations.

    Args:
        batch_size: Number of rows in the 2D input.
        hidden_dim: Size of the last (grouped) dimension.
        group_size: Number of consecutive elements along ``hidden_dim``
            that share one scale.
        seed: RNG seed so the random input is reproducible.
        use_ue8m0: Forwarded unchanged to ``QuantFP8(use_ue8m0=...)``.
    """
    set_random_seed(seed)

    x = torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
    # A trailing partial group still gets its own scale, hence ceil-division.
    expected_num_groups = (hidden_dim + group_size - 1) // group_size
    is_divisible = hidden_dim % group_size == 0

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=False,
        use_ue8m0=use_ue8m0,
    )

    # 1. Test native implementation (always available)
    x_quant_native, scales_native = quant_op.forward_native(x.clone())
    assert x_quant_native.shape == x.shape
    assert scales_native.shape == (batch_size, expected_num_groups)

    # 2. Test column-major scales configuration
    quant_op_col = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=True,
        use_ue8m0=use_ue8m0,
    )
    _, scales_col = quant_op_col.forward_native(x.clone())
    # The logical shape is unchanged; only the memory layout differs
    # (rows are contiguous, i.e. stride 1 along dim 0).
    assert scales_col.shape == (batch_size, expected_num_groups)
    assert scales_col.stride(0) == 1
    assert scales_col.stride(1) == batch_size

    # Test column-major scales consistency
    torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8)

    # 3. Test CUDA implementation (only for divisible dimensions)
    if is_divisible:
        x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
        assert x_quant_cuda.shape == x.shape
        assert scales_cuda.shape == (batch_size, expected_num_groups)

        # Verify CUDA/native consistency
        torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8)

        # Quantized values should mostly match; a small fraction may differ
        # between the two implementations, so allow < 0.2% mismatches.
        diff_count = (x_quant_cuda != x_quant_native).sum().item()
        diff_ratio = diff_count / x_quant_cuda.numel()
        assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"


@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
    """Test native group quantization on 3D and 4D inputs.

    Only the last dimension is grouped. The quantized output must keep the
    input shape, and the scales must keep every leading dimension followed
    by one entry per group. This holds for both row-major and column-major
    scales, matching the 2D expectations in
    ``test_quantfp8_group_functionality``.

    Args:
        seed: RNG seed so the random inputs are reproducible.
        use_ue8m0: Forwarded unchanged to ``QuantFP8(use_ue8m0=...)``.
    """
    set_random_seed(seed)

    group_size = 64

    # Test with 3D input
    batch1, batch2, hidden_dim = 4, 8, 1024
    x_3d = (
        torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda")
        * 8
    )

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=False,
        use_ue8m0=use_ue8m0,
    )

    x_quant, scales = quant_op.forward_native(x_3d.clone())
    assert x_quant.shape == x_3d.shape
    assert scales.shape == (batch1, batch2, hidden_dim // group_size)

    # Test column_major_scales with multi-dim
    quant_op_col = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=True,
        use_ue8m0=use_ue8m0,
    )
    _, scales_col = quant_op_col.forward_native(x_3d.clone())
    assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)

    # Test with 4D input
    batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
    x_4d = (
        torch.randn(
            (batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda"
        )
        * 8
    )

    x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
    assert x_quant_4d.shape == x_4d.shape
    assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size)

    # Column-major scales keep the same logical shape as the row-major ones,
    # exactly as asserted for the 2D and 3D inputs. The previous ordering
    # (batch1, batch2, groups, batch3) only passed because
    # batch3 == hidden_dim // group_size == 4 here.
    _, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
    assert scales_4d_col.shape == (batch1, batch2, batch3, hidden_dim // group_size)


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(seed: int) -> None:
    """Test native group quantization on boundary inputs.

    Covers three cases:
    - a hidden dimension smaller than the group size (a single group);
    - an all-zero input, where scales must remain strictly positive;
    - a large constant input, where scales must exceed 1.

    Args:
        seed: RNG seed so the random input is reproducible.
    """
    set_random_seed(seed)

    batch_size = 16
    group_size = 64

    # Test with single group (group_size >= hidden_dim)
    x_small = torch.randn((batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(
        static=False, group_shape=group_shape, column_major_scales=False
    )

    x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
    assert x_quant_small.shape == x_small.shape
    assert scales_small.shape == (batch_size, 1)

    # Test with zero inputs: a zero absmax must not produce a zero scale.
    x_zero = torch.zeros((batch_size, 256), dtype=torch.bfloat16, device="cuda")
    x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
    assert x_quant_zero.shape == x_zero.shape
    assert (scales_zero > 0).all(), "Scales should be clamped to minimum"

    # Test very large values
    x_large = torch.full((batch_size, 256), 1000.0, dtype=torch.bfloat16, device="cuda")
    x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
    assert x_quant_large.shape == x_large.shape
    # FP8 max is typically 448 or 224, so scales should be > 1
    assert (scales_large > 1.0).all(), "Large values should have scales > 1"