text_generation_utils.py 16.7 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Utilities for generating text."""
17

18
import copy
Mohammad's avatar
Mohammad committed
19
20
21
22
import json
import os
import time

23
24
25
import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
26
27
28
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
29
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
Jared Casper's avatar
Jared Casper committed
30
from megatron.p2p_communication import recv_forward, send_forward
31

32
33
34
35
36
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

37
38
39
40
def get_batch(context_tokens):
    """Build the model inputs for a batch of context tokens.

    Args:
        context_tokens: integer token tensor (presumably shaped
            (batch, seq_len) -- TODO confirm against callers).

    Returns:
        ``(tokens, attention_mask, position_ids)`` where ``tokens`` is a
        contiguous CUDA version of ``context_tokens``. The mask and position
        ids come from ``get_ltor_masks_and_position_ids``; the loss mask it
        also returns is discarded.
    """
    megatron_args = get_args()
    eod_token = get_tokenizer().eod

    gpu_tokens = context_tokens.contiguous().cuda()

    # Left-to-right attention mask and position ids; any resetting at EOD
    # follows the global reset_position_ids / reset_attention_mask /
    # eod_mask_loss settings.
    mask, _loss_mask, positions = get_ltor_masks_and_position_ids(
        gpu_tokens,
        eod_token,
        megatron_args.reset_position_ids,
        megatron_args.reset_attention_mask,
        megatron_args.eod_mask_loss,
    )
    return gpu_tokens, mask, positions

def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Apply top-k and/or nucleus (top-p) filtering to ``logits`` in place.

    Mostly taken from the Hugging Face conversational AI code:
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    Args:
        logits: tensor of unnormalized scores whose last dimension indexes
            the vocabulary; any leading (batch) dimensions are supported.
        top_k: if > 0, keep only the ``top_k`` highest-scoring entries along
            the last dimension (ties with the k-th value are kept). A value
            larger than the vocabulary keeps everything.
        top_p: if > 0.0, keep the smallest prefix of the probability-sorted
            entries whose cumulative probability exceeds ``top_p``; the most
            likely entry is always kept.
        filter_value: value written into removed positions.

    Returns:
        ``logits`` itself, modified in place.
    """
    if top_k > 0:
        # torch.topk raises when k exceeds the dimension size; clamping makes
        # "top-k of fewer than k entries" keep every entry instead.
        top_k = min(top_k, logits.size(-1))
        # Remove every entry smaller than the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        # Sort descending along the vocabulary and accumulate probabilities.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold ...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ... shifted right so the first token crossing the threshold is also
        # kept, and never remove the most likely token.
        sorted_indices_to_remove[..., 1:] = \
            sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order with one
        # scatter (sorted_indices is a per-row permutation), instead of a
        # Python loop over the batch.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits

def pad_batch(batch, pad_id, max_len):
    """Right-pad every token list in ``batch`` in place with ``pad_id``.

    Every list is extended to ``longest_length + max_len`` tokens, leaving
    room for ``max_len`` extra positions after the longest entry.

    Args:
        batch: list of token-id lists; each list is mutated in place.
        pad_id: token id appended as padding.
        max_len: number of extra positions reserved past the longest list.

    Returns:
        ``(batch, context_lengths)``: the same (now padded) list object and
        the length each list had before padding.
    """
    target_len = max(len(seq) for seq in batch) + max_len
    original_lengths = []
    for seq in batch:
        seq_len = len(seq)
        original_lengths.append(seq_len)
        if seq_len < target_len:
            seq.extend([pad_id] * (target_len - seq_len))
    return batch, original_lengths

def tokenize_batch(sentences, max_len, add_BOS):
    """Tokenize ``sentences`` and pack them into padded CUDA tensors.

    Args:
        sentences: iterable of prompt strings.
        max_len: number of extra positions reserved after the longest prompt
            (forwarded to ``pad_batch``).
        add_BOS: if true, prepend the tokenizer's EOD token to every prompt.

    Returns:
        ``(context_tokens_tensor, context_length_tensor)``: a 2-D
        ``torch.cuda.LongTensor`` of token ids padded with the EOD id, and a
        1-D ``torch.cuda.LongTensor`` of the unpadded prompt lengths.
    """
    tokenizer = get_tokenizer()

    # The EOD token doubles as the beginning-of-sequence marker.
    prefix = [tokenizer.eod] if add_BOS else []
    context_tokens = [prefix + tokenizer.tokenize(s) for s in sentences]

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, max_len)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    return context_tokens_tensor, context_length_tensor

def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
    """Broadcast the generation inputs from rank 0 to all other ranks.

    Must stay in lock-step with ``receive_generate_info``. Both sides issue
    the same broadcasts from source rank 0, in this order:

    1. a 4-element int64 header
       ``[batch_size, seq_len, tokens_to_generate, all_probs]``;
    2. ``context_length_tensor``;
    3. ``context_tokens_tensor``.
    """
    batch_size = context_tokens_tensor.size(0)
    seq_len = context_tokens_tensor.size(1)
    header = torch.cuda.LongTensor(
        [batch_size, seq_len, tokens_to_generate, all_probs])

    for payload in (header, context_length_tensor, context_tokens_tensor):
        torch.distributed.broadcast(payload, 0)

def receive_generate_info():
    """Receive the generation inputs broadcast by ``send_generate_info``.

    Must stay in lock-step with ``send_generate_info``. The 4-element header
    is received first, then the context lengths, then the context tokens,
    all from source rank 0.

    Returns:
        ``(context_length_tensor, context_tokens_tensor, tokens_to_generate,
        all_probs)``. The last two are Python ints; ``all_probs`` is an int
        rather than a ``bool``.
    """
    device = torch.cuda.current_device()

    header = torch.empty(4, dtype=torch.int64, device=device)
    torch.distributed.broadcast(header, 0)
    batch_size, seq_len, tokens_to_generate, all_probs = header.tolist()

    # Receive buffers sized from the header.
    context_length_tensor = torch.empty(
        batch_size, dtype=torch.int64, device=device)
    context_tokens_tensor = torch.empty(
        batch_size, seq_len, dtype=torch.int64, device=device)

    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)

    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs

def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
    """Run batched sampling on every rank and ship the scores to the first stage.

    Drives ``sample_sequence_batch`` to completion. The last pipeline stage
    then broadcasts the per-token log-probabilities over the embedding group,
    plus the full log-softmax when ``all_probs`` is set. The first pipeline
    stage receives them.

    Returns:
        ``(tokens, output_logits, full_logits)`` on ranks whose generator
        yields tokens. ``tokens`` is trimmed to the generated length, and
        ``full_logits`` is ``None`` unless ``all_probs`` is set. Returns
        ``None`` on other ranks.
    """
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
                                                 tokens_to_generate,
                                                 all_probs,
                                                 temperature=temperature)
    # One yield per generated position; keep the values from the last one.
    for tokens, _lengths, output_logits, full_logits in batch_token_iterator:
        context_length += 1

    if mpu.is_pipeline_last_stage():
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        torch.distributed.broadcast(output_logits, src, group)
        if all_probs:
            torch.distributed.broadcast(full_logits, src, group)

    elif mpu.is_pipeline_first_stage():
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # The last stage scores every token after the first one, so both
        # received tensors have context_length - 1 positions; receive buffers
        # must match the sender's shape exactly.
        output_logits = torch.empty(tokens.size(0), context_length - 1,
                                    dtype=torch.float32,
                                    device=torch.device("cuda"))
        torch.distributed.broadcast(output_logits, src, group)

        if all_probs:
            args = get_args()
            # Assumes the last stage's vocabulary dimension equals
            # args.padded_vocab_size -- TODO confirm for all model configs.
            full_logits = torch.empty(tokens.size(0), context_length - 1,
                                      args.padded_vocab_size,
                                      dtype=torch.float32,
                                      device=torch.device("cuda"))
            torch.distributed.broadcast(full_logits, src, group)

    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits

def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
    """Generate continuations for ``sentences`` across all model-parallel ranks.

    Rank 0 tokenizes the prompts and broadcasts them; every other rank
    receives them. All ranks then take part in ``synced_generate``.

    Args:
        model: the (possibly wrapped) language model; switched to eval mode.
        sentences: list of prompt strings (only read on rank 0).
        tokens_to_generate: number of tokens to generate past the longest
            prompt.
        all_probs: if true, also return the full log-softmax for every
            position.
        temperature: sampling temperature forwarded to the sampler.
        add_BOS: if true, prepend the EOD token to every prompt.

    Returns:
        On ranks where ``synced_generate`` returns output, a tuple
        ``(resp_sentences, resp_sentences_seg, output_logits, full_logits,
        decode_tokens)`` of Python lists (``full_logits`` stays as received
        unless ``all_probs``, i.e. ``None``). ``None`` on other ranks.
    """
    model.eval()
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
    else:
        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()

    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
    if output is None:
        return None

    decode_tokens, output_logits, full_logits = output
    tokenizer = get_tokenizer()

    decode_tokens = decode_tokens.cpu().numpy().tolist()
    resp_sentences = []
    resp_sentences_seg = []
    for decode_token in decode_tokens:
        resp_sentences.append(tokenizer.detokenize(decode_token))
        # Per-token strings. Relies on the inner tokenizer exposing
        # ``decoder`` and ``byte_decoder`` (byte-level BPE, GPT-2 style);
        # TODO confirm for other tokenizers.
        words = []
        for token in decode_token:
            word = tokenizer.tokenizer.decoder[token]
            word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
            words.append(word)
        resp_sentences_seg.append(words)

    output_logits = output_logits.cpu().numpy().tolist()
    if all_probs:
        full_logits = full_logits.cpu().numpy().tolist()

    return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens

def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """Generate one continuation of ``context`` (legacy-task API).

    Provides a matching API for a legacy task; not yet verified to reproduce
    the legacy implementation exactly.

    Args:
        model: the language model.
        context: prompt string.
        max_gen_length: number of tokens to generate.
        eos_token_id: token id that ends generation. It is stored on the
            global args as ``eos_id``, where ``sample_sequence_batch`` reads
            it. The value persists after this call.

    Returns:
        The generated text with the first ``len(context)`` characters
        stripped, on ranks where ``generate`` returns output; ``None``
        otherwise.
    """
    args = get_args()
    args.eos_id = eos_token_id
    raw_text_len = len(context)
    outputs = generate(model, [context], max_gen_length)
    if outputs:
        # generate() returns (resp_sentences, resp_sentences_seg, ...). Index
        # the sentence list before slicing; slicing outputs[0] directly would
        # slice the list of sentences rather than the text.
        resp_sentences = outputs[0]
        return resp_sentences[0][raw_text_len:]

def switch(val1, val2, boolean):
    """Elementwise select ``val2`` where ``boolean`` is set, else ``val1``.

    Computed arithmetically as ``(1 - mask) * val1 + mask * val2``, with the
    mask cast to ``val1``'s type. ``boolean`` must therefore be 0/1 valued.
    Unlike ``torch.where``, a non-finite value in the unselected input can
    still yield NaN (0 * inf).
    """
    mask = boolean.type_as(val1)
    keep_first = 1 - mask
    return keep_first * val1 + mask * val2

def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 set_inference_key_value_memory=False,
                 inference_max_sequence_len=None):
    """Run one pipeline-parallel inference forward pass for this stage.

    Receives the previous stage's activation, sets it as the model's input
    tensor, runs ``model`` and sends the output to the next stage.

    Args:
        model: the (possibly torchDDP / LocalDDP / Float16Module wrapped)
            model.
        tokens: input token ids. ``tokens.shape[0]`` and ``tokens.shape[1]``
            are used as the batch size and sequence length for this step.
        position_ids, attention_mask, tokentype_ids: forwarded to ``model``.
        set_inference_key_value_memory: forwarded to ``model``.
        inference_max_sequence_len: forwarded to ``model``.

    Returns:
        This stage's model output.
    """
    # The hidden size changes when not using recompute, so the p2p
    # communication helpers must be told the current sizes. Override the
    # global seq_length / micro_batch_size for this call only.
    args = get_args()
    orig_seq_length = args.seq_length
    orig_micro_batch_size = args.micro_batch_size
    args.seq_length = tokens.shape[1]
    args.micro_batch_size = tokens.shape[0]
    try:
        input_tensor = recv_forward()

        # Forward pass through the model.
        unwrapped_model = unwrap_model(
            model, (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.set_input_tensor(input_tensor)
        output_tensor = model(
            tokens, position_ids, attention_mask,
            tokentype_ids=tokentype_ids,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len)

        send_forward(output_tensor)
    finally:
        # Restore both globals, even on error, so later code does not see
        # generation-time shapes (micro_batch_size used to leak).
        args.seq_length = orig_seq_length
        args.micro_batch_size = orig_micro_batch_size

    return output_tensor

def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
    """Generator that extends a batch of contexts by one position per step.

    Generation starts at the shortest context length. Rows with longer
    contexts keep their own context tokens until generation reaches them.
    The first step feeds the whole prefix (``set_inference_key_value_memory``
    True); later steps feed only the newest position.

    On the last pipeline stage each step yields
    ``(tokens, lengths, output_logits, full_logits)``:

    * ``tokens``: ``context_tokens``, updated in place with the new token;
    * ``lengths``: per-row position of the first EOS (``maxlen`` if none);
    * ``output_logits``: log-probability of every token after the first;
    * ``full_logits``: full log-softmax per position if ``all_probs``,
      else ``None``.

    On the first pipeline stage (when different from the last), each step
    yields ``(tokens, None, None, None)`` after receiving the sampled tokens.
    Any other stage yields ``(None, None, None, None)``.

    Stops at ``tokens_to_generate`` past the longest context (capped at
    ``args.seq_length``) or once every row has produced EOS. A
    ``temperature`` of ``None`` is treated as 1.0 (no scaling).
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Dividing by None would raise on the sampling path; 1.0 is a no-op.
    if temperature is None:
        temperature = 1.0

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0

        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None
        full_logits = None

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()
        if maxlen > args.seq_length:
            maxlen = args.seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                if type_ids is not None:
                    types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                tokens2use = tokens[:, context_length - 1].view(
                    batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                if type_ids is not None:
                    types2use = type_ids[:, context_length - 1].view(
                        batch_size, -1)

            output = forward_step(
                model, tokens2use,
                positions2use,
                attention_mask,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=maxlen,
                tokentype_ids=types2use)

            if mpu.is_pipeline_last_stage():
                assert output is not None
                output = output.float()
                logits = output[:, -1].view(batch_size, -1).contiguous()

                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    # Clone before the in-place temperature / top-k / top-p
                    # edits. When a single position is fed, contiguous()
                    # returns a view of ``output``, and editing it would
                    # corrupt the log-probs computed from ``output`` below.
                    logits = logits.clone()
                    logits /= temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(probs, num_samples=1).view(-1)
                # Rows whose context ends at or before this position.
                started = context_lengths <= context_length

                # Clamp the out of vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)

                # Rows that have not started keep their context token.
                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    # First step: score each prefix position against the
                    # token that follows it.
                    output_context = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context], 1)

                # Send the new tokens to the first stage over the embedding group.
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                # Share the termination flag with every pipeline stage.
                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                # full_logits is still None unless all_probs is set.
                yield tokens, lengths, output_logits, full_logits

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                # Receive the termination flag from the last stage.
                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break