simple_rag_rewrite_example.py 2.34 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""AWEL: Simple rag rewrite example

    Prerequisites:
        1. Install the OpenAI Python SDK:
        ```
            pip install openai
        ```
        2. Set the OpenAI API key and API base URL, either as environment
        variables:
        ```
            export OPENAI_API_KEY={your_openai_key}
            export OPENAI_API_BASE={your_openai_base}
        ```
        or in Python, before this module's DAG is built:
        ```
            import os
            os.environ["OPENAI_API_KEY"] = {your_openai_key}
            os.environ["OPENAI_API_BASE"] = {your_openai_base}
        ```
        3. Run the example:
        ```
            python examples/awel/simple_rag_rewrite_example.py
        ```
    Example:

    .. code-block:: shell

        DBGPT_SERVER="http://127.0.0.1:5555"
        curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/rewrite \
        -H "Content-Type: application/json" -d '{
            "query": "compare curry and james",
            "context":"steve curry and lebron james are nba all-stars"
        }'
"""

import os
from typing import Dict

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.operators import QueryRewriteOperator


class TriggerReqBody(BaseModel):
    # Request body accepted by the HTTP trigger of this example.
    # ``default=...`` (Ellipsis) marks each field as required.
    query: str = Field(default=..., description="User query")  # text to rewrite
    context: str = Field(default=..., description="context")  # supporting context


class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
    """Convert the parsed HTTP request body into a plain dictionary.

    The emitted dict carries exactly two keys, ``query`` and ``context``,
    copied from the incoming ``TriggerReqBody``.
    """

    def __init__(self, **kwargs):
        # Nothing to configure locally; hand every option to MapOperator.
        super().__init__(**kwargs)

    async def map(self, input_value: TriggerReqBody) -> Dict:
        """Echo the request to stdout and return its fields as a dict.

        Args:
            input_value: The request body received by the trigger.

        Returns:
            ``{"query": ..., "context": ...}`` taken from ``input_value``.
        """
        print(f"Receive input value: {input_value}")
        return dict(query=input_value.query, context=input_value.context)


# Build the example DAG. The operators below are never added to ``dag``
# explicitly; they are presumably attached through the ``with DAG(...)``
# context (the ``__main__`` guard reads ``dag.leaf_nodes``, relying on this).
with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag:
    # HTTP entry point: POST requests to this path, with a JSON body described
    # by TriggerReqBody. The module docstring shows the full URL as served,
    # under the /api/v1/awel/trigger prefix.
    trigger = HttpTrigger(
        "/examples/rag/rewrite", methods="POST", request_body=TriggerReqBody
    )
    # Turns the request body into a {"query": ..., "context": ...} dict.
    request_handle_task = RequestHandleOperator()
    # Build the query-rewrite operator, backed by an OpenAI LLM client.
    # NOTE(review): if OPENAI_API_KEY is unset, the literal placeholder
    # "your api key" is used as the key, so misconfiguration would presumably
    # only surface when the LLM is called. ``nums=2`` presumably sets how many
    # rewritten queries to produce -- confirm against QueryRewriteOperator.
    rewrite_task = QueryRewriteOperator(
        llm_client=OpenAILLMClient(api_key=os.getenv("OPENAI_API_KEY", "your api key")),
        nums=2,
    )
    # Wire the pipeline: trigger -> request handler -> query rewrite.
    trigger >> request_handle_task >> rewrite_task


if __name__ == "__main__":
    # Only act when the DAG's leaf operator reports dev mode; otherwise running
    # this script directly is a no-op. (The dead ``else: pass`` branch was
    # removed.) NOTE(review): outside dev mode the DAG is presumably meant to
    # be loaded by a DB-GPT server instead -- confirm.
    if dag.leaf_nodes[0].dev_mode:
        # Development mode: serve the DAG locally for debugging. Port 5555
        # matches DBGPT_SERVER in the module docstring's curl example.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag], port=5555)