server_start.py 5.21 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import os
import argparse
import bisect
import configparser
import subprocess
from aiohttp import web
from loguru import logger
from llm_service import Worker
from scipy.stats import rankdata


# Allowed device counts, in ascending order: 1, 2, 4, 8, 16, 32.
# NOTE: despite the name, these are the numbers that divide 32 evenly
# (the powers of two up to 32), not numbers divisible by 32.
divisible_by_32 = [1 << shift for shift in range(6)]
# Template for the upload directory; the single "%s" slot is filled with the
# configured work directory.
recv_file_path = "%s/upload"


Rayyyyy's avatar
update  
Rayyyyy committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def auto_select_dcu(config):
    """Pick the least-used DCUs and publish them via CUDA_VISIBLE_DEVICES.

    Runs ``hy-smi`` and reads three whitespace-separated fields per device
    line: device id, memory usage and DCU usage (percent signs stripped).
    Devices whose memory or DCU usage is at or above the configured
    thresholds are discarded.  From the remaining devices, the ones with the
    lowest combined usage are selected.  The number selected is the largest
    entry of ``divisible_by_32`` that does not exceed 40% of the remaining
    device count.

    Args:
        config: ``configparser.ConfigParser`` whose ``default`` section
            provides the integer options ``mem_threshold`` and
            ``dcu_threshold``.

    Returns:
        list[str]: Selected device ids, ordered from least to most used.

    Raises:
        Exception: If ``hy-smi`` reports no device, or no device is left
            after applying the thresholds and the 40% rule.
    """
    # Read threshold in config file
    mem_threshold = config.getint('default', 'mem_threshold')
    dcu_threshold = config.getint('default', 'dcu_threshold')
    # Get dcu usage
    process = subprocess.Popen("hy-smi | grep '^[0-9]' | awk '{print $1,$6,$7}' | sed 's/%//g'", shell=True,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()
    # Splitting an empty string yields [''], which is truthy; drop blank lines
    # so that "no output" is really detected below.
    result = [line for line in output.decode().splitlines() if line.strip()]
    if not result:
        if error:
            logger.error("hy-smi stderr: %s" % error.decode(errors='replace').strip())
        raise Exception("There is no dcu on this node.")

    dcu_map = {}
    for line in result:
        dcu_info = line.split()
        if len(dcu_info) < 3:
            # Not a "<id> <mem> <dcu>" line; it cannot be evaluated.
            logger.warning("skip unrecognized hy-smi line: %s" % line)
            continue
        dcu_id = dcu_info[0]
        mem_usage, dcu_usage = int(dcu_info[1]), int(dcu_info[2])
        if mem_usage >= mem_threshold or dcu_usage >= dcu_threshold:
            logger.debug("filter dcu:%s, which mem usage (%s) and dcu usage (%s) above the threshold."
                         % (dcu_id, mem_usage, dcu_usage))
            continue
        logger.debug("dcu id:%s, mem usage: %s dcu usage: %s" % (dcu_id, mem_usage, dcu_usage))
        dcu_map[dcu_id] = [mem_usage, dcu_usage]

    # The selected count must be one of the sizes listed in divisible_by_32.
    # TODO temporary use 40% of available count
    # NOTE(review): round() maps a single eligible device (0.4) to 0, so a
    # node with exactly one free DCU is rejected below -- confirm intended.
    count = round(len(dcu_map) * 0.4)
    if not count:
        logger.error("There is no available dcu device, can not start the service.")
        raise Exception("There is no available dcu device, can not start the service.")
    # Largest allowed size that is <= count.  count >= 1 and the first allowed
    # size is 1, so the index is never negative; counts above the last entry
    # clamp to the last entry instead of indexing past the end of the list.
    select_count = divisible_by_32[bisect.bisect_right(divisible_by_32, count) - 1]

    # Order devices by memory usage + dcu usage, lowest first.  sorted() is
    # stable, so ties keep the order in which hy-smi listed the devices.
    sorted_dcu_ids = sorted(dcu_map, key=lambda dcu_id: sum(dcu_map[dcu_id]))
    select_dcu_ids = sorted_dcu_ids[:select_count]
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(select_dcu_ids)
    logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {select_dcu_ids}")
    return select_dcu_ids


Rayyyyy's avatar
Rayyyyy committed
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def workflow(args):
    """Load the config, select DCUs, create the worker and serve HTTP requests.

    Two POST routes are registered on an aiohttp application:

    * ``/work``   -- reads ``query`` from the JSON body and returns the
      worker's ``reply`` and ``references`` as JSON.
    * ``/upload`` -- stores every uploaded file under
      ``<work_dir>/upload`` and then asks the worker's agent to parse that
      directory.

    The application is started with ``web.run_app`` on ``0.0.0.0`` and the
    ``bind_port`` taken from the ``default`` config section.

    Args:
        args: Parsed command-line arguments; only ``config_path`` is read.
    """
    config = configparser.ConfigParser()
    config.read(args.config_path)
    bind_port = int(config['default']['bind_port'])
    dcu_ids = auto_select_dcu(config)
    tensor_parallel_size = len(dcu_ids)
    assistant = Worker(config, tensor_parallel_size)

    async def work(request):
        """Answer the ``query`` carried in the JSON request body."""
        input_json = await request.json()
        query = input_json['query']

        # The first element of the returned triple is not used here.
        _, reply, references = assistant.produce_response(query=query,
                                                          history=[],
                                                          judgment=False)
        return web.json_response({'reply': reply, 'references': references})

    async def handle_upload(request):
        """Save the uploaded files and trigger parsing of the upload directory."""
        # Resolve the target directory up front so it is defined even when
        # the request carries no file part at all.
        save_path = recv_file_path % config['default']['work_dir']
        os.makedirs(save_path, exist_ok=True)

        received = []
        reader = await request.multipart()
        while True:
            field = await reader.next()
            if field is None:
                break
            # The filename is client-controlled: keep only its last path
            # component so the file cannot be written outside save_path.
            filename = os.path.basename(getattr(field, 'filename', None) or '')
            if filename in ('', '.', '..'):
                # Not a file part (or no usable name); nothing to store.
                continue

            # Save to server
            file_path = os.path.join(save_path, filename)
            with open(file_path, 'wb') as f:
                while True:
                    chunk = await field.read_chunk()
                    if not chunk:
                        break
                    f.write(chunk)
            logger.debug("成功接收文件:%s" % file_path)
            received.append(filename)

        if not received:
            # Nothing was stored, so there is nothing new to parse.
            return web.json_response({"reply": "未接收到文件\n"}, status=400)

        # Call file parse process
        assistant.agent.parse_file_and_merge(save_path)
        names = ', '.join(received)
        return web.json_response({"reply": f"成功接收文件:{names}\n"})

    app = web.Application()
    app.add_routes([web.post('/work', work),
                    web.post('/upload', handle_upload)])
    web.run_app(app, host='0.0.0.0', port=bind_port)


def parse_args():
    """Parse the command-line options of the service launcher.

    Returns:
        argparse.Namespace: with ``config_path`` (default ``./config.ini``)
        and ``log_path`` (default ``./log/assistant.log``).
    """
    parser = argparse.ArgumentParser(description='Start all services.')
    parser.add_argument('--config_path',
                        default='./config.ini',
                        help='Path to the config file')
    parser.add_argument('--log_path',
                        default='./log/assistant.log',
                        help='Set log file path')
    return parser.parse_args()

Rayyyyy's avatar
update  
Rayyyyy committed
126

Rayyyyy's avatar
Rayyyyy committed
127
128
def main():
    """Parse the CLI arguments, attach the file log sink and start the service."""
    args = parse_args()
    # Send log records of level DEBUG and above to the file given by --log_path.
    logger.add(sink=args.log_path, level="DEBUG", rotation="500MB", compression="zip", encoding="utf-8", enqueue=True)
    workflow(args)


# Start the service only when this file is executed as a script.
if __name__ == '__main__':
    main()