Commit 63262827 authored by Raul Puri

Merge branch 'staging_text_generation' into 'staging'

Refactoring text generation

See merge request ADLR/megatron-lm!39
parents 3977b721 fffa0497
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT2"""
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
from megatron.text_generation_utils import generate_samples_interactive
def model_provider():
"""Build the model."""
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=False)
return model
def add_text_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='Get input from file instead of interactive mode, '
                       'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file got from --sample-input-file')
    group.add_argument("--num-samples", type=int, default=0,
                       help='Number of samples to generate unconditionally, '
                       'defaults to 0 and interactive conditional sampling')
    group.add_argument("--genfile", type=str,
                       help='Output file when generating unconditionally')
    group.add_argument("--recompute", action='store_true',
                       help='During generation recompute all attention '
                       'instead of using previously computed keys/values.')

    return parser
def main():
    """Main program."""

    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    args = get_args()
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # Generate samples.
    if args.num_samples == 0:
        args.batch_size = 1
        if args.sample_input_file is not None:
            generate_samples_input_from_file(model)
        else:
            generate_samples_interactive(model)
    else:
        generate_and_write_samples_unconditional(model)


if __name__ == "__main__":
    main()
@@ -69,8 +69,10 @@ def parse_args(extra_args_provider=None, defaults={}):

     # Checks.
     assert args.hidden_size % args.num_attention_heads == 0
-    assert args.max_position_embeddings >= args.seq_length
-    assert args.min_lr <= args.lr
+    if args.seq_length is not None:
+        assert args.max_position_embeddings >= args.seq_length
+    if args.lr is not None:
+        assert args.min_lr <= args.lr
     if args.save is not None:
         assert args.save_interval is not None

@@ -134,7 +136,7 @@ def _add_regularization_args(parser):

 def _add_training_args(parser):
     group = parser.add_argument_group(title='training')

-    group.add_argument('--batch-size', type=int, required=True,
+    group.add_argument('--batch-size', type=int, default=None,
                        help='Batch size per model instance (local batch size). '
                        'Global batch size is local batch size times data '
                        'parallel size.')

@@ -301,7 +303,7 @@ def _add_data_args(parser):

                        help='Path to the vocab file.')
     group.add_argument('--merge-file', type=str, default=None,
                        help='Path to the BPE merge file.')
-    group.add_argument('--seq-length', type=int, required=True,
+    group.add_argument('--seq-length', type=int, default=None,
                        help="Maximum sequence length to process.")
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')

@@ -356,32 +358,6 @@ def _add_gpt2_args(parser):

-def add_text_generate_args(parser):
-    """Text generate arguments."""
-
-    group = parser.add_argument_group('Text generation', 'configurations')
-    group.add_argument("--temperature", type=float, default=1.0)
-    group.add_argument("--greedy", action='store_true', default=False)
-    group.add_argument("--top_p", type=float, default=0.0)
-    group.add_argument("--top_k", type=int, default=0)
-    group.add_argument("--out-seq-length", type=int, default=1024)
-    group.add_argument("--sample-input-file", type=str, default="",
-                       help='get input from file instead of interactive mode, '
-                       'each line is an input')
-    group.add_argument("--sample-output-file", type=str, default="",
-                       help='output file got from --sample-input-file')
-    group.add_argument("--num-samples", type=int, default=0,
-                       help='number of samples to generate unconditionally, '
-                       'defaults to 0 and interactive conditional sampling')
-    group.add_argument("--genfile", type=str,
-                       help='output file when generating unconditionally')
-    group.add_argument("--recompute", action='store_true',
-                       help='during generation recompute all attention '
-                       'instead of using previously computed keys/values.')
-    return parser
-
-
 def add_data_args_(parser):
     """Train/valid/test data arguments."""
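The switch from required=True to default=None above is what lets a downstream tool defer these values: parse_args() now only checks seq-length and lr when they are actually set, and initialize_megatron() lets the caller register extra flags and defaults. Below is a minimal, hypothetical sketch of that extra_args_provider / args_defaults pattern (a toy parser for illustration, not the real megatron.arguments module):

import argparse

def parse_args(extra_args_provider=None, defaults={}):
    # Core arguments; no longer required, so tools may leave them unset.
    parser = argparse.ArgumentParser(description='toy Megatron arguments')
    parser.add_argument('--batch-size', type=int, default=None)
    parser.add_argument('--seq-length', type=int, default=None)
    # Let the calling tool register its own argument group.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)
    args = parser.parse_args([])  # empty argv, for illustration only
    # Fill in caller-supplied defaults for anything still unset.
    for key, value in defaults.items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    # Conditional check, mirroring the change above.
    if args.seq_length is not None:
        assert args.seq_length > 0
    return args

def add_text_generate_args(parser):
    group = parser.add_argument_group(title='text generation')
    group.add_argument('--temperature', type=float, default=1.0)
    return parser

args = parse_args(extra_args_provider=add_text_generate_args,
                  defaults={'seq_length': 1024})
print(args.temperature, args.seq_length)  # 1.0 1024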
@@ -137,8 +137,7 @@ class BertModel(MegatronModule):

         self._binary_head_key = 'binary_head'

-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
+    def forward(self, input_ids, attention_mask, tokentype_ids=None):

         extended_attention_mask = bert_extended_attention_mask(
             attention_mask, next(self.language_model.parameters()).dtype)
@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):

     def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
+                tokentype_ids=None, layer_past=None, get_key_value=False,
+                forward_method_parallel_output=None):

         # Language model.
         lm_output = self.language_model(input_ids,

@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):

             lm_output, presents = lm_output

         # Output.
+        parallel_output = self.parallel_output
+        if forward_method_parallel_output is not None:
+            parallel_output = forward_method_parallel_output
         output = parallel_lm_logits(
             lm_output,
             self.language_model.embedding.word_embeddings.weight,
-            self.parallel_output)
+            parallel_output)

         if get_key_value:
             output = [output, presents]
@@ -13,150 +13,99 @@

 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Sample Generate GPT2"""
+"""Utilities for generating text."""

-import os
-import random
-import json
 import copy
-import numpy as np
-import torch
-import torch.nn.functional as F
-import argparse
+import json
+import os
 import time

-from arguments import get_args
-from megatron.utils import Timers
-from megatron.utils import initialize_distributed
-from megatron.utils import set_random_seed
-from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import load_checkpoint
-from megatron.data_utils import make_tokenizer
-from configure_data import configure_data
-from megatron import mpu
-from megatron.fp16 import FP16_Module
-from megatron.model import GPT2Model
-from megatron.model import DistributedDataParallel as DDP
-from megatron import print_rank_0
-
-
-def get_model(args):
-    """Build the model."""
-
-    print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_layers=args.num_layers,
-                      vocab_size=args.vocab_size,
-                      hidden_size=args.hidden_size,
-                      num_attention_heads=args.num_attention_heads,
-                      embedding_dropout_prob=args.hidden_dropout,
-                      attention_dropout_prob=args.attention_dropout,
-                      output_dropout_prob=args.hidden_dropout,
-                      max_sequence_length=args.max_position_embeddings,
-                      checkpoint_activations=args.checkpoint_activations,
-                      checkpoint_num_layers=args.checkpoint_num_layers,
-                      parallel_output=False)
-
-    if mpu.get_data_parallel_rank() == 0:
-        print(' > number of parameters on model parallel rank {}: {}'.format(
-            mpu.get_model_parallel_rank(),
-            sum([p.nelement() for p in model.parameters()])), flush=True)
-
-    # GPU allocation.
-    model.cuda(torch.cuda.current_device())
-
-    # Fp16 conversion.
-    if args.fp16:
-        model = FP16_Module(model)
-
-    # Wrap model for distributed training.
-    model = DDP(model)
-
-    return model
-
-
-def setup_model(args):
-    """Setup model and optimizer."""
-
-    model = get_model(args)
-
-    if args.load is not None:
-        _ = load_checkpoint(
-            model, None, None, args)
-
-    return model
+import torch
+import torch.nn.functional as F
+
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron.utils import get_ltor_masks_and_position_ids


-def get_batch(context_tokens, args):
-    tokens = context_tokens
-    tokens = tokens.view(args.batch_size, -1).contiguous()
-    device = args.device
-    tokens = tokens.to(device)
+def get_batch(context_tokens):
+    """Generate batch from context tokens."""
+    args = get_args()
+    tokenizer = get_tokenizer()

-    # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+    # Move to GPU.
+    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
+    # Get the attention mask and postition ids.
+    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
-        args.eod_token,
+        tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        False)
-
-    # Fp16 conversion.
-    if args.fp16:
-        attention_mask = attention_mask.half()
+        args.eod_mask_loss,
+        args.fp16)

     return tokens, attention_mask, position_ids


 def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+    """ This function has been mostly taken from huggingface conversational
+    ai code at
+        https://medium.com/huggingface/how-to-build-a-state-of-the-art-
+        conversational-ai-with-transfer-learning-2d818ac26313 """

     if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
+        # Remove all tokens with a probability less than the
+        # last token of the top-k
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value

     if top_p > 0.0:
-        #convert to 1D
-        # logits=logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+        # Cconvert to 1D
+        sorted_logits, sorted_indices = torch.sort(
+            logits, descending=True, dim=-1)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

         # Remove tokens with cumulative probability above the threshold
         sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        # Shift the indices to the right to keep also the first token
+        # above the threshold
+        sorted_indices_to_remove[..., 1:] \
+            = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0
         for i in range(sorted_indices.size(0)):
             indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
             logits[i][indices_to_remove] = filter_value
-        #going back to 2D
-        # logits=logits.view(1, -1).contiguous()

     return logits
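The filtering above can be exercised on its own. The following is a minimal, self-contained sketch (plain PyTorch; filter_logits is a stand-in re-implementation for illustration, not part of this merge request) that applies the same top-k / nucleus (top-p) masking to toy logits and then samples from the filtered distribution, mirroring the temperature / top-k / top-p path used later in sample_sequence_batch:

import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # Keep only the top_k largest logits per row.
    if top_k > 0:
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth] = filter_value
    # Nucleus filtering: keep the smallest set of tokens whose cumulative
    # probability exceeds top_p, always keeping the most likely token.
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        for i in range(sorted_indices.size(0)):
            logits[i][sorted_indices[i][sorted_mask[i]]] = filter_value
    return logits

toy_logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
filtered = filter_logits(toy_logits.clone(), top_k=2, top_p=0.0)
probs = F.softmax(filtered / 0.9, dim=-1)          # temperature = 0.9
next_token = torch.multinomial(probs, num_samples=1)
print(filtered)     # logits outside the top 2 are -inf
print(next_token)   # sampled index, always 0 or 1 here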
-def generate_samples_input_from_file(model, tokenizer, args):
-
-    if args.sample_input_file == "":
-        if mpu.get_model_parallel_rank() == 0:
-            print("args.sample_input_file CAN NOT BE empty!\n")
-        return
+def generate_samples_input_from_file(model):
+
+    args = get_args()
+    tokenizer = get_tokenizer()

+    # Read the sample file and open the output file.
+    assert args.sample_input_file is not None, \
+        'sample input file is not provided.'
     if mpu.get_model_parallel_rank() == 0:
         fname = open(args.sample_input_file, "r")
         all_raw_text = fname.readlines()
         input_count = len(all_raw_text)
         input_pos = 0
-        if args.sample_output_file == "":
-            print("Argument: sample-output-file can't be empty, setting it to\n")
-            print("\t args.sample_input_file.out")
-            args.sample_output_file = args.sample_input_file+".out"
-        fname_out = open(args.sample_output_file, "w+")
+        if args.sample_output_file is None:
+            sample_output_file = args.sample_input_file + ".out"
+            print('could not find `sample-output-file`, setting '
+                  'it to {}'.format(sample_output_file))
+        fname_out = open(sample_output_file, "w+")

-    context_count=0
+    context_count = 0
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
-            terminate_runs=0
+            terminate_runs = 0

             if mpu.get_model_parallel_rank() == 0:
                 raw_text = all_raw_text[input_pos]
@@ -167,63 +116,62 @@ def generate_samples_input_from_file(model, tokenizer, args):

                 if "stop" in raw_text:
                     terminate_runs = 1
                 else:
-                    context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
+                    context_tokens = tokenizer.tokenize(raw_text)
                     context_length = len(context_tokens)

-                    if context_length >=args.seq_length//2:
+                    if context_length >= (args.seq_length // 2):
                         print("\nContext length", context_length, \
-                            "\nPlease give smaller context (half of the sequence length)!")
+                            "\nPlease give smaller context (half of the "
+                            "sequence length)!", flush=True)
                         continue
             else:
-                context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
+                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                 context_length = len(context_tokens)

             terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-            torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
+            torch.distributed.broadcast(terminate_runs_tensor,
+                                        mpu.get_model_parallel_src_rank(),
+                                        group=mpu.get_model_parallel_group())
             terminate_runs = terminate_runs_tensor[0].item()

             if terminate_runs == 1:
                 return

-            start_time = time.time()
-            token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
-            for counter, decode_tokens in enumerate(token_stream):
-                # token_end = decode_tokens.find("<|endoftext|>")
-                # if token_end > 0:
-                #     break
+            token_stream = get_token_stream(model, [context_tokens])
+            for _, decode_tokens in enumerate(token_stream):
                 decode_tokens, _ = decode_tokens
                 decode_tokens = decode_tokens[0].cpu().numpy().tolist()

             if mpu.get_model_parallel_rank() == 0:
                 os.system('clear')
-                #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                 print("\nContext:", raw_text, flush=True)
-                trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
-                #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
+                trim_decode_tokens = tokenizer.detokenize(
+                    decode_tokens)[len(raw_text):]
                 print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                 fname_out.write("\nContext:")
                 fname_out.write(raw_text)
                 fname_out.write("\n\nMegatron-LM:")
                 fname_out.write(trim_decode_tokens)
-                #fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
                 fname_out.write("\n")

             raw_text = None
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             context_count += 1


-def generate_samples_interactive(model, tokenizer, args):
-    print_frequency = 24
-    context_count=0
+def generate_samples_interactive(model, print_frequency=24):
+
+    args = get_args()
+    tokenizer = get_tokenizer()
+    context_count = 0
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
-            terminate_runs=0
+            terminate_runs = 0

             if mpu.get_model_parallel_rank() == 0:
                 os.system('clear')
@@ -231,83 +179,83 @@ def generate_samples_interactive(model, tokenizer, args):

                 while not raw_text:
                     print('Prompt should not be empty!')
                     raw_text = input("\nContext prompt (stop to exit) >>> ")

                 if "stop" in raw_text:
                     terminate_runs = 1
                 else:
-                    context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
+                    context_tokens = tokenizer.tokenize(raw_text)
                     context_length = len(context_tokens)

-                    if context_length >=args.seq_length//2:
+                    if context_length >= (args.seq_length // 2):
                         print("\nContext length", context_length, \
-                            "\nPlease give smaller context (half of the sequence length)!")
+                            "\nPlease give smaller context (half of the "
+                            "sequence length)!", flush=True)
                         continue
             else:
-                context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
+                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                 context_length = len(context_tokens)

             terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-            torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
+            torch.distributed.broadcast(terminate_runs_tensor,
+                                        mpu.get_model_parallel_src_rank(),
+                                        group=mpu.get_model_parallel_group())
             terminate_runs = terminate_runs_tensor[0].item()

             if terminate_runs == 1:
                 return

-            start_time = time.time()
-            token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
+            token_stream = get_token_stream(model, [context_tokens])

             for counter, decode_tokens in enumerate(token_stream):
-                # token_end = decode_tokens.find("<|endoftext|>")
-                # if token_end > 0:
-                #     break
                 decode_tokens, _ = decode_tokens
                 decode_tokens = decode_tokens[0].cpu().numpy().tolist()

-                if mpu.get_model_parallel_rank() == 0 and counter % print_frequency == 0:
+                if mpu.get_model_parallel_rank() == 0 and \
+                   counter % print_frequency == 0:
                     os.system('clear')
-                    #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                     print("\nContext:", raw_text, flush=True)
-                    trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
-                    #print("\nGPT2:", trim_decode_tokens, flush=True)
-                    #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
+                    trim_decode_tokens = tokenizer.detokenize(
+                        decode_tokens)[len(raw_text):]
                     print("\nMegatron-LM:", trim_decode_tokens, flush=True)

             if mpu.get_model_parallel_rank() == 0:
                 os.system('clear')
-                #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                 print("\nContext:", raw_text, flush=True)
-                trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
-                #print("\nGPT2:", trim_decode_tokens, flush=True)
-                #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
+                trim_decode_tokens = tokenizer.detokenize(
+                    decode_tokens)[len(raw_text):]
                 print("\nMegatron-LM:", trim_decode_tokens, flush=True)

             raw_text = None
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             context_count += 1
             if mpu.get_model_parallel_rank() == 0:
                 input("\nPress any key to continue >>>")


-def generate_samples_unconditional(model, tokenizer, args):
+def generate_samples_unconditional(model):
+
+    args = get_args()
+    tokenizer = get_tokenizer()

     num_samples = args.num_samples
-    context_tokens = [[tokenizer.get_command('pad').Id] for _ in range(args.batch_size)]
-    samples = []
-    # with open(args.genfile, 'w') as f:
+    context_tokens = [[tokenizer.eod]
+                      for _ in range(args.batch_size)]
     ctr = 0
     while True:
         start_time = time.time()
-        for token_stream in get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args):
+        for token_stream in get_token_stream(model,
+                                             copy.deepcopy(context_tokens)):
             pass
-        # token_stream = list(get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args))
         if ctr%args.log_interval == 0:
-            print('Avg s/batch:', (time.time()-start_time)/min(args.log_interval, ctr+1))
+            print('Avg s/batch:',
+                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
             start_time = time.time()
         length = len(token_stream)
         token_batch = token_stream[0].cpu().numpy().tolist()
         length_batch = token_stream[1].cpu().numpy().tolist()
         for tokens, length in zip(token_batch, length_batch):
             tokens = tokens[1:length-1]
-            text = tokenizer.DecodeIds(tokens)
+            text = tokenizer.detokenize(tokens)
             is_finished = length < args.seq_length - 1
             datum = {'text': text, 'length': length-1, 'finished': is_finished}
             yield datum
@@ -317,66 +265,73 @@ def generate_samples_unconditional(model, tokenizer, args):

             if ctr >= num_samples:
                 break


-def write_and_generate_samples_unconditional(model, tokenizer, args):
+def generate_and_write_samples_unconditional(model):
+
+    args = get_args()
     assert args.genfile is not None
     with open(args.genfile, 'w') as f:
-        for datum in generate_samples_unconditional(model, tokenizer, args):
+        for datum in generate_samples_unconditional(model):
             f.write(json.dumps(datum)+'\n')


-def pad_batch(batch, tokenizer, args):
-    pad_id = tokenizer.get_command('pad').Id
+def pad_batch(batch, pad_id, args):

     context_lengths = []
     for tokens in batch:
         context_length = len(tokens)
         if context_length < args.seq_length:
-            tokens.extend([pad_id]*(args.seq_length-context_length))
+            tokens.extend([pad_id]*(args.seq_length - context_length))
         context_lengths.append(context_length)
     return batch, context_lengths


-def get_token_stream(model, context_tokens, tokenizer, args):
-    pad_id = tokenizer.get_command('pad').Id
-    # context_length = len(context_tokens)
-    # if context_length < args.seq_length:
-    #     context_tokens = context_tokens + [pad_id] * (args.seq_length - context_length)
-    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)
+def get_token_stream(model, context_tokens):
+
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    context_tokens, context_lengths = pad_batch(context_tokens,
+                                                tokenizer.eod, args)

     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
-    # context_length_tensor = torch.cuda.LongTensor([context_length])

-    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
-    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
+    torch.distributed.broadcast(context_length_tensor,
+                                mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    torch.distributed.broadcast(context_tokens_tensor,
+                                mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())

     context_length = context_length_tensor.min().item()
-    tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, args)
-    counter = 0
-    org_context_length = context_length
-
-    layer_past = None
-    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, attention_mask, position_ids, tokenizer, args)
+    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
+
+    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+                                                 context_length_tensor,
+                                                 attention_mask, position_ids)
     for tokens, lengths in batch_token_iterator:
         context_length += 1
         yield tokens[:, :context_length], lengths


 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
-    return (1-boolean)*val1 + boolean*val2
+    return (1 - boolean) * val1 + boolean * val2
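The switch helper above is what lets the whole batch advance in lockstep: rows still inside their own prompt keep the original token, while rows that have started generating take the freshly sampled one. A toy check (plain PyTorch, illustrative values only, not part of the merge request):

import torch

def switch(val1, val2, boolean):
    # Identical to the helper above: select val2 where boolean is true.
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

original = torch.tensor([11, 22, 33])        # tokens already in the prompt
sampled = torch.tensor([7, 8, 9])            # newly sampled tokens
started = torch.tensor([True, False, True])  # rows past their own context
print(switch(original, sampled, started))    # tensor([ 7, 22,  9])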
-def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
-    actual_model = model
-    if isinstance(actual_model, DDP):
-        actual_model = actual_model.module
-    if isinstance(actual_model, FP16_Module):
-        actual_model = actual_model.module
-    original_output_parallel = actual_model.parallel_output
-    actual_model.parallel_output = False
+def sample_sequence_batch(model, context_tokens, context_lengths,
+                          attention_mask, position_ids,
+                          maxlen=None, type_ids=None):
+
+    args = get_args()
+    tokenizer = get_tokenizer()

     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
-        eos_id = tokenizer.get_command('eos').Id
+        eos_id = tokenizer.eod

         counter = 0
         org_context_length = context_length
@@ -391,11 +346,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask

         maxlen = org_context_length + args.out_seq_length
         lengths = torch.ones([batch_size]).long().cuda()*maxlen

         while context_length <= (maxlen):

             if args.recompute:
-                logits = model(tokens, position_ids, attention_mask, tokentype_ids=type_ids)
+                logits = model(tokens,
+                               position_ids,
+                               attention_mask,
+                               tokentype_ids=type_ids,
+                               forward_method_parallel_output=False)
                 logits = logits[:, context_length - 1, :]
             else:
                 types2use = None
@@ -405,113 +364,48 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask

                     if type_ids is not None:
                         types2use = type_ids[:, :context_length]
                 else:
-                    tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
-                    positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
+                    tokens2use = tokens[:, context_length - 1].view(
+                        batch_size, -1)
+                    positions2use = position_ids[:, context_length - 1].view(
+                        batch_size, -1)
                     if type_ids is not None:
-                        types2use = type_ids[:, context_length - 1].view(batch_size, -1)
-                logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_key_value=True, tokentype_ids=types2use)
-                logits = logits[:, -1].view(batch_size,-1).contiguous()
+                        types2use = type_ids[:, context_length - 1].view(
+                            batch_size, -1)
+                logits, layer_past = model(tokens2use,
+                                           positions2use,
+                                           attention_mask,
+                                           layer_past=layer_past,
+                                           get_key_value=True,
+                                           tokentype_ids=types2use,
+                                           forward_method_parallel_output=False)
+                logits = logits[:, -1].view(batch_size, -1).contiguous()

             if args.greedy:
                 prev = torch.argmax(logits, dim=-1).view(-1)
             else:
                 logits = logits.float()
                 logits /= args.temperature
-                logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
+                logits = top_k_logits(logits, top_k=args.top_k,
+                                      top_p=args.top_p)
                 log_probs = F.softmax(logits, dim=-1)
                 prev = torch.multinomial(log_probs, num_samples=1).view(-1)

             print_logits = []
             for p in prev:
-                print_logits.append([logits[i, p].item() for i in range(batch_size)])
+                print_logits.append([logits[i, p].item()
+                                     for i in range(batch_size)])
             started = context_lengths <= context_length
-            tokens[:, context_length] = switch(tokens[:, context_length].view(-1), prev, started)
+            tokens[:, context_length] = switch(
+                tokens[:, context_length].view(-1), prev, started)
             context_length += 1
             counter += 1

             done_token = (prev == eos_id).byte() & started.byte()
             just_finished = (done_token & ~is_done).bool()
             lengths[just_finished.view(-1)] = context_length
-            was_done = is_done
             is_done = is_done | done_token
             done = torch.all(is_done)

             yield tokens, lengths
             if done:
                 break
-
-    actual_model.parallel_output = original_output_parallel
-
-
-def prepare_tokenizer(args):
-
-    tokenizer_args = {
-        'tokenizer_type': args.tokenizer_type,
-        'corpus': None,
-        'model_path': args.tokenizer_path,
-        'vocab_size': args.vocab_size,
-        'model_type': args.tokenizer_model_type,
-        'cache_dir': args.cache_dir}
-    tokenizer = make_tokenizer(**tokenizer_args)
-
-    args.tokenizer_num_tokens = tokenizer.num_tokens
-    args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
-    args.eod_token = tokenizer.get_command('eos').Id
-
-    after = tokenizer.num_tokens
-    multiple = args.make_vocab_size_divisible_by * \
-        mpu.get_model_parallel_world_size()
-    if multiple != 0:
-        while (after % multiple) != 0:
-            after += 1
-    args.vocab_size = after
-    print("prepare tokenizer done", flush=True)
-
-    return tokenizer
-
-
-def main():
-    """Main training program."""
-
-    print('Generate Samples')
-
-    # Disable CuDNN.
-    torch.backends.cudnn.enabled = False
-
-    # Timer.
-    timers = Timers()
-
-    # Arguments.
-    args = get_args()
-
-    # Pytorch distributed.
-    initialize_distributed(args)
-
-    # Random seeds for reproducability.
-    set_random_seed(args.seed)
-
-    # get the tokenizer
-    tokenizer = prepare_tokenizer(args)
-
-    # Model, optimizer, and learning rate.
-    model = setup_model(args)
-
-    # setting default batch size to 1
-    # args.batch_size = 1
-    args.device = torch.cuda.current_device()
-
-    # generate samples
-    if args.num_samples == 0:
-        args.batch_size = 1
-        if args.sample_input_file != "":
-            generate_samples_input_from_file(model, tokenizer, args)
-        else:
-            generate_samples_interactive(model, tokenizer, args)
-    else:
-        write_and_generate_samples_unconditional(model, tokenizer, args)
-
-
-if __name__ == "__main__":
-    main()
@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC):

     def tokenize(self, text):
         pass

+    def detokenize(self, token_ids):
+        raise NotImplementedError('detokenizer is not implemented for {} '
+                                  'tokenizer'.format(self.name))
+
     @property
     def cls(self):
         raise NotImplementedError('CLS is not provided for {} '

@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer):

     def tokenize(self, text):
         return self.tokenizer.encode(text)

+    def detokenize(self, token_ids):
+        return self.tokenizer.decode(token_ids)
+
     @property
     def eod(self):
         return self.eod_id
File mode changed from 100755 to 100644
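The new detokenize method completes the round-trip contract the generation loops rely on: detokenizing the generated ids reproduces the prompt text, so the prompt can be sliced off with [len(raw_text):]. A hypothetical character-level tokenizer (a stand-in for the GPT-2 BPE tokenizer, which behaves this way for typical text) illustrates the pattern:

class ToyCharTokenizer:
    """Hypothetical character-level stand-in for the GPT-2 BPE tokenizer."""

    def tokenize(self, text):
        return [ord(c) for c in text]

    def detokenize(self, token_ids):
        return ''.join(chr(i) for i in token_ids)

tok = ToyCharTokenizer()
raw_text = "Hello"
# Pretend the model extended the prompt with ", world!".
generated_ids = tok.tokenize(raw_text) + tok.tokenize(", world!")
trim_decode_tokens = tok.detokenize(generated_ids)[len(raw_text):]
print(trim_decode_tokens)   # -> ", world!"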