text_generation_utils.py 14.8 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Utilities for generating text."""
17

18
import copy
Mohammad's avatar
Mohammad committed
19
20
21
22
import json
import os
import time

23
24
25
import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
26
27
28
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
29
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
Jared Casper's avatar
Jared Casper committed
30
from megatron.p2p_communication import recv_forward, send_forward
31

32
33
34
35
36
# These wrapper classes are needed to unwrap the model; ideally they would live in megatron.utils.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

37
38
39
40
def get_batch(context_tokens):
    """Build model inputs for a batch of context tokens.

    The tokens are moved to the GPU as a contiguous tensor. The left-to-right
    attention mask and position ids come from
    ``get_ltor_masks_and_position_ids``; the loss mask it also returns is
    discarded.

    Returns:
        ``(tokens, attention_mask, position_ids)``
    """
    args = get_args()
    tokenizer = get_tokenizer()

    gpu_tokens = context_tokens.contiguous().cuda()

    eod_token = tokenizer.eod
    attention_mask, _loss_mask, position_ids = get_ltor_masks_and_position_ids(
        gpu_tokens,
        eod_token,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )
    return gpu_tokens, attention_mask, position_ids

55

56
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Apply top-k and/or nucleus (top-p) filtering to ``logits`` in place.

    Adapted from the Hugging Face conversational AI code:
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    Args:
        logits: tensor whose last dimension indexes the vocabulary (e.g.
            ``(batch, vocab)``). It is modified in place.
        top_k: if > 0, keep only the ``top_k`` largest logits along the last
            dimension. A value larger than the vocabulary keeps everything.
        top_p: if > 0.0, keep the smallest set of most likely tokens whose
            cumulative probability exceeds ``top_p``. The most likely token is
            always kept.
        filter_value: value written into every filtered-out position.

    Returns:
        ``logits`` (the same tensor object) with filtered positions set to
        ``filter_value``.
    """
    if top_k > 0:
        # torch.topk raises if k exceeds the dimension size; clamp so an
        # oversized top_k simply keeps every token.
        top_k = min(top_k, logits.size(-1))
        # Remove every token whose logit is below the k-th largest logit.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits.masked_fill_(logits < kth_largest, filter_value)

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ...but shift the mask right by one so the first token that crosses
        # the threshold is kept, and the most likely token is never removed.
        sorted_indices_to_remove[..., 1:] = \
            sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order in a single
        # scatter, instead of a Python loop over the batch.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits.masked_fill_(indices_to_remove, filter_value)

    return logits

Mohammad's avatar
Mohammad committed
88
def pad_batch(batch, pad_id, args):
    """Right-pad every token list in ``batch`` up to ``args.seq_length``.

    The lists are extended in place with ``pad_id``. Lists that already have
    ``args.seq_length`` or more entries are left untouched.

    Returns:
        ``(batch, context_lengths)``: the same (now padded) ``batch`` object
        and the original, unpadded length of each list.
    """
    target_length = args.seq_length
    original_lengths = []
    for token_ids in batch:
        unpadded = len(token_ids)
        original_lengths.append(unpadded)
        missing = target_length - unpadded
        if missing > 0:
            token_ids.extend([pad_id] * missing)
    return batch, original_lengths

97
98
99
100
101
102
103
104
105
106
def tokenize_batch(sentences):
    """Tokenize ``sentences`` and pack them into CUDA LongTensors.

    Each sentence is tokenized and right-padded with the tokenizer's EOD id
    up to ``args.seq_length`` (see ``pad_batch``).

    Returns:
        ``(context_tokens_tensor, context_length_tensor)``: the padded token
        ids and each sentence's unpadded token count.
    """
    args = get_args()
    tokenizer = get_tokenizer()
    tokenized = list(map(tokenizer.tokenize, sentences))
    padded, lengths = pad_batch(tokenized, tokenizer.eod, args)
    return torch.cuda.LongTensor(padded), torch.cuda.LongTensor(lengths)

107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
    """Broadcast a generation request from rank 0 to all other ranks.

    Must stay in sync with ``receive_generate_info``. It first broadcasts a
    3-element header ``[batch_size, seq_len, max_len]``, then the per-sample
    context lengths, then the token ids.
    """
    batch_size = context_tokens_tensor.size(0)
    seq_len = context_tokens_tensor.size(1)
    header = torch.cuda.LongTensor([batch_size, seq_len, max_len])
    torch.distributed.broadcast(header, 0)

    # The payload order must match the receiving side exactly.
    for payload in (context_length_tensor, context_tokens_tensor):
        torch.distributed.broadcast(payload, 0)

def receive_generate_info():
    """Receive a generation request broadcast from rank 0.

    Must stay in sync with ``send_generate_info``.

    Returns:
        ``(context_length_tensor, context_tokens_tensor, max_len)``
    """
    cuda = torch.device("cuda")
    header = torch.empty(3, dtype=torch.int64, device=cuda)
    torch.distributed.broadcast(header, 0)
    batch_size, seq_len, max_len = header.tolist()

    # Allocate receive buffers shaped from the header, then receive the
    # payloads in the same order the sender broadcasts them.
    context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=cuda)
    context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=cuda)
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)

    return context_length_tensor, context_tokens_tensor, max_len

139
140
141
142
143
144
145
146
def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len):
    """Run the sampling loop on this rank and collect per-token log-probs.

    Must be called on every rank with the same request (as distributed by
    ``send_generate_info`` / ``receive_generate_info``), because the sampling
    loop performs collective communication.

    Returns:
        On ranks that hold tokens, ``(tokens, output_logits)``:
        - ``tokens`` is cut at the last generated position.
        - ``output_logits`` holds the log-probs computed on the last pipeline
          stage and broadcast to the first stage. It is an empty
          ``(batch, 0)`` tensor when nothing was generated.

        ``None`` on ranks where the sampler yields no tokens.
    """
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
                                                 max_len)

    # Bound up front. If the sampler yields nothing (e.g. max_len=0), the
    # loop body never runs and the code below would otherwise hit an unbound
    # name.
    output_logits = None
    num_steps = 0
    for tokens, _lengths, output_logits in batch_token_iterator:
        context_length += 1
        num_steps += 1

    if num_steps == 0:
        # No position was sampled, so there are no log-probs to share.
        # Whether any step runs depends only on the broadcast request and
        # args, so all ranks are expected to take this branch together and
        # skip the collective below. Return an empty (batch, 0) tensor so
        # callers still receive a tensor.
        output_logits = torch.empty(tokens.size(0), 0, dtype=torch.float32,
                                    device=tokens.device)
    elif mpu.is_pipeline_last_stage():
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        torch.distributed.broadcast(output_logits, src, group)
    elif mpu.is_pipeline_first_stage():
        # Receive buffer with one log-prob per position after the first.
        # NOTE(review): assumes the last stage produces float32 log-probs;
        # the broadcast requires matching dtypes.
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        output_logits = torch.empty(tokens.size(0), context_length - 1,
                                    dtype=torch.float32,
                                    device=torch.device("cuda"))
        torch.distributed.broadcast(output_logits, src, group)

    if tokens is not None:
        return tokens[:, :context_length], output_logits
    return None
163
164
165
166

def generate(model, sentences=None, max_len=0):
    """Generate continuations for ``sentences`` across all ranks.

    Rank 0 tokenizes ``sentences`` and broadcasts the request; every other
    rank receives it. All ranks then take part in ``synced_generate``.

    Args:
        model: the (possibly wrapped) language model.
        sentences: list of prompt strings. Only read on rank 0.
        max_len: number of tokens to generate. ``sample_sequence_batch``
            caps it at ``args.out_seq_length``.

    Returns:
        On rank 0, ``(resp_sentences, resp_sentences_seg, output_logits)``:
        - the detokenized texts;
        - for each text, the list of its individually decoded tokens;
        - the per-token log-probs as nested Python lists.

        ``None`` on every other rank.
    """
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
        # Stats for the timing line printed at the end. ``.item()`` makes the
        # line show a plain number instead of a CUDA tensor repr.
        first_context_length = context_length_tensor[0].item()
        batch_size = context_tokens_tensor.size(0)
        start = time.time()
        send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
    else:
        context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()

    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
    if output is not None:
        decode_tokens, output_logits = output

    if torch.distributed.get_rank() == 0:
        tokenizer = get_tokenizer()
        # The underlying tokenizer's ``decoder`` / ``byte_decoder`` tables
        # (presumably GPT-2 style byte-level BPE) decode single tokens.
        bpe = tokenizer.tokenizer
        resp_sentences = []
        resp_sentences_seg = []
        # One device-to-host copy for the whole batch instead of one per row.
        for decode_token in decode_tokens.cpu().tolist():
            resp_sentences.append(tokenizer.detokenize(decode_token))

            words = []
            for token in decode_token:
                word = bpe.decoder[token]
                word = bytearray([bpe.byte_decoder[ch] for ch in word]).decode('utf-8', errors='replace')
                words.append(word)
            resp_sentences_seg.append(words)

        output_logits = output_logits.cpu().numpy().tolist()

        end = time.time()
        print(f"{batch_size},{first_context_length},{decode_tokens.size(1)},{end - start}", flush=True)
        return resp_sentences, resp_sentences_seg, output_logits
197
198

def switch(val1, val2, boolean):
    """Select elementwise between two tensors using a 0/1 mask.

    Where ``boolean`` is true (1) the result takes ``val2``; where it is false
    (0) it takes ``val1``. The mask is first cast to ``val1``'s type, and the
    selection is done arithmetically as a blend.
    """
    mask = boolean.type_as(val1)
    keep_first = 1 - mask
    return keep_first * val1 + mask * val2
201

202

203
204
205
206
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
    """Run one pipeline-stage forward pass during generation.

    The stage receives its input activation from the previous pipeline stage
    (``recv_forward``), runs ``model``, and sends the output on to the next
    stage (``send_forward``).

    Args:
        model: possibly wrapped model. The input tensor is set on the model
            unwrapped from DDP / Float16Module.
        tokens, position_ids, attention_mask, tokentype_ids, layer_past,
        get_key_value, forward_method_parallel_output: forwarded to
            ``model(...)`` unchanged.

    Returns:
        ``(output_tensor, layer_past)`` when ``get_key_value`` is truthy,
        otherwise ``output_tensor``.
    """
    # The hidden size changes when not using recompute, so the p2p
    # communication functions must be told the current size via the global
    # args.
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]
    # NOTE(review): micro_batch_size is overridden but never restored (same
    # as before this change); confirm nothing downstream relies on the
    # training value.
    args.micro_batch_size = tokens.shape[0]

    try:
        input_tensor = recv_forward()

        # Forward pass through the model.
        unwrapped_model = unwrap_model(
            model, (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.set_input_tensor(input_tensor)
        output_tensor = model(tokens, position_ids, attention_mask,
                              tokentype_ids=tokentype_ids,
                              layer_past=layer_past,
                              get_key_value=get_key_value,
                              forward_method_parallel_output=forward_method_parallel_output)

        if get_key_value:
            output_tensor, layer_past = output_tensor

        send_forward(output_tensor)
    finally:
        # Restore the global even if the forward or communication fails, so
        # later callers do not see a stale sequence length.
        args.seq_length = orig_seq_length

    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


237
238
239
240
241
def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Generator that samples new tokens for a batch, one position per step.

    The first step feeds the whole prompt prefix (up to the shortest context
    length) to the model. Later steps feed only the most recent token and
    pass back the ``layer_past`` returned by ``forward_step``. The loop stops
    once ``maxlen`` is reached or every sample has emitted ``eos_id``.

    Args:
        model: model passed through to ``forward_step``.
        context_tokens: token-id tensor indexed as ``[sample, position]``.
            Sampled tokens are written into it in place, because ``tokens``
            aliases it.
        context_lengths: per-sample prompt lengths. Generation starts at their
            minimum. Samples with a longer prompt keep their own prompt
            tokens until the loop reaches their length.
        attention_mask, position_ids: as produced by ``get_batch``.
        maxlen: number of new positions to generate. Defaults to
            ``args.seq_length - 1`` and is capped at ``args.out_seq_length``.
        type_ids: optional token-type ids, sliced the same way as the tokens.

    Yields:
        One tuple per generated position:
        - last pipeline stage: ``(tokens, lengths, output_logits)``.
          ``lengths[i]`` is the position at which sample ``i`` produced
          ``eos_id`` (initialised to the final ``maxlen``). ``output_logits``
          holds, for every position from 1 up to the current one, the
          log-probability the model assigned to the token at that position.
        - first stage (when it is not also the last): ``(tokens, None, None)``.
        - any other stage: ``(None, None, None)``.
    """
    args = get_args()
    tokenizer = get_tokenizer()
Mohammad's avatar
Mohammad committed
242

243
244
245
    model.eval()
    with torch.no_grad():
        # Generation starts at the shortest prompt in the batch.
        context_length = context_lengths.min().item()
246

Mostofa Patwary's avatar
Mostofa Patwary committed
247
248
        # eos_id was added to support generate_samples_eval, which passes
        # eos_id as an argument and needs termination when that id is found.
249
250
251
252
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod
253
254
255
256
257
258
259
260

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        # Per-sample flag: has this sample already produced eos_id?
        is_done = torch.zeros([batch_size]).byte().cuda()
        # Alias, not a copy: sampled tokens are written into context_tokens.
        tokens = context_tokens
rprenger's avatar
rprenger committed
261
262
        output_logits = None

263
264
        if maxlen is None:
            maxlen = args.seq_length - 1
265
266
267
268
269
270
        
        # From here on maxlen is an absolute position, capped so that at
        # most args.out_seq_length new tokens are generated.
        maxlen = maxlen + org_context_length
        
        if maxlen > (org_context_length + args.out_seq_length):
            maxlen = org_context_length + args.out_seq_length
        
Neel Kant's avatar
Neel Kant committed
271
        # Samples that never emit eos_id report the final maxlen.
        lengths = torch.ones([batch_size]).long().cuda() * maxlen
Mohammad's avatar
Mohammad committed
272

273
        while context_length < maxlen:
274
275
276
277
278
279
            types2use = None
            # First step: run the whole prefix. Later steps: run only the
            # token at context_length - 1 and rely on the cached layer_past.
            if counter == 0:
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                if type_ids is not None:
                    types2use = type_ids[:, :context_length]
280
            else:
281
282
283
284
285
286
                tokens2use = tokens[:, context_length - 1].view(
                    batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                if type_ids is not None:
                    types2use = type_ids[:, context_length - 1].view(
287
                        batch_size, -1)
rprenger's avatar
rprenger committed
288
            
289
290
291
292
293
294
295
296
297
298
            output, layer_past = forward_step(model, tokens2use,
                                              positions2use,
                                              attention_mask,
                                              layer_past=layer_past,
                                              get_key_value=True,
                                              tokentype_ids=types2use,
                                              forward_method_parallel_output=False)
            if mpu.is_pipeline_last_stage():
                assert output is not None
                # Logits for the next token come from the last position.
                logits = output[:, -1].view(batch_size, -1).contiguous()
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    # NOTE: despite the name these are probabilities
                    # (softmax), which is what torch.multinomial samples from.
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                # Samples whose prompt is longer than context_length have not
                # started generating yet; switch() keeps their prompt token.
                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
rprenger's avatar
rprenger committed
316
317
318
                
                # Record the log-prob of each token actually placed in
                # `tokens` (prompt or sampled): on the first step for
                # positions 1..context_length, afterwards for the newly
                # written position only.
                if output_logits is None:
                    output_context = F.log_softmax(output[:, :context_length, :], 2)
319
                    indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
rprenger's avatar
rprenger committed
320
321
322
323
324
325
326
327
328
                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                else:
                    indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
                    new_output_logits = torch.gather(F.log_softmax(output,2), 2, indices).squeeze(2)
                    
                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat([output_logits, new_output_logits],1)
                
329
330
331
332
333
334
335
336
337
338
339
340
341
                # Send the chosen tokens to the first stage, which needs them
                # as the next step's input.
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                # Mark samples that emitted eos_id on this step (only if they
                # had started generating) and record where they stopped.
                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                # Share the "all samples finished" flag with every stage of
                # this pipeline so all stages leave the loop on the same step.
                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
rprenger's avatar
rprenger committed
342
                yield tokens, lengths, output_logits
343

344
            else:
345
346
347
348
349
350
                if mpu.is_pipeline_first_stage():
                    # Receive the tokens sampled on the last stage and write
                    # them in, so they are fed in on the next step.
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
rprenger's avatar
rprenger committed
351
                    yield tokens, None, None
352
                else:
rprenger's avatar
rprenger committed
353
                    yield None, None, None
354

355
356
357
358
                # Receive buffer; overwritten by the last stage's broadcast.
                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
359

360
361
            context_length += 1
            counter += 1
362
363
            if done:
                break