simple_dbschema_retriever_example.py 4.15 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""AWEL: Simple rag db schema embedding operator example

    If you do not pass a vector store (``index_store``), it returns the schemas of all tables in the database.
    ```
    retriever_task = DBSchemaRetrieverOperator(
        connector=_create_temporary_connection()
    )
    ```
    If you pass a vector store via ``index_store``, it recalls the schemas of the ``top_k`` most similar tables in the database.
    ```
    retriever_task = DBSchemaRetrieverOperator(
        connector=_create_temporary_connection(),
        top_k=1,
        index_store=vector_store_connector
    )
    ```

    Examples:
        .. code-block:: shell

            curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/dbschema' \
            --header 'Content-Type: application/json' \
            --data '{"query": "what is user name?"}'
"""

import os
from typing import Dict, List

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core import Chunk
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt_ext.rag.operators import DBSchemaAssemblerOperator
from dbgpt_ext.rag.operators.db_schema import DBSchemaRetrieverOperator
from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig


def _create_vector_connector():
    """Build the Chroma vector store used by this example.

    The store is configured with ``persist_path=PILOT_PATH``, is named
    ``embedding_rag_test``, and embeds text with the function produced by
    ``DefaultEmbeddingFactory`` for the ``text2vec-large-chinese`` model
    directory under ``MODEL_PATH``.

    Returns:
        A ``ChromaStore`` instance.
    """
    store_config = ChromaVectorConfig(persist_path=PILOT_PATH)

    embedding_model_path = os.path.join(MODEL_PATH, "text2vec-large-chinese")
    embedding_fn = DefaultEmbeddingFactory(
        default_model_name=embedding_model_path,
    ).create()

    return ChromaStore(
        store_config,
        name="embedding_rag_test",
        embedding_fn=embedding_fn,
    )


def _create_temporary_connection():
    """Return a temporary SQLite connector seeded with a sample ``user`` table.

    The table has the columns ``id`` (INTEGER PRIMARY KEY), ``name`` (TEXT)
    and ``age`` (INTEGER), and is populated with five rows.
    """
    user_table = {
        "columns": {
            "id": "INTEGER PRIMARY KEY",
            "name": "TEXT",
            "age": "INTEGER",
        },
        "data": [
            (1, "Tom", 10),
            (2, "Jerry", 16),
            (3, "Jack", 18),
            (4, "Alice", 20),
            (5, "Bob", 22),
        ],
    }

    temp_db = SQLiteTempConnector.create_temporary_db()
    temp_db.create_temp_tables({"user": user_table})
    return temp_db


def _join_fn(chunks: List[Chunk], query: str) -> str:
    print(f"db schema info is {[chunk.content for chunk in chunks]}")
    return query


# Request body model passed to the HTTP trigger as ``request_body``.
# (Documented with comments rather than a docstring so the generated model
# schema stays unchanged.)
class TriggerReqBody(BaseModel):
    # Required field (``...`` means no default): the user's query text.
    query: str = Field(..., description="User query")


class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Map operator that turns a ``TriggerReqBody`` into a parameter dict."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the ``MapOperator`` constructor."""
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Print the incoming request and return ``{"query": <its query>}``."""
        # Read the attribute before printing so a malformed input fails
        # before anything is written to stdout.
        query = input_value.query
        print(f"Receive input value: {input_value}")
        return {"query": query}


# Build the example DAG. Operators are created inside the ``with`` block and
# wired together with ``>>`` at the end.
with DAG("simple_rag_db_schema_example") as dag:
    # HTTP entry point: POST /examples/rag/dbschema with a TriggerReqBody.
    trigger = HttpTrigger(
        "/examples/rag/dbschema", methods="POST", request_body=TriggerReqBody
    )
    # Request branch: body -> {"query": ...} -> the bare query string.
    request_handle_task = RequestHandleOperator()
    query_operator = MapOperator(lambda request: request["query"])
    # Vector store shared by the assembler and the retriever below.
    index_store = _create_vector_connector()
    connector = _create_temporary_connection()
    assembler_task = DBSchemaAssemblerOperator(
        connector=connector,
        index_store=index_store,
    )
    # _join_fn prints the chunk contents and forwards only the query.
    join_operator = JoinOperator(combine_function=_join_fn)
    # NOTE(review): the retriever gets its own, second temporary database
    # rather than reusing ``connector``; both are seeded identically by
    # _create_temporary_connection(). Confirm whether this is intentional.
    retriever_task = DBSchemaRetrieverOperator(
        connector=_create_temporary_connection(),
        top_k=1,
        index_store=index_store,
    )
    # Reduce the retrieved chunks to a list of their ``content`` values.
    result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
    # Wiring: the trigger feeds both the assembler branch and the request
    # branch; the two meet at join_operator, whose output (the query) goes
    # to the retriever and then to the result parser.
    # Presumably the join inputs arrive in the order the upstreams are
    # connected (assembler output first, query second), matching
    # _join_fn(chunks, query) — TODO confirm against the AWEL docs.
    trigger >> assembler_task >> join_operator
    trigger >> request_handle_task >> query_operator >> join_operator
    join_operator >> retriever_task >> result_parse_task


if __name__ == "__main__":
    # Running this file directly only does something when the DAG's first
    # leaf node is in dev mode; otherwise it is a no-op.
    if dag.leaf_nodes[0].dev_mode:
        # Imported lazily: only needed for local debugging runs.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)