# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sample Generate GPT2"""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model
from megatron.training import get_model
from megatron.utils import get_ltor_masks_and_position_ids


def model_provider():
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_tokentypes=0, parallel_output=False)

    return model


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.fp16)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
     ai code at
         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
              conversational-ai-with-transfer-learning-2d818ac26313 """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort the logits in descending order and compute the
        # cumulative probability mass.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def generate_samples_input_from_file(model):
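    """Read prompts line-by-line from args.sample_input_file, generate a
    completion for each and write the results to the sample output file."""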
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1


def generate_samples_interactive(model, print_frequency=24):
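    """Prompt interactively for context text on rank 0 and print the
    generated continuation, refreshing the screen every `print_frequency`
    tokens."""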
    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank() == 0 and \
                   counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1
            if mpu.get_model_parallel_rank() == 0:
                input("\nPress Enter to continue >>>")


def generate_samples_unconditional(model):
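    """Yield unconditional samples (contexts seeded with a single EOD token)
    as dicts with 'text', 'length' and 'finished' fields."""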
    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if ctr % args.log_interval == 0:
            print('Avg s/batch:',
                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
            start_time = time.time()
        length = len(token_stream)
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length - 1]
            text = tokenizer.detokenize(tokens)
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length-1, 'finished': is_finished}
            yield datum
            ctr += 1
            if ctr >= num_samples:
                break
        if ctr >= num_samples:
            break


def write_and_generate_samples_unconditional(model):
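    """Run unconditional generation and write one JSON record per sample to
    args.genfile."""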
    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args):
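    """Pad each context in `batch` with `pad_id` up to args.seq_length and
    return the padded batch along with the original context lengths."""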

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def get_token_stream(model, context_tokens):
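    """Pad and broadcast the contexts to all model-parallel ranks, then yield
    the growing (tokens, lengths) pairs produced by sample_sequence_batch."""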
    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths


def switch(val1, val2, boolean):
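    """Elementwise select between val1 and val2: positions where `boolean`
    is 1 take val2, the rest keep val1."""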
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
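    """Autoregressively sample a batch of sequences starting from the given
    contexts, yielding (tokens, lengths) after each generated position."""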
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= maxlen:

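            # Two decoding paths: with --recompute, the full sequence is re-run
            # through the model at every step; otherwise only the newest token
            # is fed and previously computed key/values (layer_past) are reused.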
            if args.recompute:
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False)
                logits = logits[:, context_length - 1, :]
            else:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False)
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if args.greedy:
                prev = torch.argmax(logits, dim=-1).view(-1)
            else:
                logits = logits.float()
                logits /= args.temperature
                logits = top_k_logits(logits, top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            print_logits = []
            for p in prev:
                print_logits.append([logits[i, p].item()
                                     for i in range(batch_size)])
            started = context_lengths <= context_length
            tokens[:, context_length] = switch(
                tokens[:, context_length].view(-1), prev, started)
            context_length += 1
            counter += 1

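            # Record which sequences just produced an end-of-document token so
            # their lengths can be fixed and generation stopped once every
            # sequence in the batch is done.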
            done_token = (prev == eos_id).byte() & started.byte()
            just_finished = (done_token & ~is_done).bool()
            lengths[just_finished.view(-1)] = context_length
            is_done = is_done | done_token
            done = torch.all(is_done)

            yield tokens, lengths
            if done:
                break

def add_text_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='Get input from file instead of interactive mode, '
                       'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file for samples generated from '
                       '--sample-input-file.')
    group.add_argument("--num-samples", type=int, default=0,
                       help='Number of samples to generate unconditionally; '
                       'defaults to 0, which selects conditional sampling '
                       '(interactive or from --sample-input-file).')
    group.add_argument("--genfile", type=str,
                       help='Output file when generating unconditionally')
    group.add_argument("--recompute", action='store_true',
                       help='During generation recompute all attention '
                       'instead of using previously computed keys/values.')

    return parser


def main():
    """Main program."""

    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    args = get_args()
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # Generate samples.
    if args.num_samples == 0:
        args.batch_size = 1
        if args.sample_input_file is not None:
            generate_samples_input_from_file(model)
        else:
            generate_samples_interactive(model)
    else:
        write_and_generate_samples_unconditional(model)


if __name__ == "__main__":

    main()