# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from typing import Callable, Optional, Sequence, Union

import torch
from einops import rearrange
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None
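
# Example (sketch, mirroring how `decode` below constructs this): decoding a batch of 4
# sequences up to 512 tokens starts from
#   InferenceParams(max_sequence_len=512, max_batch_size=4)
# and `sequence_len_offset` is advanced by one after each generated token.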


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf."""
    if top_p <= 0.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative probability (in ascending order) is at most
    # 1 - top_p, so the kept tokens cover at least top_p of the probability mass.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    # Modify in place: callers rely on `logits` being filtered after this call.
    logits.masked_fill_(indices_to_remove, float("-inf"))
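    # Worked example (hypothetical numbers): with probabilities (0.1, 0.2, 0.3, 0.4) and
    # top_p=0.9, the ascending cumulative sums are (0.1, 0.3, 0.6, 1.0); only the first
    # token satisfies cumsum <= 1 - top_p = 0.1, so its logit is set to -inf and the
    # remaining tokens keep 0.9 of the probability mass.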


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            logits_top = logits / temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )
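
# Usage sketch (hypothetical shapes): given `logits` of shape (batch, vocab_size),
#   sample(logits, top_k=40, top_p=0.9, temperature=0.8)
# returns one token id per batch row; top_k=1 short-circuits to greedy argmax.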


def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
    timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of tensors of shape (batch, vocab_size), one per generated token
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        assert fused_ft_kernel
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
    else:
        inference_params = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
        )

    def logits_forward_fn(input_ids, position_ids, inference_params):
        if not cg:
            return model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                last_token_only=True,
            ).logits
        else:
            return model._decoding_cache.run(
                input_ids, position_ids, inference_params.sequence_len_offset
            ).clone()

    logits_postprocess_fn = (
        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
    )

    scores = []
    with torch.inference_mode():
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            start = time.time()
        logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
        logits = logits_postprocess_fn(logits)
        scores.append(logits if not cg else logits.clone())
        if teacher_outputs is None or teacher_output_len <= seqlen_og:
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            next_token = teacher_outputs[:, seqlen_og]
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
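        # Decoding loop: after the prompt pass above, each iteration feeds only the
        # last sampled token (shape (batch, 1)) plus its position id, reusing the KV
        # cache stored in inference_params.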
        while True:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.sequence_len_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
            logits = logits_postprocess_fn(logits_forward_fn(
                rearrange(next_token, "b -> b 1"), position_ids, inference_params
            ))
            scores.append(logits)
            if (
                teacher_outputs is None
                or teacher_output_len <= inference_params.sequence_len_offset + 1
            ):
                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            else:
                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
            if inference_params.sequence_len_offset >= max_length - 1:
                break
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
    )


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
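
# Usage sketch (argument values are illustrative): a model that mixes in
# GenerationMixin and implements allocate_inference_cache can call e.g.
#   out = model.generate(input_ids, max_length=100, top_k=40, top_p=0.9,
#                        temperature=0.8, return_dict_in_generate=True,
#                        output_scores=True, fused_ft_kernel=True, cg=True)
# which returns the GreedySearchDecoderOnlyOutput / SampleDecoderOnlyOutput from `decode`.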


def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if dtype == torch.float32 else 8
    assert headdim % packsize == 0
    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {
        i: (
            torch.empty(k_cache_shape, device=device, dtype=dtype),
            torch.empty(v_cache_shape, device=device, dtype=dtype),
        )
        for i in layers
    }
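
# Shape sketch (illustrative numbers): with dtype=torch.float16 (packsize=8),
# max_batch_size=2, max_seqlen=1024, nheads=16, headdim=64, each layer gets
#   k_cache of shape (2, 16, 64 // 8, 1024, 8) = (2, 16, 8, 1024, 8)
#   v_cache of shape (2, 16, 1024, 64)
# i.e. the key cache is stored in the packed layout expected by the fused FT kernel.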


def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Convert sequence length to a seqlen_type.
    This is used to determine which cuda graph to use.
    Arguments:
        seqlen: int
    """
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)


def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    assert seqlen_type in [0, 1, 2]
    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
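
# Bucketing sketch: seqlen 16 -> type 0 (graph captured for seqlen up to 32),
# seqlen 100 -> type 1 (up to 2048), seqlen 4096 -> type 2, so one CUDA graph per
# (batch_size, seqlen_type) bucket covers a whole range of decoding steps.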


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen,
            max_batch_size=batch_size,
            sequence_len_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        if (batch_size, s_type) not in cache.callables:
            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
            cache.callables[batch_size, s_type] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen_,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size = input_ids.shape[0]
        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
            input_ids, position_ids, seqlen
        )

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
    return cache
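
# Usage sketch (mirrors the cg=True path in `decode` above): after
#   cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
# each decoding step calls cache.run(input_ids, position_ids, seqlen), which replays
# the captured graph for the (batch_size, seqlen_to_seqlen_type(seqlen)) bucket.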


def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    sequence_len_offset_og = inference_params.sequence_len_offset
    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
    inference_params.sequence_len_offset = max_seqlen - 1
    inference_params.lengths_per_sample[:] = max_seqlen - 1

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                last_token_only=True,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            last_token_only=True,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits

    inference_params.sequence_len_offset = sequence_len_offset_og
    return run
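
# Note on the captured graph: `run` copies the new token and position ids into the
# static `input_ids`/`position_ids` buffers and replays the graph, so the returned
# `logits` tensor is overwritten on the next replay; that is why the cg=True path in
# `decode` clones it before appending to `scores`.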