import json
import os
import pickle
import re

from loguru import logger

from .utils import COMMON
from .helper import ErrorCode
from .http_client import OpenAPIClient, ClassifyModel, CacheRetriever


# Prompt asking the LLM to score, 0-10, whether a sentence involves prohibited
# content (politics, abuse, pornography, terror/violence, religion,
# cyber-bullying, racism); any hit scores 10, none scores 0.
SECURITY_TEMPLATE = '判断以下句子是否涉及政治、辱骂、色情、恐暴、宗教、网络暴力、种族歧视等违禁内容,结果用 0~10 表示,不要解释直接给出得分。判断标准:涉其中任一问题直接得 10 分;完全不涉及得 0 分。直接给得分不要解释:“{}”'  # noqa: E501
# Prompt that wraps retrieved context in <Data></Data> tags and instructs the
# model to answer consistently with it, in Markdown, in the question's language.
GENERATE_TEMPLATE = '<Data>{}</Data> \n 回答要求:\n如果你不清楚答案,你需要澄清。\n避免提及你是从 <Data></Data> 获取的知识。\n保持答案与 <Data></Data> 中描述的一致。\n使用 Markdown 语法优化回答格式。\n使用与问题相同的语言回答。问题:"{}"'  # noqa: E501
# Prompt that asks the model to answer the question in Markdown format.
MARKDOWN_TEMPLATE = '问题:“{}”  \n请使用markdown格式回答此问题'


def substitution(chunks):
    # Replace special '<...>' placeholder tags in each chunk with the mapped
    # value from COMMON, when a mapping exists.
    new_chunks = []
    for chunk in chunks:
        match_obj = re.split(r'.*(<.*>).*', chunk, flags=re.M | re.I)
        if len(match_obj) > 1:
            obj = match_obj[1]
            replace_str = COMMON.get(obj)
            if replace_str:
                chunk = chunk.replace(obj, replace_str)
                logger.info(f'{obj} replaced with {replace_str}, after: {chunk}')
        new_chunks.append(chunk)
    return new_chunks
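
# Example (illustrative; COMMON's real mapping lives in .utils and is not
# shown here): with COMMON = {'<br>': '\n'}, substitution(['第一行<br>第二行'])
# would return ['第一行\n第二行'].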


class Worker:
    '''Routes a query through classification, retrieval and LLM services to
    assemble the final response.'''

    def __init__(self, config):
        self.work_dir = config['default']['work_dir']
        llm_model = config['model']['llm_model']
        local_model = config['model']['local_model']
        llm_service_address = config['model']['llm_service_address']
        cls_model_path = config['model']['cls_model_path']
        local_server_address = config['model']['local_service_address']
        reject_throttle = float(config['feature_database']['reject_throttle'])
        self.embedding_model_path = config['feature_database']['embedding_model_path']
        self.reranker_model_path = config['feature_database']['reranker_model_path']
        if not llm_service_address:
            raise ValueError('llm_service_address is required in config.ini')
        if not cls_model_path:
            raise ValueError('cls_model_path is required in config.ini')

        self.max_input_len = int(config['model']['max_input_length'])

        self.retriever = CacheRetriever(
                self.embedding_model_path,
                self.reranker_model_path).get(reject_throttle=reject_throttle,
                                              work_dir=self.work_dir)
        self.openapi_service = OpenAPIClient(llm_service_address, llm_model)
        self.openapi_local_server = OpenAPIClient(local_server_address, local_model)
        self.classify_service = ClassifyModel(cls_model_path)

        self.tasks = {}
        tasks_path = os.path.join(self.work_dir, 'tasks_status.pkl')
        if os.path.exists(tasks_path):
            with open(tasks_path, 'rb') as f:
                self.tasks = pickle.load(f)

    def generate_prompt(self,
                        history_pair,
                        instruction: str,
                        context=''):
        # Fold the retrieved context (a string or a list of chunks) into the
        # generation template, capped at the model's maximum input length.
        if context is not None and len(context) > 0:
            str_context = str(context)
            if len(str_context) > self.max_input_len:
                str_context = str_context[:self.max_input_len]
            instruction = GENERATE_TEMPLATE.format(str_context, instruction)

        # Drop history pairs where either the question or the answer is
        # missing or empty.
        real_history = []
        for pair in history_pair:
            if pair[0] is None or pair[1] is None:
                continue
            if len(pair[0]) < 1 or len(pair[1]) < 1:
                continue
            real_history.append(pair)

        return instruction, real_history
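
    # Example (illustrative): generate_prompt(history_pair=[('q', 'a'), ('', 'x')],
    # instruction='如何安装?', context=['第一步...']) returns the filled
    # GENERATE_TEMPLATE plus [('q', 'a')] as the filtered history.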

    async def generater(self, content):
        # Wrap an in-memory string as an async generator so template answers
        # can be consumed the same way as streamed LLM output.
        for word in content:
            yield word
            # await asyncio.sleep(0.1)

    async def response_by_common(self, query, history, output_format=False, stream=False):
        if output_format:
            query = MARKDOWN_TEMPLATE.format(query)
        logger.info('Prompt is: {}, History is: {}'.format(query, history))
        response_direct = await self.openapi_service.chat(query, history, stream=stream)
        return response_direct

    def format_rag_result(self, chunks, references, stream=False):
        # Render the retrieved chunks as a numbered list under a header that
        # reads 'For your question, we found the following solutions:'.
        result = '针对您的问题,我们找到了如下解决方案:\n%s'
        content = ''
        for i, item in enumerate(references):
            if item.endswith('.json'):
                content += ' - %s.%s\n' % (i + 1, chunks[i])
            else:
                line = chunks[i]
                if len(line) > 300:
                    # Truncate long chunks and append a pointer to the source
                    # ('详细内容参见' = 'for details, see').
                    line = line[:300] + '...' + '\n'
                    line += '详细内容参见:%s' % item
                content += ' - %s.%s\n' % (i + 1, line)
        if stream:
            return self.generater(result % content)
        return result % content
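
    # Example (illustrative): with chunks=['第一步:下载安装包...'] and
    # references=['docs/install.md'], the rendered (non-stream) result is
    # '针对您的问题,我们找到了如下解决方案:\n - 1.第一步:下载安装包...\n'.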

    async def response_by_finetune(self, query, history=None):
        '''Answer with the locally served fine-tuned model.'''
        history = history or []
        logger.info('Prompt is: {}, History is: {}'.format(query, history))
        response_direct = await self.openapi_local_server.chat(query, history)
        data = json.loads(response_direct.content.decode("utf-8"))
        output = data["text"]
        return output

    async def produce_response(self, config, query, history, stream=False):
        '''Answer a query via RAG, the fine-tuned model or the common model,
        returning (error_code, response, references).'''
        response = ''
        references = []
        use_template = config.getboolean('default', 'use_template')
        output_format = config.getboolean('default', 'output_format')

        if query is None:
            return ErrorCode.NOT_A_QUESTION, response, references

        logger.info('input: %s' % [query, history])

        # Route on the classifier score: queries scoring above 0.8 go through
        # the RAG pipeline, everything else to the common model.
        score = self.classify_service.classfication(query)
        if score > 0.8:
            logger.debug('Start RAG search')
            chunks, references = self.retriever.query(query)

            if len(chunks) == 0:
                logger.debug('Response by finetune model')
                response = await self.response_by_finetune(query, history=history)
                chunks = [response]
            elif use_template:
                logger.debug('Response by template')
                response = self.format_rag_result(chunks, references, stream=stream)
                return ErrorCode.SUCCESS, response, references

            # Feed the retrieved (or fine-tuned) chunks to the common model
            # as context.
            logger.debug('Response with common model')
            new_chunks = substitution(chunks)
            prompt, history = self.generate_prompt(
                instruction=query,
                context=new_chunks,
                history_pair=history)

            logger.debug('prompt: {}'.format(prompt))
            response = await self.response_by_common(prompt, history=history, output_format=False, stream=stream)
            return ErrorCode.SUCCESS, response, references
        else:
            logger.debug('Response by common model')
            response = await self.response_by_common(query, history=history, output_format=output_format, stream=stream)
            return ErrorCode.SUCCESS, response, references
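

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module). It
# assumes a config.ini that provides every key this module reads ([default],
# [model] and [feature_database] sections) and reachable model services; the
# query string is a placeholder.
if __name__ == '__main__':
    import asyncio
    import configparser

    async def _demo():
        config = configparser.ConfigParser()
        config.read('config.ini')
        worker = Worker(config)
        code, response, references = await worker.produce_response(
            config, query='如何安装部署?', history=[])
        print(code, response, references)

    asyncio.run(_demo())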