from loguru import logger

from .helper import ErrorCode, LogManager
from .retriever import CacheRetriever
from .inferencer import LLMInference
from .feature_database import DocumentProcessor, FeatureDataBase


class ChatAgent:
    """Wraps the RAG retriever and the local LLM behind a single interface."""

    def __init__(self, config, tensor_parallel_size) -> None:
        self.work_dir = config['default']['work_dir']
        self.embedding_model_path = config['feature_database']['embedding_model_path']
        self.reranker_model_path = config['feature_database']['reranker_model_path']
        reject_throttle = float(config['feature_database']['reject_throttle'])
        local_llm_path = config['llm']['local_llm_path']
        accelerate = config.getboolean('llm', 'accelerate')
        self.retriever = CacheRetriever(
            self.embedding_model_path,
            self.reranker_model_path).get(reject_throttle=reject_throttle,
                                          work_dir=self.work_dir)
        self.llm_server = LLMInference(local_llm_path,
                                       tensor_parallel_size,
                                       accelerate=accelerate)

    def generate_prompt(self,
                        history_pair,
                        instruction: str,
                        template: str,
                        context: str = ''):
        """Fill the template with retrieved context and drop incomplete history pairs."""
        if context is not None and len(context) > 0:
            # The template takes the retrieved context first and the question second.
            instruction = template.format(context, instruction)
        real_history = []
        for pair in history_pair:
            if pair[0] is None or pair[1] is None:
                continue
            if len(pair[0]) < 1 or len(pair[1]) < 1:
                continue
            real_history.append(pair)
        return instruction, real_history

    def call_rag_retrieve(self, query):
        return self.retriever.query(query)

    def call_llm_response(self, prompt, history=None):
        text, error = self.llm_server.generate_response(prompt=prompt,
                                                        history=history)
        return text

    def parse_file_and_merge(self, file_dir):
        """Scan a directory, ingest its files into the feature database, then rebuild the retriever."""
        file_opr = DocumentProcessor()
        files = file_opr.scan_directory(repo_dir=file_dir)
        file_handler = FeatureDataBase(embeddings=self.retriever.embeddings,
                                       reranker=self.retriever.reranker)
        file_handler.preprocess(files=files,
                                work_dir=self.work_dir,
                                file_opr=file_opr)
        file_handler.merge_db_response(self.retriever.vector_store,
                                       files=files,
                                       work_dir=self.work_dir,
                                       file_opr=file_opr)
        file_opr.summarize(files)
        self.retriever = CacheRetriever(
            self.embedding_model_path,
            self.reranker_model_path).get(work_dir=self.work_dir)
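

# ---------------------------------------------------------------------------
# A minimal sketch of what ChatAgent.generate_prompt does, with illustrative
# values only (nothing below is executed by this module): the template gets
# the retrieved context in its first slot and the user question in its
# second, and history pairs with a missing or empty side are discarded.
#
#     >>> template = '"{}"\nQuestion: "{}"'         # cf. GENERATE_TEMPLATE
#     >>> template.format('retrieved chunks', 'user question')
#     '"retrieved chunks"\nQuestion: "user question"'
#     >>> [p for p in [('hi', 'hello'), ('', 'x'), (None, 'y')]
#     ...  if p[0] and p[1]]                         # same effect as the loop
#     [('hi', 'hello')]
# ---------------------------------------------------------------------------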
class Worker:
    """Routes a query through compliance checks, RAG retrieval and LLM generation."""

    def __init__(self, config, tensor_parallel_size):
        self.agent = ChatAgent(config, tensor_parallel_size)
        # Ask for the topic of a sentence, with no explanation.
        self.TOPIC_TEMPLATE = '告诉我这句话的主题,直接说主题不要解释:“{}”'
        # Score the relevance of each retrieved chunk to the question, 0~10, space-separated.
        self.SCORING_RELEVANCE_TEMPLATE = '问题:“{}”\n材料:“{}”\n请仔细阅读以上内容,材料里为一个列表,列表里面有若干子列表,请判断每个子列表的内容和问题的相关度,不要解释直接给出相关度得分列表并以空格分隔,用0~10表示。判断标准:非常相关得 10 分;完全没关联得 0 分。\n'  # noqa E501
        # Turn a user question into Google search keywords or phrases.
        self.KEYWORDS_TEMPLATE = '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。搜索参数类型 string, 内容是短语或关键字,以空格分隔。\n你现在是搜搜小助手,用户提问“{}”,你打算通过谷歌搜索查询相关资料,请提供用于搜索的关键字或短语,不要解释直接给出关键字或短语。'  # noqa E501
        # Score whether a sentence involves prohibited content (politics, abuse, porn, violence, ...), 0~10.
        self.SECURITY_TEMPLATE = '判断以下句子是否涉及政治、辱骂、色情、恐暴、宗教、网络暴力、种族歧视等违禁内容,结果用 0~10 表示,不要解释直接给出得分。判断标准:涉其中任一问题直接得 10 分;完全不涉及得 0 分。直接给得分不要解释:“{}”'  # noqa E501
        # Score how strongly an answer admits not knowing, 0~10.
        self.PERPLEXITY_TEMPLATE = '“question:{} answer:{}”\n阅读以上对话,answer 是否在表达自己不知道,回答越全面得分越少,用0~10表示,不要解释直接给出得分。\n判断标准:准确回答问题得 0 分;答案详尽得 1 分;知道部分答案但有不确定信息得 8 分;知道小部分答案但推荐求助其他人得 9 分;不知道任何答案直接推荐求助别人得 10 分。直接打分不要解释。'  # noqa E501
        # Summarize the given text briefly.
        self.SUMMARIZE_TEMPLATE = '{} \n 仔细阅读以上内容,总结得简短有力点'  # noqa E501
        # Answer a question in markdown, grounded on the given context.
        self.GENERATE_TEMPLATE = '“{}” \n问题:“{}” \n请仔细阅读上述文字, 并使用markdown格式回答问题,直接给出回答不做任何解释。'  # noqa E501
        # Answer a question directly in markdown.
        self.MARKDOWN_TEMPLATE = '问题:“{}” \n请使用markdown格式回答此问题'

    def judgment_results(self, query, chunks, throttle):
        """Let the LLM score each retrieved chunk against the query and keep the relevant ones."""
        relation_score = self.agent.call_llm_response(
            prompt=self.SCORING_RELEVANCE_TEMPLATE.format(query, chunks))
        logger.info('score: %s' % [relation_score, throttle])
        # Keep only the chunks whose relevance score reaches the threshold.
        filtered_chunks = []
        for chunk, score in zip(chunks, relation_score.split()):
            try:
                if float(score) >= float(throttle):
                    filtered_chunks.append(chunk)
            except ValueError:
                # The LLM may emit non-numeric tokens; skip such scores.
                logger.warning('unparsable relevance score: {}'.format(score))
        return filtered_chunks

    def extract_topic(self, query):
        topic = self.agent.call_llm_response(self.TOPIC_TEMPLATE.format(query))
        return topic

    def response_direct_by_llm(self, query):
        # Compliance check: the LLM scores the query for prohibited content.
        prompt = self.SECURITY_TEMPLATE.format(query)
        score = self.agent.call_llm_response(prompt=prompt)
        logger.debug('score:{}, prompt:{}'.format(score, prompt))
        if int(score) > 5:
            # Reply: 'Your question touches on a sensitive topic, please rephrase.'
            return ErrorCode.NON_COMPLIANCE_QUESTION, '您的问题中涉及敏感话题,请重新提问。', None

        logger.info('LLM direct response and prompt is: {}'.format(query))
        prompt = self.MARKDOWN_TEMPLATE.format(query)
        response_direct = self.agent.call_llm_response(prompt=prompt)
        return ErrorCode.NOT_FIND_RELATED_DOCS, response_direct, None

    def produce_response(self, query, history, judgment, topic=False, rag=True):
        """Answer a query, optionally extracting its topic first and grounding on retrieved chunks."""
        response = ''
        references = []
        if query is None:
            return ErrorCode.NOT_A_QUESTION, response, references
        logger.info('input: %s' % [query, history])

        if rag:
            if topic:
                query = self.extract_topic(query)
                logger.info('topic: %s' % query)
                if len(query) <= 0:
                    return ErrorCode.NO_TOPIC, response, references

            chunks, references = self.agent.call_rag_retrieve(query)
            if len(chunks) == 0:
                return self.response_direct_by_llm(query)
            if judgment:
                chunks = self.judgment_results(query, chunks, throttle=5)

            # If retrieval produced usable chunks, answer with them as context.
            if len(chunks) > 0:
                prompt, history = self.agent.generate_prompt(
                    instruction=query,
                    context=chunks,
                    history_pair=history,
                    template=self.GENERATE_TEMPLATE)
                logger.debug('prompt: {}'.format(prompt))
                response = self.agent.call_llm_response(prompt=prompt,
                                                        history=history)
                return ErrorCode.SUCCESS, response, references

        # No RAG requested, or every chunk was filtered out: answer directly.
        return self.response_direct_by_llm(query)
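

if __name__ == '__main__':
    # Minimal usage sketch, not part of the service itself. It assumes an
    # INI file named 'config.ini' containing the sections ChatAgent.__init__
    # reads ('default', 'feature_database', 'llm'); the file name, the query,
    # and tensor_parallel_size=1 are illustrative assumptions. Because of the
    # relative imports above, run it as a module, e.g.
    # `python -m <package>.<this_module>`.
    import configparser

    config = configparser.ConfigParser()
    config.read('config.ini')

    worker = Worker(config, tensor_parallel_size=1)
    code, response, references = worker.produce_response(
        query='How do I install this project?', history=[], judgment=True)
    logger.info('code: {}, response: {}, references: {}'.format(
        code, response, references))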