#!/usr/bin/env python3
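"""Entry point: optionally spawn a local LLM inference service in a background
process, then run a sample query through the llm_service Worker assistant."""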
import argparse
import os
import time
from multiprocessing import Process, Value

from loguru import logger

from llm_service import Worker, llm_inference


def set_envs(dcu_ids):
    """Expose the selected DCU cards to this process via CUDA_VISIBLE_DEVICES."""
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
    except Exception as e:
        logger.error(f"Failed to set CUDA_VISIBLE_DEVICES, got {dcu_ids}: {e}")
        raise ValueError(str(e)) from e


def parse_args():
    """Parse args."""
    parser = argparse.ArgumentParser(description='Executor.')
    parser.add_argument(
        '--DCU_ID',
        type=str,
        default='0',
        help='DCU card IDs, separated by commas, e.g. "0,1,2"')
    parser.add_argument(
        '--config_path',
        default='/path/of/config.ini',
        type=str,
        help='Path to config.ini')
    parser.add_argument(
        '--standalone',
        action='store_true',
        help='Deploy the LLM inference service locally')
    parser.add_argument(
        '--use_vllm',
        action='store_true',
        help='Enable vLLM to accelerate LLM inference')
    args = parser.parse_args()
    return args


def build_reply_text(reply: str, references: list):
    if len(references) < 1:
        return reply

    ret = reply
    for ref in references:
        ret += '\n'
        ret += ref
    return ret


def reply_workflow(assistant):
    # Query translates to: "Our company wants to buy a few test machines; who should we contact?"
    queries = ['我们公司想要购买几台测试机,请问需要联系哪位?']
    for query in queries:
        code, reply, references = assistant.produce_response(
            query=query, history=[], judgment=False)
        logger.info(f'{code}, {query}, {reply}, {references}')


def run():
    args = parse_args()
    if args.standalone:
        set_envs(args.DCU_ID)
        server_ready = Value('i', 0)
        server_process = Process(target=llm_inference,
                                 args=(args.config_path,
                                       len(args.DCU_ID.split(',')),  # number of DCU cards
                                       args.use_vllm,
                                       server_ready))

        server_process.daemon = True
        server_process.start()
        while True:
            if server_ready.value == 0:
                logger.info('waiting for server to be ready.')
                time.sleep(15)
            elif server_ready.value == 1:
                break
            else:
                logger.error('Failed to start the local LLM server, quitting.')
                raise RuntimeError('failed to start the local LLM server')
        logger.info('LLM server started.')

    assistant = Worker(args=args)
    reply_workflow(assistant)


if __name__ == '__main__':
    run()
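

# Example invocations (a sketch; the card IDs below are illustrative, and the
# config path is the placeholder default, not a real file):
#
#   Standalone mode: spawn the local LLM inference server on DCU cards 0 and 1
#   with vLLM acceleration, then run the built-in test query:
#       python3 main.py --standalone --use_vllm --DCU_ID "0,1" \
#           --config_path /path/of/config.ini
#
#   Client-only mode: assume a server is already running and only run the query:
#       python3 main.py --config_path /path/of/config.ini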