"""
Script for English retrieval based conversation system backed by LLaMa2
"""
import argparse
import os

from colossalqa.chain.retrieval_qa.base import RetrievalQA
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.local.llm import ColossalAPI, ColossalLLM
from colossalqa.memory import ConversationBufferWithSummary
from colossalqa.prompt.prompt import (
    EN_RETRIEVAL_QA_REJECTION_ANSWER,
    EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
    PROMPT_DISAMBIGUATE_EN,
    PROMPT_RETRIEVAL_QA_EN,
)
from colossalqa.retriever import CustomRetriever
from langchain import LLMChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser(description="English retrieval-based conversation system backed by Llama 2")
    parser.add_argument("--model_path", type=str, default=None, help="path to the model")
    parser.add_argument("--model_name", type=str, default=None, help="name of the model")
    parser.add_argument(
        "--sql_file_path", type=str, default=None, help="path to an empty folder for storing SQL files for indexing"
    )

    args = parser.parse_args()
    if not os.path.exists(args.sql_file_path):
        os.makedirs(args.sql_file_path)
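
    # Example invocation (all paths below are placeholders):
    #   python retrieval_conversation_en.py --model_path /path/to/llama2-7b-hf \
    #       --model_name llama2 --sql_file_path ./sql_files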

    colossal_api = ColossalAPI.get_api(args.model_name, args.model_path)
    llm = ColossalLLM(n=1, api=colossal_api)
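    # ColossalLLM exposes the local Colossal-AI inference API through LangChain's LLM interface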

    # Define the retriever
    information_retriever = CustomRetriever(k=3, sql_file_path=args.sql_file_path, verbose=True)
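    # k=3: fetch the three most relevant chunks from the vector store for each query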

    # Set up the embedding model locally
    embedding = HuggingFaceEmbeddings(
        model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
    )
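    # Embeddings are computed on CPU here; set model_kwargs={"device": "cuda"} to use a GPU instead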

    # Define memory with summarization ability
    memory = ConversationBufferWithSummary(
        llm=llm, max_tokens=2000, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
    )
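    # The buffer keeps the verbatim conversation until max_tokens is reached, after which
    # older turns are condensed into a running summary generated by the same LLM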

    # Define the chain to preprocess the input
    # Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
    llm_chain_disambiguate = LLMChain(
        llm=llm, prompt=PROMPT_DISAMBIGUATE_EN, llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True}
    )

    def disambiguate(question):
        """Rewrite ambiguous references (e.g. pronouns) in the question using the chat history."""
        out = llm_chain_disambiguate.run(input=question, chat_history=memory.buffer, stop=["\n"])
        return out.split("\n")[0]
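
    # Example: with "France" in the chat history, disambiguate("What is its capital?")
    # should return "What is the capital of France?"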

    # Load data to vector store
    print("Select files for constructing retriever")
    documents = []
    while True:
        file = input("Enter a file path, or press Enter directly without input to exit: ").strip()
        if file == "":
            break
        data_name = input("Enter a short description of the data:")
        separator = input(
            "Enter a separator to force splitting the text into chunks. If no separator is given, the default separator is '\\n\\n'. Note that "
            + "we use a neural text splitter to split texts into chunks; the separator only serves as a delimiter to force-split long passages "
            + "into chunks before they are passed to the neural network. Press ENTER directly to skip: "
        )
        separator = separator if separator != "" else "\n\n"
        retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data

        # Split into chunks; list the user-provided separator first so it takes priority
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=100, chunk_overlap=20, separators=[separator, "\n\n", "\n", " ", ""]
        )
        splits = text_splitter.split_documents(retriever_data)
        documents.extend(splits)
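        # With chunk_size=100 and chunk_overlap=20, neighboring chunks share 20 characters of
        # context, so answers spanning a chunk boundary remain retrievable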
    # Create retriever
    information_retriever.add_documents(docs=documents, cleanup="incremental", mode="by_source", embedding=embedding)
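    # "incremental" cleanup re-indexes documents per source, so unchanged files are skipped
    # on subsequent runs (mirroring LangChain's indexing semantics)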

    # Set up the document retrieval chain; the memory object needs it to calculate prompt length
    memory.initiate_document_retrieval_chain(
        llm,
        PROMPT_RETRIEVAL_QA_EN,
        information_retriever,
        chain_type_kwargs={
            "chat_history": "",
        },
    )

    # Define retrieval chain
    retrieval_chain = RetrievalQA.from_chain_type(
        llm=llm,
        verbose=False,
        chain_type="stuff",
        retriever=information_retriever,
        chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_EN, "memory": memory},
        llm_kwargs={"max_new_tokens": 50, "temperature": 0.75, "do_sample": True},
    )
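    # chain_type="stuff" concatenates every retrieved chunk into a single prompt, together
    # with the summarized chat history carried by `memory`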
    # Set the disambiguation handler
    information_retriever.set_rephrase_handler(disambiguate)

    # Start conversation
    while True:
        user_input = input("User: ")
        if "END" == user_input:
            print("Agent: Happy to chat with you :)")
            break
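        # If the answer contains a rejection trigger keyword (i.e. the model found no support
        # in the retrieved context), the chain is expected to return the canned rejection answer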
        agent_response = retrieval_chain.run(
            query=user_input,
            stop=["Human: "],
            rejection_trigger_keywords=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
            rejection_answer=EN_RETRIEVAL_QA_REJECTION_ANSWER,
        )
        agent_response = agent_response.split("\n")[0]
        print(f"Agent: {agent_response}")