# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from typing import Callable, Optional, Sequence, Union

import torch
from einops import rearrange
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None

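# A hedged usage sketch (not part of the original file): during decoding, a single
# InferenceParams instance is threaded through every forward call. sequence_len_offset tells
# the attention layers where the new tokens start in the KV cache; decode() below sets it to
# the prompt length after the prompt pass and then advances it by one per generated token.
#
#   params = InferenceParams(max_sequence_len=512, max_batch_size=4)
#   params.sequence_len_offset = 0      # prompt pass
#   params.sequence_len_offset = 100    # e.g. after a 100-token prompt, before step 1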

# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))

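# A minimal illustration (not in the original file) of the in-place top-k filter above:
#
#   logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
#   modify_logits_for_top_k_filtering(logits, top_k=2)
#   # logits -> [[-inf, 3.0, 2.0, -inf]]: everything outside the 2 largest logits is masked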

# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter sorted tensors back to the original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    # Modify logits in place so the caller sees the filtered values
    logits.masked_fill_(indices_to_remove, float("-inf"))

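# A minimal illustration (not in the original file) of the in-place top-p (nucleus) filter:
#
#   logits = torch.tensor([[0.0, 0.0, 0.0, 10.0]])
#   modify_logits_for_top_p_filtering(logits, top_p=0.9)
#   # the three near-zero-probability logits are masked to -inf; only the dominant one remains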

def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )

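# A hedged usage sketch (not in the original file) of sample():
#
#   logits = torch.randn(4, 32000)
#   greedy = sample(logits)                                        # argmax, shape (4,)
#   tokens = sample(logits, top_k=50, top_p=0.9, temperature=0.8)  # top-k then top-p, shape (4,)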

def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    fused_ft_kernel=False,
    cg=False,
    timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of (batch, vocab_size) tensors, one per generated token
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        assert fused_ft_kernel
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
    else:
        inference_params = InferenceParams(
            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
        )

    def logits_forward_fn(input_ids, position_ids, inference_params):
        if not cg:
            return model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            return model._decoding_cache.run(
                input_ids, position_ids, inference_params.sequence_len_offset
            ).clone()

    logits_postprocess_fn = (
        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
    )

    scores = []
    with torch.inference_mode():
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            start = time.time()
        logits = model(
            input_ids, inference_params=inference_params, num_last_tokens=1
        ).logits.squeeze(dim=1)
        logits = logits_postprocess_fn(logits)
        scores.append(logits if not cg else logits.clone())
        if teacher_outputs is None or teacher_output_len <= seqlen_og:
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            next_token = teacher_outputs[:, seqlen_og]
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.sequence_len_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
            logits = logits_postprocess_fn(
                logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
            )
            scores.append(logits)
            if (
                teacher_outputs is None
                or teacher_output_len <= inference_params.sequence_len_offset + 1
            ):
                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            else:
                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
            if inference_params.sequence_len_offset >= max_length - 1:
                break
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
    )


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences

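# A hedged usage sketch (not in the original file). Assumes `model` is a decoder-only model
# that mixes in GenerationMixin (e.g. a GPT-style model from this repo) and `input_ids` is a
# (batch, seqlen) LongTensor on the same device as the model:
#
#   out = model.generate(
#       input_ids, max_length=128, top_k=40, top_p=0.95, temperature=0.8,
#       return_dict_in_generate=True, output_scores=True,
#       fused_ft_kernel=True, cg=True,  # optional: CUDA-graph decoding path
#   )
#   out.sequences  # (batch, max_length): prompt followed by generated tokens
#   out.scores     # tuple of (batch, vocab_size) logits, one per generated token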

def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if dtype == torch.float32 else 8
    assert headdim % packsize == 0
    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {
        i: (
            torch.empty(k_cache_shape, device=device, dtype=dtype),
            torch.empty(v_cache_shape, device=device, dtype=dtype),
        )
        for i in layers
    }

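# A shape sketch (not in the original file): for a hypothetical model with 32 heads of head
# dim 128, batch size 2, max_seqlen 1024, 32 layers, in fp16 (packsize 8):
#
#   cache = allocate_inference_cache(2, 1024, 32, 128, 32, "cuda", torch.float16)
#   cache[0][0].shape  # k_cache: (2, 32, 16, 1024, 8), head dim split into 128 // 8 packs
#   cache[0][1].shape  # v_cache: (2, 32, 1024, 128)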

def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Convert sequence length to a seqlen_type.
    This is used to determine which cuda graph to use.
    Arguments:
        seqlen: int
    """
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)

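# Bucketing sketch (not in the original file): sequence-length offsets below 32 share one
# captured graph, offsets in [32, 2048) share another, and anything longer uses the third:
#
#   seqlen_to_seqlen_type(31)    # 0
#   seqlen_to_seqlen_type(512)   # 1
#   seqlen_to_seqlen_type(4096)  # 2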

def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    assert seqlen_type in [0, 1, 2]
    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen,
            max_batch_size=batch_size,
            sequence_len_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        if (batch_size, s_type) not in cache.callables:
            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
            cache.callables[batch_size, s_type] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen_,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size = input_ids.shape[0]
        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
            input_ids, position_ids, seqlen
        )

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
    return cache

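# How the cache is used (a sketch, not in the original file): decode(..., cg=True) calls
# update_graph_cache at the start of every call, but graphs are only (re)captured when the
# device/dtype, batch size, or sequence-length bucket has not been seen before. Each decoding
# step then dispatches to the captured graph for the current offset:
#
#   logits = model._decoding_cache.run(next_token_ids, position_ids, seqlen_offset)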

def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    sequence_len_offset_og = inference_params.sequence_len_offset
    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
    inference_params.sequence_len_offset = max_seqlen - 1
    inference_params.lengths_per_sample[:] = max_seqlen - 1

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launches and non-captured launches do not overlap (I think
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=1,
        ).logits.squeeze(dim=1)

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits

    inference_params.sequence_len_offset = sequence_len_offset_og
    return run
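# Note (not in the original file): `input_ids`, `position_ids`, and `logits` above are static
# buffers captured by the graph; run() copies the new inputs into them and replays the graph,
# so callers that need to keep logits across steps must clone them (decode() does this via
# .clone() on the cuda-graph path).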