"official/projects/panoptic/tasks/panoptic_maskrcnn.py" did not exist on "17b3db9f2e414cb61c637c7f97bb7250c7e5791c"
generation.py 31.8 KB
Newer Older
1
# Copyright (c) 2023, Tri Dao.
Tri Dao's avatar
Tri Dao committed
2
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
3
import gc
Tri Dao's avatar
Tri Dao committed
4
5
import time
from collections import namedtuple
Tri Dao's avatar
Tri Dao committed
6
from dataclasses import dataclass, field
7
from functools import partial
Tri Dao's avatar
Tri Dao committed
8
from typing import Callable, Optional, Sequence, Union
Tri Dao's avatar
Tri Dao committed
9

Tri Dao's avatar
Tri Dao committed
10
import torch
11
12
import torch.nn.functional as F
from einops import rearrange, repeat
Tri Dao's avatar
Tri Dao committed
13
14
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
15
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
Tri Dao's avatar
Tri Dao committed
16
17
18
19
20
21


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    # Upper bounds for the current generation call (set by the decode functions).
    max_sequence_len: int
    max_batch_size: int
    # Number of tokens already fed through the model; the decode loops advance
    # this after every forward pass.
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    # Presumably the per-layer key/value cache filled by the model -- TODO confirm.
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    # Reset with ``zero_()`` by ``decode`` when CUDA graphs are used.
    lengths_per_sample: Optional[Tensor] = None
Tri Dao's avatar
Tri Dao committed
30
31


32
33
34
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place.

    Arguments:
        logits: floating-point Tensor; the top-k is taken over the last dimension.
        top_k: int, number of candidates to keep. Must not exceed logits.size(-1)
            (callers clamp it before calling).
    """
    # The k-th largest value along the last dim is the cut-off; everything strictly
    # below it is masked out (ties with the cut-off value are kept).
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


40
41
42
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place.

    Arguments:
        logits: floating-point Tensor of shape (..., vocab_size). Any number of leading
            dimensions is supported; filtering is done over the last dimension.
        top_p: float. Values <= 0.0 or >= 1.0 disable the filtering (no-op).
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter sorted tensors back to the original indexing. The scatter must run along
    # the same (last) dimension that was sorted, otherwise inputs with more than
    # 2 dimensions (e.g. (batch, seqlen, vocab_size)) are indexed along the wrong axis.
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))
56
57
58
59
60
61
62
63
64
65
66


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample one token per row from the logits (greedy, top-k and/or top-p).

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
        top_k: 1 means greedy decoding (argmax); 0 means no top-k restriction;
            otherwise only the top_k largest logits are candidates.
        top_p: if in (0, 1), nucleus filtering is applied after top-k.
        temperature: logits are divided by this value before sampling.
    Return:
        Tensor of shape (batch_size,) with the sampled token indices.
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        # torch.topk returns new tensors, so the in-place ops below don't touch `logits`.
        logits_top, indices = torch.topk(logits, top_k, dim=-1)
        if temperature != 1.0:
            logits_top /= temperature
        modify_logits_for_top_p_filtering(logits_top, top_p)
        # The multinomial draw is an index into the top-k candidates; map it back
        # to a vocabulary index through `indices`.
        return indices[
            torch.arange(indices.shape[0], device=indices.device),
            torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
        ]
    # Clone so that when we modify for top_p we don't change the original logits
    logits_top = logits / temperature if temperature != 1.0 else logits.clone()
    modify_logits_for_top_p_filtering(logits_top, top_p)
    return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
85
86


Tri Dao's avatar
Tri Dao committed
87
@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
    enable_timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        model: callable returning an object with a `.logits` attribute; called with
            `position_ids`, `inference_params` and `num_last_tokens` keyword arguments.
        max_length: int
        eos_token_id (optional): generation stops early once every sequence in the batch
            has just produced this token.
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
        vocab_size (optional): if given, logits are truncated to the first vocab_size entries.
        tensor_parallel: world size for tensor parallelism; only used for the graph cache and
            for the barrier when timing.
        fused_ft_kernel: forwarded to the inference params / graph cache.
        cg: whether to run the single-token decoding steps through the CUDA graph cache.
        enable_timing: whether to measure and print the total generation time (requires CUDA).
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, L) with L <= max_length (shorter if eos_token_id stops generation)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
            fused_ft_kernel=fused_ft_kernel,
        )
        # Reuse the cached inference params, resetting the per-call state.
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
        inference_params.lengths_per_sample.zero_()
    else:
        inference_params = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
        )

    def get_logits(input_ids, inference_params):
        # sequence_len_offset == 0 means we are still processing the prompt.
        decoding = inference_params.sequence_len_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.sequence_len_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            # Clone because the graph cache output is consumed after later runs.
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.sequence_len_offset
            ).clone()
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.sequence_len_offset]
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        # Always run at least the prompt through the model.
        if inference_params.sequence_len_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.sequence_len_offset >= max_length - 1:
            return True
        return False

    if enable_timing:
        # Only touch the CUDA event API when timing is actually requested.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        if tensor_parallel > 1:
            torch.distributed.barrier()
        start.record()
    scores, sequences = [], [input_ids]
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.sequence_len_offset += sequences[-1].shape[1]
        sequences.append(sample_tokens(scores[-1], inference_params))
    if enable_timing:
        end.record()
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
Tri Dao's avatar
Tri Dao committed
203
204


205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
    """Algorithm 1 from [1]
    [1] Fast Inference from Transformers via Speculative Decoding
    Yaniv Leviathan, Matan Kalman, Yossi Matias
    https://arxiv.org/abs/2211.17192

    Arguments:
        logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
        logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
        tokens_draft: Tensor of shape (batch_size, seqlen)
        top_k, top_p, temperature: sampling parameters, applied identically to both
            the main and the draft logits before they are compared.
    Return:
        tokens: Tensor of shape (batch_size, seqlen + 1)
        num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].
            For each sequence in the batch, the number of valid tokens that were sampled by
            speculative sampling.
    """
    batch, seqlen_p_1, vocab_size = logits.shape
    seqlen = seqlen_p_1 - 1
    assert logits_draft.shape == (batch, seqlen, vocab_size)
    assert tokens_draft.shape == (batch, seqlen)
    assert tokens_draft.dtype in [torch.int64, torch.int32]
    # TODO: if top_k = 1 we can simplify things and only work with indices
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    # Clone so that when we modify for top_p we don't change the original logits
    logits = logits / temperature if temperature != 1.0 else logits.clone()
    logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        modify_logits_for_top_k_filtering(logits, top_k)
        modify_logits_for_top_k_filtering(logits_draft, top_k)
    modify_logits_for_top_p_filtering(logits, top_p)
    modify_logits_for_top_p_filtering(logits_draft, top_p)
    probs = torch.softmax(logits, dim=-1)
    probs_draft = torch.softmax(logits_draft, dim=-1)

    def gather(probs, tokens):
        # Probability assigned to each given token: (..., vocab) , (...) -> (...)
        return rearrange(
            probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..."
        )

    # (batch, seqlen). Accept draft token x with probability min(1, p(x) / q(x)).
    accepted = torch.rand(batch, seqlen, device=probs.device) * gather(
        probs_draft, tokens_draft
    ) <= gather(probs[:, :-1], tokens_draft)
    accepted_all = accepted.all(dim=-1)
    # (batch,). argmin over the 0/1 mask gives the first rejected position.
    first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
    probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
    # torch.multinomial can deal with unnormalized probabilities
    # probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
    # If every draft token was accepted (index == seqlen), sample the extra token from
    # the main model's distribution at the last position instead.
    resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)
    resample_probs = rearrange(
        resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)),
        "b 1 d -> b d",
    )
    resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1)  # (batch,)
    tokens = F.pad(tokens_draft, (0, 1))
    # Write each sequence's resampled token at that sequence's own first rejected position.
    # Indexing with `tokens[:, first_rejected_idx]` would instead write every sequence's
    # token into every row whenever batch > 1.
    batch_idx = torch.arange(batch, device=tokens.device)
    tokens[batch_idx, first_rejected_idx] = resample.to(tokens.dtype)
    return tokens, first_rejected_idx + 1


def decode_speculative(
    input_ids,
    model,
    model_draft,
    max_length,
    speculative_lookahead=3,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    vocab_size=None,
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
    enable_timing=False,
    debug=False,
):
    """
    TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now.

    Speculative decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        model: the main model, used to verify the draft tokens. Never run through CUDA graphs.
        model_draft: the draft model that proposes up to `speculative_lookahead` tokens per step.
        max_length: int
        speculative_lookahead: maximum number of draft tokens proposed per iteration.
        eos_token_id: must be None (not supported yet).
        vocab_size (optional): if given, logits are truncated to the first vocab_size entries.
        cg: whether the draft model's single-token steps go through the CUDA graph cache.
        enable_timing: whether to measure and print the total generation time (requires CUDA).
        debug: print intermediate tokens and the max abs difference against reference logits.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: Tensor of logits concatenated along the sequence dimension
            (unlike `decode`, this is a single Tensor rather than a tuple).
    """
    batch_size, seqlen_og = input_ids.shape
    assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
    assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
    if cg:
        if not hasattr(model_draft, "_decoding_cache"):
            model_draft._decoding_cache = None
        model_draft._decoding_cache = update_graph_cache(
            model_draft,
            model_draft._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
            fused_ft_kernel=fused_ft_kernel,
        )
        inference_params_draft = model_draft._decoding_cache.inference_params
        inference_params_draft.max_sequence_len = max_length
        inference_params_draft.max_batch_size = batch_size
        inference_params_draft.sequence_len_offset = 0
    else:
        inference_params_draft = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
        )
    # fused_ft_kernel doesn't support passing in multiple tokens at once
    inference_params = InferenceParams(
        max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
    )

    def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
        if not cg:
            return model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            return model._decoding_cache.run(
                input_ids, position_ids, inference_params.sequence_len_offset
            ).clone()

    logits_postprocess_fn = (
        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
    )

    def sample_tokens(
        input_ids,
        model,
        inference_params,
        sample_fn,
        num_tokens=1,
        cg=False,
        decoding=True,
        last_token_logits=False,
    ):
        """Sample `num_tokens` tokens from the model, given the previous logits.
        Also return the logits of the sampled tokens.
        Arguments:
            input_ids: (batch, seqlen)
            decoding: whether we're in the decoding phase or the prefilling phase. Prefill doesn't
                need special position_ids.
            last_token_logits: whether to return the logits of the last token. Normally we don't need this.
                However, for speculative sampling, if the main model accepts all the draft tokens, plus it
                samples one new token, then by right at the next iteration the draft model need to evaluate
                the logits of the last draft token and the logits of the newly sampled token.
                This makes implementation more complicated. So here we just evaluate the logits of the last
                token in the draft model to simplify the implementation.
        Return:
            tokens: (batch, num_tokens)
            scores: (batch, num_tokens, vocab), which contains @previous_logits and the logits of the next
                (num_tokens - 1) tokens. The logits of the last token isn't computed unless last_token_logits=True.
                In which case we have scores of shape (batch, num_tokens + 1, vocab)
        """
        batch_size, seqlen = input_ids.shape
        assert num_tokens >= 1
        sequences = []
        if decoding:
            assert seqlen == 1
            position_ids = repeat(
                torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
                + inference_params.sequence_len_offset,
                "s -> b s",
                b=batch_size,
            )
        else:
            position_ids = None
        logits = logits_postprocess_fn(
            logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
        )
        inference_params.sequence_len_offset += input_ids.shape[1]
        scores = [logits]
        next_token = sample_fn(logits)
        sequences.append(next_token)
        for i in range(num_tokens):
            if i < num_tokens - 1 or last_token_logits:
                # Use the offset of the inference params belonging to the model being run
                # (not unconditionally the draft model's), so the helper is also correct
                # when it is bound to the main model.
                position_ids = torch.full(
                    (batch_size, 1),
                    inference_params.sequence_len_offset,
                    dtype=torch.long,
                    device=input_ids.device,
                )
                logits = logits_postprocess_fn(
                    logits_forward_fn(
                        model,
                        rearrange(next_token, "b -> b 1"),
                        position_ids,
                        inference_params,
                        cg=cg,
                    )
                )
                inference_params.sequence_len_offset += 1
                scores.append(logits)
            if i < num_tokens - 1:
                next_token = sample_fn(logits)
                sequences.append(next_token)
        return torch.stack(sequences, dim=1), torch.stack(scores, dim=1)

    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
    sample_fn = partial(sample, **sampling_kwargs)
    sample_tokens_main = partial(
        sample_tokens, model=model, sample_fn=sample_fn, inference_params=inference_params, cg=False
    )  # main model doesn't use CUDA graph
    sample_tokens_draft = partial(
        sample_tokens,
        model=model_draft,
        sample_fn=sample_fn,
        last_token_logits=True,
        inference_params=inference_params_draft,
        cg=cg,
    )

    if debug:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    sequences = [input_ids]
    scores = []
    with torch.inference_mode():
        if enable_timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            start = time.time()

        if seqlen_og >= max_length - 1:
            # Don't do speculative sampling, just sample 1 token from the model
            tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1, decoding=False)
            sequences.append(tokens)
            scores.append(scores_new)
        else:
            # Sample from draft model, which produces @n_spec_tokens, and @model
            # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
            # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
            n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
            tokens_draft, scores_draft = sample_tokens_draft(
                input_ids,
                num_tokens=n_spec_tokens,
                decoding=False,
            )
            if debug:
                scores_draft_ref = model_draft(
                    torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
                ).logits
                print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())

            # Evaluate the draft tokens with the model
            logits = model(
                torch.cat([input_ids, tokens_draft], dim=1),
                inference_params=inference_params,
                num_last_tokens=n_spec_tokens + 1,
            ).logits
            logits = logits_postprocess_fn(logits)
            tokens, num_generated_tokens = sample_speculative(
                logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
            )
            if debug:
                print(tokens)
                print(num_generated_tokens)
            # TODO: we're using the fact that batch_size == 1
            # TODO: check eos_token_id
            sequences.append(tokens[:1, : num_generated_tokens[0]])
            scores.append(logits[:1, : num_generated_tokens[0]])
            # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
            # that in the next time we call @model.
            inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
            inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
            if debug:
                cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
                scores_ref = model(
                    cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
                ).logits
                print((scores[-1] - scores_ref[:, :-1]).abs().max())

        while True:
            # sequence_len_offset is total length generated - 1
            if inference_params.sequence_len_offset >= max_length - 1:
                break
            if inference_params.sequence_len_offset >= max_length - 2:
                # Don't do speculative sampling, just sample 1 token from the model
                tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
                sequences.append(tokens)
                scores.append(scores_new)
                break
            # Sample from draft model
            n_spec_tokens = min(
                speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
            )
            tokens_draft, scores_draft = sample_tokens_draft(
                sequences[-1][:, -1:], num_tokens=n_spec_tokens
            )
            if debug:
                scores_draft_ref = model_draft(
                    torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
                ).logits
                print((scores_draft[:, :-1] - scores_draft_ref[:, :-1]).abs().max())
            # Evaluate the draft tokens with the model
            position_ids = repeat(
                torch.arange(
                    inference_params.sequence_len_offset,
                    # 1 extra token from last time that hasn't been passed through model
                    inference_params.sequence_len_offset + n_spec_tokens + 1,
                    dtype=torch.long,
                    device=input_ids.device,
                ),
                "s -> b s",
                b=batch_size,
            )
            logits = model(
                torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
                position_ids=position_ids,
                inference_params=inference_params,
            ).logits  # (batch, n_spec_tokens, vocab_size)
            logits = logits_postprocess_fn(logits)
            inference_params.sequence_len_offset += 1
            if debug:
                logits_ref = model(
                    torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
                ).logits
                print((logits - logits_ref).abs().max())
            tokens, num_generated_tokens = sample_speculative(
                logits, scores_draft[:, :-1], tokens_draft, **sampling_kwargs
            )
            if debug:
                print(tokens)
                print(num_generated_tokens)
            sequences.append(tokens[:1, : num_generated_tokens[0]])
            scores.append(logits[:1, : num_generated_tokens[0]])
            inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
            inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
            if debug:
                cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
                scores_ref = model(
                    cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1
                ).logits
                print((scores[-1] - scores_ref[:, :-1]).abs().max())

        if enable_timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    sequences = torch.cat(sequences, dim=1)
    scores = torch.cat(scores, dim=1)
    if debug:
        scores_ref = model(sequences).logits
        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=sequences, scores=scores)


Tri Dao's avatar
Tri Dao committed
578
class GenerationMixin:
    """Mixin that adds a `generate` method built on top of `decode`.

    The class mixing this in is passed to `decode` as the model, so it must be callable
    the way `decode` calls its `model` argument.
    """

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate the inference cache for this model. Must be overridden by subclasses."""
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        """Generate tokens with `decode`, using this object as the model.

        Arguments:
            input_ids: (batch, seq_len)
            max_length: int
            top_k, top_p, temperature: sampling parameters forwarded to `decode`.
            return_dict_in_generate: if True return the full output object, otherwise
                only its `sequences` field.
            output_scores: if False, the `scores` field of the output is set to None.
            **kwargs: forwarded unchanged to `decode`.
        """
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
Tri Dao's avatar
Tri Dao committed
599
600


Tri Dao's avatar
Tri Dao committed
601
602
603
604
605
606
607
608
def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
    fused_ft_kernel=False,
):
    """Allocate uninitialized key/value cache tensors for each layer.

    Arguments:
        max_batch_size, max_seqlen, nheads, headdim: ints that determine the cache shapes.
        layers: either the number of layers (keys are then 0 .. layers - 1) or an explicit
            sequence of layer indices to use as keys.
        device: device on which the tensors are allocated.
        dtype: torch.float16, torch.bfloat16 or torch.float32.
        fused_ft_kernel: selects the cache layout (see Return).
    Return:
        dict mapping layer index to
            - if fused_ft_kernel: a tuple (k_cache, v_cache) with shapes
              (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize) and
              (max_batch_size, nheads, max_seqlen, headdim);
            - otherwise: a single tensor of shape (max_batch_size, max_seqlen, 2, nheads, headdim).
    """
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    # 4 elements per pack for fp32, 8 for the 16-bit dtypes.
    packsize = 4 if dtype == torch.float32 else 8
    assert headdim % packsize == 0
    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {
        i: (
            torch.empty(k_cache_shape, device=device, dtype=dtype),
            torch.empty(v_cache_shape, device=device, dtype=dtype),
        )
        if fused_ft_kernel
        else torch.empty(kv_cache_shape, device=device, dtype=dtype)
        for i in layers
    }
Tri Dao's avatar
Tri Dao committed
628
629
630
631
632
633
634
635
636
637
638


def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Bucket a sequence length into one of three seqlen types.

    The bucket index is used to determine which cuda graph to use.
    Arguments:
        seqlen: int
    Return:
        0 for seqlen < 32, 1 for seqlen < 2048, otherwise 2.
    """
    if seqlen < 32:
        return 0
    if seqlen < 2048:
        return 1
    return 2


Tri Dao's avatar
Tri Dao committed
639
def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    """Return the exclusive upper bound on sequence length for a seqlen_type.

    Arguments:
        seqlen_type: one of 0, 1, 2.
    Return:
        32 for type 0, 2048 for type 1, and 2**32 for type 2.
    """
    assert seqlen_type in [0, 1, 2]
    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
Tri Dao's avatar
Tri Dao committed
642
643


644
645
646
647
648
649
650
651
652
653
654
655
656
@dataclass
class DecodingCGCache:
    """State kept between generation calls for CUDA-graph decoding.

    `update_graph_cache` creates and refreshes instances of this class, and `decode`
    stores one on the model as `model._decoding_cache`.
    """

    # Largest batch size / sequence length the cache was built for; a larger
    # request makes update_graph_cache invalidate and rebuild the cache.
    max_batch_size: int = 0
    max_seqlen: int = 0
    # NOTE: `device`, `dtype` and `mempool` have no annotation, so they are plain class
    # attributes rather than dataclass fields (not accepted by __init__, not in repr/eq).
    # update_graph_cache compares (device, dtype) against the model's to detect staleness.
    device = None
    dtype = None
    # Reset to {} when the cache is invalidated; presumably holds the captured
    # graph callables -- TODO confirm against the rest of update_graph_cache.
    callables: dict = field(default_factory=dict)
    mempool = None
    # Inference params reused across decode calls that run with cg=True.
    inference_params: Optional[InferenceParams] = None
    # Called by the decode functions as run(input_ids, position_ids, seqlen_offset).
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    tensor_parallel=1,
    dtype=None,
    n_warmups=2,
    fused_ft_kernel=False,
):
    """Make sure ``cache`` holds captured decoding graphs for this batch size and length range.

    Arguments:
        model: model whose single-token decoding step is captured.
        cache: an existing DecodingCGCache, or None to create a new one.
        batch_size: batch size to capture graphs for.
        seqlen_og: starting sequence length; selects the first seqlen type to capture.
        max_seqlen: largest sequence length; selects the last seqlen type to capture.
        tensor_parallel: divisor applied to the number of attention heads when the
            KV cache is allocated by the module-level ``allocate_inference_cache``.
        dtype: KV cache dtype; defaults to the dtype of the model's first parameter.
        n_warmups: number of warmup iterations forwarded to ``capture_graph``.
        fused_ft_kernel: forwarded to the cache allocation and to InferenceParams.
    Returns:
        The (possibly newly created) cache, with ``cache.run`` set to a function
        ``run(input_ids, position_ids, seqlen)`` that dispatches to the matching graph.
    """
    if cache is None:
        cache = DecodingCGCache()
    first_param = next(iter(model.parameters()))
    device = first_param.device
    dtype = first_param.dtype if dtype is None else dtype

    # Anything captured for another device/dtype, or for smaller buffers, is unusable.
    cache_is_stale = (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    )
    if cache_is_stale:
        # Drop every reference to the old graphs and buffers before collecting,
        # then rebuild from scratch.
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen

        if hasattr(model, "allocate_inference_cache"):
            # The model knows how to lay out its own KV cache.
            kv_memory = model.allocate_inference_cache(
                batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
            )
        else:
            config = model.config
            headdim = getattr(
                config,
                "head_dim",
                config.hidden_size // config.num_attention_heads,
            )
            kv_memory = allocate_inference_cache(
                batch_size,
                max_seqlen,
                config.num_attention_heads // tensor_parallel,
                headdim,
                config.num_hidden_layers,
                device,
                dtype,
                fused_ft_kernel=fused_ft_kernel,
            )

        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen,
            max_batch_size=batch_size,
            sequence_len_offset=seqlen_og,
            key_value_memory_dict=kv_memory,
            fused_ft_kernel=fused_ft_kernel,
            lengths_per_sample=torch.full(
                (batch_size,), seqlen_og, dtype=torch.int32, device=device
            ),
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()

    # Capture one graph per seqlen type spanned by [seqlen_og, max_seqlen].
    first_type = seqlen_to_seqlen_type(seqlen_og)
    last_type = seqlen_to_seqlen_type(max_seqlen)
    for s_type in range(first_type, last_type + 1):
        key = (batch_size, s_type)
        if key in cache.callables:
            continue
        # Clamp the bucket's upper bound into [seqlen_og, max_seqlen].
        bucket_max_seqlen = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
        cache.callables[key] = capture_graph(
            model,
            cache.inference_params,
            batch_size,
            bucket_max_seqlen,
            mempool=cache.mempool,
            n_warmups=n_warmups,
        )

    def dispatch(input_ids, position_ids, seqlen):
        # Pick the graph by the actual batch size of the input and the current length bucket.
        graph_fn = cache.callables[input_ids.shape[0], seqlen_to_seqlen_type(seqlen)]
        return graph_fn(input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
    return cache


Tri Dao's avatar
Tri Dao committed
738
def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
    """Capture one single-token forward pass of ``model`` in a CUDA graph.

    Arguments:
        model: called as ``model(input_ids, position_ids=..., inference_params=...,
            num_last_tokens=1)``; the result must expose ``.logits``.
        inference_params: decoding state passed to the model. Its
            ``lengths_per_sample`` is overwritten in place; its
            ``sequence_len_offset`` is changed during capture and restored afterwards.
        batch_size: batch dimension of the static input buffers.
        max_seqlen: capture is done with the sequence offset set to ``max_seqlen - 1``.
        mempool: memory pool handle passed to ``torch.cuda.graph``.
        n_warmups: number of eager forward passes run before capture.
    Returns:
        ``run(new_input_ids, new_position_ids, seqlen)``, which copies the new inputs
        into the static buffers, replays the graph and returns a clone of the logits.
    """
    device = next(iter(model.parameters())).device
    # Static buffers: the captured graph always reads its inputs from these tensors.
    static_input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    static_position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)

    saved_offset = inference_params.sequence_len_offset
    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
    inference_params.sequence_len_offset = max_seqlen - 1
    inference_params.lengths_per_sample[:] = max_seqlen - 1

    # Same keyword arguments for the warmup passes and for the captured pass.
    step_kwargs = dict(
        position_ids=static_position_ids,
        inference_params=inference_params,
        num_last_tokens=1,
    )

    # Warmup on a side stream before capture.
    warmup_stream = torch.cuda.Stream()
    warmup_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(warmup_stream):
        for _ in range(n_warmups):
            static_logits = model(static_input_ids, **step_kwargs).logits
        warmup_stream.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(warmup_stream)

    # Capture the graph.
    # To allow capture, the context automatically sets a side stream as the current stream.
    cuda_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cuda_graph, pool=mempool):
        static_logits = model(static_input_ids, **step_kwargs).logits.squeeze(dim=1)

    # Capture is done: give the caller back the offset it passed in.
    inference_params.sequence_len_offset = saved_offset

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        static_input_ids.copy_(new_input_ids)
        static_position_ids.copy_(new_position_ids)
        cuda_graph.replay()
        # Clone so the caller's tensor is not overwritten by the next replay.
        return static_logits.clone()

    return run