"""
Chain for question-answering against a vector database.

Modified from the original source.

This code is based on LangChain AI's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
from __future__ import annotations

import copy
import inspect
from typing import Any, Dict, List, Optional

from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, Callbacks
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.chains.retrieval_qa.base import BaseRetrievalQA
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel

27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class CustomBaseRetrievalQA(BaseRetrievalQA):
    """Base class for retrieval question-answering chains with answer rejection.

    On top of ``BaseRetrievalQA`` this class:

    * forwards ``llm_kwargs`` to the underlying LLM chain at construction time,
    * passes a fixed set of per-call generation options from the chain inputs
      to the combine-documents chain,
    * replaces an answer containing any of the caller-supplied
      ``rejection_trigger_keywords`` with ``rejection_answer``,
    * saves the final question/answer pair into the combine-documents chain's
      memory when that chain has one.
    """

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Build the chain from an LLM using a ``CustomStuffDocumentsChain``.

        Args:
            llm: Language model used to answer the question.
            prompt: Prompt for the LLM chain. When omitted, ``PROMPT_SELECTOR``
                picks one for ``llm``.
            callbacks: Callbacks attached to every chain created here.
            **kwargs: ``llm_kwargs`` (dict handed to ``LLMChain``) and
                ``document_prompt`` (prompt used to format each document) are
                consumed here; everything else is forwarded to ``cls(...)``.

        Returns:
            A new instance of ``cls``.
        """
        llm_kwargs = kwargs.pop("llm_kwargs", {})
        # Pop rather than get: the prompt is consumed by the combine-documents
        # chain below and must not be forwarded a second time to ``cls(...)``
        # through ``**kwargs``.
        document_prompt = kwargs.pop("document_prompt", None)
        if document_prompt is None:
            document_prompt = PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")

        _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
        llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks, llm_kwargs=llm_kwargs)
        combine_documents_chain = CustomStuffDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name="context",
            document_prompt=document_prompt,
            callbacks=callbacks,
        )

        return cls(
            combine_documents_chain=combine_documents_chain,
            callbacks=callbacks,
            **kwargs,
        )

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Build the chain from a named combine-documents chain type.

        Args:
            llm: Language model used to answer the question.
            chain_type: Chain type name handed to ``load_qa_chain``.
            chain_type_kwargs: Extra keyword arguments for ``load_qa_chain``.
            **kwargs: ``llm_kwargs`` is consumed here and handed to
                ``load_qa_chain``; everything else is forwarded to ``cls(...)``.

        Returns:
            A new instance of ``cls``.
        """
        llm_kwargs = kwargs.pop("llm_kwargs", {})
        _chain_type_kwargs = chain_type_kwargs or {}
        combine_documents_chain = load_qa_chain(llm, chain_type=chain_type, **_chain_type_kwargs, llm_kwargs=llm_kwargs)
        return cls(combine_documents_chain=combine_documents_chain, **kwargs)

    @staticmethod
    def _generation_kwargs(inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the per-call generation options present in ``inputs``."""
        allowed_keys = ("stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix")
        return {key: value for key, value in inputs.items() if key in allowed_keys}

    @staticmethod
    def _apply_rejection(answer: Optional[str], inputs: Dict[str, Any]) -> str:
        """Swap ``answer`` for the rejection answer when a trigger keyword appears in it.

        With no ``rejection_trigger_keywords`` in ``inputs`` the LLM answer is
        returned unchanged. A ``None`` answer is always replaced.
        """
        # ``or []`` tolerates an explicit ``None`` passed for the keyword list.
        rejection_trigger_keywords = inputs.get("rejection_trigger_keywords") or []
        if answer is None or any(keyword in answer for keyword in rejection_trigger_keywords):
            return inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
        return answer

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve documents for the query and run the combine-documents chain.

        Args:
            inputs: Chain inputs. The query is read from ``self.input_key``.
                Optional keys: ``stop``, ``temperature``, ``top_k``, ``top_p``,
                ``max_new_tokens`` and ``doc_prefix`` (forwarded to the
                combine-documents chain), ``rejection_trigger_keywords`` and
                ``rejection_answer`` (see class docstring).
            run_manager: Callback manager for this run; a no-op manager is used
                when omitted.

        Returns:
            A dict with the answer under ``self.output_key``. If
            ``return_source_documents`` is true, the retrieved documents are
            also returned under ``'source_documents'``.

        Example:
            .. code-block:: python

                res = indexqa({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        # Subclasses may define ``_get_docs`` with or without ``run_manager``.
        accepts_run_manager = "run_manager" in inspect.signature(self._get_docs).parameters
        if accepts_run_manager:
            docs = self._get_docs(question, run_manager=_run_manager)
        else:
            docs = self._get_docs(question)  # type: ignore[call-arg]

        kwargs = self._generation_kwargs(inputs)

        # Snapshot the memory state before running the chain.
        # NOTE(review): presumably ``run`` itself records the exchange into
        # memory, and restoring afterwards lets the final (possibly rejected)
        # answer be saved exactly once below - confirm against the memory class.
        chain = self.combine_documents_chain
        if chain.memory is not None:
            buffered_history_backup = copy.deepcopy(chain.memory.buffered_history)
            summarized_history_temp_backup = copy.deepcopy(chain.memory.summarized_history_temp)
        else:
            buffered_history_backup = None
            summarized_history_temp_backup = None

        answer = chain.run(input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs)

        if summarized_history_temp_backup is not None and buffered_history_backup is not None:
            # The backups are private to this call, so they can be handed back
            # directly without another deep copy.
            chain.memory.buffered_history = buffered_history_backup
            chain.memory.summarized_history_temp = summarized_history_temp_backup

        # if rejection_trigger_keywords is not given, return the response from LLM directly
        answer = self._apply_rejection(answer, inputs)

        if chain.memory is not None:
            chain.memory.save_context({"question": question}, {"output": answer})

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        return {self.output_key: answer}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously retrieve documents and run the combine-documents chain.

        Args:
            inputs: Chain inputs. The query is read from ``self.input_key``.
                Optional keys: ``stop``, ``temperature``, ``top_k``, ``top_p``,
                ``max_new_tokens`` and ``doc_prefix`` (forwarded to the
                combine-documents chain), ``rejection_trigger_keywords`` and
                ``rejection_answer`` (see class docstring).
            run_manager: Async callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            A dict with the answer under ``self.output_key``. If
            ``return_source_documents`` is true, the retrieved documents are
            also returned under ``'source_documents'``.

        Example:
            .. code-block:: python

                res = await indexqa.acall({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        # Subclasses may define ``_aget_docs`` with or without ``run_manager``.
        accepts_run_manager = "run_manager" in inspect.signature(self._aget_docs).parameters
        if accepts_run_manager:
            docs = await self._aget_docs(question, run_manager=_run_manager)
        else:
            docs = await self._aget_docs(question)  # type: ignore[call-arg]

        kwargs = self._generation_kwargs(inputs)
        # NOTE(review): unlike ``_call``, the memory state is not snapshotted and
        # restored around the chain run here - confirm whether that is intended.
        answer = await self.combine_documents_chain.arun(
            input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
        )

        # if rejection_trigger_keywords is not given, return the response from LLM directly
        answer = self._apply_rejection(answer, inputs)

        # Guarded like ``_call``: a chain without memory must not fail here.
        if self.combine_documents_chain.memory is not None:
            self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        return {self.output_key: answer}


class RetrievalQA(CustomBaseRetrievalQA):
    """Question-answering chain that fetches its context from a retriever.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            from langchain.vectorstores import FAISS
            from langchain.vectorstores.base import VectorStoreRetriever

            # RetrievalQA below is the class defined in this module.
            retriever = VectorStoreRetriever(vectorstore=FAISS(...))
            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)

    """

    # Retriever queried for documents relevant to the question.
    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Return the documents the retriever finds for ``question``."""
        child_callbacks = run_manager.get_child()
        return self.retriever.get_relevant_documents(question, callbacks=child_callbacks)

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Asynchronously return the documents the retriever finds for ``question``."""
        child_callbacks = run_manager.get_child()
        relevant_docs = await self.retriever.aget_relevant_documents(question, callbacks=child_callbacks)
        return relevant_docs

    @property
    def _chain_type(self) -> str:
        """Identifier of this chain type."""
        return "retrieval_qa"