# chat_data_with_awel.py (9.09 KB) -- v1.0, committed by chenzk.
# NOTE: the repository web viewer's header and line-number gutter (1-263) were
# captured here during extraction; they are not part of the program.

import asyncio
import json
import shutil

import pandas as pd

from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    SQLOutputParser,
    SystemPromptTemplate,
)
from dbgpt.core.awel import (
    DAG,
    BranchOperator,
    InputOperator,
    InputSource,
    JoinOperator,
    MapOperator,
    is_empty_data,
)
from dbgpt.core.operators import PromptBuilderOperator, RequestBuilderOperator
from dbgpt.datasource.operators import DatasourceOperator
from dbgpt.model.operators import LLMOperator
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt_ext.rag import ChunkParameters
from dbgpt_ext.rag.operators import DBSchemaAssemblerOperator
from dbgpt_ext.rag.operators.db_schema import DBSchemaRetrieverOperator
from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig

# Delete old vector store directory(/tmp/awel_with_data_vector_store)
# NOTE(review): the vector store configured further down uses
# persist_path=PILOT_PATH, not this /tmp directory, so this cleanup may not
# remove the data the example actually writes -- confirm which location is
# intended before relying on a "fresh store per run".
shutil.rmtree("/tmp/awel_with_data_vector_store", ignore_errors=True)

# Embedding function passed to the vector store configuration below.
embeddings = DefaultEmbeddingFactory.openai()

# Here we use the openai LLM model, if you want to use other models, you can replace
# it according to the previous example.
llm_client = OpenAILLMClient()

# Temporary SQLite database seeded with one `user` table (id, name, age) and
# five sample rows; it is the datasource for both schema loading and querying.
db_conn = SQLiteTempConnector.create_temporary_db()
db_conn.create_temp_tables(
    {
        "user": {
            "columns": {
                "id": "INTEGER PRIMARY KEY",
                "name": "TEXT",
                "age": "INTEGER",
            },
            "data": [
                (1, "Tom", 10),
                (2, "Jerry", 16),
                (3, "Jack", 18),
                (4, "Alice", 20),
                (5, "Bob", 22),
            ],
        }
    }
)

# Chroma-backed store named "db_schema_vector_store"; it is shared by the
# schema assembler and the schema retriever defined later in this file.
config = ChromaVectorConfig(
    persist_path=PILOT_PATH,
    name="db_schema_vector_store",
    embedding_fn=embeddings,
)
vector_store = ChromaStore(config)

# (chart name, description) pairs for every display method offered to the model.
_CHART_CATALOG = (
    ("response_line_chart", "used to display comparative trend analysis data"),
    (
        "response_pie_chart",
        "suitable for scenarios such as proportion and distribution statistics",
    ),
    (
        "response_table",
        "suitable for display with many display columns or non-numeric columns",
    ),
    # Currently disabled entry, kept for reference:
    # ("response_data_text", " the default display method, suitable for single-line or simple content display"),
    (
        "response_scatter_plot",
        "Suitable for exploring relationships between variables, detecting outliers, etc.",
    ),
    (
        "response_bubble_chart",
        "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc.",
    ),
    (
        "response_donut_chart",
        "Suitable for hierarchical structure representation, category proportion display and highlighting key categories, etc.",
    ),
    (
        "response_area_chart",
        "Suitable for visualization of time series data, comparison of multiple groups of data, analysis of data change trends, etc.",
    ),
    (
        "response_heatmap",
        "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc.",
    ),
)

# One single-key dict per chart, in catalog order.
antv_charts = [{chart_name: summary} for chart_name, summary in _CHART_CATALOG]

# Newline-separated "name:description" listing substituted into the prompt's
# {display_type} placeholder.
display_type = "\n".join(
    f"{chart_name}:{summary}"
    for chart in antv_charts
    for chart_name, summary in chart.items()
)

# System prompt template for the text-to-SQL step.
# Placeholders filled at format time: {db_name}, {table_info}, {dialect},
# {top_k}, {display_type}, {user_input} and {response}.
# The fallback display method in rule 6 is spelled 'response_table' so that it
# matches one of the names listed in `antv_charts` (the previous wording,
# 'Table', was not among the offered display methods).
system_prompt = """You are a database expert. Please answer the user's question based on the database selected by the user and some of the available table structure definitions of the database.
Database name:
    {db_name}
Table structure definition:
    {table_info}

Constraint:
1.Please understand the user's intention based on the user's question, and use the given table structure definition to create a grammatically correct {dialect} sql. If sql is not required, answer the user's question directly.
2.Always limit the query to a maximum of {top_k} results unless the user specifies in the question the specific number of rows of data he wishes to obtain.
3.You can only use the tables provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql queries." It is prohibited to fabricate information at will.
4.Please be careful not to mistake the relationship between tables and columns when generating SQL.
5.Please check the correctness of the SQL and ensure that the query performance is optimized under correct conditions.
6.Please choose the best one from the display methods given below for data rendering, and put the type name into the name parameter value that returns the required format. If you cannot find the most suitable one, use 'response_table' as the display method.
the available data display methods are as follows: {display_type}

User Question:
    {user_input}
Please think step by step and respond according to the following JSON format:
    {response}
Ensure the response is correct json and can be parsed by Python json.loads.
"""

# Shape of the JSON object the model is asked to return; the values are
# human-readable descriptions of each field, not sample data.
RESPONSE_FORMAT_SIMPLE = {
    "thoughts": "thoughts summary to say to user",
    "sql": "SQL Query to run",
    "display_type": "Data display method",
}

# JSON rendering of the expected answer shape, handed to the system template
# as its `response_format` (how the template consumes it is defined by the
# dbgpt prompt classes, not in this file).
_response_format_json = json.dumps(
    RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4
)
_system_message = SystemPromptTemplate.from_template(
    system_prompt,
    response_format=_response_format_json,
)
# The human turn is just the raw user question.
_human_message = HumanPromptTemplate.from_template("{user_input}")

# Chat prompt: system instructions first, then the user's question.
prompt = ChatPromptTemplate(messages=[_system_message, _human_message])


class TwoSumOperator(MapOperator[pd.DataFrame, int]):
    """Map operator that totals the ``age`` column of the incoming DataFrame."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, df: pd.DataFrame) -> int:
        """Return the ``age`` total for ``df``.

        The summation is delegated to ``blocking_func_to_async`` rather than
        being executed inline in this coroutine.
        """
        age_total = await self.blocking_func_to_async(self._two_sum, df)
        return age_total

    def _two_sum(self, df: pd.DataFrame) -> int:
        """Sum the ``age`` column of ``df`` (raises ``KeyError`` if absent)."""
        age_column = df["age"]
        return age_column.sum()


def branch_even(x: int) -> bool:
    """Tell whether ``x`` leaves no remainder when divided by two."""
    remainder = x % 2
    return remainder == 0


def branch_odd(x: int) -> bool:
    """Tell whether ``x`` is odd, defined as the negation of ``branch_even``."""
    is_even = branch_even(x)
    return not is_even


class DataDecisionOperator(BranchOperator[int, int]):
    """Branch operator that maps parity predicates to downstream task names.

    Args:
        odd_task_name: Name of the task associated with ``branch_odd``.
        even_task_name: Name of the task associated with ``branch_even``.
        **kwargs: Forwarded unchanged to ``BranchOperator``.
    """

    def __init__(self, odd_task_name: str, even_task_name: str, **kwargs):
        super().__init__(**kwargs)
        self.odd_task_name = odd_task_name
        self.even_task_name = even_task_name

    async def branches(self):
        """Return the predicate -> task-name table (even entry first)."""
        routing = {}
        routing[branch_even] = self.even_task_name
        routing[branch_odd] = self.odd_task_name
        return routing


class OddOperator(MapOperator[int, str]):
    """Map operator that reports its input as odd."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, x: int) -> str:
        """Print and return the message ``"<x> is odd"``."""
        message = f"{x} is odd"
        print(message)
        return message


class EvenOperator(MapOperator[int, str]):
    """Map operator that reports its input as even."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, x: int) -> str:
        """Print and return the message ``"<x> is even"``."""
        message = f"{x} is even"
        print(message)
        return message


class MergeOperator(JoinOperator[str]):
    """Join operator that keeps whichever of its two inputs carries data."""

    def __init__(self, **kwargs):
        # The bound coroutine method is used as the join's combine function.
        super().__init__(combine_function=self.merge_func, **kwargs)

    async def merge_func(self, odd: str, even: str) -> str:
        """Return ``odd`` unless ``is_empty_data(odd)`` holds, else ``even``."""
        if is_empty_data(odd):
            return even
        return odd


# First DAG: feed the temporary database's schema into the shared vector store.
with DAG("load_schema_dag") as load_schema_dag:
    # Dummy input node; the run below calls the DAG without any request data.
    input_task = InputOperator.dummy_input()
    # Load database schema to vector store
    assembler_task = DBSchemaAssemblerOperator(
        connector=db_conn,
        index_store=vector_store,
        chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"),
    )
    input_task >> assembler_task

# Execute the schema-loading DAG immediately (at module run time) and print
# whatever the assembler returned.
chunks = asyncio.run(assembler_task.call())
print(chunks)


# Second DAG: question -> schema retrieval -> prompt -> LLM -> SQL -> query,
# followed by a small parity-branching demo on the query result.
with DAG("chat_data_dag") as chat_data_dag:
    input_task = InputOperator(input_source=InputSource.from_callable())
    retriever_task = DBSchemaRetrieverOperator(
        top_k=1,
        index_store=vector_store,
    )
    # Keep only the text content of each retrieved chunk.
    content_task = MapOperator(lambda cks: [c.content for c in cks])
    # Joins the retrieved table info with the original request dict to form the
    # prompt variables. Deliberately NOT called `merge_task`: that name is
    # reserved for the final MergeOperator below, which code after this block
    # calls. Binding both operators to one name hid the first behind the second.
    prompt_input_task = JoinOperator(
        lambda table_info, ext_dict: {"table_info": table_info, **ext_dict}
    )
    prompt_task = PromptBuilderOperator(prompt)
    req_build_task = RequestBuilderOperator(model="gpt-3.5-turbo")
    llm_task = LLMOperator(llm_client=llm_client)
    sql_parse_task = SQLOutputParser()
    db_query_task = DatasourceOperator(connector=db_conn)

    # Upstream order matters for the join: the retrieval chain is connected
    # first (-> `table_info`), the raw request second (-> `ext_dict`).
    (
        input_task
        >> MapOperator(lambda x: x["user_input"])
        >> retriever_task
        >> content_task
        >> prompt_input_task
    )
    input_task >> prompt_input_task
    prompt_input_task >> prompt_task >> req_build_task >> llm_task >> sql_parse_task
    # Only the "sql" field of the parsed model output is sent to the database.
    sql_parse_task >> MapOperator(lambda x: x["sql"]) >> db_query_task

    two_sum_task = TwoSumOperator()
    decision_task = DataDecisionOperator(
        odd_task_name="odd_task", even_task_name="even_task"
    )
    # Task names must match the names given to `decision_task` above.
    odd_task = OddOperator(task_name="odd_task")
    even_task = EvenOperator(task_name="even_task")
    # Terminal node of the DAG; referenced after this block to run the flow.
    merge_task = MergeOperator()

    db_query_task >> two_sum_task >> decision_task
    decision_task >> odd_task >> merge_task
    decision_task >> even_task >> merge_task

# Request passed to the chat DAG; its keys line up with the placeholders used
# by the system prompt ({table_info} is added inside the DAG itself).
_chat_request = {
    "user_input": "Query the name and age of users younger than 18 years old",
    "db_name": "user_management",
    "dialect": "SQLite",
    "top_k": 1,
    "display_type": display_type,
    "response": json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
}

# Run the DAG through its terminal node and show the merged branch output.
final_result = asyncio.run(merge_task.call(_chat_request))
print("The final result is:")
print(final_result)