import json
import os
import pickle
import re
from loguru import logger
from .utils import COMMON
from .helper import ErrorCode
from .http_client import OpenAPIClient, ClassifyModel, CacheRetriever
SECURITY_TEMPLATE = '判断以下句子是否涉及政治、辱骂、色情、恐暴、宗教、网络暴力、种族歧视等违禁内容,结果用 0~10 表示,不要解释直接给出得分。判断标准:涉其中任一问题直接得 10 分;完全不涉及得 0 分。直接给得分不要解释:“{}”'  # noqa E501
GENERATE_TEMPLATE = '{} \n 回答要求:\n如果你不清楚答案,你需要澄清。\n避免提及你是从 获取的知识。\n保持答案与 中描述的一致。\n使用 Markdown 语法优化回答格式。\n使用与问题相同的语言回答。问题:"{}"'
MARKDOWN_TEMPLATE = '问题:“{}” \n请使用markdown格式回答此问题'
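# SECURITY_TEMPLATE asks the LLM to score a sentence 0-10 for prohibited
# content, GENERATE_TEMPLATE wraps retrieved context around the user
# question, and MARKDOWN_TEMPLATE only requests markdown-formatted output.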


def substitution(chunks):
    '''Replace special placeholder tokens such as `<name>` in each chunk
    with their mapped values from COMMON.'''
    new_chunks = []
    for chunk in chunks:
        # `flags` must be passed by keyword: the third positional argument
        # of re.split is maxsplit, not flags.
        parts = re.split(r'.*(<.*>).*', chunk, flags=re.M | re.I)
        if len(parts) > 1:
            obj = parts[1]
            replace_str = COMMON.get(obj)
            if replace_str:
                chunk = chunk.replace(obj, replace_str)
                logger.info(f'{obj} replaced with {replace_str}, after: {chunk}')
        new_chunks.append(chunk)
    return new_chunks
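
# Usage sketch (hypothetical mapping; COMMON's real keys live in .utils):
# if COMMON were {'<gpu>': 'NVIDIA GPU'}, then
# substitution(['works on <gpu> only']) -> ['works on NVIDIA GPU only'].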


class Worker:

    def __init__(self, config):
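        '''Read model and retrieval settings from `config`, build the
        feature retriever, the remote and local LLM clients and the
        intent classifier, then restore task status from
        `tasks_status.pkl` under `work_dir` if it exists.'''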
        self.work_dir = config['default']['work_dir']
        llm_model = config['model']['llm_model']
        local_model = config['model']['local_model']
        llm_service_address = config['model']['llm_service_address']
        cls_model_path = config['model']['cls_model_path']
        local_server_address = config['model']['local_service_address']
        reject_throttle = float(config['feature_database']['reject_throttle'])
        self.embedding_model_path = config['feature_database']['embedding_model_path']
        self.reranker_model_path = config['feature_database']['reranker_model_path']
        if not llm_service_address:
            raise Exception('llm_service_address is required in config.ini')
        if not cls_model_path:
            raise Exception('cls_model_path is required in config.ini')
        self.max_input_len = int(config['model']['max_input_length'])
        self.retriever = CacheRetriever(
            self.embedding_model_path,
            self.reranker_model_path).get(reject_throttle=reject_throttle,
                                          work_dir=self.work_dir)
        self.openapi_service = OpenAPIClient(llm_service_address, llm_model)
        self.openapi_local_server = OpenAPIClient(local_server_address, local_model)
        self.classify_service = ClassifyModel(cls_model_path)
        self.tasks = {}
        if os.path.exists(self.work_dir + '/tasks_status.pkl'):
            with open(self.work_dir + '/tasks_status.pkl', 'rb') as f:
                self.tasks = pickle.load(f)

    def generate_prompt(self,
                        history_pair,
                        instruction: str,
                        context: str = ''):
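        '''Wrap `instruction` in GENERATE_TEMPLATE when retrieval context is
        available (truncated to `max_input_len`), and drop history pairs
        with empty or missing entries.'''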
        if context is not None and len(context) > 0:
            str_context = str(context)
            if len(str_context) > self.max_input_len:
                str_context = str_context[:self.max_input_len]
            instruction = GENERATE_TEMPLATE.format(str_context, instruction)
        real_history = []
        for pair in history_pair:
            if pair[0] is None or pair[1] is None:
                continue
            if len(pair[0]) < 1 or len(pair[1]) < 1:
                continue
            real_history.append(pair)
        return instruction, real_history

    async def generater(self, content):
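        '''Yield `content` piece by piece so a plain string can be served
        through the same streaming interface as model output.'''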
        for word in content:
            yield word
            # await asyncio.sleep(0.1)

    async def response_by_common(self, query, history, output_format=False, stream=False):
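        '''Answer with the remote general-purpose model, optionally asking
        for markdown output via MARKDOWN_TEMPLATE.'''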
        if output_format:
            query = MARKDOWN_TEMPLATE.format(query)
        logger.info('Prompt is: {}, History is: {}'.format(query, history))
        response_direct = await self.openapi_service.chat(query, history, stream=stream)
        return response_direct

    def format_rag_result(self, chunks, references, stream=False):
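        '''Render retrieved chunks as a numbered list; chunks longer than
        300 characters are truncated with a pointer to the source
        reference. Returns an async generator when `stream` is True.'''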
result = "针对您的问题,我们找到了如下解决方案:\n%s"
content = ""
for i, item in enumerate(references):
if item.endswith(".json"):
content += " - %s.%s\n" % (i + 1, chunks[i])
else:
line = chunks[i]
if len(line) > 300:
line = line[:300] + "..." + '\n'
line += "详细内容参见:%s" % item
content += " - %s.%s\n" % (i + 1, line)
if stream:
return self.generater((result % content))
return result % content

    async def response_by_finetune(self, query, history=[]):
        '''Answer with the fine-tuned local model.'''
        logger.info('Prompt is: {}, History is: {}'.format(query, history))
        # The client call must be awaited, matching response_by_common and
        # the `await` at the call site in produce_response.
        response_direct = await self.openapi_local_server.chat(query, history)
        data = json.loads(response_direct.content.decode("utf-8"))
        output = data["text"]
        return output

    async def produce_response(self, config, query, history, stream=False):
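        '''Route a query: classify it first, then either run RAG (falling
        back to the fine-tuned model when retrieval finds nothing, or to
        a template answer when `use_template` is set) or answer directly
        with the common model.'''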
        response = ''
        references = []
        use_template = config.getboolean('default', 'use_template')
        output_format = config.getboolean('default', 'output_format')
        if query is None:
            return ErrorCode.NOT_A_QUESTION, response, references
        logger.info('input: %s' % [query, history])
        # classify
        score = self.classify_service.classfication(query)
        if score > 0.8:
            logger.debug('Start RAG search')
            chunks, references = self.retriever.query(query)
            if len(chunks) == 0:
                logger.debug('Response by finetune model')
                response = await self.response_by_finetune(query, history=history)
                chunks = [response]
            elif use_template:
                logger.debug('Response by template')
                response = self.format_rag_result(chunks, references, stream=stream)
                return ErrorCode.SUCCESS, response, references
            logger.debug('Response with common model')
            new_chunks = substitution(chunks)
            prompt, history = self.generate_prompt(
                instruction=query,
                context=new_chunks,
                history_pair=history)
            logger.debug('prompt: {}'.format(prompt))
            response = await self.response_by_common(
                prompt, history=history, output_format=False, stream=stream)
            return ErrorCode.SUCCESS, response, references
        else:
            logger.debug('Response by common model')
            response = await self.response_by_common(
                query, history=history, output_format=output_format, stream=stream)
            return ErrorCode.SUCCESS, response, references
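

# Minimal usage sketch (assumes a config.ini containing the sections read
# in Worker.__init__; the query text and file name are illustrative):
#
#   import asyncio
#   import configparser
#
#   config = configparser.ConfigParser()
#   config.read('config.ini')
#   worker = Worker(config)
#   code, response, refs = asyncio.run(
#       worker.produce_response(config, query='如何安装?', history=[]))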