simple_rag_retriever_example.py 3.84 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""AWEL: Simple rag embedding operator example

    Prerequisites:
        1. install openai python sdk

        ```
            pip install openai
        ```
        2. set openai key and base
        ```
            export OPENAI_API_KEY={your_openai_key}
            export OPENAI_API_BASE={your_openai_base}
        ```
        3. make sure you have a vector store.
        If there is no data in the vector store, please run examples/awel/simple_rag_embedding_example.py first.


    Ensure your embedding model is in DB-GPT/models/.

    Examples:
        .. code-block:: shell

            DBGPT_SERVER="http://127.0.0.1:5555"
            curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/retrieve \
            -H "Content-Type: application/json" -d '{ \
                "query": "what is awel talk about?"
            }'
"""

import os
from typing import Dict, List

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core import Chunk
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt_ext.rag.operators import (
    EmbeddingRetrieverOperator,
    QueryRewriteOperator,
    RerankOperator,
)
from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig


class TriggerReqBody(BaseModel):
    # Request body model for this example; it carries a single required field.
    # (Documented with comments rather than a docstring so the model's
    # generated schema stays exactly as before.)

    # The user's query text. ``...`` marks the field as required.
    query: str = Field(
        ...,
        description="User query",
    )


class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Map operator that turns a ``TriggerReqBody`` into a plain dict."""

    def __init__(self, **kwargs):
        """Pass every keyword argument through to the parent operator."""
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Extract the query from the request body and log the request.

        Args:
            input_value (TriggerReqBody): request body carrying the user query.

        Returns:
            Dict: a dict with the single key ``"query"``.
        """
        # Read the attribute before logging, mirroring the original ordering.
        user_query = input_value.query
        print(f"Receive input value: {input_value}")
        return {"query": user_query}


def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
    """context Join function for JoinOperator.

    Args:
        context_dict (Dict): context dict
        chunks (List[Chunk]): chunks
    Returns:
        Dict: context dict
    """
    context_dict["context"] = "\n".join([chunk.content for chunk in chunks])
    return context_dict


def _create_vector_connector():
    """Create the Chroma vector store named ``embedding_rag_test``.

    The store is configured to persist under ``PILOT_PATH`` and uses an
    embedding function created by ``DefaultEmbeddingFactory`` for the
    ``text2vec-large-chinese`` model located under ``MODEL_PATH``.

    Returns:
        The constructed ``ChromaStore`` instance.
    """
    # Build the pieces in the same order the original expression evaluated
    # them: config first, then the embedding function, then the store.
    chroma_config = ChromaVectorConfig(persist_path=PILOT_PATH)
    embedding_model_path = os.path.join(MODEL_PATH, "text2vec-large-chinese")
    embedding_factory = DefaultEmbeddingFactory(
        default_model_name=embedding_model_path,
    )
    embedding_fn = embedding_factory.create()
    return ChromaStore(
        chroma_config,
        name="embedding_rag_test",
        embedding_fn=embedding_fn,
    )


# Build the example DAG. Everything below is created inside the DAG context
# and then linked with ``>>``.
with DAG("simple_sdk_rag_retriever_example") as dag:
    # One vector store instance, shared by both retriever operators below.
    vector_store = _create_vector_connector()
    # HTTP trigger on /examples/rag/retrieve (POST) using TriggerReqBody as
    # the request body model.
    trigger = HttpTrigger(
        "/examples/rag/retrieve", methods="POST", request_body=TriggerReqBody
    )
    # Converts the request body into a ``{"query": ...}`` dict.
    request_handle_task = RequestHandleOperator()
    # Pulls the raw query value out of that dict.
    query_parser = MapOperator(map_function=lambda x: x["query"])
    # Combines its inputs with _context_join_fn(context_dict, chunks).
    context_join_operator = JoinOperator(combine_function=_context_join_fn)
    # Query rewriting backed by an OpenAI client built with no explicit
    # arguments — presumably configured through the OPENAI_API_KEY /
    # OPENAI_API_BASE environment variables named in the module docstring;
    # verify.
    rewrite_operator = QueryRewriteOperator(llm_client=OpenAILLMClient())
    # Two retrievers with identical settings (top_k=3, same store): this one
    # feeds the join, the next one runs after the rewrite step.
    retriever_context_operator = EmbeddingRetrieverOperator(
        top_k=3,
        index_store=vector_store,
    )
    retriever_operator = EmbeddingRetrieverOperator(
        top_k=3,
        index_store=vector_store,
    )
    rerank_operator = RerankOperator()
    # NOTE(review): model_parse_task is created here but is not linked to any
    # other node in the wiring below — confirm whether it is still needed.
    model_parse_task = MapOperator(lambda out: out.to_dict())

    # Branch 1: the request dict goes straight into the join.
    # NOTE(review): this link is declared before branch 2; presumably the join
    # passes upstream outputs to _context_join_fn in the order they were
    # wired (dict first, chunks second) — verify before reordering.
    trigger >> request_handle_task >> context_join_operator
    # Branch 2: the query string is extracted, used for retrieval, and the
    # retriever output becomes the second input of the join.
    (
        trigger
        >> request_handle_task
        >> query_parser
        >> retriever_context_operator
        >> context_join_operator
    )
    # Tail: joined context -> query rewrite -> second retrieval -> rerank.
    context_join_operator >> rewrite_operator >> retriever_operator >> rerank_operator

# Script entry point. The dev-mode check is only evaluated when the file is
# executed directly (short-circuit ``and``); in any other case nothing happens.
if __name__ == "__main__" and dag.leaf_nodes[0].dev_mode:
    # Development mode, you can run the dag locally for debugging.
    # The helper is imported here so it is only loaded when actually used.
    from dbgpt.core.awel import setup_dev_environment

    setup_dev_environment([dag], port=5555)