# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for QuantFP8 Group Quantization implementation."""

import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform


# @pytest.mark.parametrize(
#     "batch_size,hidden_dim,group_size",
#     [
#         (16, 256, 32),  # Small
#         (64, 1024, 64),  # Medium
#         (128, 2048, 128),  # Large
#         (8, 513, 64),  # Non-divisible (native only)
#     ])
# @pytest.mark.parametrize("seed", [42])
# @torch.inference_mode()
# def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
#                                       group_size: int, seed: int) -> None:
#     """Test QuantFP8 group quantization with various configurations.

#     Tests both CUDA and native implementations, column-major scales,
#     and verifies consistency between implementations.
#     """
#     current_platform.seed_everything(seed)

#     x = torch.randn(
#         (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
#     expected_num_groups = (hidden_dim + group_size - 1) // group_size
#     is_divisible = hidden_dim % group_size == 0

#     group_shape = GroupShape(1, group_size)
#     quant_op = QuantFP8(static=False,
#                         group_shape=group_shape,
#                         column_major_scales=False)

#     # 1. Test native implementation (always available)
#     x_quant_native, scales_native = quant_op.forward_native(x.clone())
#     assert x_quant_native.shape == x.shape
#     assert scales_native.shape == (batch_size, expected_num_groups)

#     # 2. Test column-major scales configuration
#     quant_op_col = QuantFP8(static=False,
#                             group_shape=group_shape,
#                             column_major_scales=True)
#     _, scales_col = quant_op_col.forward_native(x.clone())
#     assert scales_col.shape == (expected_num_groups, batch_size)

#     # 3. Test CUDA implementation (only for divisible dimensions)
#     if is_divisible:
#         x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
#         assert x_quant_cuda.shape == x.shape
#         assert scales_cuda.shape == (batch_size, expected_num_groups)

#         # Verify CUDA/native consistency
#         assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)

#         # Quantized values should mostly match
#         diff_count = (x_quant_cuda != x_quant_native).sum().item()
#         diff_ratio = diff_count / x_quant_cuda.numel()
#         assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None:
    """Check output shapes of native group quantization on 3D and 4D inputs.

    For each input the quantized tensor must keep the input's shape. With
    ``column_major_scales=False`` the scales are expected to keep the leading
    dimensions and end in ``hidden_dim // group_size``; with
    ``column_major_scales=True`` the last two scale dimensions are expected
    in swapped order.

    Args:
        seed: Value passed to ``current_platform.seed_everything`` before the
            random inputs are generated.
    """
    current_platform.seed_everything(seed)

    group_size = 64

    # Test with 3D input
    batch1, batch2, hidden_dim = 4, 8, 512
    x_3d = torch.randn(
        (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    x_quant, scales = quant_op.forward_native(x_3d.clone())
    assert x_quant.shape == x_3d.shape
    assert scales.shape == (batch1, batch2, hidden_dim // group_size)

    # Test column_major_scales with multi-dim
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True)
    _, scales_col = quant_op_col.forward_native(x_3d.clone())
    assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

    # Test with 4D input (the same two quantizers are reused; only the input
    # rank and sizes change)
    batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
    x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
                       dtype=torch.bfloat16,
                       device="cuda") * 8

    x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
    assert x_quant_4d.shape == x_4d.shape
    assert scales_4d.shape == (batch1, batch2, batch3,
                               hidden_dim // group_size)

    _, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
    assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
                                   batch3)


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(seed: int) -> None:
    """Exercise native group quantization on boundary-style inputs.

    Three inputs are run through one row-major-scale quantizer with a group
    width of 64: a tensor whose hidden size (32) is smaller than one group,
    an all-zero tensor, and a tensor filled with 1000.0.

    Args:
        seed: Value passed to ``current_platform.seed_everything`` before the
            random input is generated.
    """
    current_platform.seed_everything(seed)

    num_rows, group_width = 16, 64
    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}

    # Hidden size below the group width: expect exactly one scale per row.
    narrow_input = torch.randn((num_rows, 32), **tensor_opts) * 8
    quantizer = QuantFP8(static=False,
                         group_shape=GroupShape(1, group_width),
                         column_major_scales=False)

    narrow_q, narrow_scales = quantizer.forward_native(narrow_input.clone())
    assert narrow_q.shape == narrow_input.shape
    assert narrow_scales.shape == (num_rows, 1)

    # All-zero input: every scale must still be strictly positive.
    zero_input = torch.zeros((num_rows, 256), **tensor_opts)
    zero_q, zero_scales = quantizer.forward_native(zero_input.clone())
    assert zero_q.shape == zero_input.shape
    assert (zero_scales > 0).all(), "Scales should be clamped to minimum"

    # Constant 1000.0 input. FP8 max is typically 448 or 224, so every
    # scale is expected to exceed 1.
    large_input = torch.full((num_rows, 256), 1000.0, **tensor_opts)
    large_q, large_scales = quantizer.forward_native(large_input.clone())
    assert large_q.shape == large_input.shape
    assert (large_scales > 1.0).all(), "Large values should have scales > 1"