Unverified Commit 4b8312c0 authored by Camille Zhong's avatar Camille Zhong Committed by GitHub
Browse files

fix sft single turn inference example (#5416)

parent a1c6cdb1
...@@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs): ...@@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs):
model.to(device) model.to(device)
try: try:
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
except OSError: except OSError:
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.") raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
...@@ -29,6 +29,7 @@ def generate(args): ...@@ -29,6 +29,7 @@ def generate(args):
if args.prompt_style == "sft": if args.prompt_style == "sft":
conversation = default_conversation.copy() conversation = default_conversation.copy()
conversation.append_message("Human", args.input_txt) conversation.append_message("Human", args.input_txt)
conversation.append_message("Assistant", None)
input_txt = conversation.get_prompt() input_txt = conversation.get_prompt()
else: else:
BASE_INFERENCE_SUFFIX = "\n\n->\n\n" BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
...@@ -46,7 +47,7 @@ def generate(args): ...@@ -46,7 +47,7 @@ def generate(args):
num_return_sequences=1, num_return_sequences=1,
) )
response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True) response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}") logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
return response return response
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment