simple_rag_embedding_example.py 2.94 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""AWEL: Simple rag embedding operator example.

    Examples:
    pre-requirements:
        python examples/awel/simple_rag_embedding_example.py
        .. code-block:: shell
            curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \
            --header 'Content-Type: application/json' \
            --data-raw '{
              "url": "https://docs.dbgpt.site/docs/latest/awel/"
            }'
"""

import os
from typing import Dict, List

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeType
from dbgpt_ext.rag.operators import EmbeddingAssemblerOperator, KnowledgeOperator
from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig


def _create_vector_connector():
    """Build the Chroma vector store used by this example.

    The store is configured with ``PILOT_PATH`` as its persist path, is named
    ``embedding_rag_test`` and receives an embedding function created by
    ``DefaultEmbeddingFactory`` for the ``text2vec-large-chinese`` model
    directory under ``MODEL_PATH``.

    Returns:
        ChromaStore: The configured vector store instance.
    """
    chroma_config = ChromaVectorConfig(persist_path=PILOT_PATH)

    # The embedding model is resolved relative to the configured model folder.
    embedding_model_path = os.path.join(MODEL_PATH, "text2vec-large-chinese")
    embedding_fn = DefaultEmbeddingFactory(
        default_model_name=embedding_model_path
    ).create()

    return ChromaStore(
        chroma_config,
        name="embedding_rag_test",
        embedding_fn=embedding_fn,
    )


class TriggerReqBody(BaseModel):
    """Request body model passed to the example's HTTP trigger."""

    # ``...`` as the default marks the field as required.
    url: str = Field(default=..., description="url")


class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Map operator that turns a ``TriggerReqBody`` into a parameter dict."""

    def __init__(self, **kwargs):
        """Forward every keyword argument to the parent operator."""
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Print the received request and return its URL in a dict.

        Args:
            input_value: The request body received by this operator.

        Returns:
            A dict with the single key ``"url"`` holding ``input_value.url``.
        """
        requested_url = input_value.url
        print(f"Receive input value: {input_value}")
        return {"url": requested_url}


class ResultOperator(MapOperator):
    """Final operator: report how many chunks it received."""

    def __init__(self, **kwargs):
        """Forward every keyword argument to the parent operator."""
        super().__init__(**kwargs)

    async def map(self, chunks: List) -> str:
        """Print and return a success message containing ``len(chunks)``.

        Args:
            chunks: The list handed over by the upstream operator.

        Returns:
            The same summary message that is printed.
        """
        summary = f"embedding success, there are {len(chunks)} chunks."
        print(summary)
        return summary


# Build the example DAG at import time. The pipeline wired at the bottom is:
# HTTP trigger -> request handler -> URL extractor -> knowledge operator
# -> embedding assembler -> result reporter.
with DAG("simple_sdk_rag_embedding_example") as dag:
    # HTTP entry point: POST requests whose body is described by
    # TriggerReqBody. The module docstring shows the full URL used to reach
    # it (prefixed with /api/v1/awel/trigger).
    trigger = HttpTrigger(
        "/examples/rag/embedding", methods="POST", request_body=TriggerReqBody
    )
    # Converts the request body into a dict containing the "url" key.
    request_handle_task = RequestHandleOperator()
    # Configured with the URL knowledge type.
    # NOTE(review): presumably loads the document behind the URL it is
    # given - confirm against KnowledgeOperator.
    knowledge_operator = KnowledgeOperator(knowledge_type=KnowledgeType.URL.name)
    # Chroma store created by the helper above; built while the module loads.
    vector_store = _create_vector_connector()
    # Pulls the plain URL string out of the dict from request_handle_task.
    url_parser_operator = MapOperator(map_function=lambda x: x["url"])
    # Receives the vector store as its index store.
    # NOTE(review): presumably embeds the knowledge chunks and writes them
    # to the store - confirm against EmbeddingAssemblerOperator.
    embedding_operator = EmbeddingAssemblerOperator(
        index_store=vector_store,
    )
    # Prints and returns the number of chunks it receives.
    output_task = ResultOperator()
    # Chain the operators with ``>>`` in execution order.
    (
        trigger
        >> request_handle_task
        >> url_parser_operator
        >> knowledge_operator
        >> embedding_operator
        >> output_task
    )

if __name__ == "__main__":
    # Running this file as a script only does something when the DAG's first
    # leaf node reports development mode; otherwise it is a no-op.
    if dag.leaf_nodes[0].dev_mode:
        # Development mode: run the DAG locally for debugging. The helper is
        # imported here so it is only loaded when actually needed.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)