import argparse
import asyncio
import bisect
import configparser
import json
import os
import subprocess

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from loguru import logger

from llm_service import Worker

# Allowed DCU group sizes (powers of two up to 32); the serving backend
# presumably requires the visible-device count to be one of these — TODO confirm.
divisible_by_32 = [1, 2, 4, 8, 16, 32]
# Template for an upload directory; "%s" is filled in by callers elsewhere.
recv_file_path = "%s/upload"

app = FastAPI()


def workflow(args):
    """Build the Worker from the INI config, register the HTTP endpoints and
    run the uvicorn server (blocks until shutdown).

    Args:
        args: parsed CLI namespace; ``args.config_path`` points to an INI file
            whose ``[default]`` section contains ``bind_port``.
    """
    config = configparser.ConfigParser()
    config.read(args.config_path)
    bind_port = int(config['default']['bind_port'])

    # Let a Worker construction failure propagate as-is; the original
    # ``try: ... except Exception as e: raise (e)`` wrapper added nothing.
    assistant = Worker(config)

    @app.post("/work")
    async def work(request: Request):
        """Blocking Q&A endpoint: returns the complete reply as one JSON body."""
        input_json = await request.json()
        query = input_json['query']
        history = input_json.get('history', [])
        try:
            _, reply, references = await assistant.produce_response(
                config, query=query, history=history)
        except Exception as e:
            # Service boundary: never leak an internal error to the client.
            logger.error(e)
            reply = "服务异常"
            references = []
        return JSONResponse({'reply': reply, 'references': references})

    @app.post("/stream")
    async def stream(request: Request):
        """Streaming Q&A endpoint emitting Server-Sent Events."""
        input_json = await request.json()
        query = input_json['query']
        history = input_json.get('history', [])

        async def event_generator():
            try:
                _, reply, references = await assistant.produce_response(
                    config, query=query, history=history, stream=True)
            except Exception as e:
                logger.error(e)
                yield "data: 服务异常\n\n"
                # SSE field names must start at column 0; the original
                # "event: end\n data: ..." carried a leading space that makes
                # conforming clients discard the data line.
                yield 'event: end\ndata: End of stream\n\n'
                return
            word = 'data: %s\n\n'
            try:
                async for request_output in reply:
                    text = json.dumps(request_output)
                    yield (word % text).encode('utf-8')
                yield 'event: end\ndata: End of stream\n\n'
            except (asyncio.CancelledError, ConnectionResetError):
                # Client went away mid-stream; nothing to clean up.
                logger.debug('user interrupt')
                return

        return StreamingResponse(event_generator(),
                                 media_type="text/event-stream")

    uvicorn.run(app, host='0.0.0.0', port=bind_port)


def auto_select_dcu(config):
    """Select idle DCU devices and export them via ``CUDA_VISIBLE_DEVICES``.

    Devices whose memory or compute utilisation meets or exceeds the configured
    thresholds are filtered out. Of the remainder, 40% (rounded) are taken,
    the count is snapped DOWN to the largest allowed group size (1/2/4/8/16/32),
    and the least-loaded devices (memory% + compute% combined) are chosen.

    Args:
        config: ConfigParser with ``[default]`` keys ``mem_threshold`` and
            ``dcu_threshold`` (integer percentages).

    Returns:
        list[str]: ids of the selected devices, least loaded first.

    Raises:
        Exception: if no DCU is present, or none is below the thresholds.
    """
    mem_threshold = config.getint('default', 'mem_threshold')
    dcu_threshold = config.getint('default', 'dcu_threshold')
    # Columns from hy-smi: device id, memory usage %, compute usage %.
    cmd = "hy-smi | grep '^[0-9]' | awk '{print $1,$6,$7}' | sed 's/%//g'"
    process = subprocess.Popen(cmd, shell=True,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()
    # Drop blank lines: the original ``.strip().split('\n')`` yields [''] for
    # empty output, so the "no dcu" check below could never fire.
    result = [line for line in output.decode().splitlines() if line.strip()]
    if not result:
        raise Exception("There is no dcu on this node.")

    dcu_map = {}
    for line in result:
        dcu_info = line.split()
        if int(dcu_info[1]) >= mem_threshold or int(dcu_info[2]) >= dcu_threshold:
            logger.debug("filter dcu:%s, which mem usage (%s) and dcu usage (%s) above the threshold." % (dcu_info[0], dcu_info[1], dcu_info[2]))
            continue
        logger.debug("dcu id:%s, mem usage: %s dcu usage: %s" % (dcu_info[0], dcu_info[1], dcu_info[2]))
        dcu_map[dcu_info[0]] = [int(dcu_info[1]), int(dcu_info[2])]

    count = round(len(dcu_map) * 0.4)
    if not count:
        logger.error("There is no available dcu device, can not start the service.")
        raise Exception("There is no available dcu device, can not start the service.")

    # Largest allowed group size <= count. The original bisect_left branching
    # raised IndexError whenever count exceeded 32 (insert_index == len(list)).
    select_count = divisible_by_32[bisect.bisect_right(divisible_by_32, count) - 1]

    # Rank devices by combined load, ascending; sorted() is stable, so ties
    # keep discovery order, matching the original parallel-list ranking.
    sorted_dcu_ids = [dcu_id for dcu_id, usage in
                      sorted(dcu_map.items(), key=lambda item: sum(item[1]))]
    select_dcu_ids = sorted_dcu_ids[:select_count]
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, select_dcu_ids))
    logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {select_dcu_ids}")
    return select_dcu_ids


def parse_args():
    """Parse command-line options for the service launcher."""
    parser = argparse.ArgumentParser(description='Start all services.')
    # Despite the original "Config directory" help text, this is a file path.
    parser.add_argument('--config_path', default='ai/config.ini',
                        help='Path to the INI config file')
    parser.add_argument('--log_path', default='', help='Set log file path')
    return parser.parse_args()


def main():
    """Entry point: configure file logging, then start the service."""
    args = parse_args()
    log_path = args.log_path or '/var/log/assistant.log'
    logger.add(sink=log_path, level="DEBUG", rotation="500MB",
               compression="zip", encoding="utf-8", enqueue=True)
    workflow(args)


if __name__ == '__main__':
    main()