# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sample Generate GPT2"""

import os
import random
import json
import copy
import numpy as np
import torch
import torch.nn.functional as F
import argparse
import time

from megatron.utils import Timers
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import load_checkpoint
from megatron.data_utils import make_tokenizer
from configure_data import configure_data
from megatron import mpu

from megatron.fp16 import FP16_Module
from megatron.model import GPT2Model
from megatron.model import DistributedDataParallel as DDP
from megatron import print_rank_0

# get_args, get_tokenizer, initialize_megatron and get_model are used below
# but were never imported; the import paths assumed here follow the
# refactored megatron package layout.
from megatron import get_args
from megatron import get_tokenizer
from megatron.initialize import initialize_megatron
from megatron.training import get_model


def model_provider():
    """Build the model."""
    args = get_args()

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_tokentypes=0, parallel_output=False)

    return model


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.fp16)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
     ai code at
         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
              conversational-ai-with-transfer-learning-2d818ac26313 """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Convert to 1D
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask back to the original (unsorted) indices, row by row.
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def generate_samples_input_from_file(model):
    """XXX"""
    args = get_args()
    tokenizer = get_tokenizer()
109

110
111
112
    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
113
114
115
116
117
    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
118
119
120
121
122
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.formatsample_output_file())
        fname_out = open(sample_output_file, "w+")
123

124
125
126
127
128
129
130
131
    context_count=0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs=0

            if mpu.get_model_parallel_rank() == 0:
132
133
134
135
136
137
138
139
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
140
                    context_tokens = tokenizer.tokenize(raw_text)
141
142
                    context_length = len(context_tokens)

143
                    if context_length >= (args.seq_length // 2):
144
                        print("\nContext length", context_length, \
145
146
                            "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
147
148
                        continue
            else:
149
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
150
151
152
                context_length = len(context_tokens)
            
            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
153
154
155
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
156
157
158
159
160
161
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
162
            token_stream = get_token_stream(model, [context_tokens])
163
164
165
166
167
168
169
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
170
171
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
172
173
174
175
176
177
178
179
180
181
182
183
184
185
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")
 
            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
            

def generate_samples_interactive(model, print_frequency=24):
    """XXX"""
    args = get_args()
    tokenizer = get_tokenizer()
190
191
192
193
194
195
196
197
198
199

    context_count=0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs=0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
200
201
202
203
204
205
206
207
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
           
                if "stop" in raw_text:
                    terminate_runs = 1
                else:
208
                    context_tokens = tokenizer.tokenize(raw_text)
209
210
                    context_length = len(context_tokens)

211
                    if context_length >= (args.seq_length // 2):
212
                        print("\nContext length", context_length, \
213
214
                            "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
215
216
                        continue
            else:
217
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
218
219
220
                context_length = len(context_tokens)
            
            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
221
222
223
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
224
225
226
227
228
229
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
230
            token_stream = get_token_stream(model, [context_tokens])
231
232
233
234
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

235
236
                if mpu.get_model_parallel_rank() == 0 and \
                   counter % print_frequency == 0:
237
238
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
239
240
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
241
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)
242
243
244
245

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
246
247
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
248
249
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

250
251
252
            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
253
254
255
256
            
            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")


def generate_samples_unconditional(model):
    """XXX"""
    args = get_args()
    tokenizer = get_tokenizer()
    
263
    num_samples = args.num_samples
264
265
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.batch_size)]
266
267
268
269
    samples = []
    ctr = 0
    while True:
        start_time = time.time()
270
271
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
272
273
            pass
        if ctr%args.log_interval == 0:
274
275
            print('Avg s/batch:',
                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
276
277
278
279
280
281
            start_time = time.time()
        length = len(token_stream)
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length-1]
282
            text = tokenizer.detokenize(tokens)
283
284
285
286
287
288
289
290
291
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length-1, 'finished': is_finished}
            yield datum
            ctr += 1
            if ctr >= num_samples:
                break
        if ctr >= num_samples:
            break


def write_and_generate_samples_unconditional(model):
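    """Write unconditionally generated samples to args.genfile, one JSON
    object per line."""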
    args = get_args()

    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum) + '\n')

def pad_batch(batch, tokenizer, args):
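    """Pad each context up to args.seq_length with the EOD token and return
    the padded batch together with the original context lengths."""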
    pad_id = tokenizer.eod
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def get_token_stream(model, context_tokens):
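    """Pad the contexts, broadcast them to all model-parallel ranks and yield
    (tokens, lengths) with one extra generated position per iteration."""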
    args = get_args()
    tokenizer = get_tokenizer()

    pad_id = tokenizer.eod
    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    # Make sure every model-parallel rank sees the same contexts and lengths.
    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths


def switch(val1, val2, boolean):
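    """Elementwise select: keep val1 where boolean is 0, take val2 where it is 1."""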
    boolean = boolean.type_as(val1)
    return (1-boolean)*val1 + boolean*val2


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """XXX"""
    args = get_args()
    tokenizer = get_tokenizer()
    
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        # get_tokenizer() returns the GPT2 BPE tokenizer, which exposes the
        # end-of-document id directly (it is used as pad_id above as well).
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= maxlen:

            if args.recompute:
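                # --recompute: re-run the full forward pass over everything
                # generated so far at every step (no key/value caching).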
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False)
                logits = logits[:, context_length - 1, :]
            else:
                # Incremental decoding: feed only the new token(s) and reuse
                # the cached keys/values from previous steps.
                types2use = None
                if counter == 0:
                    # First step: run the full context through the model.
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    # Subsequent steps: only the most recent token is needed.
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False)
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if args.greedy:
                prev = torch.argmax(logits, dim=-1).view(-1)
            else:
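                # Sample: apply temperature, filter with top-k / nucleus
                # (top-p) filtering, then draw from the softmax distribution.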
                logits = logits.float()
                logits /= args.temperature
                logits = top_k_logits(logits, top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            print_logits = []
            for p in prev:
                print_logits.append([logits[i, p].item()
                                     for i in range(batch_size)])
            # Only overwrite positions past each sample's own context; samples
            # whose context has not ended yet keep their given context token.
            started = context_lengths <= context_length
            tokens[:, context_length] = switch(
                tokens[:, context_length].view(-1), prev, started)
            context_length += 1
            counter += 1

            # Track which sequences just produced end-of-document and record
            # the length at which they finished.
            done_token = (prev == eos_id).byte() & started.byte()
            just_finished = (done_token & ~is_done).bool()
            lengths[just_finished.view(-1)] = context_length
            was_done = is_done
            is_done = is_done | done_token
            done = torch.all(is_done)

            yield tokens, lengths
            if done:
                break

def add_text_generate_args(parser):
    """Text generate arguments."""

    group = parser.add_argument_group('Text generation', 'configurations')
    group.add_argument("--temperature", type=float, default=1.0)
    group.add_argument("--greedy", action='store_true', default=False)
    group.add_argument("--top_p", type=float, default=0.0)
    group.add_argument("--top_k", type=int, default=0)
    group.add_argument("--out-seq-length", type=int, default=1024)
    group.add_argument("--sample-input-file", type=str, default=None,
                      help='get input from file instead of interactive mode, '
                           'each line is an input' )
    group.add_argument("--sample-output-file", type=str, default=None,
                      help='output file got from --sample-input-file')
    group.add_argument("--num-samples", type=int, default=0,
                       help='number of samples to generate unconditionally, '
                       'defaults to 0 and interactive conditional sampling')
    group.add_argument("--genfile", type=str,
                       help='output file when generating unconditionally')
    group.add_argument("--recompute", action='store_true',
                       help='during generation recompute all attention '
                       'instead of using previously computed keys/values.')
    return parser


def main():
    """Main program."""

    print('Generate Samples')

    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
    args = get_args()

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # Generate samples.
    if args.num_samples == 0:
        args.batch_size = 1
        if args.sample_input_file is not None:
            generate_samples_input_from_file(model)
        else:
            generate_samples_interactive(model)
    else:
        write_and_generate_samples_unconditional(model)

if __name__ == "__main__":
    main()