text_generation_utils.py 16.5 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Utilities for generating text."""
17

18
import copy
Mohammad's avatar
Mohammad committed
19
20
21
22
import json
import os
import time

23
24
25
import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
26
27
28
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
29
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
Jared Casper's avatar
Jared Casper committed
30
from megatron.p2p_communication import recv_forward, send_forward
31

32
33
34
35
36
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

37
38
39
40
def get_batch(context_tokens):
    """Build the model inputs for a batch of context tokens.

    The tokens are made contiguous and moved to the GPU, then the
    attention mask and position ids are derived from them with
    get_ltor_masks_and_position_ids.

    Returns:
        A tuple ``(tokens, attention_mask, position_ids)``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Contiguous copy of the context on the GPU.
    tokens = context_tokens.contiguous().cuda()

    # The helper returns three values; the middle one is not used here.
    attention_mask, _ignored, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids

55

56
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter logits with top-k and/or nucleus (top-p) filtering.

    This function has been mostly taken from huggingface conversational
    ai code at
        https://medium.com/huggingface/how-to-build-a-state-of-the-art-
             conversational-ai-with-transfer-learning-2d818ac26313

    Args:
        logits: tensor whose last dimension indexes the vocabulary. It is
            modified in place.
        top_k: when > 0, only the ``top_k`` largest logits along the last
            dimension are kept. Values larger than the size of that
            dimension are clamped to it.
        top_p: when > 0.0, tokens are kept in order of decreasing
            probability up to and including the first one at which the
            cumulative probability exceeds ``top_p``.
        filter_value: value written over every filtered logit.

    Returns:
        The same ``logits`` tensor, after filtering.
    """
    if top_k > 0:
        # torch.topk raises when k is larger than the dimension it is
        # applied to, so never ask for more entries than exist.
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens with a logit smaller than the last (smallest)
        # logit of the top-k.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token
        # above the threshold; the most likely token is always kept.
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order in a
        # single scatter instead of looping over the rows in Python.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits

88
def pad_batch(batch, pad_id, max_len):
    """Right-pad every token list in ``batch`` to a common length, in place.

    Each list is extended with ``pad_id`` until it holds
    ``max(len(tokens) for tokens in batch) + max_len`` entries.

    Returns:
        A tuple ``(batch, context_lengths)``: the same (now padded) batch
        object and the length each list had before padding.
    """
    target_length = max(map(len, batch)) + max_len
    context_lengths = []
    for tokens in batch:
        original_length = len(tokens)
        context_lengths.append(original_length)
        # A non-positive repeat count produces an empty list, so rows
        # that are already long enough are left untouched.
        tokens.extend([pad_id] * (target_length - original_length))
    return batch, context_lengths

98
def tokenize_batch(sentences, max_len, add_BOS):
    """Tokenize sentences and pack them into padded CUDA tensors.

    Every sentence is tokenized with the global tokenizer; when
    ``add_BOS`` is set, ``tokenizer.eod`` is prepended to each token list.
    The lists are then padded by pad_batch using ``tokenizer.eod`` as the
    pad id and ``max_len`` as the extra room.

    Returns:
        A tuple ``(context_tokens_tensor, context_length_tensor)`` of CUDA
        long tensors holding the padded tokens and the unpadded lengths.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    def encode(sentence):
        token_ids = tokenizer.tokenize(sentence)
        if add_BOS:
            return [tokenizer.eod] + token_ids
        return token_ids

    padded_tokens, unpadded_lengths = pad_batch(
        [encode(s) for s in sentences], tokenizer.eod, max_len)

    context_tokens_tensor = torch.cuda.LongTensor(padded_tokens)
    context_length_tensor = torch.cuda.LongTensor(unpadded_lengths)
    return context_tokens_tensor, context_length_tensor

111
def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
    """Broadcast a generation request from rank 0.

    Needs to be synced up with receive_generate_info: a 7-element float
    header (batch size, sequence length, tokens_to_generate, logprobs,
    temperature, top_k, top_p) is broadcast first, followed by the context
    lengths and then the context tokens.
    """
    # The header carries the tensor sizes plus the sampling parameters.
    batch_size = context_tokens_tensor.size(0)
    seq_len = context_tokens_tensor.size(1)
    header = torch.cuda.FloatTensor(
        [batch_size, seq_len, tokens_to_generate, logprobs,
         temperature, top_k, top_p])
    torch.distributed.broadcast(header, 0)

    # Then the payload, in the order the receiver expects it.
    for payload in (context_length_tensor, context_tokens_tensor):
        torch.distributed.broadcast(payload, 0)

def receive_generate_info():
    """Receive a generation request broadcast from rank 0.

    Needs to be synced up with send_generate_info: first a 7-element float
    header is received, then the context lengths and the context tokens,
    whose shapes are taken from the header.

    Returns:
        A tuple ``(context_length_tensor, context_tokens_tensor,
        tokens_to_generate, logprobs, temperature, top_k, top_p)``.
    """
    device = torch.cuda.current_device()

    header = torch.empty(7, dtype=torch.float32, device=device)
    torch.distributed.broadcast(header, 0)
    fields = header.tolist()

    # Everything travels as float32; convert each field back to its type.
    batch_size, seq_len, tokens_to_generate = (int(v) for v in fields[:3])
    logprobs = bool(fields[3])
    temperature = float(fields[4])
    top_k = int(fields[5])
    top_p = float(fields[6])

    # Buffers that the following broadcasts fill in.
    context_length_tensor = torch.empty(
        batch_size, dtype=torch.int64, device=device)
    context_tokens_tensor = torch.empty(
        batch_size, seq_len, dtype=torch.int64, device=device)
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)

    return (context_length_tensor, context_tokens_tensor, tokens_to_generate,
            logprobs, temperature, top_k, top_p)
146

147
def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
    """Run sample_sequence_batch to completion and collect its final state.

    The generator is drained; the values of its last yield are kept and
    ``context_length`` (starting from the shortest context) is advanced by
    one per step. When ``logprobs`` is set and at least one step ran, the
    log-probabilities are broadcast over the embedding group from the last
    pipeline stage, and a first stage that is not also the last stage
    receives them into a freshly allocated buffer.

    Returns:
        ``(tokens[:, :context_length], output_logits)`` when ``tokens`` is
        not None, otherwise None. ``output_logits`` is None when the
        generator never yielded (for example when there is nothing to
        generate) or when this rank holds no log-probabilities.
    """
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model,
                                                 context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask,
                                                 position_ids,
                                                 tokens_to_generate,
                                                 logprobs,
                                                 temperature,
                                                 top_k,
                                                 top_p)

    # Bound up front: if the generator yields nothing, the loop below never
    # assigns output_logits and the return statement would otherwise raise
    # UnboundLocalError.
    output_logits = None
    steps_taken = 0
    for tokens, _, output_logits in batch_token_iterator:
        context_length += 1
        steps_taken += 1

    # Without a single step there are no log-probabilities to exchange; the
    # step count is derived from inputs that were broadcast to every rank.
    if logprobs and steps_taken > 0:
        if mpu.is_pipeline_last_stage():
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            torch.distributed.broadcast(output_logits, src, group)
        elif mpu.is_pipeline_first_stage():
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            # Receive buffer sized to match what the last stage sends.
            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
            torch.distributed.broadcast(output_logits, src, group)

    if tokens is not None:
        return tokens[:, :context_length], output_logits
    return None
180

181
def generate(model, sentences=None, tokens_to_generate=0, logprobs=False, temperature=1.0, top_k=0, top_p=0.0, add_BOS=False):
    """Generate text for ``sentences`` across all ranks.

    Rank 0 tokenizes the sentences and broadcasts the request with
    send_generate_info; every other rank obtains it from
    receive_generate_info. All ranks then call synced_generate.

    Returns:
        None when synced_generate returns None. Otherwise a tuple
        ``(resp_sentences, resp_sentences_seg, output_logits)``: the
        detokenized text of each row, the per-token text pieces of each
        row, and the log-probabilities (converted to nested lists when
        ``logprobs`` is set).
    """
    model.eval()
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
    else:
        context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p = receive_generate_info()

    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
    if output is None:
        return None

    decode_tokens, output_logits = output
    args = get_args()
    tokenizer = get_tokenizer()

    def piece_to_text(token):
        # Look the token up in the wrapped tokenizer's decoder table, then
        # map its characters to bytes and decode those as UTF-8.
        piece = tokenizer.tokenizer.decoder[token]
        raw = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in piece])
        return raw.decode('utf-8', errors='replace')

    resp_sentences = []
    resp_sentences_seg = []
    for row in decode_tokens.cpu().numpy().tolist():
        resp_sentences.append(tokenizer.detokenize(row))
        resp_sentences_seg.append([piece_to_text(token) for token in row])

    if logprobs:
        output_logits = output_logits.cpu().numpy().tolist()
    return resp_sentences, resp_sentences_seg, output_logits
212

213
214
215
216
217
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """
    This function is here to provide a matching API for a legacy task.
    This implementation hasn't been tested yet to make sure it matches.

    ``eos_token_id`` is stored on the global args as ``eos_id`` and a single
    context is passed to generate() with ``max_gen_length`` tokens to
    generate.

    Returns:
        The first generated sentence with its leading ``len(context)``
        characters removed, or None when generate() produced no output.
    """
    #assert False, "Implementation untested"
    args = get_args()
    args.eos_id = eos_token_id
    raw_text_len = len(context)
    output = generate(model, [context], max_gen_length)
    if output is None:
        return None

    # generate() returns (sentences, segmented sentences, log-probs); only
    # the sentences are wanted. Indexing the tuple itself would slice the
    # list of sentences rather than the text of the first one.
    resp_sentences = output[0]
    if resp_sentences:
        return resp_sentences[0][raw_text_len:]
    return None
225
226

def switch(val1, val2, boolean):
    """Select element-wise between two tensors.

    ``boolean`` is cast to the type of ``val1`` and used as an arithmetic
    mask: the result is ``val2`` where the mask is 1 and ``val1`` where it
    is 0.
    """
    mask = boolean.type_as(val1)
    keep_first = 1 - mask
    return keep_first * val1 + mask * val2
229

230

231
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 set_inference_key_value_memory=False,
                 inference_max_sequence_len=None):
    """Run one inference forward pass on this pipeline stage.

    The input tensor is obtained from recv_forward() and installed on the
    unwrapped model, the model is called on the given inputs, and the
    result is handed to send_forward().

    While the step runs, ``args.seq_length`` and ``args.micro_batch_size``
    are set from the shape of ``tokens``. ``args.seq_length`` is restored
    afterwards, also when the step raises.

    Returns:
        The output tensor returned by ``model``.
    """
    # Hidden size changes when not using recompute, need to tell p2p_communicate
    # functions the correct size
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]
    # NOTE(review): micro_batch_size is overwritten and never restored; this
    # is kept as it was -- confirm whether later code relies on it.
    args.micro_batch_size = tokens.shape[0]

    try:
        input_tensor = recv_forward()

        # Forward pass through the model.
        unwrapped_model = unwrap_model(
            model, (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.set_input_tensor(input_tensor)
        output_tensor = model(
            tokens, position_ids, attention_mask,
            tokentype_ids=tokentype_ids,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len)

        send_forward(output_tensor)
    finally:
        # Do not leave the temporary sequence length on the shared args
        # object if anything above fails.
        args.seq_length = orig_seq_length

    return output_tensor


261
262
263
264
265
266
267
268
269
270
271
def sample_sequence_batch(model,
                          context_tokens,
                          context_lengths,
                          attention_mask,
                          position_ids,
                          tokens_to_generate,
                          logprobs,
                          temperature,
                          top_k,
                          top_p,
                          type_ids=None):
    """Generator that produces one token per row per step.

    Starting at the shortest context length, each step runs forward_step,
    and on the last pipeline stage picks the next token for every row
    (argmax when ``args.greedy`` is set, otherwise temperature scaling,
    top_k_logits filtering and multinomial sampling). Rows whose context
    has not ended yet keep their existing token at that position. The loop
    stops at ``tokens_to_generate`` past the longest context (capped at
    ``args.seq_length``) or as soon as every row has produced the end
    token.

    Yields, once per step:
        * last pipeline stage: ``(tokens, lengths, output_logits)``, where
          ``output_logits`` stays None unless ``logprobs`` is set;
        * first pipeline stage that is not the last: ``(tokens, None, None)``;
        * any other stage: ``(None, None, None)``.

    ``context_tokens`` is updated in place (``tokens`` is the same object).
    """
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        # Generation starts where the shortest context ends.
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        # Number of steps executed so far.
        counter = 0

        batch_size = context_tokens.size(0)
        # Per-row flag, set once the row has produced eos_id.
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        # Never go past the configured sequence length.
        if maxlen > args.seq_length:
            maxlen = args.seq_length

        # Per-row length; stays at maxlen for rows that never emit eos_id.
        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            types2use = None
            if counter == 0:
                # Allocate memory for the entire context.
                set_inference_key_value_memory = True
                # First step: feed everything up to the current position.
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                if type_ids is not None:
                    types2use = type_ids[:, :context_length]
            else:
                # Set this to false so the memory is not reallocated.
                set_inference_key_value_memory = False
                # Later steps: feed only the most recent token.
                tokens2use = tokens[:, context_length - 1].view(
                    batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                if type_ids is not None:
                    types2use = type_ids[:, context_length - 1].view(
                        batch_size, -1)

            output = forward_step(
                model, tokens2use,
                positions2use,
                attention_mask,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=maxlen,
                tokentype_ids=types2use)

            if mpu.is_pipeline_last_stage():
                assert output is not None
                output = output.float()
                # Logits of the last position, one row per batch entry.
                logits = output[:, -1].view(batch_size, -1).contiguous()

                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    logits = top_k_logits(logits, top_k=top_k,
                                          top_p=top_p)
                    # NOTE(review): despite the name, this holds softmax
                    # probabilities, which is what multinomial samples from.
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                # Rows whose own context ends at or before this position.
                started = context_lengths <= context_length

                # Clamp the out of vocabulary tokens.
                tokenizer = get_tokenizer()
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)

                # Started rows take the new token; the others keep the
                # token already present at this position.
                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens

                if logprobs:
                    if output_logits is None:
                        # First step: gather the log-probability of each
                        # token given the positions before it, for the whole
                        # span fed to the model.
                        output_context = F.log_softmax(output[:, :context_length, :], 2)
                        indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
                        output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    else:
                        # Later steps: only the token just written.
                        output_context = F.log_softmax(output, 2)
                        indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
                        new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)

                        # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                        output_logits = torch.cat([output_logits, new_output_logits],1)

                # Share the new tokens over the embedding group; the first
                # stage receives them in the else-branch below.
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                # A row finishes when it has started and sampled eos_id;
                # its length is recorded the first time that happens.
                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                # Tell the whole pipeline group whether every row is done.
                # NOTE(review): torch.all returns a 0-dim tensor while the
                # receiving side below allocates a 1-element ByteTensor --
                # confirm the broadcast tolerates the differing shape/dtype.
                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths, output_logits

            else:
                if mpu.is_pipeline_first_stage():
                    # Receive the tokens chosen by the last stage and write
                    # them into the local copy.
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None
                else:
                    yield None, None, None

                # Receive the termination flag from the last stage. This
                # runs only after the consumer resumes the generator.
                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break