# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function

try:
    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
except ImportError:
    GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"])
    SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"])


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()
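

# Example (illustrative sketch, not part of the original file): decode() below builds these
# parameters once per generation and advances `seqlen_offset` by the number of tokens fed to
# the model at each step. A hypothetical manual setup:
def _example_inference_params(batch_size: int, max_length: int, prompt_len: int) -> InferenceParams:
    """Hypothetical helper: InferenceParams advanced past an already-processed prompt."""
    params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
    params.seqlen_offset = prompt_len  # the next forward pass decodes from this position
    return params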


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative probability is at most (1 - top_p); the most probable token is always kept
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))
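

# Example (hypothetical helper, not part of the original file): both filters above modify the
# logits in place, so a caller that wants to keep the originals works on a copy, applying
# top-k first and then top-p, the same order sample() below uses.
def _example_filter_logits(logits: Tensor, top_k: int = 50, top_p: float = 0.9) -> Tensor:
    """Illustrative sketch: return a filtered copy of (batch_size, vocab_size) logits."""
    filtered = logits.clone()
    if top_k > 0:
        modify_logits_for_top_k_filtering(filtered, top_k)
    if top_p > 0.0:
        modify_logits_for_top_p_filtering(filtered, top_p)
    return filtered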


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )
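

# Example (illustrative, not part of the original file): drawing one next token per sequence
# from raw (batch_size, vocab_size) logits with temperature plus nucleus sampling.
def _example_sample_next_token(logits: Tensor) -> Tensor:
    """Hypothetical sketch: returns a (batch_size,) tensor of sampled token ids."""
    # top_k=0 disables the top-k cutoff, so only top-p (nucleus) filtering applies.
    return sample(logits, top_k=0, top_p=0.9, temperature=0.8)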


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of tensors of shape (batch, vocab_size), one per generated token
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        start.record()
    scores, sequences = [], [input_ids]
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        sequences.append(sample_tokens(scores[-1], inference_params))
    if enable_timing:
        end.record()
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
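

# Example (illustrative, not part of the original file): decode() expects `model` to be a
# module whose forward accepts `position_ids`, `inference_params` and `num_last_tokens` and
# returns an object with a `.logits` attribute (the interface used by the models in this
# repository). The helper below is hypothetical.
def _example_decode(model, input_ids, max_length):
    """Sketch: greedy decoding without CUDA graphs; returns the generated token ids."""
    out = decode(input_ids, model, max_length, top_k=1, cg=False, enable_timing=False)
    return out.sequences  # (batch, <= max_length), prompt tokens included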


def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
    """Algorithm 1 from [1]
    [1] Fast Inference from Transformers via Speculative Decoding
    Yaniv Leviathan, Matan Kalman, Yossi Matias
    https://arxiv.org/abs/2211.17192

    Arguments:
        logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
        logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
        tokens_draft: Tensor of shape (batch_size, seqlen)
    Return:
        tokens: Tensor of shape (batch_size, seqlen + 1)
        num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].
            For each sequence in the batch, the number of valid tokens that were sampled by
            speculative sampling.
    """
    batch, seqlen_p_1, vocab_size = logits.shape
    seqlen = seqlen_p_1 - 1
    assert logits_draft.shape == (batch, seqlen, vocab_size)
    assert tokens_draft.shape == (batch, seqlen)
    assert tokens_draft.dtype in [torch.int64, torch.int32]
    # TODO: if top_k = 1 we can simplify things and only work with indices
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    # Clone so that when we modify for top_p we don't change the original logits
    logits = logits / temperature if temperature != 1.0 else logits.clone()
    logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        modify_logits_for_top_k_filtering(logits, top_k)
        modify_logits_for_top_k_filtering(logits_draft, top_k)
    modify_logits_for_top_p_filtering(logits, top_p)
    modify_logits_for_top_p_filtering(logits_draft, top_p)
    probs = torch.softmax(logits, dim=-1)
    probs_draft = torch.softmax(logits_draft, dim=-1)
    gather = lambda probs, tokens: rearrange(
        probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..."
    )
    # (batch, seqlen)
    accepted = torch.rand(batch, seqlen, device=probs.device) * gather(
        probs_draft, tokens_draft
    ) <= gather(probs[:, :-1], tokens_draft)
    accepted_all = accepted.all(dim=-1)
    # (batch,)
    first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
    probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
    # torch.multinomial can deal with unnormalized probabilities
    # probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
    resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)
    resample_probs = rearrange(
        resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)),
        "b 1 d -> b d",
    )
    resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1)  # (batch,)
    tokens = F.pad(tokens_draft, (0, 1))
    tokens[:, first_rejected_idx] = resample
    return tokens, first_rejected_idx + 1
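

# Example (illustrative, not part of the original file): shape-level sketch of the
# verification step with random inputs; in real use `logits` comes from the main model and
# `logits_draft` / `tokens_draft` from the draft model.
def _example_sample_speculative(batch=1, seqlen=4, vocab_size=16):
    """Hypothetical sketch showing the input/output shapes of sample_speculative()."""
    logits = torch.randn(batch, seqlen + 1, vocab_size)
    logits_draft = torch.randn(batch, seqlen, vocab_size)
    tokens_draft = torch.randint(0, vocab_size, (batch, seqlen))
    tokens, num_generated = sample_speculative(logits, logits_draft, tokens_draft, top_k=0)
    # tokens: (batch, seqlen + 1); num_generated: (batch,), each value in [1, seqlen + 1]
    return tokens, num_generated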


@torch.inference_mode()
def decode_speculative(
    input_ids,
    model,
    model_draft,
    max_length,
    speculative_lookahead=3,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
    debug=False,
):
    """
    TD: WIP, for my own understanding, lightly tested. Only supports batch_size == 1 for now.

    Speculative decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: (batch, number of generated tokens, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
    assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
    if cg:
        if not hasattr(model_draft, "_decoding_cache"):
            model_draft._decoding_cache = None
        model_draft._decoding_cache = update_graph_cache(
            model_draft,
            model_draft._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            # draft model needs to process either 1 or 2 tokens at a time
            decoding_seqlens=(1, 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params_draft = model_draft._decoding_cache.inference_params
        inference_params_draft.reset(max_length, batch_size)
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            decoding_seqlens=range(1, speculative_lookahead + 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            seqlen = input_ids.shape[1]
            # if inference_params.lengths_per_sample is None:
            # TODO: in the case of batched decoding where each sequence has a different length,
            # we need to compute the position_ids for each sequence using lengths_per_sample
            if True:
                cache_seqlens = torch.full(
                    (input_ids.shape[0],),
                    inference_params.seqlen_offset,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            else:
                cache_seqlens = inference_params.lengths_per_sample
            position_ids = cache_seqlens[:, None] + torch.arange(
                seqlen, dtype=torch.long, device=input_ids.device
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=num_last_tokens,
            ).logits
        else:
            # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1].
            # This might not be compatible with the num_last_tokens used here.
            assert num_last_tokens <= input_ids.shape[1]
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            )[:, -num_last_tokens:]
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):
        """Sample `num_tokens` tokens from the model, given the previous logits.
        Also return the logits of the sampled tokens.
        Arguments:
            input_ids: (batch, seqlen)
        Return:
            tokens: (batch, num_tokens)
            scores: (batch, num_tokens, vocab_size), the logits used to sample each of the
                returned tokens. The logits of the last sampled token aren't computed.
        """
        assert num_tokens >= 1
        sequences, scores = [input_ids], []
        for i in range(num_tokens):
            scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])
            inference_params.seqlen_offset += sequences[-1].shape[1]
            sequences.append(sample_fn(scores[-1]).unsqueeze(1))
        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)

    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
    sample_fn = partial(sample, **sampling_kwargs)
    get_logits_main = partial(get_logits, model=model, cg=cg)
    get_logits_draft = partial(get_logits, model=model_draft, cg=cg)
    sample_tokens_main = partial(
        sample_tokens,
        get_logits_fn=get_logits_main,
        sample_fn=sample_fn,
        inference_params=inference_params,
    )
    sample_tokens_draft = partial(
        sample_tokens,
        get_logits_fn=get_logits_draft,
        sample_fn=sample_fn,
        inference_params=inference_params_draft,
    )

    if debug:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        start = time.time()

    sequences, scores = [input_ids], []
    num_main_model_calls = 0
    num_draft_tokens = 0
    num_accepted_tokens_history = []
    if seqlen_og >= max_length - 1:
        # Don't do speculative sampling, just sample 1 token from the model
        tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1)
        sequences.append(tokens)
        scores.append(scores_new)
    else:
        # Sample from the draft model, which produces @n_spec_tokens tokens, which @model
        # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
        # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
        n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
        tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())

        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([input_ids, tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        # TODO: we're using the fact that batch_size == 1
        # TODO: check eos_token_id
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
        # that in the next time we call @model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset = seqlen_og + num_generated - 1
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    while True:
        # seqlen_offset is total length generated - 1
        if inference_params.seqlen_offset >= max_length - 1:
            break
        if inference_params.seqlen_offset >= max_length - 2:
            # Don't do speculative sampling, just sample 1 token from the model
            tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
            sequences.append(tokens)
            scores.append(scores_new)
            break
        # Sample from draft model
        n_spec_tokens = min(
            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
        )
        # If the main model accepts all the draft tokens, plus it samples one new token,
        # then at the next iteration the draft model needs to evaluate the logits of the last draft
        # token and the logits of the newly sampled token. So here we pass in the last 2 tokens
        # of sequences[-1].
        # The exception is when the main model rejects all the draft tokens, in which case we
        # will only have 1 token to pass in.
        tokens_draft, scores_draft = sample_tokens_draft(
            sequences[-1][:, -2:], num_tokens=n_spec_tokens
        )
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
            # breakpoint()
        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )  # (batch, n_spec_tokens + 1, vocab_size)
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # We've evaluated 1 token from sequences[-1][:, -1:] above, plus
        # num_generated_tokens[0].item() - 1 tokens from the draft model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset += num_generated
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        print(f"Number of calls to main model: {num_main_model_calls}")
        print(
            f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%"
        )
    sequences = torch.cat(sequences, dim=1)
    scores = torch.cat(scores, dim=1)
    if debug:
        scores_ref = model(sequences).logits
        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=sequences, scores=scores)
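

# Example (illustrative, not part of the original file): speculative decoding pairs a small
# draft model with the main model; both must expose the same inference interface as the
# `model` argument of decode() above. Hypothetical helper (batch size must be 1):
def _example_decode_speculative(model, model_draft, input_ids, max_length):
    """Sketch: draft 4 tokens per step and let the main model verify them."""
    out = decode_speculative(
        input_ids,
        model,
        model_draft,
        max_length,
        speculative_lookahead=4,
        top_k=0,
        top_p=0.9,
        cg=False,
    )
    return out.sequences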


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
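

# Example (illustrative, not part of the original file): a model class that inherits from
# GenerationMixin and implements the forward interface used by decode() above gets a
# HuggingFace-style generate(). Hypothetical usage:
def _example_generate(model, input_ids, max_length):
    """Sketch: nucleus sampling through GenerationMixin.generate."""
    return model.generate(
        input_ids,
        max_length,
        top_k=0,
        top_p=0.9,
        temperature=0.8,
        return_dict_in_generate=True,
        output_scores=True,
    )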


def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
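

# Example (illustrative, not part of the original file): the returned dict maps layer index
# to a preallocated KV buffer of shape (max_batch_size, max_seqlen, 2, nheads, headdim),
# where the size-2 dimension holds keys and values. Hypothetical sketch:
def _example_kv_cache(device="cuda"):
    """Sketch: allocate a 4-layer cache and slice out one buffer of layer 0."""
    cache = allocate_inference_cache(
        max_batch_size=1,
        max_seqlen=256,
        nheads=16,
        headdim=64,
        layers=4,
        device=device,
        dtype=torch.float16,
    )
    keys_layer0 = cache[0][:, :, 0]  # assumed layout: index 0 = keys, index 1 = values
    return keys_layer0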


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    tensor_parallel=1,
    dtype=None,
    n_warmups=2,
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache


def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launches and non-captured launches do not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run