"""Retrieval QA demo: index a local text file with FAISS and answer questions
with a local Hugging Face model through LangChain's RetrievalQA chain."""
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument('--file_path', required=True, type=str, help="plain-text file used as the knowledge base")
parser.add_argument('--embedding_path', required=True, type=str, help="path or name of the Hugging Face embedding model")
parser.add_argument('--model_path', required=True, type=str, help="path or name of the Hugging Face causal LM")
parser.add_argument('--gpu_id', default="0", type=str, help="value assigned to CUDA_VISIBLE_DEVICES")
parser.add_argument('--chain_type', default="refine", type=str, choices=["stuff", "refine"], help="RetrievalQA chain type")
args = parser.parse_args()
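# Restrict the GPUs visible to this process before torch is imported so that all
# CUDA allocations land on the requested device.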
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
file_path = args.file_path
embedding_path = args.embedding_path
model_path = args.model_path
import torch
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.document_loaders import TextLoader
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
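# The prompt templates below follow the Llama-2 chat format:
# [INST] <<SYS>> system message <</SYS>> user message [/INST]
# This first template is used by the "stuff" chain: context and question in one call.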
prompt_template = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n"
    "{context}\n{question} [/INST]"
)
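# Refine-chain follow-up prompt: refine an existing answer using one more
# retrieved passage (inputs: question, existing_answer, context_str).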
refine_prompt_template = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n"
    "这是原始问题: {question}\n"
    "已有的回答: {existing_answer}\n"
    "现在还有一些文字,(如果有需要)你可以根据它们完善现有的回答。"
    "\n\n"
    "{context_str}\n"
    "\n\n"
    "请根据新的文段,进一步完善你的回答。"
    " [/INST]"
)
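# Refine-chain initial prompt: answer the question from the first retrieved passage.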
initial_qa_template = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n"
    "以下为背景知识:\n"
    "{context_str}"
    "\n"
    "请根据以上背景知识, 回答这个问题:{question}。"
    " [/INST]"
)
if __name__ == '__main__':
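    # Load weights in half precision; a CUDA-capable GPU is required.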
load_type = torch.float16
if not torch.cuda.is_available():
raise RuntimeError("No CUDA GPUs are available.")
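    # Load the source document and split it into overlapping chunks
    # (600 characters per chunk, 100 characters of overlap) for retrieval.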
loader = TextLoader(file_path)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=600, chunk_overlap=100)
texts = text_splitter.split_documents(documents)
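    # Embed the chunks and build an in-memory FAISS index for similarity search.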
print("Loading the embedding model...")
embeddings = HuggingFaceEmbeddings(model_name=embedding_path)
docsearch = FAISS.from_documents(texts, embeddings)
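    # Wrap the local causal LM in a transformers text-generation pipeline on the
    # single GPU left visible by CUDA_VISIBLE_DEVICES above.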
    print("Loading the LLM...")
model = HuggingFacePipeline.from_model_id(model_id=model_path,
task="text-generation",
device=0,
pipeline_kwargs={
"max_new_tokens": 400,
"do_sample": True,
"temperature": 0.2,
"top_k": 40,
"top_p": 0.9,
"repetition_penalty": 1.1},
model_kwargs={
"torch_dtype": load_type,
"low_cpu_mem_usage": True,
"trust_remote_code": True}
)
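    # Build the RetrievalQA chain. "stuff" places the retrieved chunk directly in
    # one prompt; "refine" answers from the first chunk and then revises that
    # answer with each further chunk (k=1, so only one chunk is retrieved here).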
if args.chain_type == "stuff":
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
chain_type_kwargs = {"prompt": PROMPT}
qa = RetrievalQA.from_chain_type(
llm=model,
chain_type="stuff",
retriever=docsearch.as_retriever(search_kwargs={"k": 1}),
chain_type_kwargs=chain_type_kwargs)
elif args.chain_type == "refine":
refine_prompt = PromptTemplate(
input_variables=["question", "existing_answer", "context_str"],
template=refine_prompt_template,
)
initial_qa_prompt = PromptTemplate(
input_variables=["context_str", "question"],
template=initial_qa_template,
)
chain_type_kwargs = {"question_prompt": initial_qa_prompt, "refine_prompt": refine_prompt}
qa = RetrievalQA.from_chain_type(
llm=model, chain_type="refine",
retriever=docsearch.as_retriever(search_kwargs={"k": 1}),
chain_type_kwargs=chain_type_kwargs)
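    # Interactive loop: an empty question (after stripping whitespace) exits.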
while True:
query = input("请输入问题:")
        if not query.strip():
break
print(qa.run(query))