# server_start.py
import argparse
import asyncio
import bisect
import configparser
import json
import os
import subprocess
from typing import Optional, Sequence

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from loguru import logger

from llm_service import Worker


# Device counts that auto_select_dcu() may settle on, in ascending order
# (the bisect lookup there relies on the list being sorted). Despite the name,
# the entries are the values that divide 32 evenly.
divisible_by_32 = [1, 2, 4, 8, 16, 32]
# %-format template with a single placeholder; it is not referenced in the
# visible part of this file.
# NOTE(review): presumably expands to "<base dir>/upload" - confirm with callers.
recv_file_path = "%s/upload"

# Module-level FastAPI application: workflow() registers the /work and /stream
# routes on it and then serves it with uvicorn.
app = FastAPI()

def workflow(args):
    """Load the service config, register the HTTP routes and run the server.

    Reads the INI file at ``args.config_path``, builds a ``Worker`` from it,
    registers ``POST /work`` and ``POST /stream`` on the module-level ``app``
    and then blocks inside ``uvicorn.run`` on ``0.0.0.0:<bind_port>``.

    Args:
        args: Parsed command-line namespace; only ``config_path`` is read.

    Raises:
        KeyError: If the config has no ``[default]`` section or no
            ``bind_port`` option.
        ValueError: If ``bind_port`` is not an integer.
        Exception: Anything raised by ``Worker(config)`` propagates unchanged.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open, so a wrong
    # path surfaces as the KeyError on the lookup below.
    config.read(args.config_path)
    bind_port = int(config['default']['bind_port'])

    # No try/except here: the previous handler only re-raised the exception.
    assistant = Worker(config)

    # Terminating SSE frame. Every field must start at column 0: with a
    # leading space the field name is parsed as " data" and ignored, which
    # leaves the "end" event without data so clients never dispatch it.
    sse_end = 'event: end\ndata: End of stream\n\n'

    @app.post("/work")
    async def work(request: Request):
        """Answer one query and return the whole reply as a JSON document."""
        input_json = await request.json()
        query = input_json['query']
        history = input_json.get('history', [])
        try:
            _code, reply, references = await assistant.produce_response(config,
                                                                        query=query,
                                                                        history=history)
        except Exception as e:
            # Best effort: log and answer with a generic "service error"
            # reply instead of failing the request.
            logger.error(e)
            reply = "服务异常"
            references = []
        return JSONResponse({'reply': reply, 'references': references})

    @app.post("/stream")
    async def stream(request: Request):
        """Answer one query as a server-sent-event stream."""
        input_json = await request.json()
        query = input_json['query']
        history = input_json.get('history', [])

        async def event_generator():
            """Yield one SSE ``data`` frame per reply chunk, then the end frame."""
            try:
                _code, reply, _references = await assistant.produce_response(config,
                                                                             query=query,
                                                                             history=history,
                                                                             stream=True)
            except Exception as e:
                logger.error(e)
                # Send the generic "service error" message, then close the stream.
                yield "data: 服务异常\n\n"
                yield sse_end
                return

            try:
                async for request_output in reply:
                    # Yield str throughout; StreamingResponse encodes str
                    # chunks itself, so the generator no longer mixes
                    # str and bytes.
                    yield 'data: %s\n\n' % json.dumps(request_output)
                yield sse_end
            except (asyncio.CancelledError, ConnectionResetError):
                # The client went away mid-stream; stop quietly.
                logger.debug('user interrupt')
                return

        return StreamingResponse(event_generator(), media_type="text/event-stream")

    # Kept function-local, so uvicorn is only needed once the server starts.
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=bind_port)
def auto_select_dcu(config):
    """Pick the least-loaded DCUs and expose them via CUDA_VISIBLE_DEVICES.

    Runs ``hy-smi`` through a shell pipeline that keeps the lines starting
    with a digit and prints their columns 1, 6 and 7 with ``%`` removed.
    Following the log messages below, these are handled as device id, memory
    usage and dcu usage.
    NOTE(review): the column meaning depends on the hy-smi output layout -
    confirm against the installed version.

    Devices whose memory usage or dcu usage reaches its threshold are dropped.
    From the rest, 40% (rounded) is taken as the target count, lowered to the
    largest entry of ``divisible_by_32`` that does not exceed it, and the
    devices with the smallest ``mem usage + dcu usage`` are selected.

    Args:
        config: ``configparser.ConfigParser`` providing the integer options
            ``mem_threshold`` and ``dcu_threshold`` in section ``default``.

    Returns:
        list[str]: Ids of the selected devices, least loaded first.

    Raises:
        Exception: If hy-smi reports no device line, or if no device passes
            the thresholds.
    """
    mem_threshold = config.getint('default', 'mem_threshold')
    dcu_threshold = config.getint('default', 'dcu_threshold')
    process = subprocess.Popen("hy-smi | grep '^[0-9]' | awk '{print $1,$6,$7}' | sed 's/%//g'", shell=True,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()
    # ''.split('\n') is [''], which is truthy, so the old emptiness check could
    # never fire; drop blank lines before testing for "no devices".
    lines = [line for line in output.decode().splitlines() if line.strip()]
    if not lines:
        if error:
            logger.error("hy-smi stderr: %s" % error.decode(errors='replace').strip())
        raise Exception("There is no dcu on this node.")

    dcu_map = {}
    for line in lines:
        dcu_info = line.split()
        try:
            dcu_id = dcu_info[0]
            mem_usage = int(dcu_info[1])
            dcu_usage = int(dcu_info[2])
        except (IndexError, ValueError):
            # A malformed line should not abort the whole selection.
            logger.warning("skip unparsable hy-smi line: %r" % line)
            continue
        if mem_usage >= mem_threshold or dcu_usage >= dcu_threshold:
            logger.debug("filter dcu:%s, which mem usage (%s) and dcu usage (%s) above the threshold." % (dcu_id,
                                                                                                         dcu_info[1],
                                                                                                         dcu_info[2]))
            continue
        logger.debug("dcu id:%s, mem usage: %s dcu usage: %s" % (dcu_id, dcu_info[1], dcu_info[2]))
        dcu_map[dcu_id] = [mem_usage, dcu_usage]

    count = round(len(dcu_map) * 0.4)
    if not count:
        logger.error("There is no available dcu device, can not start the service.")
        raise Exception("There is no available dcu device, can not start the service.")

    # Largest allowed size <= count. count >= 1 here, so the index is never
    # negative, and counts above the last entry clamp to it instead of
    # indexing past the end of the list.
    index = bisect.bisect_right(divisible_by_32, count) - 1
    select_count = divisible_by_32[index]

    # Rank by combined load; sorted() is stable, so ties keep hy-smi order.
    sorted_dcu_ids = sorted(dcu_map, key=lambda dcu_id: sum(dcu_map[dcu_id]))
    select_dcu_ids = sorted_dcu_ids[:select_count]
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, select_dcu_ids))
    logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {select_dcu_ids}")
    return select_dcu_ids

def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the service launcher.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what a plain ``parse_args()``
            call did before this parameter existed.

    Returns:
        Namespace with ``config_path`` (default ``'ai/config.ini'``) and
        ``log_path`` (default ``''``, meaning "not set").
    """
    parser = argparse.ArgumentParser(description='Start all services.')
    parser.add_argument('--config_path',
                        default='ai/config.ini',
                        # The value is handed to ConfigParser.read(), i.e. it
                        # names a file, not a directory.
                        help='Path to the INI config file')
    parser.add_argument('--log_path',
                        default='',
                        help='Set log file path')
    return parser.parse_args(argv)

def main():
    """Entry point: parse the CLI options, set up file logging, start serving.

    The log goes to ``--log_path`` when it is given and non-empty, otherwise
    to ``/var/log/assistant.log``. ``workflow(args)`` is then called with the
    parsed arguments.
    """
    args = parse_args()
    # An empty --log_path (the argparse default) selects the built-in location.
    log_path = args.log_path or '/var/log/assistant.log'
    logger.add(sink=log_path, level="DEBUG", rotation="500MB", compression="zip", encoding="utf-8", enqueue=True)
    workflow(args)

# Run the launcher only when executed as a script, not when imported.
if __name__ == '__main__':
    main()