"tests/test_indoor_augment.py" did not exist on "3b9ade96c53364dcb77512236ef787b079fb742b"
text_generation_utils.py 15.5 KB
Newer Older
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
     ai code at
         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
              conversational-ai-with-transfer-learning-2d818ac26313 """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Convert to 1D
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits
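# Illustrative example with assumed toy values: top_k_logits masks the
# excluded logits to -inf so they receive zero probability after softmax.
#   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#   top_k_logits(logits.clone(), top_k=2)
#       -> [[2.0, 1.0, -inf, -inf]]  (only the two largest logits survive)
#   top_k_logits(logits.clone(), top_p=0.9)
#       -> keeps the shortest prefix of probability-sorted tokens whose
#          cumulative probability reaches 0.9; the first token crossing the
#          threshold is always kept, the rest are masked to -inf.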


def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1


def generate_samples_interactive(model, print_frequency=24):

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank() == 0 and \
                   counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")


def generate_samples_unconditional(model):

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.batch_size)]
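    # Each sample in the unconditional batch starts from a single
    # end-of-document token, i.e. the model is given no prompt.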
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if ctr % args.log_interval == 0:
            print('Avg s/batch:',
                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
            start_time = time.time()
        length = len(token_stream)
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length - 1]
            text = tokenizer.detokenize(tokens)
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length - 1, 'finished': is_finished}
            yield datum
            ctr += 1
            if ctr >= num_samples:
                break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
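        # One JSON object is written per generated sample (JSON Lines format).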
        for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args):

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths
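# Illustrative example with assumed values: given args.seq_length = 8 and
# pad_id = 0, pad_batch([[5, 6, 7]], 0, args) returns
# ([[5, 6, 7, 0, 0, 0, 0, 0]], [3]); each context is padded in place to the
# full sequence length while its original length is recorded.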


def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths
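# Usage sketch: get_token_stream is a generator; each iteration yields the
# token batch truncated to the positions generated so far together with the
# per-sample lengths, so callers can stream partial generations as
# sample_sequence_batch produces them.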


def switch(val1, val2, boolean):

    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2
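# Illustrative example: switch is an elementwise select over the batch, e.g.
# with val1 = [10, 20], val2 = [1, 2] and boolean = [0, 1] it returns
# [10, 2], i.e. val2 replaces val1 wherever boolean is set.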


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):

            if args.recompute:
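                # Recompute path: rerun the full forward pass over every
                # position generated so far and read off the logits at the
                # current position (no key/value cache is kept).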
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False)
                logits = logits[:, context_length - 1, :]
            else:
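                # Cached path: after the first step, only the most recently
                # generated token (and its position id) is fed to the model;
                # layer_past carries the cached key/value states forward.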
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False)
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if args.greedy:
                prev = torch.argmax(logits, dim=-1).view(-1)
            else:
                logits = logits.float()
                logits /= args.temperature
                logits = top_k_logits(logits, top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            print_logits = []
            for p in prev:
                print_logits.append([logits[i, p].item()
                                     for i in range(batch_size)])
            started = context_lengths <= context_length
            tokens[:, context_length] = switch(
                tokens[:, context_length].view(-1), prev, started)
            context_length += 1
            counter += 1

            done_token = (prev == eos_id).byte() & started.byte()
            just_finished = (done_token & ~is_done).bool()
            lengths[just_finished.view(-1)] = context_length
            is_done = is_done | done_token
            done = torch.all(is_done)

            yield tokens, lengths
            if done:
                break