Commit ccf0eae2 authored by zihanl

use new text generation

parent b3cd8a47
@@ -26,19 +26,26 @@
 from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from tasks.knwl_dialo.utils import get_token_stream
-# from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import generate_and_post_process
 
 
-def call_model_api(inputs):
+def call_model_api(inputs, tokens_to_generate):
     """Calling the model api to get the output generations"""
-    # TODO
-    # Implement the model api, and get output generations from the inputs
-    # After that, return the output generations
 
-    # outputs = call_model_api(inputs)
-    # return outputs
-    pass
+    args = get_args()
+
+    # The following is an example of using the Megatron API
+    # You can also implement your own API function to place this part
+    headers = {'Content-Type': 'application/json; charset=UTF-8'}
+    data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
+    data_json = json.dumps(data)
+    outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]
+
+    input_len = len(inputs)
+    outputs = outputs[input_len:]
+    outputs = outputs.split("\n")[0].strip()
+    return outputs
 
 
 def read_prompts(prompt_path, prompt_type, n_example):
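For reference, a minimal standalone sketch of the same request pattern that call_model_api now implements, assuming a Megatron text-generation server is already running; the URL, endpoint path, and helper name below are placeholders, not part of this repository:

import json
import requests

def query_megatron_server(prompt, tokens_to_generate, url="http://localhost:5000/api"):
    # Send the prompt and decoding options to the server as JSON.
    headers = {'Content-Type': 'application/json; charset=UTF-8'}
    data = {"prompts": [prompt], "tokens_to_generate": tokens_to_generate, "top_k": 1}
    response = requests.put(url, headers=headers, data=json.dumps(data))
    # The response text contains the prompt followed by its generation;
    # strip the prompt and keep only the first generated line, as above.
    text = response.json()["text"][0]
    return text[len(prompt):].split("\n")[0].strip()

if __name__ == "__main__":
    print(query_megatron_server("Question: who wrote Hamlet? Answer:", 32))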
@@ -107,7 +114,7 @@ def generate_samples_by_calling_api():
 
         # prepare the inputs for the api
         if args.prompt_type == "knowledge":
-            # inputs = prompt + current test
+            ## inputs = prompt + current test
             # get the prompt
             turns = splits[1].split(" [SEP] ")
             last_turn = turns[-1]
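The hunk above only touches a comment, but for context, a hedged sketch of how the "knowledge" inputs appear to be assembled: the dialogue turns are joined with " [SEP] " and only the last turn is appended to the few-shot prompt. The tab-separated field layout and the function name here are assumptions:

def build_knowledge_inputs(prompt, test_line):
    # test_line is assumed to look like "<topic>\t<turn1 [SEP] turn2 [SEP] ...>".
    splits = test_line.strip().split("\t")
    turns = splits[1].split(" [SEP] ")
    last_turn = turns[-1]
    # inputs = prompt + current test (the most recent dialogue turn)
    return prompt + last_turn + "\n"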
@@ -216,7 +223,6 @@ def generate_samples_by_prompting_input_from_file(model):
             instance = instance.strip()
             prompt += instance + " \n"
 
-    context_count = 0
     input_pos = 0
     model.eval()
     # perform prompting
@@ -261,47 +267,32 @@ def generate_samples_by_prompting_input_from_file(model):
                 input_pos += 1
                 raw_text_len = len(raw_text)
-                context_tokens = tokenizer.tokenize(raw_text)
 
             else:
-                context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                # raw_text = "EMPTY TEXT"
+                raw_text = "EMPTY TEXT"
 
             if input_pos % 100 == 0:
                 print_rank_0("input_pos: %d" % input_pos)
 
-            # get the generation outputs (in decode_tokens)
-            token_stream = get_token_stream(model, [context_tokens])
-            for _, decode_tokens in enumerate(token_stream):
-                pass
-
-            # outputs = generate_and_post_process(
-            #             model=model,
-            #             prompts=[raw_text],
-            #             tokens_to_generate=args.out_seq_length,
-            #             top_k_sampling=1)
-            # prompts_plus_generations = outputs[0]
+            outputs = generate_and_post_process(
+                        model=model,
+                        prompts=[raw_text],
+                        tokens_to_generate=args.out_seq_length,
+                        top_k_sampling=1)
+            prompts_plus_generations = outputs[0]
+            prompts_plus_generations = prompts_plus_generations[0]
 
             # write the generated output to the output file
             if mpu.get_tensor_model_parallel_rank() == 0:
                 if mpu.is_pipeline_first_stage():
-                    decode_tokens, _ = decode_tokens
-                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                    trim_decode_tokens = tokenizer.detokenize(
-                        decode_tokens)[raw_text_len:]
-
-                    generated_output = trim_decode_tokens.split("\n")[0]
-                    generated_output = generated_output.strip()
-                    fname_out.write(generated_output)
+                    generations = prompts_plus_generations[raw_text_len:]
+                    generations = generations.split("\n")[0]
+                    generations = generations.strip()
+                    fname_out.write(generations)
                     fname_out.write("\n")
-
-                    # generations = prompts_plus_generations[raw_text_len:]
-                    # generations = generations.split("\n")[0]
-                    # generations = generations.strip()
-                    # fname_out.write(generations)
-                    # fname_out.write("\n")
 
             raw_text = None
-            context_count += 1
 
             if input_pos == input_count:
                 return
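This hunk drops the token-stream/detokenize path in favor of generate_and_post_process, which already returns strings. A small sketch of the post-processing applied above, with illustrative names that are not part of the Megatron API:

def trim_generation(prompt_plus_generation, raw_text_len):
    # Keep only the text produced after the prompt and cut it at the
    # first newline, mirroring the new code path above.
    generation = prompt_plus_generation[raw_text_len:]
    generation = generation.split("\n")[0]
    return generation.strip()

# Example: prompt of length 10, one generated line followed by extra text.
assert trim_generation("0123456789 hello\nmore text", 10) == "hello"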
@@ -309,7 +300,7 @@ def generate_samples_by_prompting_input_from_file(model):
 def main():
 
     args = get_args()
-    if args.api_prompting:
+    if args.api_prompt:
         # obtain the generations by calling the api
         generate_samples_by_calling_api()
         return
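The flag read in main() changes from args.api_prompting to args.api_prompt. A hedged guess at how the corresponding task arguments might be registered; the group title, defaults, and help strings are assumptions, only the flag names are taken from the diff:

def get_tasks_args(parser):
    group = parser.add_argument_group(title='knowledge-grounded dialogue prompting')
    group.add_argument('--api-prompt', action='store_true', default=False,
                       help='Obtain generations by calling an external model API '
                            'instead of running a local model.')
    group.add_argument('--megatron-api-url', type=str, default=None,
                       help='URL of a running Megatron text-generation server.')
    return parser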
@@ -319,7 +310,7 @@ def main():
         exit()
 
     # Set up model and load checkpoint.
-    model = get_model(model_provider)
+    model = get_model(model_provider, wrap_with_ddp=False)
     if args.load is not None:
         _ = load_checkpoint(model, None, None)
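get_model is now called with wrap_with_ddp=False, since generation needs no gradient synchronization. A hedged sketch of the setup that typically follows in Megatron inference scripts; the list-unwrapping step is an assumption and is not shown in this diff:

model = get_model(model_provider, wrap_with_ddp=False)
if args.load is not None:
    _ = load_checkpoint(model, None, None)

# get_model returns a list of model chunks; with no virtual pipeline
# parallelism there is exactly one, so unwrap it before generating.
assert len(model) == 1
model = model[0]
model.eval()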