import os

import pytest

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.core import Chunk, HumanPromptTemplate, ModelMessage, ModelRequest
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.retriever import RetrieverStrategy
from dbgpt_ext.rag import ChunkParameters
from dbgpt_ext.rag.assembler import EmbeddingAssembler
from dbgpt_ext.rag.knowledge import KnowledgeFactory
from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig
from dbgpt_ext.storage.knowledge_graph.community_summary import (
    CommunitySummaryKnowledgeGraph,
)
from dbgpt_ext.storage.knowledge_graph.knowledge_graph import (
    BuiltinKnowledgeGraph,
)

"""GraphRAG example.
    ```
    # Set LLM config (url/sk) in `.env`.
    # Install pytest utils: `pip install pytest pytest-asyncio`
    GRAPH_STORE_TYPE=TuGraph
    TUGRAPH_HOST=127.0.0.1
    TUGRAPH_PORT=7687
    TUGRAPH_USERNAME=admin
    TUGRAPH_PASSWORD=73@TuGraph
    ```
    Examples:
        ..code-block:: shell
            pytest -s examples/rag/graph_rag_example.py
"""

llm_client = OpenAILLMClient()
model_name = "gpt-4o-mini"
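
# The OpenAI-compatible client above is assumed to read its endpoint and API key
# from the environment (the `.env` mentioned in the module docstring), e.g. via
# OPENAI_API_KEY / OPENAI_API_BASE; adjust these names to your deployment.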


@pytest.mark.asyncio
async def test_naive_graph_rag():
    await __run_graph_rag(
        knowledge_file="examples/test_files/graphrag-mini.md",
        chunk_strategy="CHUNK_BY_SIZE",
        knowledge_graph=__create_naive_kg_connector(),
        question="What's the relationship between TuGraph and DB-GPT ?",
    )


@pytest.mark.asyncio
async def test_community_graph_rag():
    await __run_graph_rag(
        knowledge_file="examples/test_files/graphrag-mini.md",
        chunk_strategy="CHUNK_BY_MARKDOWN_HEADER",
        knowledge_graph=__create_community_kg_connector(),
        question="What's the relationship between TuGraph and DB-GPT ?",
    )
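

# The two tests above differ in chunking strategy (fixed-size vs. markdown-header
# splitting) and in the knowledge-graph connector they build; the connector
# factories are defined below.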


def __create_naive_kg_connector():
    """Create knowledge graph connector."""
    return BuiltinKnowledgeGraph(
        config=TuGraphStoreConfig(),
        name="naive_graph_rag_test",
        embedding_fn=None,
        llm_client=llm_client,
        llm_model=model_name,
    )


def __create_community_kg_connector():
    """Create community knowledge graph connector."""
    return CommunitySummaryKnowledgeGraph(
        config=TuGraphStoreConfig(),
        name="community_graph_rag_test",
        embedding_fn=DefaultEmbeddingFactory.openai(),
        llm_client=llm_client,
        llm_model=model_name,
    )
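

# NOTE: `TuGraphStoreConfig()` is built with defaults here; the actual connection
# settings are expected to come from the TUGRAPH_* variables shown in the module
# docstring. The naive connector only extracts triplets (so `embedding_fn=None`),
# while the community-summary connector also embeds and summarizes graph
# communities, which is why it needs an embedding function.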


async def ask_chunk(chunk: Chunk, question: str) -> str:
    """Ask the LLM to answer the question using the retrieved chunk as context."""
    rag_template = (
        "Based on the following [Context] {context}, answer [Question] {question}."
    )
    template = HumanPromptTemplate.from_template(rag_template)
    messages = template.format_messages(context=chunk.content, question=question)
    model_messages = ModelMessage.from_base_messages(messages)
    request = ModelRequest(model=model_name, messages=model_messages)
    response = await llm_client.generate(request=request)

    if not response.success:
        code = str(response.error_code)
        reason = response.text
        raise Exception(f"request llm failed ({code}) {reason}")

    return response.text


async def __run_graph_rag(knowledge_file, chunk_strategy, knowledge_graph, question):
    file_path = os.path.join(ROOT_PATH, knowledge_file)
    knowledge = KnowledgeFactory.from_file_path(file_path)
    try:
        chunk_parameters = ChunkParameters(chunk_strategy=chunk_strategy)

        # build the assembler: chunk the document and extract the graph into the store
        assembler = await EmbeddingAssembler.aload_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            index_store=knowledge_graph,
            retrieve_strategy=RetrieverStrategy.GRAPH,
        )
        await assembler.apersist()

        # retrieve the most relevant chunk for the question (top_k=1),
        # filtering out results below the score threshold
        retriever = assembler.as_retriever(1)
        chunks = await retriever.aretrieve_with_scores(question, score_threshold=0.3)

        # answer the question with the top retrieved chunk
        print(await ask_chunk(chunks[0], question))

    finally:
        # clean up the test graph from the store
        knowledge_graph.delete_vector_name(knowledge_graph.get_config().name)
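

# A minimal way to run one of the examples outside pytest (a sketch; it assumes
# the same `.env` configuration as the tests above):
if __name__ == "__main__":
    import asyncio

    asyncio.run(test_community_graph_rag())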