import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument('--file_path', required=True, type=str)
parser.add_argument('--model_path', required=True, type=str)
parser.add_argument('--gpu_id', default="0", type=str)
parser.add_argument('--chain_type', default="refine", choices=["stuff", "refine"], type=str)
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
file_path = args.file_path
model_path = args.model_path
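# torch and langchain are imported only after CUDA_VISIBLE_DEVICES is set,
# so the process sees just the GPU requested via --gpu_id.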
import torch
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.chains.summarize import load_summarize_chain
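# Prompt templates in the Llama-2 chat format: a single [INST] ... [/INST] turn with a
# <<SYS>> ... <</SYS>> system block. The Chinese instruction below means
# "Please write a summary of the following text."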
prompt_template = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n"
    "请为以下文字写一段摘要:\n{text} [/INST]"
)
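# Refine-step prompt: given the summary so far ({existing_answer}) and the next chunk
# ({text}), the Chinese instruction asks the model to improve the summary if the new
# text helps, otherwise keep the existing one, and to output a final summary.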
refine_template = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n"
    "已有一段摘要:{existing_answer}\n"
    "现在还有一些文字,(如果有需要)你可以根据它们完善现有的摘要。"
    "\n"
    "{text}\n"
    "\n"
    "如果这段文字没有用,返回原来的摘要即可。请你生成一个最终的摘要。"
    " [/INST]"
)
if __name__ == '__main__':
    load_type = torch.float16
    if not torch.cuda.is_available():
        raise RuntimeError("No CUDA GPUs are available.")
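    # Chunk the input; chunk_size and chunk_overlap are measured in characters
    # (length_function=len), with a 100-character overlap so context is shared
    # across adjacent chunks.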
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100, length_function=len)
    with open(file_path, encoding="utf-8") as f:
        text = f.read()
    docs = text_splitter.create_documents([text])
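    # create_documents returns one langchain Document per chunk.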
print("loading LLM...")
model = HuggingFacePipeline.from_model_id(model_id=model_path,
task="text-generation",
device=0,
pipeline_kwargs={
"max_new_tokens": 400,
"do_sample": True,
"temperature": 0.2,
"top_k": 40,
"top_p": 0.9,
"repetition_penalty": 1.1},
model_kwargs={
"torch_dtype" : load_type,
"low_cpu_mem_usage" : True,
"trust_remote_code": True}
)
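    # Low temperature combined with top-k/top-p sampling keeps the summary focused,
    # and repetition_penalty discourages the model from looping.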
    PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
    REFINE_PROMPT = PromptTemplate(
        template=refine_template, input_variables=["existing_answer", "text"],
    )
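    # "stuff" concatenates all chunks into a single prompt; "refine" summarizes the first
    # chunk with PROMPT, then folds each remaining chunk in through REFINE_PROMPT.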
    if args.chain_type == "stuff":
        chain = load_summarize_chain(model, chain_type="stuff", prompt=PROMPT)
    elif args.chain_type == "refine":
        chain = load_summarize_chain(model, chain_type="refine", question_prompt=PROMPT, refine_prompt=REFINE_PROMPT)
    print(chain.run(docs))
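# Example invocation (the script name and paths are placeholders, not fixed by this repo):
#   python summarize.py --file_path doc.txt --model_path /path/to/model --gpu_id 0 --chain_type refine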