"""
References: https://docs.llamaindex.ai/en/stable/use_cases/q_and_a/
"""
import argparse
import sys

import gradio as gr
from llama_index.core import Settings

from models.embedding import GLMEmbeddings
from models.synthesizer import CodegeexSynthesizer
from utils.vector import load_vectors


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--vector_path', type=str, help="path where the vector index is stored", default='vectors')
    parser.add_argument('--model_name_or_path', type=str, help="model checkpoint name or local path", default='THUDM/codegeex4-all-9b')
    parser.add_argument('--device', type=str, help="cpu or cuda", default="cpu")
    parser.add_argument('--temperature', type=float, help="sampling temperature of the model", default=0.2)
    return parser.parse_args()


def chat(query, history):
    """Streaming handler for gr.ChatInterface: first yield the retrieved
    source documents, then append the model's answer."""
    resp = query_engine.query(query)

    ans = "Related Documents".center(150, '-') + '\n'
    yield ans
    for i, node in enumerate(resp.source_nodes):
        file_name = node.metadata['filename']
        ext = node.metadata['extension']
        text = node.text
        ans += f"File {i + 1}: {file_name}\n```{ext}\n{text}\n```\n"
        yield ans

    ans += "Model Response".center(150, '-') + '\n'
    ans += resp.response
    yield ans


if __name__ == '__main__':
    args = parse_arguments()
    # Embed queries with the GLM embedding model; this must match the model
    # used when the vector index was built.
    Settings.embed_model = GLMEmbeddings()
    try:
        query_engine = load_vectors(args.vector_path).as_query_engine(
            response_synthesizer=CodegeexSynthesizer(args)
        )
    except Exception as e:
        print(f"Failed to load vectors: {e}")
        sys.exit(1)
    demo = gr.ChatInterface(chat).queue()
    demo.launch(server_name="127.0.0.1", server_port=8080)