# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional, Union, Sequence, Callable
import gc
import time

from dataclasses import dataclass, field
from collections import namedtuple

import torch
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity

from einops import rearrange

from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None
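    # Typical lifecycle (illustrative note, not from the original source): the offset is
    # 0 while the prompt is processed, is then set to the prompt length, and is bumped by
    # one for every generated token; key_value_memory_dict maps each layer index to that
    # layer's (k_cache, v_cache) tensors.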


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf."""
    if top_p <= 0.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative probability (from the least likely token up) is <= 1 - top_p,
    # i.e. keep only the top-p nucleus.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits.masked_fill_(indices_to_remove, float('-inf'))
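
# A small illustration (not part of the original file): with top_p=0.9, the least
# likely tokens whose cumulative probability stays at or below 1 - 0.9 are masked,
# keeping only the nucleus that covers the top 90% of the probability mass.
#
#     logits = torch.tensor([[3.0, 2.0, 1.0, -2.0]])
#     modify_logits_for_top_p_filtering(logits, top_p=0.9)
#     # logits is modified in place: masked entries are now -inf, so a subsequent
#     # softmax + multinomial can only pick tokens from the kept nucleus.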


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, 'top-p should be in (0, 1].'
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
            ]
        else:
            logits_top = logits / temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
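
# Illustrative calls (hypothetical values, not part of the original file), showing how
# top_k, top_p and temperature combine:
#
#     logits = torch.randn(4, 32000)                        # (batch_size, vocab_size)
#     greedy = sample(logits)                               # top_k=1 -> plain argmax
#     nucleus = sample(logits, top_k=0, top_p=0.9)          # pure top-p sampling
#     both = sample(logits, top_k=40, top_p=0.9, temperature=0.8)  # top-k, then top-p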


def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
           eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
           fused_ft_kernel=False, cg=False, timing=False):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of tensors, each of shape (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        assert fused_ft_kernel
        if not hasattr(model, '_decoding_cache'):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model, model._decoding_cache, batch_size, seqlen_og, max_length,
            tensor_parallel=tensor_parallel
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
    else:
        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
                                           fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            start = time.time()
        logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
        if vocab_size is not None:
            logits = logits[..., :vocab_size]
        scores.append(logits if not cg else logits.clone())
        if teacher_outputs is None or teacher_output_len <= seqlen_og:
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            next_token = teacher_outputs[:, seqlen_og]
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            if not cg:
                logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                               inference_params=inference_params, last_token_only=True).logits
            else:
                logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
                                                   inference_params.sequence_len_offset)
            if vocab_size is not None:
                logits = logits[..., :vocab_size]
            scores.append(logits if not cg else logits.clone())
            if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            else:
                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
            if inference_params.sequence_len_offset >= max_length - 1:
                break
        if timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
            print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )
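

# A hedged usage sketch (not part of the original file). `model` stands for any causal
# LM whose forward accepts `inference_params` and `last_token_only` and returns an
# object with a `.logits` field (e.g. a model using GenerationMixin below); the prompt
# tensor here is a placeholder. Note that cg=True requires fused_ft_kernel=True.
#
#     input_ids = torch.randint(0, 50257, (2, 16), dtype=torch.long, device='cuda')
#     out = decode(input_ids, model, max_length=64, top_k=40, top_p=0.9,
#                  temperature=0.8, fused_ft_kernel=True, cg=True, timing=True)
#     out.sequences   # (2, 64): the prompt followed by the generated tokens
#     out.scores      # tuple of per-step logits, each of shape (2, vocab_size)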


class GenerationMixin:

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 return_dict_in_generate=False, output_scores=False, **kwargs):
        output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
                        temperature=temperature, **kwargs)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
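

# A hedged sketch (not part of the original file) of how a model plugs into the mixin:
# subclass it, implement allocate_inference_cache (the module-level helper below has a
# matching layout), and give forward the `inference_params` / `last_token_only` keyword
# arguments that decode() uses. The attribute names on `self` are placeholders.
#
#     class MyLM(torch.nn.Module, GenerationMixin):
#         def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
#             return allocate_inference_cache(batch_size, max_seqlen, self.nheads,
#                                             self.headdim, self.num_layers,
#                                             next(self.parameters()).device,
#                                             dtype or torch.float16)
#
#         def forward(self, input_ids, position_ids=None, inference_params=None,
#                     last_token_only=False):
#             ...  # must return an output object with a .logits attribute
#
#     tokens = MyLM(...).generate(input_ids, max_length=128, top_k=40, top_p=0.9)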


def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
                             device, dtype=torch.float16):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if dtype == torch.float32 else 8
    assert headdim % packsize == 0
    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype),
                torch.empty(v_cache_shape, device=device, dtype=dtype))
            for i in layers}
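

# Worked example (illustrative numbers, not from the original file): with
# dtype=torch.float16 the pack size is 8, so headdim=128 and max_seqlen=2048 give
# per-layer cache shapes
#     k_cache: (max_batch_size, nheads, 16, 2048, 8)    # headdim split into 128 // 8 packs
#     v_cache: (max_batch_size, nheads, 2048, 128)
# which is the packed K layout expected by the fused FT decoding attention kernel
# referenced elsewhere in this file.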


def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Convert sequence length to a seqlen_type.
    This is used to determine which cuda graph to use.
    Arguments:
        seqlen: int
    """
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)


def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
    assert seqlen_type in [0, 1, 2]
    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
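

# Concretely (illustrative, not part of the original file): prompts shorter than 32
# tokens fall in bucket 0, lengths in [32, 2048) in bucket 1, and anything longer in
# bucket 2:
#     seqlen_to_seqlen_type(16)   == 0    seqlen_type_to_max_seqlen(0) == 32
#     seqlen_to_seqlen_type(512)  == 1    seqlen_type_to_max_seqlen(1) == 2048
#     seqlen_to_seqlen_type(4096) == 2    seqlen_type_to_max_seqlen(2) == 2**32
# One CUDA graph is captured per (batch_size, bucket) in update_graph_cache below.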


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
                       dtype=None, n_warmups=2):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, 'allocate_inference_cache'):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(model.config, 'head_dim',
                              model.config.hidden_size // model.config.num_attention_heads)
            inf_cache = allocate_inference_cache(
                batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
                model.config.num_hidden_layers, device, dtype
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen, max_batch_size=batch_size,
            sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        if (batch_size, s_type) not in cache.callables:
            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
            cache.callables[batch_size, s_type] = capture_graph(
                model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
                n_warmups=n_warmups
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size = input_ids.shape[0]
        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
    return cache
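
# A hedged illustration (not part of the original file): for a 16-token prompt and
# max_seqlen=4096, seqlen_to_seqlen_type spans buckets 0..2, so three graphs are
# captured for that batch size, with maximum decoding lengths 32, 2048 and 4096.
# Calling update_graph_cache again with the same device/dtype, the same batch size
# and a max_seqlen that fits the existing cache reuses the captured graphs instead
# of re-capturing them.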


def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    sequence_len_offset_og = inference_params.sequence_len_offset
    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
    inference_params.sequence_len_offset = max_seqlen - 1
    inference_params.lengths_per_sample[:] = max_seqlen - 1

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
                           last_token_only=True).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launches and non-captured launches do not overlap (I think
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Capture the graph. To allow capture, the context manager automatically sets a side
    # stream as the current stream.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
                       last_token_only=True).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits

    inference_params.sequence_len_offset = sequence_len_offset_og
    return run
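

# Usage note (illustrative, not part of the original file): the returned `run` closure
# copies its arguments into the static `input_ids` / `position_ids` buffers captured
# above and replays the graph, so callers must pass tensors of shape (batch_size, 1)
# with the captured batch size:
#
#     run = capture_graph(model, inference_params, batch_size=2, max_seqlen=2048)
#     logits = run(next_token.unsqueeze(1), position_ids, inference_params.sequence_len_offset)
#
# The returned logits live in graph-owned memory and are overwritten by the next
# replay, which is why decode() clones them before appending to `scores` when cg=True.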