import inspect
import json
import os
import pickle
import re

from loguru import logger

from .utils import COMMON
from .helper import ErrorCode
from .http_client import OpenAPIClient, ClassifyModel, CacheRetriever

# Asks the LLM to score a sentence 0-10 for prohibited content (politics,
# abuse, pornography, terror/violence, religion, cyber-bullying, racial
# discrimination) and to reply with the bare score only.
SECURITY_TEMAPLTE = '判断以下句子是否涉及政治、辱骂、色情、恐暴、宗教、网络暴力、种族歧视等违禁内容,结果用 0~10 表示,不要解释直接给出得分。判断标准:涉其中任一问题直接得 10 分;完全不涉及得 0 分。直接给得分不要解释:“{}”'  # noqa E501
# Correctly spelled alias; the misspelled name above is kept for existing importers.
SECURITY_TEMPLATE = SECURITY_TEMAPLTE

# RAG prompt: the first placeholder is the retrieved context, the second the
# user question. The answering rules refer to the context section by its tag
# name, so the context itself must be wrapped in the same <Data> markers.
GENERATE_TEMPLATE = '<Data>{}</Data> \n 回答要求:\n如果你不清楚答案,你需要澄清。\n避免提及你是从 <Data></Data> 获取的知识。\n保持答案与 <Data></Data> 中描述的一致。\n使用 Markdown 语法优化回答格式。\n使用与问题相同的语言回答。问题:"{}"'  # noqa E501

# Wraps a plain question and asks for a Markdown-formatted answer.
MARKDOWN_TEMPLATE = '问题:“{}” \n请使用markdown格式回答此问题'

# Captures, on the first line that contains one, the span from the last '<'
# to the last '>' after it (greedy '.*' on both sides; '.' does not cross
# newlines).
_TAG_PATTERN = re.compile(r'.*(<.*>).*')


def substitution(chunks):
    """Translate special ``<...>`` tokens in each chunk via ``COMMON``.

    For every chunk, the first regex match of ``_TAG_PATTERN`` is looked up
    in ``COMMON``; when a non-empty replacement exists, every occurrence of
    that token in the chunk is replaced.

    Args:
        chunks: iterable of text chunks.

    Returns:
        A new list with the (possibly) substituted chunks, in input order.
    """
    new_chunks = []
    for chunk in chunks:
        # The original passed re.M|re.I positionally to re.split, where the
        # third parameter is maxsplit, not flags. search() yields the same
        # first match without that pitfall.
        match = _TAG_PATTERN.search(chunk)
        if match:
            obj = match.group(1)
            replace_str = COMMON.get(obj)
            if replace_str:
                chunk = chunk.replace(obj, replace_str)
                logger.info(f"{obj} be replaced {replace_str}, after {chunk}")
        new_chunks.append(chunk)
    return new_chunks


class Worker:
    """Answer user queries by routing between retrieval (RAG), a fine-tuned
    local model and a general LLM service, based on a classifier score."""

    def __init__(self, config):
        """Build the worker from a parsed config (``config.ini`` sections).

        Reads the ``default``, ``model`` and ``feature_database`` sections,
        creates the retriever, the two OpenAPI clients and the classifier,
        and restores ``self.tasks`` from ``<work_dir>/tasks_status.pkl`` if
        that file exists.

        Raises:
            ValueError: if ``llm_service_address`` or ``cls_model_path`` is
                empty (a subclass of the ``Exception`` raised previously).
        """
        self.work_dir = config['default']['work_dir']
        llm_model = config['model']['llm_model']
        local_model = config['model']['local_model']
        llm_service_address = config['model']['llm_service_address']
        cls_model_path = config['model']['cls_model_path']
        local_server_address = config['model']['local_service_address']
        reject_throttle = float(config['feature_database']['reject_throttle'])
        self.embedding_model_path = config['feature_database']['embedding_model_path']
        self.reranker_model_path = config['feature_database']['reranker_model_path']

        if not llm_service_address:
            raise ValueError('llm_service_address is required in config.ini')
        if not cls_model_path:
            raise ValueError('cls_model_path is required in config.ini')

        self.max_input_len = int(config['model']['max_input_length'])
        self.retriever = CacheRetriever(
            self.embedding_model_path,
            self.reranker_model_path).get(reject_throttle=reject_throttle,
                                          work_dir=self.work_dir)
        self.openapi_service = OpenAPIClient(llm_service_address, llm_model)
        self.openapi_local_server = OpenAPIClient(local_server_address, local_model)
        self.classify_service = ClassifyModel(cls_model_path)

        self.tasks = {}
        tasks_path = os.path.join(self.work_dir, 'tasks_status.pkl')
        if os.path.exists(tasks_path):
            # NOTE(review): pickle.load can execute arbitrary code; this is only
            # safe while work_dir is trusted and written solely by this service.
            with open(tasks_path, 'rb') as f:
                self.tasks = pickle.load(f)

    def generate_prompt(self, history_pair, instruction: str, context=''):
        """Build the final prompt and a cleaned conversation history.

        Args:
            history_pair: sequence of (question, answer) pairs; pairs with a
                missing or empty side are dropped. ``None`` means no history.
            instruction: the user question.
            context: retrieved context, either a string or a list of chunk
                strings (joined with newlines). When non-empty it is truncated
                to ``self.max_input_len`` characters and embedded together with
                the question via ``GENERATE_TEMPLATE``.

        Returns:
            ``(instruction, real_history)``.
        """
        if context:
            if isinstance(context, (list, tuple)):
                # str() of a list would hand the LLM a Python repr with
                # brackets, quotes and escaped newlines.
                str_context = '\n'.join(str(chunk) for chunk in context)
            else:
                str_context = str(context)
            if len(str_context) > self.max_input_len:
                str_context = str_context[:self.max_input_len]
            instruction = GENERATE_TEMPLATE.format(str_context, instruction)

        real_history = [
            pair for pair in (history_pair or [])
            if pair[0] and pair[1]
        ]
        return instruction, real_history

    async def generater(self, content):
        """Asynchronously yield ``content`` item by item (characters for a str)."""
        for word in content:
            yield word

    async def response_by_common(self, query, history, output_format=False, stream=False):
        """Answer with the general LLM service.

        When ``output_format`` is true the query is wrapped in
        ``MARKDOWN_TEMPLATE``. Returns whatever ``OpenAPIClient.chat`` returns.
        """
        if output_format:
            query = MARKDOWN_TEMPLATE.format(query)
        logger.info('Prompt is: {}, History is: {}'.format(query, history))
        response_direct = await self.openapi_service.chat(query, history, stream=stream)
        return response_direct

    def format_rag_result(self, chunks, references, stream=False):
        """Render retrieved chunks as a numbered answer list.

        References ending in ``.json`` show the chunk as-is. Other references
        show the chunk (truncated to 300 characters plus ``...`` and a newline
        when longer) followed by a pointer to the reference. Chunks and
        references are paired positionally.

        Returns:
            The formatted string, or an async generator over its characters
            when ``stream`` is true.
        """
        result = "针对您的问题,我们找到了如下解决方案:\n%s"
        entries = []
        for i, (chunk, item) in enumerate(zip(chunks, references), start=1):
            if item.endswith(".json"):
                entries.append(" - %s.%s\n" % (i, chunk))
                continue
            line = chunk
            if len(line) > 300:
                line = line[:300] + "..." + '\n'
            line += "详细内容参见:%s" % item
            entries.append(" - %s.%s\n" % (i, line))

        text = result % ''.join(entries)
        if stream:
            return self.generater(text)
        return text

    async def response_by_finetune(self, query, history=None):
        """Answer with the fine-tuned local model.

        ``produce_response`` awaits this method, and ``OpenAPIClient.chat`` is
        awaited elsewhere in this class, so it is a coroutine that awaits the
        chat result whenever that result is awaitable.

        Returns:
            The ``"text"`` field of the JSON body returned by the local server.
            Assumes the chat result exposes the raw body as ``.content`` bytes
            -- TODO confirm against OpenAPIClient.
        """
        history = [] if history is None else history
        logger.info('Prompt is: {}, History is: {}'.format(query, history))
        response_direct = self.openapi_local_server.chat(query, history)
        if inspect.isawaitable(response_direct):
            response_direct = await response_direct
        data = json.loads(response_direct.content.decode("utf-8"))
        return data["text"]

    async def produce_response(self, config, query, history, stream=False):
        """Route a query and produce the answer.

        The classifier score decides the route: above 0.8 the knowledge base
        is searched. With no hits, the fine-tuned model's answer becomes the
        context for the common model. With hits and ``use_template``, a
        templated list is returned. Otherwise the common model answers over
        the retrieved chunks. Scores of 0.8 or below go straight to the
        common model.

        Returns:
            ``(ErrorCode, response, references)``; ``ErrorCode.NOT_A_QUESTION``
            for a missing or blank query.
        """
        response = ''
        references = []
        use_template = config.getboolean('default', 'use_template')
        output_format = config.getboolean('default', 'output_format')

        if query is None or not str(query).strip():
            return ErrorCode.NOT_A_QUESTION, response, references

        logger.info('input: %s' % [query, history])

        score = self.classify_service.classfication(query)
        if score > 0.8:
            logger.debug('Start RAG search')
            chunks, references = self.retriever.query(query)
            if not chunks:
                logger.debug('Response by finetune model')
                response = await self.response_by_finetune(query, history=history)
                # The fine-tuned answer becomes the context for the common model.
                chunks = [response]
            elif use_template:
                logger.debug('Response by template')
                response = self.format_rag_result(chunks, references, stream=stream)
                return ErrorCode.SUCCESS, response, references

            logger.debug('Response with common model')
            new_chunks = substitution(chunks)
            prompt, history = self.generate_prompt(
                instruction=query, context=new_chunks, history_pair=history)
            logger.debug('prompt: {}'.format(prompt))
            response = await self.response_by_common(
                prompt, history=history, output_format=False, stream=stream)
            return ErrorCode.SUCCESS, response, references

        logger.debug('Response by common model')
        response = await self.response_by_common(
            query, history=history, output_format=output_format, stream=stream)
        return ErrorCode.SUCCESS, response, references