# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
     ai code at
         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
              conversational-ai-with-transfer-learning-2d818ac26313 """
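    # top_k keeps only the k highest-scoring logits; top_p (nucleus
    # sampling) keeps the smallest set whose cumulative probability
    # exceeds top_p. Filtered positions are set to filter_value.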

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort the logits so we can compute cumulative probabilities.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def generate_samples_input_from_file(model):
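    """Read prompts from args.sample_input_file (one per line), generate a
    completion for each with the model, and write the results to
    args.sample_output_file (or `<input>.out` if not given)."""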

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1


def generate_samples_interactive(model, print_frequency=24):
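    """Prompt for context on rank 0, stream generated tokens and print the
    partial output every `print_frequency` steps; a prompt containing
    "stop" terminates generation on all ranks."""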

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank() == 0 and \
                   counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")


def generate_samples_unconditional(model):
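    """Generate args.num_samples unconditional samples, seeding each context
    with the end-of-document token, and yield one dict per sample with keys
    'text', 'length' and 'finished'."""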

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if ctr % args.log_interval == 0:
            print('Avg s/batch:',
                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
            start_time = time.time()
        length = len(token_stream)
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length - 1]
            text = tokenizer.detokenize(tokens)
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length - 1, 'finished': is_finished}
            yield datum
            ctr += 1
            if ctr >= num_samples:
                break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):
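    """Write unconditional samples to args.genfile as JSON lines."""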

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args):
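    """Pad each context in the batch with pad_id up to args.seq_length and
    return the padded batch plus the original context lengths."""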

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def get_token_stream(model, context_tokens):
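    """Pad the contexts, broadcast them to all model-parallel ranks and yield
    (tokens, lengths) from sample_sequence_batch as generation proceeds."""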

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths


def switch(val1, val2, boolean):
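    """Element-wise select: return val2 where boolean is set, else val1."""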

    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
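    """Sample a batch of sequences token by token starting from the given
    contexts; yield (tokens, lengths) after each step until every sequence
    emits an end-of-document token or maxlen is reached."""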

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):

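            # With args.recompute, re-run the full forward pass over all
            # tokens at every step; otherwise reuse cached keys/values
            # (layer_past) and feed only the newly generated token.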
            if args.recompute:
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False)
                logits = logits[:, context_length - 1, :]
            else:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False)
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if args.greedy:
                prev = torch.argmax(logits, dim=-1).view(-1)
            else:
                logits = logits.float()
                logits /= args.temperature
                logits = top_k_logits(logits, top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            print_logits = []
            for p in prev:
                print_logits.append([logits[i, p].item()
                                     for i in range(batch_size)])
            started = context_lengths <= context_length
            tokens[:, context_length] = switch(
                tokens[:, context_length].view(-1), prev, started)
            context_length += 1
            counter += 1

            done_token = (prev == eos_id).byte() & started.byte()
            just_finished = (done_token & ~is_done).bool()
            lengths[just_finished.view(-1)] = context_length
            is_done = is_done | done_token
            done = torch.all(is_done)

            yield tokens, lengths
            if done:
                break