# encoding=gbk
import faiss
import numpy as np
import tiktoken
import torch
from datasets import load_from_disk
from langchain import PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI
from pyserini.search.lucene import LuceneSearcher
from transformers import AutoTokenizer, AutoModel
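# Note: the LangChain imports above follow the pre-0.1 LangChain API
# (langchain.PromptTemplate / LLMChain and langchain.chat_models.ChatOpenAI).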


class LocalDatasetLoader:
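    """Load a dataset saved with `datasets.save_to_disk` together with
    precomputed document embeddings (a NumPy array row-aligned with it)."""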

    def __init__(self,
                 data_path,
                 embedding_path):
        dataset = load_from_disk(data_path)
        title = dataset['title']
        text = dataset['text']
        # One searchable passage per row: "title -- text".
        self.data = [title[i] + ' -- ' + text[i] for i in range(len(title))]
        # Precomputed document embeddings, row-aligned with self.data.
        self.doc_emb = np.load(embedding_path)


class QueryGenerator:
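    """Rewrite a follow-up question into a standalone query.

    The Chinese prompt asks the model to rewrite the new question against
    the history, replacing pronouns with what they refer to when the
    question is related to the history, and otherwise to leave the
    question unchanged, without answering it.
    """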
    def __init__(self):
        prompt_template = """请基于历史,对新问题进行修改,如果新问题与历史相关,你必须结合语境将代词替换为相应的指代内容,让它的提问更加明确;否则保留原始的新问题。只修改新问题,不作任何回答:\n历史:{history}\n新问题:{question}\n修改后的新问题:"""
        llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")
        prompt = PromptTemplate(
            input_variables=["history", "question"],
            template=prompt_template,
        )

        self.llm_chain = LLMChain(llm=llm, prompt=prompt)

    def run(self, history, question):
        return self.llm_chain.predict(history=history, question=question).strip()


class AnswerGenerator:
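    """Generate the final answer from history, question, and references.

    The Chinese prompt asks the model to answer the question based on the
    references, without adding any citation markers.
    """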
    def __init__(self):
        prompt_template = """请基于参考,回答问题,不需要标注任何引用:\n历史:{history}\n问题:{question}\n参考:{references}\n答案:"""
        llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613")
        prompt = PromptTemplate(
            input_variables=["history", "question", "references"],
            template=prompt_template,
        )

        self.llm_chain = LLMChain(llm=llm, prompt=prompt)

    def run(self, history, question, references):
        return self.llm_chain.predict(history=history, question=question, references=references).strip()


class BMVectorIndex:
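    """Hybrid retriever: BM25 recall followed by dense re-ranking.

    A Pyserini/Lucene BM25 index recalls RANKING candidate documents,
    whose precomputed embeddings are then re-scored against the query
    embedding with an in-memory FAISS inner-product index to keep the
    TOP_N best documents.
    """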
    def __init__(self,
                 model_path,
                 bm_index_path,
                 data_loader):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)
        self.bm_searcher = LuceneSearcher(bm_index_path)
        self.loader = data_loader

        if '-en' in model_path:
            raise NotImplementedError("only Chinese models are supported currently")

        if 'noinstruct' in model_path:
            self.instruction = None
        else:
            # BGE-style retrieval instruction prefix, i.e. "Generate a
            # representation for this sentence to retrieve related articles:".
            self.instruction = "为这个句子生成表示以用于检索相关文章:"

    def search_for_doc(self, query: str, RANKING: int = 1000, TOP_N: int = 5):
        # Stage 1: BM25 recall of candidate documents.
        hits = self.bm_searcher.search(query, RANKING)
        ids = [int(e.docid) for e in hits]
        if not ids:
            return ""
        use_docs = self.loader.doc_emb[ids]

        # Stage 2: embed the query and re-rank the candidates by inner product.
        if self.instruction is not None:
            query = self.instruction + query
        encoded_input = self.tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
            # CLS-token embedding, L2-normalized so inner product equals cosine similarity.
            query_emb = model_output.last_hidden_state[:, 0]
            query_emb = torch.nn.functional.normalize(query_emb, p=2, dim=-1).detach().cpu().numpy()

        # Temporary in-memory index over the candidate embeddings only.
        temp_index = faiss.IndexFlatIP(use_docs.shape[1])
        temp_index.add(np.ascontiguousarray(use_docs, dtype='float32'))

        # Never ask FAISS for more results than there are candidates.
        top_n = min(TOP_N, len(ids))
        _, I = temp_index.search(query_emb, top_n)

        return "\n".join(self.loader.data[ids[i]] for i in I[0])


class Agent:
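    """Conversational RAG agent with a token-bounded dialogue memory."""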
    def __init__(self, index):
        self.memory = ""
        self.index = index
        self.query_generator = QueryGenerator()
        self.answer_generator = AnswerGenerator()
        # Tokenizer used only to budget the memory size; created once here
        # instead of on every update_memory call.
        self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

    def empty_memory(self):
        self.memory = ""

    def update_memory(self, question, answer):
        self.memory += f"问:{question}\n答:{answer}\n"
        # Trim the oldest QA pair at a time until the memory fits the token budget.
        while len(self.encoding.encode(self.memory)) > 3500:
            # Find the start of the second "问:" entry (search past the first one).
            pos = self.memory[2:].find("问")
            if pos == -1:
                break
            self.memory = self.memory[pos + 2:]

    def generate_query(self, question):
        if self.memory == "":
            return question
        else:
            return self.query_generator.run(self.memory, question)

    def generate_answer(self, query, references):
        return self.answer_generator.run(self.memory, query, references)

    def answer(self, question, RANKING=1000, TOP_N=5, verbose=True):
        query = self.generate_query(question)
        references = self.index.search_for_doc(query, RANKING=RANKING, TOP_N=TOP_N)
        # The answer is generated from the original question; the history
        # passed alongside it supplies the conversational context.
        answer = self.generate_answer(question, references)
        self.update_memory(question, answer)
        if verbose:
            print('\033[96m' + "查:" + query + '\033[0m')       # rewritten query
            print('\033[96m' + "参:" + references + '\033[0m')  # retrieved references
        print("答:" + answer)                                    # final answer
        return answer