retro_generation.py 11.2 KB
Newer Older
wangsen's avatar
wangsen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.


"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron.training import get_args, get_tokenizer
from megatron.training import get_retro_args
from megatron.core import mpu
from megatron.training.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.inference.text_generation.communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage, broadcast_int_list, broadcast_tensor)
from megatron.inference.text_generation.generation import _build_attention_mask_and_position_ids
from megatron.inference.text_generation.sampling import sample



def retro_generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths, neighbours_array=None,
        return_output_log_probs=False,
        top_k=0, top_p=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True,
        stop_on_double_eol=False,
        stop_on_eol=False,
        logits_mask=None):
    """Main token generation function for Retro models.

    Runs an autoregressive sampling loop. At every step the model is fed the
    token buffer from position 0 (no incremental slice is used) together
    with the retrieved neighbour tokens, one token per sequence is sampled
    on the last pipeline stage, and the result is copied back to the first
    stage.

    Args:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max-sequence-length].
            Generated tokens are written into this tensor in place.
        lengths: original prompt length, size: [b]
        neighbours_array: neighbours array of size [b, l, k, r]. It is only
            read on global rank 0, where it is reshaped to
            [-1, retro_gpt_retrieved_length] and broadcast to the other
            ranks. It must not be None on rank 0.
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            from the original logit (see the review note on `logits_mask`
            in the body for one exception).
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, do early termination if
            all the sequences have reached a stopping condition.
        stop_on_double_eol: if True, a sequence is finished when the sampled
            token id is 628, or when token id 198 is sampled right after
            another 198. (Presumably the GPT-2 BPE ids for a double and a
            single newline -- tokenizer dependent, confirm.)
        stop_on_eol: if True (and stop_on_double_eol is False), a sequence
            is finished when the sampled token id is 628 or 198.
        logits_mask: optional index into the vocabulary dimension; the
            selected logits are set to -inf before sampling (word banning).

    When neither stop flag is set, a sequence is finished when it samples
    the termination id (`args.eos_id` if present, otherwise
    `tokenizer.eod`); in addition, all sequences are marked finished once
    context_length exceeds min_prompt_length + 64.

    NOTE(review): the original docstring stated that, outside of the model,
    the other parameters only need to be available on rank 0. In this
    function `tokens` and `lengths` are read on every rank, so callers
    presumably broadcast them beforehand -- confirm.

    Returns: Note that the size is adjusted to a lower value than
             max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens.
            size: [b, s], or None if return_output_log_probs is False.

    Raises:
        ValueError: if the shortest prompt already fills the usable
            sequence length, or if neighbours_array is None on rank 0.
    """

    args = get_args()
    retro_args = get_retro_args()

    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    min_prompt_length = lengths.min().item()
    max_sequence_length = tokens.size(1)
    # Diagnostic output; emitted by every rank that calls this function.
    print("max_sequence_length", max_sequence_length)
    print("min_prompt_length", min_prompt_length)
    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

    # If the context is too big, this happens
    if min_prompt_length >= max_sequence_length:
        raise ValueError("context length + tokens_to_generate too large")

    # forward step.
    unwrapped_model = unwrap_model(model)
    unwrapped_model.language_model.seq_length = max_sequence_length

    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
        termination_id = args.eos_id
    else:
        termination_id = tokenizer.eod

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of generated sequence including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length

    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # The model is always fed the sequence from position 0, so the start of
    # the slice (and of the log-prob window) is fixed for the whole loop.
    prev_context_length = 0

    # Upper bound of the slice handed to the model.
    # NOTE(review): hard-coded in the original; it is not tied to
    # max_sequence_length or args.max_position_embeddings here -- confirm
    # whether it should be.
    max_model_input_length = 4096

    # =============
    # Run inference
    # =============

    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)

        # The retrieved neighbours do not change while generating, so they
        # are built on rank 0, broadcast and converted to position ids once
        # instead of once per generated token.
        sizes_list = None
        neighbor_tokens_cuda_long_tensor = None
        if torch.distributed.get_rank() == 0:
            if neighbours_array is None:
                raise ValueError(
                    "neighbours_array must be provided on rank 0 "
                    "for retro generation")
            neighbor_tokens_cuda_long_tensor = torch.cuda.LongTensor(
                neighbours_array.reshape(
                    (-1, retro_args.retro_gpt_retrieved_length)))
            sizes_list = [neighbor_tokens_cuda_long_tensor.size(0),  # Batch size
                          neighbor_tokens_cuda_long_tensor.size(1)]  # Sequence length
        sizes_tensor = broadcast_int_list(2, int_list=sizes_list)
        sizes = sizes_tensor.tolist()
        neighbor_tokens_cuda_long_tensor = broadcast_tensor(
            sizes, torch.int64, tensor=neighbor_tokens_cuda_long_tensor)

        _, _, neighbor_position_ids = get_ltor_masks_and_position_ids(
            neighbor_tokens_cuda_long_tensor,
            tokenizer.eod,
            args.reset_position_ids,
            args.reset_attention_mask,
            args.eod_mask_loss)
        neighbor_attention_mask = None

        for context_length in range(min_prompt_length, max_sequence_length):

            # Pick the slice that we need to pass through the network.
            # `tokens` is updated in place every step, so re-slice here.
            tokens2use = tokens[:, prev_context_length:max_model_input_length]
            positions2use = position_ids[
                :, prev_context_length:max_model_input_length]
            attention_mask2use = attention_mask[
                ...,
                prev_context_length:max_model_input_length,
                :max_model_input_length]

            logits = model(tokens2use, positions2use, attention_mask2use,
                           retriever_input_ids=neighbor_tokens_cuda_long_tensor,
                           retriever_position_ids=neighbor_position_ids,
                           retriever_attn_mask=neighbor_attention_mask)

            if mpu.is_pipeline_last_stage():
                # Always the last stage should have an output.
                assert logits is not None

                # Sample from the logits of the last position that has
                # already been filled in for every sequence.
                last_token_logits = logits[:, context_length - 1, :]

                # Word banning.
                # NOTE(review): last_token_logits is a slice of `logits`,
                # so this in-place write also changes the values used for
                # the log-probabilities below -- confirm this is intended.
                if logits_mask is not None:
                    last_token_logits[:, logits_mask] = float('-Inf')

                new_sample = sample(last_token_logits,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    vocab_size=tokenizer.vocab_size)

                # If a prompt length is smaller or equal to the current
                # context length, it means we have started generating tokens
                started = lengths <= context_length
                # Update the tokens.
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that next input token is
                    # the token which we selected in the current logits,
                    # so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[
                            :,
                            (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:, prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Check if all the sequences have hit the termination_id.
            done = None
            if mpu.is_pipeline_last_stage():
                # TODO(rprenger) These stopping methods are tokenizer dependent
                # instead tokenization should be in the inference loop so stop sequences can be used
                if stop_on_double_eol:
                    hit_double_eol = (new_sample == 628).byte() & started.byte()
                    hit_two_eols = (new_sample == 198).byte() & (
                            tokens[:, context_length - 1] == 198).byte() & started.byte()
                    done_token = hit_double_eol | hit_two_eols
                elif stop_on_eol:
                    hit_double_eol = (new_sample == 628).byte() & started.byte()
                    hit_eol = (new_sample == 198).byte() & started.byte()
                    done_token = hit_double_eol | hit_eol
                elif context_length > min_prompt_length + 64:  # previous retrov1 limitations
                    # Plain int: combined with the uint8 tensors below it
                    # marks every sequence as finished.
                    done_token = 1
                else:
                    done_token = (new_sample == termination_id).byte() & \
                                 started.byte()

                # Record the length only for sequences finishing this step.
                just_finished = (done_token & ~is_generation_done).bool()
                generated_sequence_lengths[just_finished.view(-1)] = \
                    context_length + 1
                is_generation_done = is_generation_done | done_token
                done = torch.all(is_generation_done)
            done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                      tensor=done)
            if use_eod_token_for_early_termination and done:
                break

    # ==============================================
    # Update the length based on max generated length.
    # ==============================================

    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================

    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs