Commit 458d7785 authored by rprenger

Fixing merge conflict

parents d7bf1ab5 9c5a830f
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompting the pretrained language model to generate knowledge/response"""
import json
import torch
import requests
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.text_generation import generate_and_post_process


def call_model_api(inputs, tokens_to_generate):
    """Call the model API to get the output generations."""
    args = get_args()

    # The following is an example of using the Megatron API
    # You can also implement your own API function to replace this part
    headers = {'Content-Type': 'application/json; charset=UTF-8'}
    data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
    data_json = json.dumps(data)
    outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]

    input_len = len(inputs)
    outputs = outputs[input_len:]
    outputs = outputs.split("\n")[0].strip()

    return outputs
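
# Note on the API contract assumed by call_model_api: the payload keys
# ("prompts", "tokens_to_generate", "top_k") follow the Megatron text
# generation API, and the server is expected to return a JSON object whose
# "text" field holds the prompt concatenated with the generation, e.g.
# (hypothetical values):
#   request:  {"prompts": ["( do you like movies? ) movies =>"],
#              "tokens_to_generate": 64, "top_k": 1}
#   response: {"text": ["( do you like movies? ) movies => Movies are ..."]}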


def read_prompts(prompt_path, prompt_type, n_example):
    """Read prompt data"""

    if prompt_type == "knowledge":
        # prompts for the knowledge generation
        prompt_examples_dict = {}
        # read prompt_path
        with open(prompt_path, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + " \n"
                    prompt_examples_dict[key] = prompt

        return prompt_examples_dict

    else:
        # prompts for the response generation
        # read prompt_path
        prompt = ""
        with open(prompt_path, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:n_example]

            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + " \n"

        return prompt
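
# Expected prompt-file layouts, inferred from the parsing above (the concrete
# values are illustrative only):
# - "knowledge" prompts: a JSON-lines file in which each line maps a single
#   "<topic> <last user turn>" key to a list of prompt example strings, e.g.
#     {"movies do you like movies ?": ["( do you like movies ? ) movies => ...", ...]}
# - "response" prompts: a plain-text file with one prompt example per line;
#   only the first n_example lines are concatenated into the prompt.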


def generate_samples_by_calling_api():
    """Generate outputs by calling the model API."""
    args = get_args()
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    if args.prompt_type == "knowledge":
        # read knowledge generation prompts
        knwl_gen_prompt_dict = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)
    else:
        resp_gen_prompt = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    # read the test data
    fname = open(args.sample_input_file, "r")
    test_sample_list = fname.readlines()
    # create output file
    fname_out = open(args.sample_output_file, "w")

    # call the api to get the output generations
    for test_sample in test_sample_list:
        test_sample = test_sample.strip()
        splits = test_sample.split("\t")
        topic = splits[0]

        # prepare the inputs for the api
        if args.prompt_type == "knowledge":
            # inputs = prompt + current test
            # get the prompt
            turns = splits[1].split(" [SEP] ")
            last_turn = turns[-1]
            key = topic + " " + last_turn
            inputs = knwl_gen_prompt_dict[key]

            # add current test
            inputs += "( " + last_turn + " ) " + topic + " =>"
        else:
            # inputs = prompt + current test
            # get the prompt
            inputs = resp_gen_prompt

            # add current test
            turns = splits[1].split(" [SEP] ")
            knowledge = splits[2]
            last_turn = turns[-1]
            last_turn = " ".join(word_tokenize(last_turn))
            knowledge = " ".join(word_tokenize(knowledge))
            knowledge = knowledge.strip()
            last_turn = last_turn.strip()
            inputs += "Topic: " + topic + ". "
            inputs += "User says: " + last_turn + " "
            inputs += "We know that: " + knowledge + " "
            inputs += "System replies:"

        # get the output generations from the api,
        # and write to the output file
        generations = call_model_api(inputs, args.out_seq_length)
        fname_out.write(generations)
        fname_out.write("\n")

    fname.close()
    fname_out.close()
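
# The test file consumed above is assumed to be tab-separated, one sample per
# line (fields inferred from the split logic; the layout is illustrative):
#   <topic> \t <dialogue turns joined by " [SEP] "> \t <knowledge sentence>
# Only the last dialogue turn is used for knowledge generation, and the third
# (knowledge) field is read only when the prompt type is "response".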


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process
    )
    return model


def generate_samples_by_prompting_input_from_file(model):
    """Prompt a pretrained language model to generate knowledge/response"""

    # get tokenizer
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w")

    # only two prompt types (i.e., knowledge and response) are allowed
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    # Read the prompt file
    if args.prompt_type == "knowledge":
        # read the prompts for the knowledge generation
        prompt_examples_dict = {}
        with open(args.prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                # get the prompt examples based on the key
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + " \n"
                    prompt_examples_dict[key] = prompt

    else:
        # read the prompts for the response generation
        # prompts are fixed for all test samples
        with open(args.prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:args.num_prompt_examples]

            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + " \n"

    input_pos = 0
    model.eval()
    # perform prompting
    with torch.no_grad():
        while True:
            raw_text_len = 0
            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
                topic = splits[0]

                if args.prompt_type == "knowledge":
                    # first add the prompt into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    last_turn = turns[-1]
                    key = topic + " " + last_turn
                    raw_text = prompt_examples_dict[key]

                    # construct inputs for knowledge generation
                    # then add the constructed inputs into the raw_text
                    raw_text += "( " + last_turn + " ) " + topic + " =>"

                else:
                    # first add the prompt into the raw_text
                    raw_text = prompt

                    # construct inputs for response generation
                    # then add the constructed inputs into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    knowledge = splits[2]
                    last_turn = turns[-1]
                    last_turn = " ".join(word_tokenize(last_turn))
                    knowledge = " ".join(word_tokenize(knowledge))
                    knowledge = knowledge.strip()
                    last_turn = last_turn.strip()
                    raw_text += "Topic: " + topic + ". "
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + knowledge + " "
                    raw_text += "System replies:"

                input_pos += 1
                raw_text_len = len(raw_text)

            else:
                raw_text = "EMPTY TEXT"

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            outputs = generate_and_post_process(
                model=model,
                prompts=[raw_text],
                tokens_to_generate=args.out_seq_length,
                top_k_sampling=1)
            prompts_plus_generations = outputs[0]
            prompts_plus_generations = prompts_plus_generations[0]

            # write the generated output to the output file
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    generations = prompts_plus_generations[raw_text_len:]
                    generations = generations.split("\n")[0]
                    generations = generations.strip()
                    fname_out.write(generations)
                    fname_out.write("\n")

            raw_text = None
            if input_pos == input_count:
                return


def main():
    args = get_args()
    if args.api_prompt:
        # obtain the generations by calling the api
        generate_samples_by_calling_api()
        return

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # perform the prompting
    generate_samples_by_prompting_input_from_file(model)
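
For reference, the sketch below isolates the request/response contract that `call_model_api` relies on, so it can be exercised against a running text generation server without the rest of the script. The server URL and port are placeholders (assumptions), and `generate_via_api` is a hypothetical helper; the payload keys and the `"text"` response field mirror the code above.

```python
# Minimal standalone sketch of the request made by call_model_api above.
# The URL is a placeholder and not part of this commit; point it at a
# running Megatron text generation server.
import json
import requests

MEGATRON_API_URL = "http://localhost:5000/api"  # placeholder (assumption)


def generate_via_api(prompt, tokens_to_generate=64):
    headers = {'Content-Type': 'application/json; charset=UTF-8'}
    payload = {"prompts": [prompt],
               "tokens_to_generate": tokens_to_generate,
               "top_k": 1}
    resp = requests.put(MEGATRON_API_URL, headers=headers, data=json.dumps(payload))
    # The server returns the prompt concatenated with its generation in "text";
    # strip the prompt and keep only the first generated line, as the script does.
    full_text = resp.json()["text"][0]
    return full_text[len(prompt):].split("\n")[0].strip()


if __name__ == "__main__":
    print(generate_via_api("( do you like movies? ) movies =>"))
```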
@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
 def _train(
     model,
     optimizer,
-    lr_scheduler,
+    opt_param_scheduler,
     forward_step,
     train_dataloader,
     valid_dataloader,
@@ -179,7 +179,7 @@ def _train(
             # Train for one step.
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
-                forward_step, batch, model, optimizer, lr_scheduler
+                forward_step, batch, model, optimizer, opt_param_scheduler
             )
             iteration += 1
@@ -206,7 +206,7 @@ def _train(
                 iteration % args.adlr_autoresume_interval == 0
             ):
                 check_adlr_autoresume_termination(
-                    iteration, model, optimizer, lr_scheduler
+                    iteration, model, optimizer, opt_param_scheduler
                 )
             # Checkpointing
@@ -215,7 +215,7 @@ def _train(
                 and args.save_interval
                 and iteration % args.save_interval == 0
             ):
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -231,7 +231,7 @@ def _train(
         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:
@@ -266,7 +266,7 @@ def finetune(
     # Build model, optimizer and learning rate scheduler.
     timers("model and optimizer").start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
     timers("model and optimizer").stop()
     # If pretrained checkpoint is provided and we have not trained for
@@ -300,7 +300,7 @@ def finetune(
     _train(
         model,
         optimizer,
-        lr_scheduler,
+        opt_param_scheduler,
         forward_step,
         train_dataloader,
         valid_dataloader,
...
@@ -78,4 +78,7 @@ if __name__ == "__main__":
         choice = torch.cuda.LongTensor(1)
         torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
-            generate_and_post_process(model)
+            try:
+                generate_and_post_process(model)
+            except ValueError as ve:
+                pass