# test_top_k_per_row.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import pytest
import torch

from vllm.platforms import current_platform
9
from vllm.utils.torch_utils import set_random_seed
10
11
12

# Parameter grids fed to the pytest.mark.parametrize decorators in this file.
NUM_ROWS = [1, 32, 2050]  # row counts for the prefill test
TOP_K_VALUES = [2048, 3000]  # k values requested from the kernels
BATCH_SIZE = [1, 2, 2048]  # batch sizes for the decode test
NEXT_N = [1, 8]  # rows generated per batch entry in the decode test
DATA_GENERATION = ["random", "10LSBits"]  # modes of create_random_logits
17
18
19
20
21
22
23


def create_random_logits(
    row_starts: torch.Tensor,
    row_ends: torch.Tensor,
    dtype: torch.dtype,
    seed: int,
    clean_logits: bool,
    data_generation: str,
    *,
    device: "str | torch.device" = "cuda",
) -> torch.Tensor:
    """Create a seeded logits tensor for testing.

    The result has ``row_starts.shape[0]`` rows and ``row_ends.max()``
    columns (zero columns when ``row_ends`` is empty).

    Args:
        row_starts: Only its first dimension is read (number of rows).
        row_ends: Per-row exclusive end column; its maximum sets the width.
        dtype: Element type of the returned tensor. In ``"10LSBits"`` mode
            the int32 bit patterns are reinterpreted with ``view(dtype)``.
            NOTE(review): that presumably expects a 32-bit dtype; every
            caller in this file passes ``torch.float32``.
        seed: Seed applied to both the torch and numpy generators.
        clean_logits: When true, columns at or beyond ``row_ends[i]`` in
            row ``i`` are overwritten with ``-inf``.
        data_generation: ``"random"`` draws from ``torch.randn``;
            ``"10LSBits"`` fixes the upper 22 bits of every element to
            ``0x3F900000`` and randomizes only the lowest 10 bits.
        device: Device on which the tensors are allocated. Defaults to
            ``"cuda"``.

    Returns:
        The generated logits tensor.

    Raises:
        ValueError: If ``data_generation`` is not a supported mode.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    num_rows = row_starts.shape[0]
    # One tensor reduction; builtin max() would walk the tensor element by
    # element from Python.
    num_cols = int(row_ends.max()) if row_ends.numel() > 0 else 0

    if data_generation == "random":
        logits = torch.randn(num_rows, num_cols, dtype=dtype, device=device)
    elif data_generation == "10LSBits":
        top_22_bits_mask = 0xFFFFFC00
        last_10_bits_mask = 0x000003FF
        fixed_top_22_bits = 0x3F900000
        # Generate random bits for the last 10 bits
        random_bottom_bits = torch.randint(
            0,
            2**10,
            (num_rows, num_cols),
            dtype=torch.int32,
            device=device,
        )
        # Combine: fixed top 22 bits with random last 10 bits
        logits_bits = (fixed_top_22_bits & top_22_bits_mask) | (
            random_bottom_bits & last_10_bits_mask
        )
        logits = logits_bits.view(dtype)
    else:
        # Without this branch an unknown mode surfaced later as an
        # UnboundLocalError on `logits`.
        raise ValueError(f"Unknown data_generation mode: {data_generation!r}")

    if clean_logits:
        # Convert the ends to Python ints once instead of converting a 0-d
        # tensor to an index on every iteration.
        for i, end in enumerate(row_ends.tolist()):
            logits[i, end:] = float("-inf")

    return logits


def create_row_boundaries(
    seq_len: int,
    vocab_size: int,
    *,
    device: "str | torch.device" = "cuda",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create row start and end indices for testing.

    Row ``i`` spans columns ``[0, i + 1)``: every start is 0 and the ends
    run from 1 to ``seq_len`` inclusive.

    Args:
        seq_len: Number of rows to describe.
        vocab_size: Accepted for interface compatibility; the boundaries
            built here do not depend on it.
        device: Device on which both tensors are allocated. Defaults to
            ``"cuda"``.

    Returns:
        ``(row_starts, row_ends)``, both int32 tensors of length
        ``seq_len``.
    """
    row_starts = torch.zeros(seq_len, dtype=torch.int32, device=device)
    row_ends = torch.arange(1, seq_len + 1, device=device, dtype=torch.int32)
    return row_starts, row_ends


def compare_top_k_results(
    logits: torch.Tensor,
    cuda_indices: torch.Tensor,
    torch_indices: torch.Tensor,
    row_starts: torch.Tensor,
    row_ends: torch.Tensor,
    top_k: int,
    tolerance: float = 1e-5,
) -> bool:
    """Compare CUDA top_k_per_row indices against reference indices.

    For each row only the first ``min(top_k, row_end - row_start)`` entries
    of both index tensors are considered. A row matches when both sides
    select the same set of indices, or when the indices that differ point
    at logit values that agree within ``tolerance`` (which covers ties at
    the top-k boundary). The order of indices within a row is not checked.

    Args:
        logits: Tensor the indices refer to; an index ``j`` in row ``i`` is
            read as ``logits[i][j]``.
        cuda_indices: Indices produced by the kernel under test.
        torch_indices: Indices produced by the reference implementation.
        row_starts: Per-row start offsets, used only for the row length.
        row_ends: Per-row end offsets, used only for the row length.
        top_k: Number of entries requested per row.
        tolerance: Used as both ``rtol`` and ``atol`` for ``torch.allclose``.

    Returns:
        ``True`` if every row matches, ``False`` otherwise. A differing
        index outside ``[0, logits.shape[1])`` counts as a mismatch.
    """
    num_rows = cuda_indices.shape[0]
    # Copy the boundaries to the host once instead of one .item() per row.
    starts = row_starts.tolist()
    ends = row_ends.tolist()

    for row_idx in range(num_rows):
        num_valid = min(top_k, ends[row_idx] - starts[row_idx])
        cuda_set = set(cuda_indices[row_idx][:num_valid].tolist())
        torch_set = set(torch_indices[row_idx][:num_valid].tolist())
        if cuda_set == torch_set:
            continue

        # Only the indices unique to one side need their values inspected.
        cuda_only = sorted(cuda_set - torch_set)
        torch_only = sorted(torch_set - cuda_set)
        if len(cuda_only) != len(torch_only):
            return False

        logits_row = logits[row_idx]
        row_width = logits_row.shape[0]
        # Reject invalid positions up front rather than letting the gather
        # below raise (or wrap around for negative values).
        if any(not 0 <= idx < row_width for idx in cuda_only + torch_only):
            return False

        cuda_only_values = logits_row[
            torch.tensor(cuda_only, dtype=torch.long, device=logits.device)
        ]
        torch_only_values = logits_row[
            torch.tensor(torch_only, dtype=torch.long, device=logits.device)
        ]
        # Sort both sides so the element-wise comparison does not depend on
        # the arbitrary order in which the differing indices were collected.
        if not torch.allclose(
            cuda_only_values.sort().values,
            torch_only_values.sort().values,
            rtol=tolerance,
            atol=tolerance,
        ):
            return False

    return True


@pytest.mark.parametrize("num_rows", NUM_ROWS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("clean_logits", [True, False])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row(num_rows: int, top_k: int, clean_logits: bool) -> None:
    """Check top_k_per_row_prefill against a per-row torch.topk reference."""
    set_random_seed(0)
    torch.set_default_device("cuda:0")

    # Inputs: row boundaries plus float32 logits generated in "random" mode.
    vocab_size = 20000
    row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
    logits = create_random_logits(
        row_starts, row_ends, torch.float32, 42, clean_logits, "random"
    )

    # Kernel under test writes its selection into `indices`.
    out_shape = (num_rows, top_k)
    indices = torch.empty(out_shape, dtype=torch.int32, device="cuda")
    torch.ops._C.top_k_per_row_prefill(
        logits,
        row_starts,
        row_ends,
        indices,
        num_rows,
        logits.stride(0),
        logits.stride(1),
        top_k,
    )

    # Reference: torch.topk over the valid prefix of each row. Rows shorter
    # than top_k fill only their first k entries.
    torch_indices = torch.empty(out_shape, dtype=torch.int32, device="cuda")
    for row in range(num_rows):
        end = int(row_ends[row])
        k = min(top_k, end)
        torch_indices[row, :k] = torch.topk(logits[row, :end], k, dim=-1).indices

    results_match = compare_top_k_results(
        logits, indices, torch_indices, row_starts, row_ends, top_k
    )
    assert results_match, "CUDA top_k_per_row_prefill results don't match torch.topk"
175
176


177
def _run_top_k_per_row_decode_test(
    top_k: int,
    batch_size: int,
    next_n: int,
    vocab_size: int,
    clean_logits: bool,
    data_generation: str,
) -> None:
    """Run top_k_per_row_decode once and check it against torch.topk.

    Builds ``batch_size * next_n`` rows. Each batch entry draws a random
    sequence length in ``[next_n, vocab_size)``, and the ``next_n`` rows of
    that entry end at consecutive positions, the last one at the sequence
    length itself.
    """
    torch.set_default_device("cuda:0")

    num_rows = batch_size * next_n
    out_shape = (num_rows, top_k)

    # Per-batch sequence lengths, then per-row end positions derived from them.
    seq_lens = torch.randint(
        low=next_n,
        high=vocab_size,
        size=(batch_size,),
        dtype=torch.int32,
        device="cuda",
    )
    row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
    flat_rows = torch.arange(num_rows, device="cuda")
    row_indices = flat_rows // next_n
    next_n_offset = flat_rows % next_n
    row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1

    logits = create_random_logits(
        row_starts, row_ends, torch.float32, 42, clean_logits, data_generation
    )

    # Kernel under test writes its selection into `indices`.
    indices = torch.empty(out_shape, dtype=torch.int32, device="cuda")
    torch.ops._C.top_k_per_row_decode(
        logits,
        next_n,
        seq_lens,
        indices,
        num_rows,
        logits.stride(0),
        logits.stride(1),
        top_k,
    )

    torch.cuda.synchronize()

    # Reference: torch.topk over the valid prefix of each row. Rows shorter
    # than top_k fill only their first k entries.
    torch_indices = torch.empty(out_shape, dtype=torch.int32, device="cuda")
    for row in range(num_rows):
        end = int(row_ends[row])
        k = min(top_k, end)
        torch_indices[row, :k] = torch.topk(logits[row, :end], k, dim=-1).indices

    results_match = compare_top_k_results(
        logits, indices, torch_indices, row_starts, row_ends, top_k
    )
    assert results_match, "CUDA top_k_per_row_decode results don't match torch.topk"


@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.parametrize("clean_logits", [True, False])
@pytest.mark.parametrize("data_generation", DATA_GENERATION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
    top_k: int,
    batch_size: int,
    next_n: int,
    clean_logits: bool,
    data_generation: str,
) -> None:
    """Run the decode check over the parameter grid with a 20000 vocabulary."""
    set_random_seed(0)
    _run_top_k_per_row_decode_test(
        top_k=top_k,
        batch_size=batch_size,
        next_n=next_n,
        vocab_size=20000,
        clean_logits=clean_logits,
        data_generation=data_generation,
    )


@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize("clean_logits", [True, False])
@torch.inference_mode()
def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
    """Run the decode check once with a 300000-entry vocabulary."""
    set_random_seed(0)
    _run_top_k_per_row_decode_test(
        top_k=2048,
        batch_size=2,
        next_n=2,
        vocab_size=300000,
        clean_logits=clean_logits,
        data_generation="random",
    )