#!/usr/bin/env python3
import argparse
import os
from multiprocessing import Process, Value
from loguru import logger
from llm_service import Worker, llm_inference


def check_envs(args):
    """Export CUDA_VISIBLE_DEVICES from the parsed ``--DCU_ID`` list.

    Args:
        args: parsed argparse namespace; ``args.DCU_ID`` must be a list of ints.

    Raises:
        ValueError: if any entry of ``args.DCU_ID`` is not an int.
    """
    device_ids = args.DCU_ID
    # Guard clause: reject anything that is not an all-int list up front.
    if not all(isinstance(dev, int) for dev in device_ids):
        logger.error(f"The --DCU_ID argument must be a list of integers, but got {device_ids}")
        raise ValueError("The --DCU_ID argument must be a list of integers")
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(dev) for dev in device_ids)
    logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {device_ids}")

def parse_args():
    """Parse args."""
    parser = argparse.ArgumentParser(description='Executor.')
    parser.add_argument(
        '--DCU_ID',
        default=[1,2,6,7],
        help='设置DCU')
    parser.add_argument(
        '--config_path',
        default='/path/to/your/ai/config.ini',
        type=str,
        help='config.ini路径')
    parser.add_argument(
        '--standalone',
        default=False,
        help='部署LLM推理服务.')
    parser.add_argument(
Rayyyyy's avatar
Rayyyyy committed
35
        '--use_vllm',
Rayyyyy's avatar
Rayyyyy committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
        default=False,
        type=bool,
        help='LLM推理是否启用加速'
    )
    args = parser.parse_args()
    return args


def build_reply_text(reply: str, references: list) -> str:
    """Append each reference on its own line after the reply.

    Args:
        reply: the base reply text.
        references: reference strings to append, one per line.

    Returns:
        ``reply`` unchanged when there are no references, otherwise the
        reply followed by each reference, newline-separated.
    """
    # Truthiness instead of len() < 1; join instead of quadratic `+=` loop.
    if not references:
        return reply
    return '\n'.join([reply, *references])


def reply_workflow(assistant):
    """Send each canned smoke-test query to *assistant* and log the response.

    Args:
        assistant: worker exposing ``produce_response(query, history, judgment)``.
    """
    for query in ['你好,我们公司想要购买几台测试机,请问需要联系贵公司哪位?']:
        code, reply, references = assistant.produce_response(
            query=query, history=[], judgment=False)
        logger.info(f'{code}, {query}, {reply}, {references}')


def run():
    """Entry point: optionally spawn a local LLM server, then run the demo workflow.

    (Also removes scraped web-viewer residue that was embedded in this
    function's body and made the file unparseable.)
    """
    args = parse_args()
    if args.standalone is True:
        import time
        check_envs(args)
        # Shared flag written by the server process:
        # 0 = still starting, 1 = ready, anything else = startup failure.
        server_ready = Value('i', 0)
        server_process = Process(target=llm_inference,
                                 args=(args.config_path,
                                       len(args.DCU_ID),
                                       args.use_vllm,
                                       server_ready))
        server_process.daemon = True  # server dies with the parent process
        server_process.start()
        # Poll until the server reports ready (or failed).
        while True:
            if server_ready.value == 0:
                logger.info('waiting for server to be ready..')
                time.sleep(15)
            elif server_ready.value == 1:
                break
            else:
                logger.error('start local LLM server failed, quit.')
                raise Exception('local LLM path')
        logger.info('LLM Server start.')

    assistant = Worker(args=args)
    reply_workflow(assistant)


# Script entry point: parse args, optionally start the server, run the demo.
if __name__ == '__main__':
    run()