"""
Implement a memory class for storing conversation history
Support long term and short term memory
"""
from typing import Any, Dict, List

from colossalqa.chain.memory.summary import ConversationSummaryMemory
from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.schema import BaseChatMessageHistory
from langchain.schema.messages import BaseMessage
from langchain.schema.retriever import BaseRetriever
from pydantic import Field


class ConversationBufferWithSummary(ConversationSummaryMemory):
    """Conversation memory that keeps recent turns verbatim and summarizes stale ones.

    Recent messages live verbatim in ``buffered_history``. When the formatted
    dialogue (plus the retrieval prompt) would exceed ``max_tokens``, the oldest
    human/AI message pair is moved into ``summarized_history_temp`` and later
    folded into ``existing_summary`` by ``format_dialogue``.
    """

    # Most recent conversation turns, kept verbatim.
    buffered_history: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
    # Staging buffer: message pairs waiting to be folded into the running summary.
    summarized_history_temp: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"
    buffer: str = ""  # Formatted conversation as a string
    existing_summary: str = ""  # Summarization of stale conversation, as a string
    # Key under which the formatted history is injected into the prompt.
    memory_key: str = "chat_history"
    input_key: str = "question"
    retriever: BaseRetriever = None
    max_tokens: int = 2000
    chain: BaseCombineDocumentsChain = None
    # FIX: was annotated ``List`` but the default and all assignments are dicts.
    input_chain_type_kwargs: Dict[str, Any] = {}

    @property
    def buffer(self) -> Any:
        """String buffer of memory (list of messages if return_messages is set)."""
        return self.buffer_as_messages if self.return_messages else self.buffer_as_str

    @property
    def buffer_as_str(self) -> str:
        """Exposes the buffer as a string in case return_messages is True."""
        # FIX: the original assigned to ``self.buffer`` and re-read it, but
        # ``buffer`` is a read-only property on this class, so the round trip
        # could fail or recurse. Formatting on demand is equivalent and safe.
        return self.format_dialogue()

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Exposes the buffer as a list of messages in case return_messages is False."""
        return self.buffered_history.messages

    def clear(self) -> None:
        """Clear all the memory."""
        self.buffered_history.clear()
        self.summarized_history_temp.clear()

    def initiate_document_retrieval_chain(
        self, llm: Any, prompt_template: Any, retriever: Any, chain_type_kwargs: Dict[str, Any] = {}
    ) -> None:
        """
        Since we need to calculate the length of the prompt, we need to initiate a retrieval chain
        to calculate the length of the prompt.
        Args:
            llm: the language model for the retrieval chain (we won't actually return the output)
            prompt_template: the prompt template for constructing the retrieval chain
            retriever: the retriever for the retrieval chain
            chain_type_kwargs: the kwargs for the retrieval chain (``memory_key`` is
                stripped because the chain must not carry its own memory here)
        """
        # NOTE: chain_type_kwargs has a mutable default, but it is only read,
        # never mutated, so the shared default is harmless.
        self.retriever = retriever
        input_chain_type_kwargs = {k: v for k, v in chain_type_kwargs.items() if k not in [self.memory_key]}
        self.input_chain_type_kwargs = input_chain_type_kwargs
        self.chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_template, **self.input_chain_type_kwargs)

    @property
    def memory_variables(self) -> List[str]:
        """Define the variables we are providing to the prompt."""
        return [self.memory_key]

    def format_dialogue(self, lang: str = "en") -> str:
        """Format memory into two parts: a summarization of historical conversation
        and the most recent conversation.

        Args:
            lang: "en" or "zh"; controls the headings of the two sections.
        Raises:
            ValueError: if ``lang`` is neither "en" nor "zh".
        """
        # Fold any staged message pairs into the running summary, two messages
        # (one human + one AI turn) at a time.
        if len(self.summarized_history_temp.messages) != 0:
            for i in range(len(self.summarized_history_temp.messages) // 2):
                self.existing_summary = (
                    self.predict_new_summary(
                        self.summarized_history_temp.messages[i * 2 : i * 2 + 2], self.existing_summary, stop=["\n\n"]
                    )
                    .strip()
                    .split("\n")[0]
                    .strip()
                )
            # Drop the pairs that were just summarized.
            for i in range(len(self.summarized_history_temp.messages) // 2):
                self.summarized_history_temp.messages.pop(0)
                self.summarized_history_temp.messages.pop(0)
        conversation_buffer = []
        for t in self.buffered_history.messages:
            if t.type == "human":
                prefix = self.human_prefix
            else:
                prefix = self.ai_prefix
            conversation_buffer.append(prefix + ": " + t.content)
        conversation_buffer = "\n".join(conversation_buffer)
        if len(self.existing_summary) > 0:
            if lang == "en":
                message = f"A summarization of historical conversation:\n{self.existing_summary}\nMost recent conversation:\n{conversation_buffer}"
            elif lang == "zh":
                message = f"历史对话概要:\n{self.existing_summary}\n最近的对话:\n{conversation_buffer}"
            else:
                raise ValueError("Unsupported language")
            return message
        else:
            message = conversation_buffer
            return message

    def get_conversation_length(self) -> int:
        """Get the token length of the formatted conversation."""
        prompt = self.format_dialogue()
        length = self.llm.get_num_tokens(prompt)
        return length

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Load the memory variables.
        Summarize oversize conversation to fit into the length constraint defined by max_tokens.
        Args:
            inputs: the kwargs of the chain of your definition
        Returns:
            a dict that maps from memory key to the formatted dialogue
            the formatted dialogue has the following format
            if conversation is too long:
                A summarization of historical conversation:
                {summarization}
                Most recent conversation:
                Human: XXX
                Assistant: XXX
                ...
            otherwise
                Human: XXX
                Assistant: XXX
                ...
        """
        # Determine the documents that share the prompt budget with the history.
        if "input_documents" in inputs:
            # Run in a retrieval qa chain
            docs = inputs["input_documents"]
        else:
            # For test
            docs = self.retriever.get_relevant_documents(inputs[self.input_key])
        inputs[self.memory_key] = ""
        inputs = {k: v for k, v in inputs.items() if k in [self.chain.input_key, self.input_key, self.memory_key]}
        prompt_length = self.chain.prompt_length(docs, **inputs)
        remain = self.max_tokens - prompt_length
        # Evict the oldest human/AI pair into the summary staging buffer until
        # the formatted conversation fits the remaining token budget.
        while self.get_conversation_length() > remain:
            if len(self.buffered_history.messages) <= 2:
                # Even a single remaining turn does not fit ("trunk" here
                # presumably means "chunk" — the retrieved documents alone
                # already exceed the budget).
                raise RuntimeError("Exceed max_tokens, trunk size of retrieved documents is too large")
            temp = self.buffered_history.messages.pop(0)
            self.summarized_history_temp.messages.append(temp)
            temp = self.buffered_history.messages.pop(0)
            self.summarized_history_temp.messages.append(temp)
        return {self.memory_key: self.format_dialogue()}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation turn to the verbatim buffer."""
        input_str, output_str = self._get_input_output(inputs, outputs)
        self.buffered_history.add_user_message(input_str.strip())
        self.buffered_history.add_ai_message(output_str.strip())