sample.py 16.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
import math
from typing import Tuple, Optional

import torch
import triton
import triton.language as tl

from vllm.model_executor.layers.ops.rand import seeded_uniform

# Lower bound applied to the uniform noise in _uniform_to_exponential before
# its log is taken, so that log(0) is never evaluated.
# NOTE(review): newer Triton releases reportedly only let @triton.jit code
# read module globals that are annotated as tl.constexpr -- confirm against
# the Triton version this file is pinned to.
_EPS = 1e-6

# Triton caps the size of a single block; this is that hardcoded maximum.
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Return how many column chunks the Triton sampler needs for n_cols.

    A single kernel launch can cover at most MAX_TRITON_N_COLS columns, so a
    wider tensor has to be cut into chunks that are sampled one by one.
    The result is n_cols / MAX_TRITON_N_COLS rounded up.
    """
    full_chunks, leftover = divmod(n_cols, MAX_TRITON_N_COLS)
    return full_chunks + 1 if leftover else full_chunks


def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    *,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """Sample tokens when the vocab is too wide for a single Triton launch.

    The vocab dimension of probs (and logprobs) is cut into n_splits chunks
    with tensor_split. Each chunk is sampled independently with _sample,
    which also reports the (noise-modified) probability of its winner. The
    final token for every (sample, best) slot is the winner of the chunk
    whose modified probability is the largest.

    Args:
        probs: Probabilities to sample from, [batch_size, vocab_size].
        seeds: Seeds, one row per split: [n_splits, n].
        n_splits: Number of chunks to cut the vocab dimension into.
        sampled_tokens_size: (n, n_best) shape of the token output.
        sampled_logprobs_size: Shape of the logprob output; (n, n_best)
            when save_logprobs is True.
        sample_indices: Rows of probs to sample from, [n].
        logprobs: Log-probabilities, same shape as probs. Required when
            save_logprobs is True; otherwise probs is used as a stand-in
            so that the kernel still receives a valid tensor.
        modify_greedy_probs: If True, probs is overwritten in-place with
            one-hot rows for the sampled tokens after the reduction.
        save_logprobs: Whether to gather the logprobs of the sampled tokens.

    Returns:
        (sampled_tokens, sampled_logprobs, sampled_modified_probs), where
        sampled_tokens holds indices into the full vocab and
        sampled_logprobs is None unless save_logprobs is True.

    Raises:
        ValueError: If save_logprobs is True but logprobs is None.
    """
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    if logprobs is None:
        if save_logprobs:
            raise ValueError(
                "logprobs tensor must be provided if save_logprobs is True")
        # Same fallback as sample(): the tensor is passed to the kernel but
        # never read when save_logprobs is False.
        logprobs = probs

    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case (they drive the reduction below).
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]

    n_samples = sample_indices.shape[0]
    n_best = sampled_tokens_size[1]
    # Index of the first vocab column covered by the current split.
    # tensor_split may produce chunks of different widths (the leading ones
    # are one column wider when vocab_size % n_splits != 0), so the offset
    # must be accumulated rather than derived from a single chunk width.
    col_offset = 0
    for i, (probs_i, logprobs_i) in enumerate(zip(split_probs,
                                                  split_logprobs)):
        n_cols = probs_i.shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=probs_i.device,
                                       dtype=probs_i.dtype)
        # TODO(yard1): See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            probs_i.contiguous(),
            logprobs_i.contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            # Greedy probs are rewritten after the reduction, once the
            # overall winner is known.
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if col_offset > 0:
            # Translate chunk-local token ids into full-vocab token ids.
            sampled_tokens_tmp[i].add_(col_offset)
        col_offset += n_cols

    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits: keep the candidate of the split
    # with the highest modified probability.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        # NOTE(review): unlike the single-split kernel, this zeroes every
        # row of probs and scatters into rows 0..n-1, ignoring both
        # sample_indices and whether a row was sampled greedily -- confirm
        # that callers only use this path in that configuration.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)


def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,  # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Draw tokens from probs using per-sequence seeds.

    Only the rows listed in sample_indices are sampled; by default that is
    every row of probs. When the vocab is wider than a single Triton launch
    can handle, the work is delegated to _multi_split_sample.

    Args:
        probs: Probabilities to sample from.
            shape = [batch_size, vocab_size]
        seeds: Per-sequence seed values; a seed of 0 selects greedy
            sampling for that sequence.
            shape = [n_splits, n] with
            n_splits = math.ceil(vocab_size / MAX_TRITON_N_COLS)
            (the multi-split path indexes seeds by split along dim 0).
        max_best_of: Number of samples to draw per sequence.
        sample_indices: Indices of the rows of probs to sample from.
            If omitted, all rows are sampled.
            shape = [n]
        logprobs: Log-probabilities matching probs. Only read when
            save_logprobs is True.
            shape = [batch_size, vocab_size]
        modify_greedy_probs: Whether to rewrite the greedy probabilities
            in-place for speculative sampling (sampled token = 1.0,
            everything else = 0.0).
        save_logprobs: Whether to return the log-probabilities of the
            sampled tokens.
        _save_modified_probs: Whether to return the modified probabilities
            (including gumbel noise) of the sampled tokens.
            DOES NOT include the modification done by modify_greedy_probs
            (because the unmodified probs are needed to pick the best
            split in case of multi-split sampling).
            This is exposed only for testing.

    Returns:
        sampled_tokens: shape = [n, max_best_of]
        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
        sampled_modified_probs: shape = [n, max_best_of]
            if _save_modified_probs else None

    Raises:
        ValueError: If save_logprobs is True and logprobs is None.
    """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    if save_logprobs and logprobs is None:
        raise ValueError(
            "logprobs tensor must be provided if save_logprobs is True")
    if not save_logprobs:
        # The kernel still takes a logprobs tensor; hand it probs as a
        # placeholder that is never read.
        logprobs = probs

    # Outputs that are not requested are allocated empty, only so that the
    # kernel can be invoked with a tensor argument.
    out_shape = (sample_indices.size(0), max_best_of)
    not_requested = (0, 0)
    logprobs_shape = out_shape if save_logprobs else not_requested
    modified_probs_shape = (out_shape
                            if _save_modified_probs else not_requested)

    # If probs has too many columns for Triton to handle, the tensor is
    # split, every split is sampled separately, and an argmax+gather
    # combines the per-split results.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        tokens, token_logprobs, modified_probs = _multi_split_sample(
            probs,
            seeds,
            n_splits,
            out_shape,
            logprobs_shape,
            sample_indices,
            logprobs=logprobs,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs)
        return (tokens, token_logprobs if save_logprobs else None,
                modified_probs if _save_modified_probs else None)

    device, dtype = probs.device, probs.dtype
    tokens = torch.empty(out_shape, dtype=torch.long, device=device)
    token_logprobs = torch.empty(logprobs_shape, dtype=dtype, device=device)
    modified_probs = torch.empty(modified_probs_shape,
                                 dtype=dtype,
                                 device=device)
    noise = seeded_uniform(sample_indices.shape[0],
                           max_best_of,
                           probs.shape[1],
                           seeds=seeds.flatten(),
                           device=device,
                           dtype=dtype)
    _sample(
        probs,
        logprobs,
        sample_indices,
        tokens,
        token_logprobs,
        modified_probs,
        seeds,
        noise,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=_save_modified_probs,
    )
    return (tokens, token_logprobs if save_logprobs else None,
            modified_probs if _save_modified_probs else None)


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sample_indices: torch.Tensor,
    output_samples: torch.Tensor,
    output_logprobs: torch.Tensor,
    output_modified_probs: torch.Tensor,
    seeds: torch.Tensor,
    uniform_noise: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = True,
    save_modified_probs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the Triton sampling kernel and fill the output tensors.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True).
        sample_indices [n]: Indices of the rows of probs to sample from.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [n, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel).
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).

    Returns:
        The three output tensors that were passed in, now filled by the
        kernel: (output_samples, output_logprobs, output_modified_probs).
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if output_samples.ndim > 1 else 1

    # The block size is the smallest power of two that is >= the number of
    # columns in probs, so that one block covers a whole row.
    block_size = triton.next_power_of_2(n_cols)
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    else:
        num_warps = 4

    # Enqueue kernel. The launch grid is 2D: one kernel instance per
    # (sampled row, best-of index) pair.
    # NOTE: the kernel reuses output_samples.stride(0) for all three outputs
    # and probs.stride(0) for logprobs, so requested outputs must share the
    # layout of output_samples and logprobs must share that of probs.
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        # With a single best-of slot the second program index is always 0,
        # so the value of this stride is irrelevant.
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs


@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Turn uniform noise into exponential noise (inversion method)."""
    # The uniform noise may contain 0; raise it to at least _EPS so that
    # the log below is never taken of 0.
    lower_bound = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    clamped_noise = tl.maximum(uniform_noise, lower_bound)
    # Inverse CDF of the exponential distribution: -log(u).
    return -tl.log(clamped_noise)


@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    """Sample one token for one (sample, best-of) pair.

    Each program instance handles the pair given by program ids 0 and 1.
    It reads the probs row selected through sample_indices, divides it by
    exponential noise when the sample's seed is non-zero, and stores the
    index of the maximum as the sampled token. A seed of 0 means the row is
    used as is, i.e. the plain argmax is taken (greedy sampling).

    Depending on the compile-time flags it additionally:
      * modify_greedy_probs: overwrites the probs row with a one-hot row
        for the sampled token (greedy samples only);
      * save_modified_probs: stores the maximum (noise-modified) value;
      * save_logprobs: stores the logprob of the sampled token.

    output_row_stride is used to address all three outputs and
    probs_row_stride is used to address both probs and logprobs.
    """
    # The rows are independent, so we parallelize across those
    # (axis 0) and across the best-of samples of each row (axis 1).
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)

    # Load the row index and the seed of this sample.
    # A zero seed selects greedy sampling.
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0

    # The stride represents how much we need to increase the
    # pointer to advance 1 row
    row_start_ptr = probs_ptr + row_idx * probs_row_stride

    # The block size is a power of two that is >= n_cols,
    # so we can fit each row in a single block
    col_offsets = tl.arange(0, block_size)

    # Load the row, using a mask since block_size may be > than n_cols.
    # Padding columns are set to -inf so that they can never be the maximum.
    row = tl.load(row_start_ptr + col_offsets,
                  mask=col_offsets < n_cols,
                  other=float("-inf"))

    if uses_random_sampling:
        # Noise is addressed by (sample, best) and not by row_idx.
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=col_offsets < n_cols,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
        # Taking the argmax of probs / exponential noise below is what
        # performs the random draw.
        row /= exponential_noise

    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # clamp sampled token to n_cols - 1
    # this should not be necessary, but we do it
    # just in case
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1
    # Write the sampled token to its (sample, best) slot in the output
    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
                            best_idx)
    tl.store(output_row_start_ptr, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets,
                     row,
                     mask=col_offsets < n_cols)

    if save_modified_probs:
        # Store the winning value (after the noise division, if any).
        output_row_start_ptr = (output_modified_probs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_value)

    if save_logprobs:
        # Load the single logprob of the sampled token. The row is located
        # with probs_row_stride, i.e. logprobs is addressed with the same
        # row stride as probs.
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        # Write the logprob to its (sample, best) slot in the output
        output_row_start_ptr = (output_logprobs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_logprob)