# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from typing import Callable, Optional, Sequence, Union

import torch
from einops import rearrange
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None
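

# Illustrative sketch (not used anywhere in this file): how InferenceParams is
# typically filled in during generation. The concrete numbers are arbitrary.
def _example_inference_params():
    params = InferenceParams(max_sequence_len=512, max_batch_size=2)
    # During decoding, sequence_len_offset tracks how many tokens are already in the
    # KV cache, and key_value_memory_dict maps layer index -> (k_cache, v_cache).
    params.sequence_len_offset = 16
    return params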


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf."""
    if top_p <= 0.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative probability (summed from least to most likely)
    # stays at or below 1 - top_p; the most likely token is always kept.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    # Mask in place: the function returns None and callers rely on the side effect.
    logits.masked_fill_(indices_to_remove, float("-inf"))
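

# Illustrative sketch (not used by the library): the logits below are made up.
# With top_p=0.9, the least likely tokens whose cumulative probability stays
# under 1 - 0.9 are masked to -inf in place and can no longer be sampled.
def _example_top_p_filtering():
    logits = torch.tensor([[4.0, 3.0, 1.0, -2.0]])
    modify_logits_for_top_p_filtering(logits, top_p=0.9)
    return logits  # the two low-probability entries are now -inf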


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            logits_top = logits / temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )
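

# Illustrative sketch (not used by the library): random logits stand in for model
# output. top_k=1 short-circuits to greedy argmax; with top_k > 0 and top_p > 0
# the top-k filter is applied first, then top-p, as `decode` describes below.
def _example_sampling():
    logits = torch.randn(4, 32000)
    greedy = sample(logits)  # top_k defaults to 1 -> argmax over the vocab
    nucleus = sample(logits, top_k=50, top_p=0.9, temperature=0.8)
    return greedy, nucleus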


def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
    timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of tensors of shape (batch, vocab_size), one per generated token
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        assert fused_ft_kernel
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
    else:
        inference_params = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
        )
    scores = []
    with torch.inference_mode():
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            start = time.time()
        logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
        if vocab_size is not None:
            logits = logits[..., :vocab_size]
        scores.append(logits if not cg else logits.clone())
        if teacher_outputs is None or teacher_output_len <= seqlen_og:
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            next_token = teacher_outputs[:, seqlen_og]
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.sequence_len_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
            if not cg:
                logits = model(
                    rearrange(next_token, "b -> b 1"),
                    position_ids=position_ids,
                    inference_params=inference_params,
                    last_token_only=True,
                ).logits
            else:
                logits = model._decoding_cache.run(
                    rearrange(next_token, "b -> b 1"),
                    position_ids,
                    inference_params.sequence_len_offset,
                )
            if vocab_size is not None:
                logits = logits[..., :vocab_size]
            scores.append(logits if not cg else logits.clone())
            if (
                teacher_outputs is None
                or teacher_output_len <= inference_params.sequence_len_offset + 1
            ):
                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            else:
                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
            if inference_params.sequence_len_offset >= max_length - 1:
                break
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
    )
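

# Illustrative sketch (not used by the library): `model` is assumed to be a module
# whose forward accepts `inference_params` and `last_token_only`, as `decode`
# expects; `input_ids` is a (batch, seqlen) LongTensor of prompt tokens.
def _example_decode(model, input_ids):
    out = decode(input_ids, model, max_length=128, top_k=1)  # greedy decoding
    return out.sequences, out.scores  # (batch, max_length), one score tensor per step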


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
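

# Illustrative sketch (not used by the library): `model` is assumed to inherit
# GenerationMixin. With the default flags only the token ids come back; request
# the output dataclass to also keep the per-step scores.
def _example_generate(model, input_ids):
    sequences = model.generate(input_ids, max_length=64)
    out = model.generate(
        input_ids, max_length=64, return_dict_in_generate=True, output_scores=True
    )
    return sequences, out.scores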


def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if dtype == torch.float32 else 8
    assert headdim % packsize == 0
    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {
        i: (
            torch.empty(k_cache_shape, device=device, dtype=dtype),
            torch.empty(v_cache_shape, device=device, dtype=dtype),
        )
        for i in layers
    }
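

# Illustrative sketch (not used by the library): for fp16 the pack size is 8, so
# with headdim=128 each layer's key cache is laid out as
# (batch, nheads, headdim // 8, seqlen, 8), the interleaved format the fused FT
# decoding kernel expects; the value cache stays (batch, nheads, seqlen, headdim).
def _example_kv_cache_shapes():
    cache = allocate_inference_cache(
        max_batch_size=2, max_seqlen=512, nheads=16, headdim=128, layers=2,
        device="cuda", dtype=torch.float16,
    )
    k_cache, v_cache = cache[0]
    return k_cache.shape, v_cache.shape  # (2, 16, 16, 512, 8), (2, 16, 512, 128)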


def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Convert sequence length to a seqlen_type.
    This is used to determine which cuda graph to use.
    Arguments:
        seqlen: int
    """
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)


def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    assert seqlen_type in [0, 1, 2]
    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
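

# Illustrative sketch (not used by the library): prompts shorter than 32 tokens,
# between 32 and 2047 tokens, and 2048 tokens or longer fall into the three
# buckets, and each bucket is captured with the largest sequence length it may see.
def _example_seqlen_buckets():
    assert seqlen_to_seqlen_type(16) == 0
    assert seqlen_to_seqlen_type(100) == 1
    assert seqlen_to_seqlen_type(4096) == 2
    return [seqlen_type_to_max_seqlen(t) for t in (0, 1, 2)]  # [32, 2048, 2**32]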


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen,
            max_batch_size=batch_size,
            sequence_len_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        if (batch_size, s_type) not in cache.callables:
            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
            cache.callables[batch_size, s_type] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen_,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size = input_ids.shape[0]
        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
            input_ids, position_ids, seqlen
        )

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
    return cache
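

# Illustrative sketch (not used by the library): `model` is assumed to expose the
# forward signature used in `decode`. The first call captures one CUDA graph per
# (batch size, seqlen bucket); later calls with the same device, dtype, batch size
# and max_seqlen reuse them, while larger shapes invalidate and re-capture.
def _example_graph_cache(model, batch_size, seqlen_og, max_seqlen):
    cache = update_graph_cache(model, None, batch_size, seqlen_og, max_seqlen)
    device = next(model.parameters()).device
    next_token = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), seqlen_og, dtype=torch.long, device=device)
    return cache.run(next_token, position_ids, seqlen_og)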


def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    sequence_len_offset_og = inference_params.sequence_len_offset
    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
    inference_params.sequence_len_offset = max_seqlen - 1
    inference_params.lengths_per_sample[:] = max_seqlen - 1

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                last_token_only=True,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            last_token_only=True,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits

    inference_params.sequence_len_offset = sequence_len_offset_og
    return run