Commit ccf0eae2 authored by zihanl

use new text generation

parent b3cd8a47
@@ -26,19 +26,26 @@ from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from tasks.knwl_dialo.utils import get_token_stream
-# from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import generate_and_post_process


-def call_model_api(inputs):
+def call_model_api(inputs, tokens_to_generate):
     """Calling the model api to get the output generations"""

-    # TODO
-    # Implement the model api, and get output generations from the inputs
-    # After that, return the output generations
+    args = get_args()
+
+    # The following is an example of using the Megatron API
+    # You can also implement your own API function to place this part
+    headers = {'Content-Type': 'application/json; charset=UTF-8'}
+    data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
+    data_json = json.dumps(data)
+    outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]

-    # outputs = call_model_api(inputs)
-    # return outputs
-    pass
+    input_len = len(inputs)
+    outputs = outputs[input_len:]
+    outputs = outputs.split("\n")[0].strip()
+
+    return outputs


 def read_prompts(prompt_path, prompt_type, n_example):
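The new `call_model_api` sends a PUT request whose JSON body carries `prompts`, `tokens_to_generate`, and `top_k`, then strips the echoed prompt from the returned text and keeps only the first generated line. Below is a minimal standalone sketch of that round trip, assuming a Megatron text generation server is already running; the URL and the `query_megatron_server` helper are placeholders for illustration, not part of this commit.

```python
import json
import requests

MEGATRON_API_URL = "http://localhost:5000/api"  # assumed server address, not from this commit


def query_megatron_server(prompt, tokens_to_generate=32):
    """Send one prompt to a running Megatron text generation server."""
    headers = {'Content-Type': 'application/json; charset=UTF-8'}
    data = {"prompts": [prompt], "tokens_to_generate": tokens_to_generate, "top_k": 1}
    response = requests.put(MEGATRON_API_URL, headers=headers, data=json.dumps(data))
    full_text = response.json()["text"][0]
    # The server echoes the prompt, so drop it and keep only the first
    # generated line, mirroring the post-processing in call_model_api above.
    return full_text[len(prompt):].split("\n")[0].strip()


if __name__ == "__main__":
    print(query_megatron_server("The capital of France is"))
```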
@@ -107,7 +114,7 @@ def generate_samples_by_calling_api():
         # prepare the inputs for the api
         if args.prompt_type == "knowledge":
-            # inputs = prompt + current test
+            ## inputs = prompt + current test
             # get the prompt
             turns = splits[1].split(" [SEP] ")
             last_turn = turns[-1]
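This hunk prepares the API input for knowledge generation: the dialogue context in `splits[1]` is a single string whose turns are joined by `" [SEP] "`, and the last turn is what the knowledge prompt is keyed on. A tiny sketch of that convention (the sample string is made up):

```python
# A dialogue context as it appears in the task's data files (sample made up).
dialog_context = "Do you like jazz? [SEP] I love it. [SEP] Who is your favorite artist?"

turns = dialog_context.split(" [SEP] ")
last_turn = turns[-1]  # "Who is your favorite artist?"
```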
@@ -216,7 +223,6 @@ def generate_samples_by_prompting_input_from_file(model):
                 instance = instance.strip()
                 prompt += instance + " \n"

-    context_count = 0
     input_pos = 0
     model.eval()
     # perform prompting
@@ -261,47 +267,32 @@ def generate_samples_by_prompting_input_from_file(model):
                 input_pos += 1
                 raw_text_len = len(raw_text)
-                context_tokens = tokenizer.tokenize(raw_text)
             else:
-                context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                # raw_text = "EMPTY TEXT"
+                raw_text = "EMPTY TEXT"

             if input_pos % 100 == 0:
                 print_rank_0("input_pos: %d" % input_pos)

-            # get the generation outputs (in decode_tokens)
-            token_stream = get_token_stream(model, [context_tokens])
-            for _, decode_tokens in enumerate(token_stream):
-                pass
-
-            # outputs = generate_and_post_process(
-            #     model=model,
-            #     prompts=[raw_text],
-            #     tokens_to_generate=args.out_seq_length,
-            #     top_k_sampling=1)
-            # prompts_plus_generations = outputs[0]
+            outputs = generate_and_post_process(
+                model=model,
+                prompts=[raw_text],
+                tokens_to_generate=args.out_seq_length,
+                top_k_sampling=1)
+            prompts_plus_generations = outputs[0]
+            prompts_plus_generations = prompts_plus_generations[0]

             # write the generated output to the output file
             if mpu.get_tensor_model_parallel_rank() == 0:
                 if mpu.is_pipeline_first_stage():
-                    decode_tokens, _ = decode_tokens
-                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                    trim_decode_tokens = tokenizer.detokenize(
-                        decode_tokens)[raw_text_len:]
-                    generated_output = trim_decode_tokens.split("\n")[0]
-                    generated_output = generated_output.strip()
-                    fname_out.write(generated_output)
+                    generations = prompts_plus_generations[raw_text_len:]
+                    generations = generations.split("\n")[0]
+                    generations = generations.strip()
+                    fname_out.write(generations)
                     fname_out.write("\n")
-                    # generations = prompts_plus_generations[raw_text_len:]
-                    # generations = generations.split("\n")[0]
-                    # generations = generations.strip()
-                    # fname_out.write(generations)
-                    # fname_out.write("\n")

             raw_text = None
-            context_count += 1
             if input_pos == input_count:
                 return
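On the double indexing above: `generate_and_post_process` returns a tuple whose first element is the list of prompt-plus-generation strings, one per prompt, so with a single prompt `outputs[0][0]` is the generated text with the prompt still prepended. The trimming then matches the API path: slice off the first `raw_text_len` characters and keep only the first line. A sketch with made-up values:

```python
# Post-processing sketch for the new code path; sample strings are made up.
raw_text = "Q: Where is NVIDIA headquartered?\nA:"
prompts_plus_generations = raw_text + " Santa Clara, California.\nQ: ..."

generations = prompts_plus_generations[len(raw_text):]  # drop the echoed prompt
generations = generations.split("\n")[0]                # keep the first generated line
generations = generations.strip()
assert generations == "Santa Clara, California."
```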
@@ -309,7 +300,7 @@ def generate_samples_by_prompting_input_from_file(model):

 def main():
     args = get_args()
-    if args.api_prompting:
+    if args.api_prompt:
         # obtain the generations by calling the api
         generate_samples_by_calling_api()
         return
@@ -319,7 +310,7 @@
         exit()

     # Set up model and load checkpoint.
-    model = get_model(model_provider)
+    model = get_model(model_provider, wrap_with_ddp=False)

     if args.load is not None:
         _ = load_checkpoint(model, None, None)
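The `wrap_with_ddp=False` change matters for the new generation path: `generate_and_post_process` expects the unwrapped model rather than a DDP-wrapped one. A condensed view of the inference-only setup this enables, assuming the `model_provider` signature used elsewhere in this script (the snippet reconstructs it, so treat this as a sketch rather than the file's exact code):

```python
from megatron import get_args
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint


def model_provider(pre_process=True, post_process=True):
    # Mirrors the provider defined earlier in this script (signature assumed).
    return GPTModel(num_tokentypes=0, parallel_output=False,
                    pre_process=pre_process, post_process=post_process)


initialize_megatron()                                    # parse args, set up distributed state
model = get_model(model_provider, wrap_with_ddp=False)   # plain module list, no DDP wrapper
args = get_args()
if args.load is not None:
    _ = load_checkpoint(model, None, None)               # no optimizer/scheduler state needed
model = model[0]                                         # get_model returns a list of modules
model.eval()
```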