# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional
import time

from dataclasses import dataclass, field
from collections import namedtuple

import torch
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity

from einops import rearrange

from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf."""
    if top_p <= 0.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens in the low-probability tail whose cumulative probability is <= 1 - top_p.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits.masked_fill_(indices_to_remove, float('-inf'))
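
# Illustrative sketch (comments only, not executed at import time): the masking is in-place,
# so callers pass logits they are free to modify.
#     logits = torch.tensor([[2.0, 1.0, 0.1, -3.0]])
#     modify_logits_for_top_p_filtering(logits, top_p=0.9)  # low-probability tail -> -inf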


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, 'top-p should be in (0, 1].'
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
            ]
        else:
            logits_top = logits / temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
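
# Illustrative sketch: greedy vs. temperature + nucleus sampling over a toy batch.
#     logits = torch.randn(4, 50257)                                 # (batch_size, vocab_size)
#     greedy = sample(logits, top_k=1)                               # argmax, shape (4,)
#     nucleus = sample(logits, top_k=0, top_p=0.9, temperature=0.8)  # shape (4,)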


def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
           fused_ft_kernel=False, cg=False, timing=False):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of tensors, each of shape (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
                                       fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
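        # Process the whole prompt in one forward pass; the logits at the last position
        # seed the first sampled token and fill the key/value cache in inference_params.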
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        scores.append(logits)
        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
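        # The steps below decode one token at a time, reusing the cached context in
        # inference_params; sequence_len_offset tracks where the next token is written.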
        if cg:
            assert fused_ft_kernel
            run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
        if timing:
            torch.cuda.synchronize()
            start = time.time()
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            if not cg:
                logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                               inference_params=inference_params).logits[:, -1]
            else:
                logits = run(rearrange(next_token, 'b -> b 1'), position_ids,
                             inference_params.sequence_len_offset)
            scores.append(logits)
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if inference_params.sequence_len_offset >= max_length - 1:
                break
        if timing:
            torch.cuda.synchronize()
            print(f'Decoding time: {time.time() - start:.3f}s')
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )
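
# Illustrative sketch, assuming `model` is a causal LM whose forward accepts
# `inference_params` (e.g. a GPTLMHeadModel):
#     out = decode(input_ids, model, max_length=input_ids.shape[1] + 64,
#                  top_k=5, temperature=0.8, fused_ft_kernel=True, cg=True)
#     generated = out.sequences   # (batch, max_length)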


class GenerationMixin:

    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 return_dict_in_generate=False, output_scores=False, **kwargs):
        output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
                        temperature=temperature, **kwargs)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
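
    # Intended use (sketch): a model class whose forward accepts `inference_params`
    # (e.g. GPTLMHeadModel) inherits this mixin to get a HuggingFace-style API:
    #     out = model.generate(input_ids, max_length=128, top_p=0.9,
    #                          return_dict_in_generate=True, output_scores=True)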


CgKey = namedtuple('CgKey', ['batch_size', 'seqlen_type', 'max_length'])
CgVal = namedtuple('CgVal', ['graph', 'input_ids', 'position_ids', 'lengths', 'logits'])


def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Convert sequence length to a seqlen_type.
    This is used to determine which cuda graph to use.
    Arguments:
        seqlen: int
    """
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
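
# For example, seqlen 16 -> type 0, seqlen 500 -> type 1, seqlen 4096 -> type 2. One graph is
# captured per (batch_size, seqlen_type, max_length) bucket and replayed for any offset in it.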


def seqlen_type_to_seqlen(seqlen_type: int) -> int:
    assert seqlen_type in [0, 1, 2]
    return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)


def capture_cg(model, inference_params, batch_size, seqlen_og, max_length, copy_output=False):
    """Build a cache of cuda graphs for decoding.
    Arguments:
        model: a GPTLMHeadModel
        inference_params: InferenceParams already holding the key/value cache from prompt processing.
        batch_size: int
        seqlen_og: int. Length of the prompt.
        max_length: int
        copy_output: if True, return a clone of the logits instead of the graph's static output buffer.
    TODO: how do we deal with the k_cache and v_cache memory? I think the CUDA graph also
    has to own the k_cache and v_cache?
    Here we assume that the model already has inference_params from the prompt processing.
    """
    assert max_length > seqlen_og
    cg_cache: dict[CgKey, CgVal] = {}
    device = next(iter(model.parameters())).device
    sequence_length_offset_og = inference_params.sequence_len_offset
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    inference_params.lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32,
                                                        device=device)

    memory_pool = None
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_length) + 1):
        seqlen = max(seqlen_og, seqlen_type_to_seqlen(s_type))
        input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
        position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
        inference_params.lengths_per_sample[:] = seqlen
        inference_params.sequence_len_offset = seqlen
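        # Capture with the bucket's representative length baked in; at replay time only the
        # values in input_ids / position_ids / lengths_per_sample change, never their shapes.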
        g = torch.cuda.CUDAGraph()
        # Warmup before capture
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(2):
                logits = model(input_ids, position_ids=position_ids,
                               inference_params=inference_params).logits[:, -1]
        torch.cuda.current_stream().wait_stream(s)
        # Captures the graph
        # To allow capture, automatically sets a side stream as the current stream in the context
        with torch.cuda.graph(g, pool=memory_pool):
            logits = model(input_ids, position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
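        # All captured graphs share one memory pool (taken from the first graph), so their
        # intermediate buffers are not duplicated per graph.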
        if memory_pool is None:
            memory_pool = g.pool()
        cg_cache[CgKey(batch_size, s_type, max_length)] = CgVal(
            g, input_ids, position_ids, inference_params.lengths_per_sample, logits
        )

    def run(new_input_ids, new_position_ids, seqlen):
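        # Replay reuses the static tensors recorded at capture time, so the new inputs must
        # be copied into those buffers before calling graph.replay().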
        cg_val = cg_cache[CgKey(batch_size, seqlen_to_seqlen_type(seqlen), max_length)]
        inference_params.lengths_per_sample = cg_val.lengths
        inference_params.lengths_per_sample[:] = seqlen
        cg_val.input_ids.copy_(new_input_ids)
        cg_val.position_ids.copy_(new_position_ids)
        cg_val.graph.replay()
        output = cg_val.logits
        return output.clone() if copy_output else output

    inference_params.sequence_len_offset = sequence_length_offset_og

    return run, cg_cache