Commit a6ba254f authored by Mohammad

generate samples linted

parent a19820b1
@@ -15,34 +15,27 @@
 """Sample Generate GPT2"""
-import os
-import random
-import json
 import copy
-import numpy as np
+import json
+import os
+import time
 import torch
 import torch.nn.functional as F
-import argparse
-import time
-from arguments import get_args
-from megatron.utils import Timers
-from megatron.utils import initialize_distributed
-from megatron.utils import set_random_seed
-from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import load_checkpoint
-from megatron.data_utils import make_tokenizer
-from configure_data import configure_data
-from megatron import mpu
-from megatron.fp16 import FP16_Module
-from megatron.model import GPT2Model
-from megatron.model import DistributedDataParallel as DDP
-from megatron import print_rank_0
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron.initialize import initialize_megatron
+from megatron.model import GPT2Model
+from megatron.training import get_model
+from megatron.utils import get_ltor_masks_and_position_ids


 def model_provider():
     """Build the model."""
-    args = get_args()

     print_rank_0('building GPT2 model ...')
     model = GPT2Model(num_tokentypes=0, parallel_output=False)
@@ -56,7 +49,7 @@ def get_batch(context_tokens):
     tokenizer = get_tokenizer()

     # Move to GPU.
-    tokens = context_tokens.view(args.batch_size, -1)..contiguous().cuda()
+    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()

     # Get the attention mask and position ids.
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
@@ -80,7 +73,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
         # last token of the top-k
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value

     if top_p > 0.0:
         # Convert to 1D
         sorted_logits, sorted_indices = torch.sort(
@@ -98,12 +91,12 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
         for i in range(sorted_indices.size(0)):
             indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
             logits[i][indices_to_remove] = filter_value

     return logits
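For reference, the top-k/nucleus filtering this function performs can be exercised standalone. A minimal sketch of the same pattern with toy values (`filter_logits` is a hypothetical rename, not the committed code):

```python
import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    if top_k > 0:
        # Mask out everything below the k-th largest logit per row.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth] = filter_value
    if top_p > 0.0:
        # Keep the smallest set of tokens whose cumulative prob exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cumprobs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # always keep the best token
        remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            logits[i][sorted_indices[i][remove[i]]] = filter_value
    return logits

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(filter_logits(logits.clone(), top_k=2))   # only the two largest survive
print(filter_logits(logits.clone(), top_p=0.9)) # nucleus filtering drops the tail
```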

 def generate_samples_input_from_file(model):
-    """XXX"""
     args = get_args()
     tokenizer = get_tokenizer()
@@ -118,15 +111,15 @@ def generate_samples_input_from_file(model):
     if args.sample_output_file is None:
         sample_output_file = args.sample_input_file + ".out"
         print('could not find `sample-output-file`, setting '
-              'it to {}'.formatsample_output_file())
+              'it to {}'.format(sample_output_file))
     fname_out = open(sample_output_file, "w+")

-    context_count=0
+    context_count = 0
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
-            terminate_runs=0
+            terminate_runs = 0

             if mpu.get_model_parallel_rank() == 0:
                 raw_text = all_raw_text[input_pos]
@@ -148,7 +141,7 @@ def generate_samples_input_from_file(model):
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
                 context_length = len(context_tokens)

             terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
             torch.distributed.broadcast(terminate_runs_tensor,
                                         mpu.get_model_parallel_src_rank(),
@@ -158,9 +151,8 @@ def generate_samples_input_from_file(model):
             if terminate_runs == 1:
                 return

-            start_time = time.time()
             token_stream = get_token_stream(model, [context_tokens])
-            for counter, decode_tokens in enumerate(token_stream):
+            for _, decode_tokens in enumerate(token_stream):
                 decode_tokens, _ = decode_tokens
                 decode_tokens = decode_tokens[0].cpu().numpy().tolist()
@@ -176,24 +168,24 @@ def generate_samples_input_from_file(model):
                 fname_out.write("\n\nMegatron-LM:")
                 fname_out.write(trim_decode_tokens)
                 fname_out.write("\n")

             raw_text = None
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             context_count += 1


 def generate_samples_interactive(model, print_frequency=24):
-    """XXX"""
     args = get_args()
     tokenizer = get_tokenizer()

-    context_count=0
+    context_count = 0
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
-            terminate_runs=0
+            terminate_runs = 0

             if mpu.get_model_parallel_rank() == 0:
                 os.system('clear')
@@ -201,7 +193,7 @@ def generate_samples_interactive(model, print_frequency=24):
                 while not raw_text:
                     print('Prompt should not be empty!')
                     raw_text = input("\nContext prompt (stop to exit) >>> ")

                 if "stop" in raw_text:
                     terminate_runs = 1
                 else:
@@ -216,7 +208,7 @@ def generate_samples_interactive(model, print_frequency=24):
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
                 context_length = len(context_tokens)

             terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
             torch.distributed.broadcast(terminate_runs_tensor,
                                         mpu.get_model_parallel_src_rank(),
@@ -226,7 +218,6 @@ def generate_samples_interactive(model, print_frequency=24):
             if terminate_runs == 1:
                 return

-            start_time = time.time()
             token_stream = get_token_stream(model, [context_tokens])

             for counter, decode_tokens in enumerate(token_stream):
                 decode_tokens, _ = decode_tokens
@@ -250,20 +241,19 @@ def generate_samples_interactive(model, print_frequency=24):
             raw_text = None
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             context_count += 1

             if mpu.get_model_parallel_rank() == 0:
                 input("\nPress any key to continue >>>")


 def generate_samples_unconditional(model):
-    """XXX"""
     args = get_args()
     tokenizer = get_tokenizer()

     num_samples = args.num_samples
     context_tokens = [[tokenizer.eod]
                       for _ in range(args.batch_size)]
-    samples = []
     ctr = 0
     while True:
         start_time = time.time()
@@ -291,6 +281,7 @@ def generate_samples_unconditional(model):

 def write_and_generate_samples_unconditional(model):
     args = get_args()
     assert args.genfile is not None
     with open(args.genfile, 'w') as f:
@@ -298,8 +289,8 @@ def write_and_generate_samples_unconditional(model):
             f.write(json.dumps(datum)+'\n')


-def pad_batch(batch, tokenizer, args):
-    pad_id = tokenizer.eod
+def pad_batch(batch, pad_id, args):
     context_lengths = []
     for tokens in batch:
         context_length = len(tokens)
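The remainder of pad_batch falls outside this hunk. For context, here is a plausible completion under the new `(batch, pad_id, args)` signature — a hedged sketch only, assuming each context is right-padded to `args.seq_length`:

```python
# Hypothetical completion of pad_batch; the lines below this hunk are
# not shown in the diff. Pads every context with pad_id up to
# args.seq_length and records the original (unpadded) lengths.
def pad_batch(batch, pad_id, args):
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths
```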
@@ -310,11 +301,12 @@ def pad_batch(batch, tokenizer, args):

 def get_token_stream(model, context_tokens):
     args = get_args()
     tokenizer = get_tokenizer()

-    pad_id = tokenizer.eod
-    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)
+    context_tokens, context_lengths = pad_batch(context_tokens,
+                                                tokenizer.eod, args)

     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
@@ -327,12 +319,7 @@ def get_token_stream(model, context_tokens):
                                 group=mpu.get_model_parallel_group())

     context_length = context_length_tensor.min().item()
-    tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, args)
+    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)

-    counter = 0
-    org_context_length = context_length
-    layer_past = None
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
@@ -343,21 +330,22 @@ def get_token_stream(model, context_tokens):

 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
-    return (1-boolean)*val1 + boolean*val2
+    return (1 - boolean) * val1 + boolean * val2
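switch is an elementwise select: wherever the mask is 1 take val2, otherwise val1. A quick standalone check with toy tensors:

```python
import torch

def switch(val1, val2, boolean):
    # Cast the mask to the value dtype so the arithmetic blend works.
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

a = torch.tensor([10.0, 20.0, 30.0])
b = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([0, 1, 0])
print(switch(a, b, mask))  # tensor([10.,  2., 30.])
```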

 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):
-    """XXX"""
     args = get_args()
     tokenizer = get_tokenizer()

     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
-        eos_id = tokenizer.get_command('eos').Id
+        eos_id = tokenizer.eod

         counter = 0
         org_context_length = context_length
@@ -372,7 +360,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             maxlen = org_context_length + args.out_seq_length

         lengths = torch.ones([batch_size]).long().cuda()*maxlen

         while context_length <= (maxlen):
             if args.recompute:
@@ -404,7 +392,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                      get_key_value=True,
                                      tokentype_ids=types2use,
                                      forward_method_parallel_output=False)
-                logits = logits[:, -1].view(batch_size,-1).contiguous()
+                logits = logits[:, -1].view(batch_size, -1).contiguous()

             if args.greedy:
                 prev = torch.argmax(logits, dim=-1).view(-1)
@@ -429,7 +417,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             done_token = (prev == eos_id).byte() & started.byte()
             just_finished = (done_token & ~is_done).bool()
             lengths[just_finished.view(-1)] = context_length
-            was_done = is_done
             is_done = is_done | done_token
             done = torch.all(is_done)
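The bookkeeping above flags a row as finished the first time it samples the EOS token after its prompt has been consumed. A toy trace of the same mask arithmetic (hypothetical ids and states, not the committed code):

```python
import torch

eos_id = 50256                              # hypothetical EOD token id
prev = torch.tensor([50256, 17, 50256])     # token just sampled per row
started = torch.tensor([1, 1, 0]).byte()    # rows already past their prompt
is_done = torch.zeros(3).byte()

done_token = (prev == eos_id).byte() & started  # EOS sampled on an active row
just_finished = (done_token & ~is_done).bool()  # rows newly finished this step
is_done = is_done | done_token
print(just_finished)       # tensor([ True, False, False])
print(torch.all(is_done))  # tensor(False): rows 1 and 2 keep generating
```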
@@ -438,56 +425,59 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 break


 def add_text_generate_args(parser):
-    """Text generate arguments."""
+    """Text generation arguments."""

-    group = parser.add_argument_group('Text generation', 'configurations')
+    group = parser.add_argument_group(title='text generation')

-    group.add_argument("--temperature", type=float, default=1.0)
-    group.add_argument("--greedy", action='store_true', default=False)
-    group.add_argument("--top_p", type=float, default=0.0)
-    group.add_argument("--top_k", type=int, default=0)
-    group.add_argument("--out-seq-length", type=int, default=1024)
+    group.add_argument("--temperature", type=float, default=1.0,
+                       help='Sampling temperature.')
+    group.add_argument("--greedy", action='store_true', default=False,
+                       help='Use greedy sampling.')
+    group.add_argument("--top_p", type=float, default=0.0,
+                       help='Top p sampling.')
+    group.add_argument("--top_k", type=int, default=0,
+                       help='Top k sampling.')
+    group.add_argument("--out-seq-length", type=int, default=1024,
+                       help='Size of the output generated text.')
     group.add_argument("--sample-input-file", type=str, default=None,
-                       help='get input from file instead of interactive mode, '
-                       'each line is an input' )
+                       help='Get input from file instead of interactive mode, '
+                       'each line is an input.')
     group.add_argument("--sample-output-file", type=str, default=None,
-                       help='output file got from --sample-input-file')
+                       help='Output file got from --sample-input-file')
     group.add_argument("--num-samples", type=int, default=0,
-                       help='number of samples to generate unconditionally, '
+                       help='Number of samples to generate unconditionally, '
                        'defaults to 0 and interactive conditional sampling')
     group.add_argument("--genfile", type=str,
-                       help='output file when generating unconditionally')
+                       help='Output file when generating unconditionally')
     group.add_argument("--recompute", action='store_true',
-                       help='during generation recompute all attention '
+                       help='During generation recompute all attention '
                        'instead of using previously computed keys/values.')

     return parser
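add_text_generate_args follows Megatron's extra_args_provider convention: take a parser, attach an argument group, return the parser. That makes it easy to smoke-test with plain argparse outside initialize_megatron — a minimal sketch with a condensed subset of the arguments:

```python
import argparse

def add_text_generate_args(parser):
    # Condensed copy of the group above, just for a standalone check.
    group = parser.add_argument_group(title='text generation')
    group.add_argument("--temperature", type=float, default=1.0)
    group.add_argument("--top_k", type=int, default=0)
    group.add_argument("--out-seq-length", type=int, default=1024)
    return parser

parser = add_text_generate_args(argparse.ArgumentParser())
args = parser.parse_args(['--top_k', '40', '--temperature', '0.9'])
print(args.top_k, args.temperature, args.out_seq_length)  # 40 0.9 1024
```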

 def main():
     """Main program."""

-    print('Generate Samples')
     initialize_megatron(extra_args_provider=add_text_generate_args,
                         args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

     # Set up model and load checkpoint.
     model = get_model(model_provider)
+    args = get_args()
     if args.load is not None:
         _ = load_checkpoint(model, None, None)

-    #generate samples
+    # Generate samples.
     if args.num_samples == 0:
-        args.batch_size = 1
+        assert args.batch_size == 1

         if args.sample_input_file != "":
             generate_samples_input_from_file(model)
         else:
             generate_samples_interactive(model)
     else:
         write_and_generate_samples_unconditional(model)


 if __name__ == "__main__":
     main()
File mode changed from 100755 to 100644