# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from https://github.com/sgl-project/sglang/pull/2575
import itertools

import pytest
import torch

10
11
12
13
from tests.kernels.quant_utils import (
    native_per_token_group_quant_fp8,
    native_w8a8_block_matmul,
)
bnellnm's avatar
bnellnm committed
14
from vllm.config import VllmConfig
15
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
16
17
18
19
    cutlass_scaled_mm,
    per_token_group_quant_fp8,
    w8a8_triton_block_scaled_mm,
)
20
from vllm.platforms import current_platform
21
22
23
24
from vllm.utils.deep_gemm import (
    fp8_gemm_nt,
    get_col_major_tma_aligned_tensor,
    per_block_cast_to_fp8,
25
    should_use_deepgemm_for_fp8_linear,
26
)
27
from vllm.utils.import_utils import has_deep_gemm
28

29
# Skip the whole module when no CUDA/ROCm device is present; torch.cuda itself
# is always importable, so an import check alone would never skip.
if not torch.cuda.is_available():
    pytest.skip("CUDA is not available", allow_module_level=True)

# The FP8 block kernels exercised here need compute capability >= 9.0.
# If the platform reports no capability (None), skip instead of letting the
# tuple comparison raise a TypeError during collection.
_device_capability = current_platform.get_device_capability()
if _device_capability is None or _device_capability < (9, 0):
    pytest.skip(
        "FP8 Triton requires compute capability 9.0 or higher",
        allow_module_level=True,
    )

vllm_config = VllmConfig()

# Test configurations
DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 2050]
D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 512]
M = [1, 7, 8, 83, 84, 4096]
N = [128, 512, 7168, 7748, 13824]
K = [256, 3884, 4096, 13824, 16384]
# Reference shapes: DeepSeek-V3 has intermediate size 18432 (so a per-rank
# N of 18432*2/8 = 4608 at TP8) and hidden size 7168 (present in N above).
BLOCK_SIZE = [[128, 128]]
OUT_DTYPES = [torch.bfloat16]  # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0]


@pytest.fixture(autouse=True)
def setup_cuda():
    """Make CUDA the default device for tensor factories during each test.

    Tensors created without an explicit ``device=`` argument (e.g. the
    ``torch.rand`` calls in these tests) are allocated on the GPU. The device
    context is exited when the test finishes, so the previous default device
    is restored and the setting does not leak into other test modules.
    """
    with torch.device("cuda"):
        yield


57
58
59
60
@pytest.mark.skipif(
    current_platform.is_fp8_fnuz(),
    reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
    "num_tokens,d,dtype,group_size,seed",
    itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
)
@torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
    """Compare ``per_token_group_quant_fp8`` with the native reference.

    A random ``(num_tokens, d)`` input is quantized with ``group_size`` by
    both implementations. The quantized values are compared in float32 with
    ``rtol=0.15``, and the scales with ``torch.allclose`` default tolerances.
    """
    torch.manual_seed(seed)
    x = torch.rand(num_tokens, d, dtype=dtype)

    ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
    out, scale = per_token_group_quant_fp8(x, group_size)

    # FP8 values are compared in float32; the loose rtol absorbs rounding
    # differences between the kernel and the reference quantizer.
    assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
    assert torch.allclose(scale, ref_scale)


77
78
@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
    """Check the Triton block-scaled FP8 GEMM against the native reference.

    ``A`` is ``(M, K)`` and ``B`` is ``(N, K)``, both filled uniformly over
    the platform's FP8 range. The per-block scales are random values in
    ``[0, 1e-2)``. The check passes when the mean absolute error relative to
    the mean absolute reference output is below 0.1%.
    """
    torch.manual_seed(seed)
    factor_for_scale = 1e-2
    fp8_dtype = current_platform.fp8_dtype()
    fp8_info = torch.finfo(fp8_dtype)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
    A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(fp8_dtype)

    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
    B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(fp8_dtype)

    block_n, block_k = block_size[0], block_size[1]
    # Ceil-divide so partially filled trailing blocks still get a scale.
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

    # One scale per (row of A, K-block) and per (N-block, K-block) of B.
    As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
    out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.001


110
111
112
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
)
@torch.inference_mode()
def test_w8a8_block_fp8_cutlass_matmul():
    """Check CUTLASS block-scaled FP8 GEMM when N is not a multiple of 128.

    The shape (M=32, N=576, K=7168) mirrors DeepSeek-V3's
    ``kv_a_proj_with_mqa``. Activations are quantized twice with
    ``per_token_group_quant_fp8``: once with row-major scales for the
    reference and once with ``column_major_scales=True`` for CUTLASS. The
    check passes when the mean relative error is below 0.1%.
    """
    M = 32
    N = 576
    K = 7168
    block_size = [128, 128]
    out_dtype = torch.bfloat16
    seed = 0

    torch.manual_seed(seed)
    factor_for_scale = 1e-2
    fp8_dtype = torch.float8_e4m3fn
    fp8_info = torch.finfo(fp8_dtype)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max

    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
    B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(fp8_dtype)

    block_n, block_k = block_size[0], block_size[1]
    # Ceil-divide: N=576 leaves a partially filled last N-block.
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
    # On SM90 (Hopper) the CUTLASS path is given the weight scales transposed
    # and made contiguous; other devices receive Bs unchanged.
    Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs

    # Row-major activation scales for the native reference.
    A_fp8, As = per_token_group_quant_fp8(
        A_fp32, block_size[1], column_major_scales=False
    )
    # Column-major activation scales for CUTLASS.
    A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8(
        A_fp32, block_size[1], column_major_scales=True
    )

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
    out = cutlass_scaled_mm(
        A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype
    )

    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.001


161
162
163
164
@pytest.mark.skipif(
    current_platform.is_fp8_fnuz(),
    reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    """Check DeepGEMM's ``fp8_gemm_nt`` against the native block reference.

    Shapes rejected by ``should_use_deepgemm_for_fp8_linear`` are skipped.
    Activations are quantized per token group and weights per block. The
    reference is computed before the activation scales are converted by
    ``get_col_major_tma_aligned_tensor``. The check passes when the mean
    relative error is below 0.1%.
    """
    torch.manual_seed(seed)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max

    # only aligned sizes are supported by deepgemm
    if not should_use_deepgemm_for_fp8_linear(
        output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
    ):
        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")

    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)

    As = As_fp8.to(torch.float32)
    Bs = Bs_fp8.to(torch.float32)

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

    # Transpose earlier so that the testing will not trigger transposing kernels
    As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)

    out = torch.zeros((M, N), device="cuda", dtype=out_dtype)

    # One activation scale per row per K-block; derive the block width from
    # block_size instead of hard-coding 128.
    k_tiles = (K + block_size[1] - 1) // block_size[1]
    assert As_fp8.shape == (M, k_tiles), f"{As_fp8.shape} != {(M, k_tiles)}"

    fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)

    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.001