# RAG_ChatBot.py
import os
from typing import Dict, Optional, Tuple

from colossalqa.chain.retrieval_qa.base import RetrievalQA
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.memory import ConversationBufferWithSummary
from colossalqa.mylogging import get_logger
from colossalqa.prompt.prompt import (
    PROMPT_DISAMBIGUATE_ZH,
    PROMPT_RETRIEVAL_QA_ZH,
    SUMMARY_PROMPT_ZH,
    ZH_RETRIEVAL_QA_REJECTION_ANSWER,
    ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
)
from colossalqa.retriever import CustomRetriever
from langchain import LLMChain
from langchain.embeddings import HuggingFaceEmbeddings

logger = get_logger()


class RAG_ChatBot:
    """Retrieval-augmented chatbot assembled from colossalqa and LangChain components.

    Every component (embedding model, text splitter, summarizing memory,
    retriever, RetrievalQA chain and the optional disambiguation chain) is built
    from ``rag_config``, which is read under the keys ``"embed"``,
    ``"splitter"``, ``"chain"`` and ``"retrieval"``.
    """

    def __init__(
        self,
        llm,
        rag_config,
    ) -> None:
        """Build all components from ``rag_config`` and start with no documents.

        Args:
            llm: LangChain LLM shared by memory summarization, answer generation
                and (optionally) query disambiguation.
            rag_config: nested config dict; see the class docstring for the keys read.
        """
        self.llm = llm
        self.rag_config = rag_config
        self._build_components()

        # Raw (unsplit) documents loaded so far, and a display name per loaded file.
        self.documents = []
        self.docs_names = []

    def _build_components(self):
        """(Re)create every component from ``self.rag_config`` in dependency order."""
        chain_cfg = self.rag_config["chain"]
        self.set_embed_model(**self.rag_config["embed"])
        self.set_text_splitter(**self.rag_config["splitter"])
        # The RAG chain captures self.memory and self.info_retriever, so both must
        # exist before set_rag_chain runs.
        self.set_memory(**chain_cfg)
        self.set_info_retriever(**self.rag_config["retrieval"])
        self.set_rag_chain(**chain_cfg)
        # Query disambiguation is optional: only installed when a prompt is configured.
        if chain_cfg.get("disambig_prompt", None):
            self.set_disambig_retriv(**chain_cfg)

    def set_embed_model(self, **kwargs):
        """Create the HuggingFace embedding model used when indexing documents.

        Reads ``embed_model_name_or_path`` and ``embed_model_device`` from kwargs.
        """
        self.embed_model = HuggingFaceEmbeddings(
            model_name=kwargs["embed_model_name_or_path"],
            # NOTE(review): forwarded verbatim as ``model_kwargs``, so despite its name the
            # config value presumably is a dict such as {"device": ...} — confirm.
            model_kwargs=kwargs["embed_model_device"],
            encode_kwargs={"normalize_embeddings": False},
        )

    def set_text_splitter(self, **kwargs):
        """Instantiate the splitter class given as ``kwargs["name"]`` (called with no arguments)."""
        self.text_splitter = kwargs["name"]()

    def set_memory(self, **kwargs):
        """Create the conversation memory that summarizes history beyond ``mem_max_tokens``.

        ``mem_llm_kwargs`` is forwarded as ``llm_kwargs`` only when it is truthy.
        """
        params = {"llm_kwargs": kwargs["mem_llm_kwargs"]} if kwargs.get("mem_llm_kwargs", None) else {}
        self.memory = ConversationBufferWithSummary(
            llm=self.llm,
            prompt=kwargs["mem_summary_prompt"],
            human_prefix=kwargs["mem_human_prefix"],
            ai_prefix=kwargs["mem_ai_prefix"],
            max_tokens=kwargs["mem_max_tokens"],
            **params,
        )

    def set_info_retriever(self, **kwargs):
        """Create the document retriever returning the top ``retri_top_k`` matches.

        ``retri_kb_file_path`` is passed to ``CustomRetriever`` as ``sql_file_path``.
        """
        self.info_retriever = CustomRetriever(
            k=kwargs["retri_top_k"], sql_file_path=kwargs["retri_kb_file_path"], verbose=kwargs["verbose"]
        )

    def set_rag_chain(self, **kwargs):
        """Create the "stuff"-type RetrievalQA chain over ``self.info_retriever``.

        The chain is given ``gen_qa_prompt`` and ``self.memory``; ``gen_llm_kwargs``
        is forwarded as ``llm_kwargs`` only when it is truthy.
        """
        params = {"llm_kwargs": kwargs["gen_llm_kwargs"]} if kwargs.get("gen_llm_kwargs", None) else {}
        self.rag_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            verbose=kwargs["verbose"],
            chain_type="stuff",
            retriever=self.info_retriever,
            chain_type_kwargs={"prompt": kwargs["gen_qa_prompt"], "memory": self.memory},
            **params,
        )

    def set_disambig_retriv(self, **kwargs):
        """Install an LLM-based query rewriter on the retriever.

        The rewriter runs ``disambig_prompt`` over the incoming query and the current
        memory buffer, and keeps only the first line of the LLM output.
        """
        params = {"llm_kwargs": kwargs["disambig_llm_kwargs"]} if kwargs.get("disambig_llm_kwargs", None) else {}
        self.llm_chain_disambiguate = LLMChain(llm=self.llm, prompt=kwargs["disambig_prompt"], **params)

        def disambiguity(input: str):
            out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
            # Keep only the first line in case the LLM ignores the newline stop sequence.
            return out.split("\n")[0]

        self.info_retriever.set_rephrase_handler(disambiguity)

    def load_doc_from_console(self, json_parse_args: Optional[Dict] = None):
        """Interactively read file paths from stdin, load them, and index them.

        Prompts for a path and a short description per file until an empty path is
        entered, then splits all loaded documents and adds them to the retriever.

        Args:
            json_parse_args: extra keyword arguments forwarded to ``DocumentLoader``.
                ``None`` (the default) means no extra arguments.
        """
        if json_parse_args is None:
            json_parse_args = {}
        print("Select files for constructing the retriever")
        while True:
            file = input("Enter a file path or press Enter directly without input to exit:").strip()
            if file == "":
                break
            data_name = input("Enter a short description of the data:")
            # The name handed to DocumentLoader has spaces replaced with underscores.
            docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
            self.documents.extend(docs)
            self.docs_names.append(data_name)
        self.split_docs_and_add_to_mem(**self.rag_config["chain"])

    def load_doc_from_files(self, files, data_name="default_kb", json_parse_args: Optional[Dict] = None):
        """Load each path in ``files`` under one data name, then index all documents.

        Args:
            files: iterable of file paths.
            data_name: name passed to ``DocumentLoader`` for every file (spaces
                replaced with underscores).
            json_parse_args: extra keyword arguments forwarded to ``DocumentLoader``.
                ``None`` (the default) means no extra arguments.
        """
        if json_parse_args is None:
            json_parse_args = {}
        source_name = data_name.replace(" ", "_")
        for file in files:
            docs = DocumentLoader([[file, source_name]], **json_parse_args).all_data
            self.documents.extend(docs)
            self.docs_names.append(os.path.basename(file))
        self.split_docs_and_add_to_mem(**self.rag_config["chain"])

    def split_docs_and_add_to_mem(self, **kwargs):
        """Split all loaded documents, add them to the retriever, and refresh the memory.

        Every document in ``self.documents`` is re-split and re-added on each call,
        using ``cleanup="incremental"`` by source (presumably so already-indexed
        sources are not duplicated — confirm against CustomRetriever). Then the
        memory's document-retrieval chain is re-initiated with ``gen_qa_prompt``.
        """
        doc_splits = self.split_docs(self.documents)
        self.info_retriever.add_documents(
            docs=doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
        )
        self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever)

    def split_docs(self, documents):
        """Return ``documents`` split into chunks by ``self.text_splitter``."""
        return self.text_splitter.split_documents(documents)

    def clear_docs(self, **kwargs):
        """Forget all loaded documents, clear the retriever, and refresh the memory.

        ``gen_qa_prompt`` may be supplied in kwargs; when omitted, the prompt from
        ``self.rag_config["chain"]`` is used.
        """
        self.documents = []
        self.docs_names = []
        self.info_retriever.clear_documents()
        if "gen_qa_prompt" in kwargs:
            gen_qa_prompt = kwargs["gen_qa_prompt"]
        else:
            gen_qa_prompt = self.rag_config["chain"]["gen_qa_prompt"]
        self.memory.initiate_document_retrieval_chain(self.llm, gen_qa_prompt, self.info_retriever)

    def reset_config(self, rag_config):
        """Replace the config and rebuild every component from it.

        ``self.documents`` and ``self.docs_names`` are kept; the rebuilt retriever
        starts without them until documents are indexed again.
        """
        self.rag_config = rag_config
        self._build_components()

    def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
        """Answer ``user_input`` with the RAG chain, continuing the conversation in ``memory``.

        Args:
            user_input: the user's question.
            memory: conversation memory for this session. Its buffered and summarized
                histories are loaded into ``self.memory`` (the memory the chain was
                built with) before the chain runs. ``None`` means "use ``self.memory``".

        Returns:
            ``(answer, memory)``, where ``memory`` is the object passed in, or
            ``self.memory`` when ``None`` was passed.
        """
        if memory is None:
            memory = self.memory
        else:
            # The chain was built around self.memory (see set_rag_chain), so load the
            # caller's session history into it before running.
            self.memory.buffered_history.messages = memory.buffered_history.messages
            self.memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
        result = self.rag_chain.run(
            query=user_input,
            stop=[memory.human_prefix + ": "],
            # NOTE: "keywrods" (sic) is kept as-is; it presumably matches the keyword
            # colossalqa's RetrievalQA expects — check before "fixing" the spelling.
            rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
            rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
        )
        return result, memory

    def start_test_session(self):
        """
        Simple session for testing purpose: chat on stdin/stdout until the user types "END".
        """
        while True:
            user_input = input("User: ")
            if "END" == user_input:
                print("Agent: Happy to chat with you :)")
                break
            agent_response, self.memory = self.run(user_input, self.memory)
            print(f"Agent: {agent_response}")


if __name__ == "__main__":
    # Initialize a LangChain LLM (here we use ChatGPT as an example).
    import config
    from langchain.llms import OpenAI

    # You need to: export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
    llm = OpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"))

    # ChatGPT cannot control temperature, do_sample, etc., so clear the per-chain
    # LLM kwargs; with these set to None no llm_kwargs are forwarded to the chains.
    all_config = config.ALL_CONFIG
    all_config["chain"]["mem_llm_kwargs"] = None
    all_config["chain"]["disambig_llm_kwargs"] = None
    all_config["chain"]["gen_llm_kwargs"] = None

    rag = RAG_ChatBot(llm, all_config)
    rag.load_doc_from_console()
    rag.start_test_session()