Commit 5f433016 authored by zihanl's avatar zihanl
Browse files

add api option

parent fec9668d
......@@ -30,6 +30,121 @@ from tasks.knwl_dialo.utils import get_token_stream
# from megatron.text_generation import generate_and_post_process
def call_model_api(inputs):
    """Call the model API and return the generation for ``inputs``.

    This is a placeholder hook: no API client is implemented here, so anyone
    using the API prompting path must replace the body with a request to
    their own model endpoint.

    Args:
        inputs: The prompt string built by the caller (the prompt examples
            followed by the current test sample).

    Returns:
        The generated text as a string. The caller writes it directly to
        the output file, so it must be a ``str``.

    Raises:
        NotImplementedError: Always, until the hook is implemented. Failing
            here gives a clear message instead of the ``TypeError`` that a
            silent ``None`` return would trigger when the caller writes it.
    """
    # TODO: send `inputs` to the model API and return the output generation.
    raise NotImplementedError(
        "call_model_api is a placeholder: implement the model API call and "
        "return the generated text for the given inputs.")
def read_prompts(prompt_path, prompt_type, n_example):
    """Read prompt data from ``prompt_path``.

    Args:
        prompt_path: Path of the prompt file to read.
        prompt_type: ``"knowledge"`` selects the knowledge-generation
            format; any other value is read as a response-generation
            prompt file.
        n_example: Number of leading lines to keep as prompt examples for
            response generation. Not used for knowledge generation.

    Returns:
        For ``"knowledge"``: a dict mapping the first key of each JSON line
        to its prompt string (the first occurrence of a key wins).
        Otherwise: a single prompt string built from the first
        ``n_example`` lines of the file.

        In both cases every example is stripped and terminated by
        ``" \\n"``.
    """
    if prompt_type == "knowledge":
        # prompts for the knowledge generation:
        # one JSON object per line, {key: [example, example, ...]}
        prompt_examples_dict = {}
        with open(prompt_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # tolerate blank lines (e.g. a trailing empty line),
                    # which json.loads would otherwise reject
                    continue
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]
                if key not in prompt_examples_dict:
                    prompt_examples_dict[key] = "".join(
                        instance.strip() + " \n"
                        for instance in line_dict[key])
        return prompt_examples_dict

    # prompts for the response generation:
    # one example per line, only the first n_example lines are used
    with open(prompt_path, "r") as f:
        prompt_examples = f.readlines()[:n_example]
    return "".join(
        instance.strip() + " \n" for instance in prompt_examples)
def generate_samples_by_calling_api():
    """Generate outputs for the test samples by calling the model API.

    Configuration comes from ``get_args()``: ``prompt_type``,
    ``prompt_file``, ``num_prompt_examples``, ``sample_input_file`` and
    ``sample_output_file``.

    Each line of the input file is split on tabs: field 0 is the topic,
    field 1 holds the dialogue turns joined by ``" [SEP] "`` and, for
    response generation, field 2 is the knowledge. For every line the
    prompt is combined with the current sample, sent to
    ``call_model_api``, and the returned generation is written to the
    output file on its own line.

    Raises:
        AssertionError: If ``args.prompt_type`` is neither ``"knowledge"``
            nor ``"response"``.
        KeyError: In knowledge mode, if no prompt exists for the
            ``"<topic> <last turn>"`` key of a sample.
    """
    args = get_args()
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    if args.prompt_type == "knowledge":
        # read knowledge generation prompts
        knwl_gen_prompt_dict = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)
    else:
        # read response generation prompts
        resp_gen_prompt = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    # read the test data
    with open(args.sample_input_file, "r") as fname:
        test_sample_list = fname.readlines()

    # create the output file; the context manager closes it even if the
    # api call or the input parsing fails part-way through
    with open(args.sample_output_file, "w") as fname_out:
        # call the api to get the output generations
        for test_sample in test_sample_list:
            test_sample = test_sample.strip()
            splits = test_sample.split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")
            last_turn = turns[-1]

            # prepare the inputs for the api: inputs = prompt + current test
            if args.prompt_type == "knowledge":
                # get the prompt selected by the topic and the last turn
                key = topic + " " + last_turn
                inputs = knwl_gen_prompt_dict[key]
                # add current test
                inputs += "( " + last_turn + " ) " + topic + " =>"
            else:
                # get the prompt
                inputs = resp_gen_prompt
                # add current test
                knowledge = splits[2]
                last_turn = " ".join(word_tokenize(last_turn)).strip()
                knowledge = " ".join(word_tokenize(knowledge)).strip()
                inputs += "Topic: " + topic + ". "
                inputs += "User says: " + last_turn + " "
                inputs += "We know that: " + knowledge + " "
                inputs += "System replies:"

            # get the output generations from the api,
            # and write to the output file
            generations = call_model_api(inputs)
            fname_out.write(generations)
            fname_out.write("\n")
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
......@@ -124,9 +239,7 @@ def generate_samples_by_prompting_input_from_file(model):
# construct inputs for knowledge generation
# then add the constructed inputs into the raw_text
turns = splits[1].split(" [SEP] ")
context = turns[-1]
raw_text += "( " + context + " ) " + topic + " =>"
raw_text += "( " + last_turn + " ) " + topic + " =>"
else:
# first add the prompt into the raw_text
......@@ -196,6 +309,11 @@ def generate_samples_by_prompting_input_from_file(model):
def main():
args = get_args()
if args.api_prompting:
# obtain the generations by calling the api
generate_samples_by_calling_api()
return
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
......
......@@ -102,6 +102,8 @@ def get_tasks_args(parser):
help='datapath for golden sentences')
group.add_argument('--out-seq-length', type=int, default=100,
help='output sequence length')
group.add_argument('--api-prompt', default=False, action="store_true",
help='setup model api for prompting')
return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment