rag_embedding_api_example.py 2.84 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""A RAG example using the OpenAPIEmbeddings.

Example:

    Test with `OpenAI embeddings
    <https://platform.openai.com/docs/api-reference/embeddings/create>`_.

    .. code-block:: shell

        export API_SERVER_BASE_URL=${OPENAI_API_BASE:-"https://api.openai.com/v1"}
        export API_SERVER_API_KEY="${OPENAI_API_KEY}"
        export API_SERVER_EMBEDDINGS_MODEL="text-embedding-ada-002"
        python examples/rag/rag_embedding_api_example.py

    Test with DB-GPT `API Server
    <https://docs.dbgpt.site/docs/installation/advanced_usage/OpenAI_SDK_call#start-apiserver>`_.

    .. code-block:: shell

        export API_SERVER_BASE_URL="http://localhost:8100/api/v1"
        export API_SERVER_API_KEY="your_api_key"
        export API_SERVER_EMBEDDINGS_MODEL="text2vec"
        python examples/rag/rag_embedding_api_example.py

"""

import asyncio
import os
from typing import Optional

from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH
from dbgpt.rag.embedding import OpenAPIEmbeddings
from dbgpt_ext.rag import ChunkParameters
from dbgpt_ext.rag.assembler import EmbeddingAssembler
from dbgpt_ext.rag.knowledge import KnowledgeFactory
from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig


def _create_embeddings(
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
) -> OpenAPIEmbeddings:
    """Build an ``OpenAPIEmbeddings`` client, filling gaps from the environment.

    Args:
        api_url: Full URL of the embeddings endpoint. When empty, it is built
            from the ``API_SERVER_BASE_URL`` environment variable (default
            ``http://localhost:8100/api/v1/``) with ``/embeddings`` appended.
        api_key: API key passed to the client. When empty, read from the
            ``API_SERVER_API_KEY`` environment variable (may stay ``None``).
        model_name: Embedding model name. When empty, read from the
            ``API_SERVER_EMBEDDINGS_MODEL`` environment variable
            (default ``text2vec``).

    Returns:
        The ``OpenAPIEmbeddings`` instance created from the resolved values.
    """
    if not api_url:
        api_server_base_url = os.getenv(
            "API_SERVER_BASE_URL", "http://localhost:8100/api/v1/"
        )
        # Drop trailing slashes before joining so a base URL such as
        # ".../api/v1/" does not turn into ".../api/v1//embeddings".
        api_url = f"{api_server_base_url.rstrip('/')}/embeddings"
    if not api_key:
        api_key = os.getenv("API_SERVER_API_KEY")

    if not model_name:
        model_name = os.getenv("API_SERVER_EMBEDDINGS_MODEL", "text2vec")

    return OpenAPIEmbeddings(api_url=api_url, api_key=api_key, model_name=model_name)


def _create_vector_connector():
    """Return the Chroma store used by this example.

    The store is configured with ``persist_path=PILOT_PATH``, is named
    ``embedding_rag_test``, and embeds text through the client returned by
    ``_create_embeddings()`` called with no arguments.
    """
    chroma_config = ChromaVectorConfig(persist_path=PILOT_PATH)
    embeddings = _create_embeddings()
    store = ChromaStore(
        chroma_config,
        name="embedding_rag_test",
        embedding_fn=embeddings,
    )
    return store


async def main():
    """Run the example: load a document, embed it, retrieve, then clean up.

    Steps:
        1. Load ``docs/docs/awel/awel.md`` (relative to ``ROOT_PATH``) as
           knowledge.
        2. Chunk it with the ``CHUNK_BY_SIZE`` strategy and persist the
           chunks into the vector store.
        3. Query a retriever built with ``as_retriever(3)``, passing ``0.3``
           as the second argument to ``aretrieve_with_scores``, and print
           the results.
        4. Delete the example's vector collection.
    """
    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    knowledge = KnowledgeFactory.from_file_path(file_path)
    vector_store = _create_vector_connector()
    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
    # get embedding assembler
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=chunk_parameters,
        index_store=vector_store,
    )
    assembler.persist()
    # get embeddings retriever
    retriever = assembler.as_retriever(3)
    chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
    print(f"embedding rag example results:{chunks}")
    # Clean up using the same name the store was created with in
    # _create_vector_connector(); the previous value did not match it.
    vector_store.delete_vector_name("embedding_rag_test")


if __name__ == "__main__":
    # Run the async example only when this file is executed as a script,
    # so importing the module has no side effects.
    asyncio.run(main())