Commit ce29d4d5 authored by Mohammad

working on refactoring text generation

parent a0bcee94
@@ -39,117 +39,87 @@ from megatron.model import GPT2Model
from megatron.model import DistributedDataParallel as DDP
from megatron import print_rank_0


def model_provider():
    """Build the model."""
    args = get_args()

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_tokentypes=0, parallel_output=False)

    return model


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.fp16)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """This function has been mostly taken from huggingface conversational
    ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-
    conversational-ai-with-transfer-learning-2d818ac26313"""

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort the logits to compute cumulative probabilities for
        # nucleus (top-p) filtering.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits

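
A minimal sketch (not part of the commit) of how this filter is typically
combined with temperature scaling and multinomial sampling; it mirrors the
non-greedy branch of sample_sequence_batch below, and the tensor shapes are
illustrative only:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 50257)                 # assumed (batch, vocab) logits
    logits = logits / 0.9                          # temperature
    logits = top_k_logits(logits, top_k=40, top_p=0.0)
    probs = F.softmax(logits, dim=-1)              # filtered entries get ~0 probability
    next_tokens = torch.multinomial(probs, num_samples=1).view(-1)
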

def generate_samples_input_from_file(model):
    """Generate samples conditioned on contexts read from a file."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        sample_output_file = args.sample_output_file
        if sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.format(sample_output_file))
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
@@ -167,46 +137,44 @@ def generate_samples_input_from_file(model, tokenizer, args):

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")

            raw_text = None
@@ -214,9 +182,11 @@ def generate_samples_input_from_file(model, tokenizer, args):

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1


def generate_samples_interactive(model, print_frequency=24):
    """Interactively prompt for a context and stream generated samples."""
    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
@@ -235,79 +205,81 @@ def generate_samples_interactive(model, tokenizer, args):

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank() == 0 and \
                   counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")


def generate_samples_unconditional(model):
    """Generate unconditional samples in batches and yield them one by one."""
    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.batch_size)]
    samples = []
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if ctr % args.log_interval == 0:
            print('Avg s/batch:',
                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
            start_time = time.time()
        length = len(token_stream)
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length - 1]
            text = tokenizer.detokenize(tokens)
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length - 1, 'finished': is_finished}
            yield datum

@@ -317,35 +289,42 @@ def generate_samples_unconditional(model, tokenizer, args):

        if ctr >= num_samples:
            break


def write_and_generate_samples_unconditional(model):
    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum) + '\n')


def pad_batch(batch, tokenizer, args):

    pad_id = tokenizer.eod
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths

def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    pad_id = tokenizer.eod
    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

@@ -355,7 +334,9 @@ def get_token_stream(model, context_tokens, tokenizer, args):

    layer_past = None
    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths
@@ -365,14 +346,14 @@ def switch(val1, val2, boolean):

    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2
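
A tiny, self-contained illustration (not part of the diff) of what switch does
in the sampling loop: wherever `boolean` is 1 the freshly sampled token (val2)
is taken, otherwise the existing token (val1) is kept; the values are made up:

    import torch

    val1 = torch.tensor([10, 11, 12])      # current token ids
    val2 = torch.tensor([20, 21, 22])      # newly sampled token ids
    started = torch.tensor([0, 1, 1])      # which sequences are past their prompt
    print(switch(val1, val2, started))     # tensor([10, 21, 22])
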

def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Sample tokens for a batch of contexts, yielding after each step."""
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
@@ -395,7 +376,11 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask

        while context_length <= maxlen:
            if args.recompute:
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False)
                logits = logits[:, context_length - 1, :]
            else:
                types2use = None

@@ -405,11 +390,20 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask

                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False)
                logits = logits[:, -1].view(batch_size, -1).contiguous()
            if args.greedy:

@@ -417,15 +411,18 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask

            else:
                logits = logits.float()
                logits /= args.temperature
                logits = top_k_logits(logits, top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            print_logits = []
            for p in prev:
                print_logits.append([logits[i, p].item()
                                     for i in range(batch_size)])
            started = context_lengths <= context_length
            tokens[:, context_length] = switch(
                tokens[:, context_length].view(-1), prev, started)
            context_length += 1
            counter += 1

@@ -439,75 +436,54 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask

            yield tokens, lengths
            if done:
                break


def add_text_generate_args(parser):
    """Text generation arguments."""

    group = parser.add_argument_group('Text generation', 'configurations')
    group.add_argument("--temperature", type=float, default=1.0)
    group.add_argument("--greedy", action='store_true', default=False)
    group.add_argument("--top_p", type=float, default=0.0)
    group.add_argument("--top_k", type=int, default=0)
    group.add_argument("--out-seq-length", type=int, default=1024)
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='get input from file instead of interactive mode, '
                       'each line is an input')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='output file obtained from --sample-input-file')
    group.add_argument("--num-samples", type=int, default=0,
                       help='number of samples to generate unconditionally, '
                       'defaults to 0 and interactive conditional sampling')
    group.add_argument("--genfile", type=str,
                       help='output file when generating unconditionally')
    group.add_argument("--recompute", action='store_true',
                       help='during generation recompute all attention '
                       'instead of using previously computed keys/values.')

    return parser
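
To make the flag spelling concrete (note the mix of underscores and dashes),
here is a hedged, standalone sketch of parsing these options with plain
argparse; in the refactor they are actually wired in through
initialize_megatron(extra_args_provider=add_text_generate_args, ...):

    import argparse

    parser = argparse.ArgumentParser()
    parser = add_text_generate_args(parser)
    args = parser.parse_args(['--top_k', '40', '--temperature', '0.9',
                              '--num-samples', '8', '--genfile', 'samples.jsonl'])
    print(args.top_k, args.temperature, args.num_samples, args.genfile)
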
(The old prepare_tokenizer helper, which built the tokenizer by hand and padded
the vocabulary size to a multiple of make_vocab_size_divisible_by times the
model-parallel world size, is removed in this commit; tokenizer construction
now happens inside initialize_megatron.)

def main():
    """Main program."""

    print('Generate Samples')

    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

    args = get_args()

    # Set up the model and load the checkpoint.
    model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # Generate samples.
    if args.num_samples == 0:
        args.batch_size = 1
        if args.sample_input_file is not None:
            generate_samples_input_from_file(model)
        else:
            generate_samples_interactive(model)
    else:
        write_and_generate_samples_unconditional(model)


if __name__ == "__main__":
...
@@ -357,29 +357,7 @@ def _add_gpt2_args(parser):

(The add_text_generate_args function is deleted from arguments.py in this hunk;
it now lives with the text generation code shown above.)

def add_data_args_(parser):
...
@@ -137,8 +137,7 @@ class BertModel(MegatronModule):

        self._binary_head_key = 'binary_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
...
@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):

        # Language model.
        lm_output = self.language_model(input_ids,

@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):

            lm_output, presents = lm_output

        # Output. The caller can override the constructor's parallel_output
        # setting for this forward pass only (used during generation).
        parallel_output = self.parallel_output
        if forward_method_parallel_output is not None:
            parallel_output = forward_method_parallel_output
        output = parallel_lm_logits(
            lm_output,
            self.language_model.embedding.word_embeddings.weight,
            parallel_output)

        if get_key_value:
            output = [output, presents]
...
@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC):

    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        raise NotImplementedError('detokenizer is not implemented for {} '
                                  'tokenizer'.format(self.name))

    @property
    def cls(self):
        raise NotImplementedError('CLS is not provided for {} '
...
@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer):

    def tokenize(self, text):
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        return self.eod_id
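
A hedged usage sketch (not part of the commit) of the round trip the generation
code above relies on; it assumes initialize_megatron(...) has already run, so
that the get_tokenizer helper used throughout the diff returns the
GPT2BPETokenizer built from these defaults:

    from megatron import get_tokenizer

    tokenizer = get_tokenizer()
    ids = tokenizer.tokenize("Hello Megatron")   # byte-level BPE ids
    text = tokenizer.detokenize(ids)             # the new method added above
    print(ids)
    print(text)                                  # byte-level BPE decodes back to the input
    print(tokenizer.eod)                         # also used as the pad id in pad_batch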