# worker.py
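"""RAG chat worker: retrieves document chunks for a query, optionally
filters them by LLM-judged relevance scores, and generates a grounded
markdown answer, falling back to a compliance-checked direct LLM response
when retrieval yields nothing usable."""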
from loguru import logger
from .helper import ErrorCode, LogManager
from .retriever import CacheRetriever
from .inferencer import LLMInference
from .feature_database import DocumentProcessor, FeatureDataBase


class ChatAgent:
    def __init__(self, config, tensor_parallel_size) -> None:
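        # Assumed INI layout (a sketch inferred from the reads below;
        # values are placeholders):
        #   [default]           work_dir
        #   [feature_database]  embedding_model_path, reranker_model_path, reject_throttle
        #   [llm]               local_llm_path, accelerate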
        self.work_dir = config['default']['work_dir']
        self.embedding_model_path = config['feature_database']['embedding_model_path']
        self.reranker_model_path = config['feature_database']['reranker_model_path']
        reject_throttle = float(config['feature_database']['reject_throttle'])
        local_llm_path = config['llm']['local_llm_path']
        accelerate = config.getboolean('llm', 'accelerate')

        self.retriever = CacheRetriever(self.embedding_model_path,
                                        self.reranker_model_path).get(reject_throttle=reject_throttle,
                                                                      work_dir=self.work_dir)
        self.llm_server = LLMInference(local_llm_path, tensor_parallel_size, accelerate=accelerate)

    def generate_prompt(self,
                        history_pair,
                        instruction: str,
                        template: str,
                        context=''):
        # `context` may be a string or a list of retrieved chunks; when
        # present, fold it into the instruction via the template.
        if context:
            instruction = template.format(context, instruction)

        # Drop empty or incomplete history turns; e.g. [['hi', 'hello'],
        # ['', None]] keeps only the first pair.
        real_history = []
        for pair in history_pair:
            if pair[0] is None or pair[1] is None:
                continue
            if len(pair[0]) < 1 or len(pair[1]) < 1:
                continue
            real_history.append(pair)

        return instruction, real_history

    def call_rag_retrieve(self, query):
        return self.retriever.query(query)

    def call_llm_response(self, prompt, history=None):
        # generate_response also reports an error value; only the generated
        # text is propagated to callers here.
        text, _error = self.llm_server.generate_response(prompt=prompt, history=history)
        return text

    def parse_file_and_merge(self, file_dir):
        # Scan the directory, ingest the files into the feature database,
        # merge them into the existing vector store, then rebuild the
        # retriever so the new documents become searchable.
        file_opr = DocumentProcessor()
        files = file_opr.scan_directory(repo_dir=file_dir)
        file_handler = FeatureDataBase(embeddings=self.retriever.embeddings, reranker=self.retriever.reranker)
        file_handler.preprocess(files=files, work_dir=self.work_dir, file_opr=file_opr)
        file_handler.merge_db_response(self.retriever.vector_store, files=files, work_dir=self.work_dir, file_opr=file_opr)
        file_opr.summarize(files)
        self.retriever = CacheRetriever(self.embedding_model_path, self.reranker_model_path).get(work_dir=self.work_dir)


class Worker:
    def __init__(self, config, tensor_parallel_size):

        self.agent = ChatAgent(config, tensor_parallel_size)

        # Prompt templates, kept in Chinese since they are sent verbatim to the LLM.
        self.TOPIC_TEMPLATE = '告诉我这句话的主题,直接说主题不要解释:“{}”'
        # Score each retrieved chunk's relevance to the question, 0-10, space-separated.
        self.SCORING_RELEVANCE_TEMPLATE = '问题:“{}”\n材料:“{}”\n请仔细阅读以上内容,材料里为一个列表,列表里面有若干子列表,请判断每个子列表的内容和问题的相关度,不要解释直接给出相关度得分列表并以空格分隔,用0~10表示。判断标准:非常相关得 10 分;完全没关联得 0 分。\n'  # noqa E501
        # Produce Google search keywords or phrases for the question.
        self.KEYWORDS_TEMPLATE = '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。搜索参数类型 string, 内容是短语或关键字,以空格分隔。\n你现在是搜搜小助手,用户提问“{}”,你打算通过谷歌搜索查询相关资料,请提供用于搜索的关键字或短语,不要解释直接给出关键字或短语。'  # noqa E501
        # Score whether the sentence touches prohibited content, 0-10.
        self.SECURITY_TEMPLATE = '判断以下句子是否涉及政治、辱骂、色情、恐暴、宗教、网络暴力、种族歧视等违禁内容,结果用 0~10 表示,不要解释直接给出得分。判断标准:涉其中任一问题直接得 10 分;完全不涉及得 0 分。直接给得分不要解释:“{}”'  # noqa E501
        # Score how strongly the answer admits not knowing, 0-10.
        self.PERPLEXITY_TEMPLATE = '“question:{} answer:{}”\n阅读以上对话,answer 是否在表达自己不知道,回答越全面得分越少,用0~10表示,不要解释直接给出得分。\n判断标准:准确回答问题得 0 分;答案详尽得 1 分;知道部分答案但有不确定信息得 8 分;知道小部分答案但推荐求助其他人得 9 分;不知道任何答案直接推荐求助别人得 10 分。直接打分不要解释。'  # noqa E501
        self.SUMMARIZE_TEMPLATE = '{} \n 仔细阅读以上内容,总结得简短有力点'  # noqa E501
        # Answer the question in markdown, grounded in the given context.
        self.GENERATE_TEMPLATE = '“{}” \n问题:“{}”  \n请仔细阅读上述文字, 并使用markdown格式回答问题,直接给出回答不做任何解释。'  # noqa E501
        self.MARKDOWN_TEMPLATE = '问题:“{}”  \n请使用markdown格式回答此问题'

    def judgment_results(self, query, chunks, throttle):
        # Ask the LLM to score each chunk's relevance to the query.
        relation_score = self.agent.call_llm_response(
            prompt=self.SCORING_RELEVANCE_TEMPLATE.format(query, chunks))
        logger.info('score: %s' % [relation_score, throttle])

        # Keep only chunks whose score reaches the throttle; skip any
        # scores the LLM failed to format as numbers.
        filtered_chunks = []
        for chunk, score in zip(chunks, relation_score.split()):
            try:
                if float(score) >= float(throttle):
                    filtered_chunks.append(chunk)
            except ValueError:
                continue

        return filtered_chunks
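    # Example: with three chunks and an LLM reply of '9 2 7', a throttle of
    # 5 keeps the first and third chunks and drops the second.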

    def extract_topic(self, query):

        topic = self.agent.call_llm_response(self.TOPIC_TEMPLATE.format(query))
        return topic

    def response_direct_by_llm(self, query):
        # Compliance check: block queries that touch prohibited topics.
        prompt = self.SECURITY_TEMPLATE.format(query)
        score = self.agent.call_llm_response(prompt=prompt)
        logger.debug('score:{}, prompt:{}'.format(score, prompt))
        try:
            if int(score) > 5:
                return ErrorCode.NON_COMPLIANCE_QUESTION, '您的问题中涉及敏感话题,请重新提问。', None
        except ValueError:
            # The LLM did not return a bare integer; treat the query as
            # compliant rather than failing.
            pass
        logger.info('LLM direct response and prompt is: {}'.format(query))
        prompt = self.MARKDOWN_TEMPLATE.format(query)
        response_direct = self.agent.call_llm_response(prompt=prompt)
        return ErrorCode.NOT_FIND_RELATED_DOCS, response_direct, None

    def produce_response(self, query,
                         history,
                         judgment,
                         topic=False,
                         rag=True):

        response = ''
        references = []

        if query is None:
            return ErrorCode.NOT_A_QUESTION, response, references

        logger.info('input: %s' % [query, history])
        if rag:
            if topic:
                query = self.extract_topic(query)
                logger.info('topic: %s' % query)

                if len(query) <= 0:
                    return ErrorCode.NO_TOPIC, response, references

            chunks, references = self.agent.call_rag_retrieve(query)

            if len(chunks) == 0:
                return self.response_direct_by_llm(query)

            if judgment:
                chunks = self.judgment_results(
                    query, chunks,
                    throttle=5,
                )
            # If the database retrieved relevant chunks, answer with them.
            if len(chunks) > 0:
                prompt, history = self.agent.generate_prompt(
                    instruction=query,
                    context=chunks,
                    history_pair=history,
                    template=self.GENERATE_TEMPLATE)

                logger.debug('prompt: {}'.format(prompt))
                response = self.agent.call_llm_response(prompt=prompt, history=history)
                return ErrorCode.SUCCESS, response, references

            # Every chunk was filtered out by the relevance judgment; fall
            # back to a direct LLM answer.
            return self.response_direct_by_llm(query)

        else:
            return self.response_direct_by_llm(query)
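

# A minimal usage sketch, assuming an INI config with the sections
# ChatAgent reads ('default', 'feature_database', 'llm'); 'config.ini',
# the parallel size, and the query below are placeholders.
if __name__ == '__main__':
    import configparser

    config = configparser.ConfigParser()
    config.read('config.ini')

    worker = Worker(config, tensor_parallel_size=1)
    code, response, references = worker.produce_response(
        query='什么是 RAG?', history=[], judgment=True)
    logger.info('%s' % [code, response, references])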