# server_start.py
import os
import argparse
import bisect
import configparser
import subprocess
from aiohttp import web
from loguru import logger
from llm_service import Worker
from scipy.stats import rankdata


# Device counts that may be selected for tensor parallelism. Despite the
# name, these are the divisors of 32 (every entry divides 32 evenly).
divisible_by_32 = [1, 2, 4, 8, 16, 32]
# Template for the upload directory; "%s" is filled with the configured
# work_dir (config['default']['work_dir']) when a file is received.
recv_file_path = "%s/upload"


def workflow(args):
    """Load the configuration, start the LLM worker and serve the HTTP API.

    Two POST routes are registered:

    * ``/work``   -- JSON body with a ``query`` key; answers with the reply
      and references produced by the worker.
    * ``/upload`` -- multipart body; every file part is stored under
      ``<work_dir>/upload`` and the directory is then handed to the worker's
      file-parsing step.

    Args:
        args: Parsed command-line arguments; only ``args.config_path`` is read.

    Raises:
        FileNotFoundError: If the configuration file cannot be read.
        Exception: Whatever ``auto_select_dcu`` or ``Worker`` raise on failure.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips unreadable files; fail loudly here
    # instead of with an opaque KeyError on the first lookup.
    if not config.read(args.config_path):
        raise FileNotFoundError("Can not read config file: %s" % args.config_path)
    bind_port = int(config['default']['bind_port'])
    dcu_ids = auto_select_dcu(config)
    tensor_parallel_size = len(dcu_ids)
    assistant = Worker(config, tensor_parallel_size)

    async def work(request):
        """Answer a single query posted as JSON."""
        input_json = await request.json()
        query = input_json['query']

        # NOTE(review): produce_response is called synchronously inside the
        # event loop -- presumably it blocks while the model runs; confirm
        # whether it should be moved to an executor.
        _, reply, references = assistant.produce_response(query=query,
                                                          history=[],
                                                          judgment=False)
        return web.json_response({'reply': reply, 'references': references})

    async def handle_upload(request):
        """Store every uploaded file part, then trigger file parsing."""
        # Resolve the target directory up front so it is defined even when
        # the request carries no parts at all.
        save_path = recv_file_path % config['default']['work_dir']
        os.makedirs(save_path, exist_ok=True)

        received = []
        reader = await request.multipart()
        while True:
            field = await reader.next()
            if field is None:
                break
            # Plain form fields carry no filename; only file parts are stored.
            if not field.filename:
                continue
            # The filename is client-controlled: keep only its last path
            # component so it can not escape the upload directory.
            filename = os.path.basename(field.filename)
            if filename in ('', '.', '..'):
                continue

            # Save to server
            file_path = os.path.join(save_path, filename)
            with open(file_path, 'wb') as f:
                while True:
                    chunk = await field.read_chunk()
                    if not chunk:
                        break
                    f.write(chunk)
            logger.debug("成功接收文件:%s" % file_path)
            received.append(filename)

        if not received:
            # Nothing usable was uploaded; report it instead of parsing.
            return web.json_response({"reply": "未接收到任何文件\n"}, status=400)

        # Call file parse process
        assistant.agent.parse_file_and_merge(save_path)
        names = ', '.join(received)
        return web.json_response({"reply": f"成功接收文件:{names}\n"})

    app = web.Application()
    app.add_routes([web.post('/work', work),
                    web.post('/upload', handle_upload)])
    web.run_app(app, host='0.0.0.0', port=bind_port)


def auto_select_dcu(config):
    """Pick the least-loaded DCU devices and expose them to the process.

    Device usage is read from ``hy-smi``. Devices whose memory usage or DCU
    usage reaches the configured thresholds are discarded. Of the remaining
    devices, 40% (rounded) are targeted, and that number is reduced to the
    largest entry of ``divisible_by_32`` that does not exceed it. The devices
    with the lowest combined memory + DCU usage are selected and written to
    the ``CUDA_VISIBLE_DEVICES`` environment variable.

    Args:
        config: ``ConfigParser`` providing integer options ``mem_threshold``
            and ``dcu_threshold`` in section ``default``.

    Returns:
        list[str]: Ids of the selected devices, least loaded first.

    Raises:
        Exception: If ``hy-smi`` reports no device, or no device is left
            after filtering and applying the 40% rule.
    """
    # Read threshold in config file
    mem_threshold = config.getint('default', 'mem_threshold')
    dcu_threshold = config.getint('default', 'dcu_threshold')
    # Get dcu usage: one line per device as "<id> <mem usage> <dcu usage>".
    process = subprocess.run("hy-smi | grep '^[0-9]' | awk '{print $1,$6,$7}' | sed 's/%//g'", shell=True,
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # splitlines() gives [] for empty output, whereas ''.split('\n') gives
    # [''] and would defeat the emptiness check below.
    result = [line for line in process.stdout.decode().splitlines() if line.strip()]
    if not result:
        raise Exception("There is no dcu on this node.")

    dcu_map = {}
    for line in result:
        dcu_info = line.split()
        if len(dcu_info) < 3:
            # Incomplete row (a column was missing in the hy-smi output).
            logger.warning("skip unparsable hy-smi line: %r" % line)
            continue
        if int(dcu_info[1]) >= mem_threshold or int(dcu_info[2]) >= dcu_threshold:
            logger.debug("filter dcu:%s, which mem usage (%s) and dcu usage (%s) above the threshold." % (dcu_info[0],
                                                                                                          dcu_info[1],
                                                                                                          dcu_info[2]))
            continue
        logger.debug("dcu id:%s, mem usage: %s dcu usage: %s" % (dcu_info[0], dcu_info[1], dcu_info[2]))
        dcu_map[dcu_info[0]] = [int(dcu_info[1]), int(dcu_info[2])]
    # The selected count must be one of the values in divisible_by_32.
    # TODO temporary use 40% of available count
    # NOTE(review): with a single available device round(0.4) == 0, so the
    # service refuses to start -- confirm this is intended.
    count = round(len(dcu_map) * 0.4)
    if not count:
        logger.error("There is no available dcu device, can not start the service.")
        raise Exception("There is no available dcu device, can not start the service.")
    # Largest allowed value <= count. count >= 1 here, so the index is never
    # negative, and counts above 32 are capped at the last entry instead of
    # indexing past the end of the list.
    index = bisect.bisect_right(divisible_by_32, count) - 1
    select_count = divisible_by_32[index]
    # Rank by combined memory + dcu usage, lowest first. sorted() is stable,
    # so ties keep the order in which hy-smi listed the devices.
    sorted_dcu_ids = sorted(dcu_map, key=lambda dcu_id: sum(dcu_map[dcu_id]))
    select_dcu_ids = sorted_dcu_ids[:select_count]
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, select_dcu_ids))
    logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {select_dcu_ids}")
    return select_dcu_ids


def parse_args(argv=None):
    """Parse the command-line options of the service launcher.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: With ``config_path`` (default
        ``'/path/of/config.ini'``) and ``log_path`` (default ``''``).
    """
    parser = argparse.ArgumentParser(description='Start all services.')
    parser.add_argument('--config_path',
                        default='/path/of/config.ini',
                        help='Config directory')
    parser.add_argument('--log_path',
                        default='',
                        help='Set log file path')
    return parser.parse_args(argv)


# Script entry point.
def main():
    """Parse the arguments, set up file logging and start the service.

    Log records are written to ``--log_path`` when given, otherwise to
    ``./log/assistant.log``.
    """
    args = parse_args()
    # An empty --log_path (the default) falls back to the built-in location.
    log_path = args.log_path or './log/assistant.log'
    logger.add(sink=log_path, level="DEBUG", rotation="500MB", compression="zip", encoding="utf-8", enqueue=True)
    workflow(args)


# Run the service only when executed as a script, not when imported.
if __name__ == '__main__':
    main()