import os
import argparse
import bisect
import configparser
import subprocess
from aiohttp import web
from loguru import logger
from llm_service import Worker
from scipy.stats import rankdata

# Allowed device counts. Every entry is a divisor of 32; the list must stay
# sorted in ascending order because auto_select_dcu() searches it with bisect.
divisible_by_32 = [1, 2, 4, 8, 16, 32]
# Upload directory template; "%s" is filled with config['default']['work_dir'].
recv_file_path = "%s/upload"


def workflow(args):
    """Load the configuration, pick DCU devices and serve the HTTP API.

    Reads the INI file at ``args.config_path``, selects devices through
    ``auto_select_dcu``, builds a ``Worker`` with one tensor-parallel rank per
    selected device and then blocks in ``web.run_app`` serving two POST
    routes: ``/work`` and ``/upload``.

    Args:
        args: Parsed command-line namespace; only ``config_path`` is read.

    Raises:
        Exception: Whatever ``auto_select_dcu`` or the ``Worker`` constructor
            raises is propagated unchanged.
    """
    config = configparser.ConfigParser()
    config.read(args.config_path)
    bind_port = int(config['default']['bind_port'])

    dcu_ids = auto_select_dcu(config)
    tensor_parallel_size = len(dcu_ids)
    try:
        assistant = Worker(config, tensor_parallel_size)
    except Exception:
        # Record the failure in the log sink, then re-raise with the original
        # traceback intact (the previous `raise (e)` added nothing).
        logger.exception("Failed to initialize the Worker.")
        raise

    async def work(request):
        """Answer the ``query`` field of the JSON request body."""
        input_json = await request.json()
        query = input_json['query']
        # The first element of the returned triple is not used by this handler.
        _code, reply, references = assistant.produce_response(
            query=query, history=[], judgment=False)
        return web.json_response({'reply': reply, 'references': references})

    async def handle_upload(request):
        """Store every uploaded file part, then trigger file parsing.

        Parts without a usable file name are skipped. If no file was stored,
        a 400 response is returned and parsing is not triggered.
        """
        # Resolve and create the target directory up front so it is defined
        # even when the multipart body contains no parts.
        save_path = recv_file_path % config['default']['work_dir']
        os.makedirs(save_path, exist_ok=True)

        received = []
        reader = await request.multipart()
        while True:
            field = await reader.next()
            if field is None:
                break
            # The file name is client-controlled: keep only its final path
            # component so it cannot escape save_path (e.g. "../../x").
            filename = os.path.basename(field.filename or '')
            if filename in ('', '.', '..'):
                continue
            file_path = os.path.join(save_path, filename)
            with open(file_path, 'wb') as f:
                while True:
                    chunk = await field.read_chunk()
                    if not chunk:
                        break
                    f.write(chunk)
            logger.debug("成功接收文件:%s" % file_path)
            received.append(filename)

        if not received:
            return web.json_response({"reply": "未接收到文件\n"}, status=400)

        # Call file parse process
        assistant.agent.parse_file_and_merge(save_path)
        names = ', '.join(received)
        # This literal previously lacked the f-prefix, so clients received the
        # raw text "{filename}" instead of the actual name.
        return web.json_response({"reply": f"成功接收文件:{names}\n"})

    app = web.Application()
    app.add_routes([web.post('/work', work),
                    web.post('/upload', handle_upload)])
    web.run_app(app, host='0.0.0.0', port=bind_port)


def auto_select_dcu(config):
    """Choose the least-used DCU devices and export them to the environment.

    Runs ``hy-smi`` through a shell pipeline, drops every device whose memory
    or DCU usage reaches the configured threshold, and keeps the largest
    count from ``divisible_by_32`` that does not exceed 40% of the remaining
    devices. Devices with the lowest combined usage are preferred. The result
    is written to ``CUDA_VISIBLE_DEVICES``.

    Args:
        config: ``ConfigParser`` providing the integer options
            ``mem_threshold`` and ``dcu_threshold`` in section ``default``.

    Returns:
        list[str]: The selected device ids, least-used first.

    Raises:
        Exception: If ``hy-smi`` reports no device, or no device is left
            after filtering and applying the 40% rule.
    """
    # Read threshold in config file
    mem_threshold = config.getint('default', 'mem_threshold')
    dcu_threshold = config.getint('default', 'dcu_threshold')

    # Get dcu usage. Columns 1, 6 and 7 are presumably device id, memory
    # usage % and DCU usage % -- TODO confirm against the hy-smi output format.
    process = subprocess.Popen(
        "hy-smi | grep '^[0-9]' | awk '{print $1,$6,$7}' | sed 's/%//g'",
        shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()
    if error:
        logger.warning("hy-smi stderr: %s" % error.decode(errors='replace').strip())

    # "".split('\n') yields [''], which is truthy, so blank lines must be
    # filtered out for the emptiness check below to work.
    result = [line for line in output.decode().splitlines() if line.strip()]
    if not result:
        raise Exception("There is no dcu on this node.")

    dcu_map = {}
    for line in result:
        dcu_info = line.split()
        if len(dcu_info) < 3:
            logger.warning("skip unparsable hy-smi line: %r" % line)
            continue
        dcu_id, mem_usage, dcu_usage = dcu_info[0], int(dcu_info[1]), int(dcu_info[2])
        if mem_usage >= mem_threshold or dcu_usage >= dcu_threshold:
            logger.debug("filter dcu:%s, which mem usage (%s) and dcu usage (%s) above the threshold." % (dcu_info[0], dcu_info[1], dcu_info[2]))
            continue
        logger.debug("dcu id:%s, mem usage: %s dcu usage: %s" % (dcu_info[0], dcu_info[1], dcu_info[2]))
        dcu_map[dcu_id] = [mem_usage, dcu_usage]

    # The selected count must be a divisor of 32 (see divisible_by_32).
    # TODO temporary use 40% of available count
    count = round(len(dcu_map) * 0.4)
    if not count:
        logger.error("There is no available dcu device, can not start the service.")
        raise Exception("There is no available dcu device, can not start the service.")

    # Largest allowed value <= count. count >= 1 here, so bisect_right returns
    # at least 1; a count above 32 clamps to 32 instead of indexing past the
    # end of the list.
    index = bisect.bisect_right(divisible_by_32, count) - 1
    select_count = divisible_by_32[index]

    # Order devices by the sum of their raw memory and DCU usage values,
    # lowest first; sorted() is stable, so ties keep hy-smi order.
    sorted_rank = sorted(dcu_map.items(), key=lambda item: item[1][0] + item[1][1])
    select_dcu_ids = [name for name, _usage in sorted_rank[:select_count]]

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, select_dcu_ids))
    logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {select_dcu_ids}")
    return select_dcu_ids


def parse_args():
    """Parse the command-line options ``--config_path`` and ``--log_path``."""
    parser = argparse.ArgumentParser(description='Start all services.')
    parser.add_argument('--config_path',
                        default='ai/config.ini',
                        help='Config directory')
    parser.add_argument('--log_path',
                        default='',
                        help='Set log file path')
    return parser.parse_args()


def main():
    """Configure the loguru file sink and start the service."""
    args = parse_args()
    # An empty --log_path falls back to the default log location.
    log_path = args.log_path or '/var/log/assistant.log'
    logger.add(sink=log_path, level="DEBUG", rotation="500MB",
               compression="zip", encoding="utf-8", enqueue=True)
    workflow(args)


if __name__ == '__main__':
    main()