Commit a6ba254f authored by Mohammad

generate samples linted

parent a19820b1
@@ -15,34 +15,27 @@
"""Sample Generate GPT2"""
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model
from megatron.training import get_model
from megatron.utils import get_ltor_masks_and_position_ids

def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=False)
@@ -56,7 +49,7 @@ def get_batch(context_tokens):
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
@@ -103,7 +96,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
def generate_samples_input_from_file(model):
"""XXX"""
args = get_args()
tokenizer = get_tokenizer()
@@ -118,15 +111,15 @@ def generate_samples_input_from_file(model):
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('could not find `sample-output-file`, setting '
'it to {}'.format(sample_output_file))
fname_out = open(sample_output_file, "w+")
context_count = 0
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs = 0
if mpu.get_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
@@ -158,9 +151,8 @@ def generate_samples_input_from_file(model):
if terminate_runs == 1:
return
start_time = time.time()
token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
@@ -184,16 +176,16 @@ def generate_samples_input_from_file(model):
def generate_samples_interactive(model, print_frequency=24):
"""XXX"""
args = get_args()
tokenizer = get_tokenizer()
context_count = 0
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs = 0
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
@@ -226,7 +218,6 @@ def generate_samples_interactive(model, print_frequency=24):
if terminate_runs == 1:
return
start_time = time.time()
token_stream = get_token_stream(model, [context_tokens])
for counter, decode_tokens in enumerate(token_stream):
decode_tokens, _ = decode_tokens
@@ -256,14 +247,13 @@ def generate_samples_interactive(model, print_frequency=24):
def generate_samples_unconditional(model):
"""XXX"""
args = get_args()
tokenizer = get_tokenizer()
num_samples = args.num_samples
context_tokens = [[tokenizer.eod]
for _ in range(args.batch_size)]
samples = []
ctr = 0
while True:
start_time = time.time()
@@ -291,6 +281,7 @@ def generate_samples_unconditional(model):
def write_and_generate_samples_unconditional(model):
args = get_args()
assert args.genfile is not None
with open(args.genfile, 'w') as f:
@@ -298,8 +289,8 @@ def write_and_generate_samples_unconditional(model):
f.write(json.dumps(datum)+'\n')
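Each sample is written as one JSON object per line, so the generated file can be read back with a few lines of standard-library code. A minimal reader sketch ('samples.json' is only a stand-in for whatever --genfile points at):

# Reader sketch for the JSON-lines file produced above.
import json

with open('samples.json', 'r') as f:          # stand-in for args.genfile
    samples = [json.loads(line) for line in f]
print('loaded {} samples'.format(len(samples)))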
def pad_batch(batch, pad_id, args):
context_lengths = []
for tokens in batch:
context_length = len(tokens)
@@ -310,11 +301,12 @@ def pad_batch(batch, tokenizer, args):
def get_token_stream(model, context_tokens):
args = get_args()
tokenizer = get_tokenizer()
context_tokens, context_lengths = pad_batch(context_tokens,
                                            tokenizer.eod, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
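The body of pad_batch is cut off by the diff context above; a plausible sketch with the same interface, assuming each context is padded up to args.seq_length with pad_id while the original lengths are recorded (the elided code may differ in detail):

# Sketch only: pad every context to a fixed length and keep the
# original (unpadded) lengths so generation knows where prompts end.
def pad_batch(batch, pad_id, args):
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths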
@@ -327,12 +319,7 @@ def get_token_stream(model, context_tokens):
group=mpu.get_model_parallel_group())
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
@@ -343,21 +330,22 @@ def get_token_stream(model, context_tokens):
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
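switch does an elementwise select between two tensors using a 0/1 mask: positions where the mask is 1 take val2, the rest keep val1. A tiny usage sketch with made-up values:

# Example: keep prompt tokens, but take newly sampled tokens once a
# row has "started" generating.
import torch

prompt_tokens = torch.tensor([10, 11, 12, 13])
sampled_tokens = torch.tensor([20, 21, 22, 23])
started = torch.tensor([0, 0, 1, 1])
print(switch(prompt_tokens, sampled_tokens, started))  # tensor([10, 11, 22, 23])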
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
maxlen=None, type_ids=None):
"""XXX"""
args = get_args()
tokenizer = get_tokenizer()
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
eos_id = tokenizer.eod
counter = 0
org_context_length = context_length
@@ -404,7 +392,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
@@ -429,7 +417,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
@@ -438,46 +425,51 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
break
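The hunk above elides the non-greedy branch between the argmax case and the end-of-sequence bookkeeping. It typically follows the usual temperature / top-k / top-p recipe; a sketch, wrapped in a hypothetical helper (sample_next_token is not part of the file, and top_k_logits is the filtering helper named in a hunk header further up):

# Sketch of the standard sampling recipe; the elided code may differ.
def sample_next_token(logits, args):
    logits = logits.float() / args.temperature
    logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(-1)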
def add_text_generate_args(parser):
"""Text generate arguments."""
group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--greedy", action='store_true', default=False)
group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=1024)
"""Text generation arguments."""
group = parser.add_argument_group(title='text generation')
group.add_argument("--temperature", type=float, default=1.0,
help='Sampling temperature.')
group.add_argument("--greedy", action='store_true', default=False,
help='Use greedy sampling.')
group.add_argument("--top_p", type=float, default=0.0,
help='Top p sampling.')
group.add_argument("--top_k", type=int, default=0,
help='Top k sampling.')
group.add_argument("--out-seq-length", type=int, default=1024,
help='Size of the output generated text.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
group.add_argument("--sample-output-file", type=str, default=None,
help='Output file got from --sample-input-file')
group.add_argument("--num-samples", type=int, default=0,
help='Number of samples to generate unconditionally, '
'defaults to 0 and interactive conditional sampling')
group.add_argument("--genfile", type=str,
help='Output file when generating unconditionally')
group.add_argument("--recompute", action='store_true',
help='During generation recompute all attention '
'instead of using previously computed keys/values.')
return parser
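In the script itself these flags are wired in through initialize_megatron's extra_args_provider hook (see main below); purely as an illustration, the same function can be exercised on a bare argparse parser:

# Illustration only: inspect the generation flags on a plain parser.
import argparse

parser = argparse.ArgumentParser('generate-samples flags demo')
parser = add_text_generate_args(parser)
args = parser.parse_args(['--top_k', '40', '--temperature', '0.9',
                          '--num-samples', '8', '--genfile', 'samples.json'])
print(args.top_k, args.temperature, args.num_samples, args.genfile)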
def main():
"""Main program."""
print('Generate Samples')
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
# Set up model and load checkpoint.
model = get_model(model_provider)
args = get_args()
if args.load is not None:
_ = load_checkpoint(model, None, None)
# Generate samples.
if args.num_samples == 0:
args.batch_size = 1
assert args.batch_size == 1
if args.sample_input_file != "":
generate_samples_input_from_file(model)
else:
@@ -487,7 +479,5 @@ def main():
if __name__ == "__main__":
main()
File mode changed from 100755 to 100644