test_fp8_quant_group.py 5.92 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for QuantFP8 Group Quantization implementation."""

import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
9
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
10
from vllm.utils.torch_utils import set_random_seed
11
12
13
14
15
16
17
18
19


@pytest.mark.parametrize(
    "batch_size,hidden_dim,group_size",
    [
        (16, 256, 32),  # Small
        (64, 1024, 64),  # Medium
        (128, 2048, 128),  # Large
        (8, 513, 64),  # Non-divisible (native only)
    ],
)
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(
    default_vllm_config,
    batch_size: int,
    hidden_dim: int,
    group_size: int,
    seed: int,
    use_ue8m0: bool,
) -> None:
    """Exercise QuantFP8 dynamic per-group quantization on 2D inputs.

    Runs the native path with row-major and column-major scales and checks
    that both produce the same scale values. When ``hidden_dim`` is a
    multiple of ``group_size``, the CUDA path is also run and compared
    against the native results.
    """
    set_random_seed(seed)

    x = (
        torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda")
        * 8
    )

    # Ceiling division: the shape checks below expect one scale per group,
    # including a trailing partial group in the non-divisible case.
    num_groups = -(-hidden_dim // group_size)

    # Arguments shared by the row-major and column-major operators.
    common_kwargs = {
        "static": False,
        "group_shape": GroupShape(1, group_size),
        "use_ue8m0": use_ue8m0,
    }

    # 1. Native implementation (always available), row-major scales.
    row_major_op = QuantFP8(column_major_scales=False, **common_kwargs)
    q_native, s_native = row_major_op.forward_native(x.clone())
    assert q_native.shape == x.shape
    assert s_native.shape == (batch_size, num_groups)

    # 2. Column-major scales: same logical shape, but the batch axis is the
    #    contiguous one in memory.
    col_major_op = QuantFP8(column_major_scales=True, **common_kwargs)
    _, s_col = col_major_op.forward_native(x.clone())
    assert s_col.shape == (batch_size, num_groups)
    assert s_col.stride() == (1, batch_size)

    # The memory layout must not change the scale values themselves.
    torch.testing.assert_close(s_col, s_native, rtol=1e-9, atol=1e-8)

    # 3. CUDA implementation (only for divisible dimensions).
    if hidden_dim % group_size != 0:
        return

    q_cuda, s_cuda = row_major_op.forward_cuda(x.clone())
    assert q_cuda.shape == x.shape
    assert s_cuda.shape == (batch_size, num_groups)

    # Verify CUDA/native consistency of the scales.
    torch.testing.assert_close(s_cuda, s_native, rtol=2e-7, atol=2e-8)

    # Quantized values should mostly match; tolerate a small fraction of
    # element-wise mismatches between the two implementations.
    diff_count = (q_cuda != q_native).sum().item()
    diff_ratio = diff_count / q_cuda.numel()
    assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"


@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(
    default_vllm_config, seed: int, use_ue8m0: bool
) -> None:
    """Check QuantFP8 group quantization output shapes for 3D and 4D inputs.

    For both row-major and column-major scale configurations, the quantized
    tensor must keep the input shape and the scales must have the input's
    leading dimensions followed by the number of groups along the last axis.
    """
    set_random_seed(seed)

    group_size = 64

    # Test with 3D input
    batch1, batch2, hidden_dim = 4, 8, 1024
    x_3d = (
        torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda")
        * 8
    )

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=False,
        use_ue8m0=use_ue8m0,
    )

    x_quant, scales = quant_op.forward_native(x_3d.clone())
    assert x_quant.shape == x_3d.shape
    assert scales.shape == (batch1, batch2, hidden_dim // group_size)

    # Test column_major_scales with multi-dim
    quant_op_col = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=True,
        use_ue8m0=use_ue8m0,
    )
    _, scales_col = quant_op_col.forward_native(x_3d.clone())
    assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)

    # Test with 4D input
    batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
    x_4d = (
        torch.randn(
            (batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda"
        )
        * 8
    )

    x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
    assert x_quant_4d.shape == x_4d.shape
    assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size)

    # Column-major scales keep the same logical shape as the row-major ones,
    # as in the 3D case above. The expected tuple previously had its last two
    # axes swapped and only passed because both happen to equal 4 here.
    # NOTE(review): with batch3 == hidden_dim // group_size this assertion
    # cannot detect a transposed layout; consider sizes where they differ.
    _, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
    assert scales_4d_col.shape == (batch1, batch2, batch3, hidden_dim // group_size)
141
142
143
144


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(default_vllm_config, seed: int) -> None:
    """Check native QuantFP8 group quantization on boundary inputs.

    Covers a hidden dimension smaller than the group size, an all-zero
    input, and an input filled with a large constant.
    """
    set_random_seed(seed)

    num_rows, group_width = 16, 64

    # Single group: the hidden dimension (32) is smaller than the group size.
    narrow = torch.randn((num_rows, 32), dtype=torch.bfloat16, device="cuda") * 8
    op = QuantFP8(
        static=False,
        group_shape=GroupShape(1, group_width),
        column_major_scales=False,
    )

    narrow_q, narrow_scales = op.forward_native(narrow.clone())
    assert narrow_q.shape == narrow.shape
    assert narrow_scales.shape == (num_rows, 1)

    # All-zero input: scales must still be strictly positive.
    zeros = torch.zeros((num_rows, 256), dtype=torch.bfloat16, device="cuda")
    zeros_q, zeros_scales = op.forward_native(zeros.clone())
    assert zeros_q.shape == zeros.shape
    assert (zeros_scales > 0).all(), "Scales should be clamped to minimum"

    # Very large constant input.
    huge = torch.full((num_rows, 256), 1000.0, dtype=torch.bfloat16, device="cuda")
    huge_q, huge_scales = op.forward_native(huge.clone())
    assert huge_q.shape == huge.shape
    # FP8 max is typically 448 or 224, so scales should be > 1
    assert (huge_scales > 1.0).all(), "Large values should have scales > 1"