retriever_client.py
import argparse
import ast
import json
import os
import re
import time

import requests
from loguru import logger


def retrieve(query, url):
    """向后端服务器发送检索请求。"""
    endpoint = f"{url}/retrieve"
    payload = {"query": query}

    try:
        response = requests.post(endpoint, json=payload)
        response.raise_for_status()  # raise an exception for error status codes
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"发送请求时发生错误: {e}")
        return None

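# Note: the /retrieve endpoint is assumed to return JSON shaped roughly like
# {"chunks": [{"page_content": "...", "metadata": {"source": "doc.txt"}}, ...]};
# that is how check_hit_content() and check_hit() below consume the result.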

def remove_punctuation(query):
    """Strip punctuation and underscores, keeping word characters (including CJK) and whitespace."""
    return re.sub(r'[^\w\s]|[_]', '', query)


def check_hit_content(query, chunks):
    """Check whether the query text appears in any retrieved chunk; return (hit, reciprocal rank)."""
    logger.info(f"Query: {query}")
    query_move = remove_punctuation(query)
    chunks = chunks.get('chunks', [])
    for rank, chunk in enumerate(chunks, 1):
        content = remove_punctuation(chunk['page_content'])
        if query_move.lower() in content.lower():
            return True, 1 / rank
    return False, None


def check_hit(query, chunks):
    """Check whether the query hits any retrieved document; return (hit, reciprocal rank)."""

    json2text_path = '/home/zhangwq/data/for_test_new/json2txt'

    logger.info(f"Query: {query}")
    chunks = chunks.get('chunks', [])

    query_move = remove_punctuation(query)

    for rank, chunk in enumerate(chunks, 1):
        metadata = chunk.get('metadata', {})
        source = metadata.get('source', '')
        logger.info(f"Document: {source}")

        source_original = os.path.splitext(source)[0]
        source_cleaned = remove_punctuation(source_original)

        if re.match(r'json_obj_\d+\.txt', source):
            # Handle json_obj_{i}.txt files by checking the file contents on disk
            json_file_path = os.path.join(json2text_path, source)
            logger.info(f"Checking JSON file: {json_file_path}")
            try:
                with open(json_file_path, 'r', encoding='utf-8') as f:
                    json_content = f.read()
                    json_content_cleaned = remove_punctuation(json_content)
                    if query_move.lower() in json_content_cleaned.lower():
                        logger.info(f"Match found in JSON file {source}!")
                        return True, 1 / rank
            except Exception as e:
                logger.error(f"Error reading or processing {json_file_path}: {str(e)}")

        if query_move.lower() in source_cleaned.lower() or query_move in source_cleaned:
            return True, 1 / rank

    return False, None


def process_queries_with_polluted_file(file_path, url):
    """Process each query line in a file of polluted queries."""
    total_queries = 0
    hits = 0
    total_inverse_rank = 0

    time1 = time.time()

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            try:
                # data = json.loads(line.strip())
                data = ast.literal_eval(line.strip())
                polluted_query = data['Polluted']
                original_query = data['original']
            except (ValueError, SyntaxError):
                # ast.literal_eval raises ValueError/SyntaxError (json.JSONDecodeError is a ValueError)
                logger.error(f"Unable to parse line: {line.strip()}")
                continue
            except KeyError:
                logger.error(f"Line is missing a required key: {line.strip()}")
                continue

            total_queries += 1
            logger.debug(f"处理查询 {total_queries}: 污染查询 = '{polluted_query}', 原始查询 = '{original_query}'")
            results = retrieve(polluted_query, url)
            if results:
                hit, inverse_rank = check_hit(original_query, results)
                if hit:
                    hits += 1
                    total_inverse_rank += inverse_rank
                    logger.debug(f"命中! 倒数排名: {inverse_rank:.4f}")
                else:
                    logger.debug("未命中")

            logger.debug("-" * 50)

    hit_rate = hits / total_queries if total_queries > 0 else 0
    average_inverse_rank = total_inverse_rank / hits if hits > 0 else 0

    time2 = time.time()
    elapsed_time = time2 - time1

    logger.debug(f"总查询数: {total_queries}")
    logger.debug(f"命中数: {hits}")
    logger.debug(f"命中率: {hit_rate:.2%}")
    logger.debug(f"平均倒数排名: {average_inverse_rank:.4f}")
    logger.debug(f"花费时间: {elapsed_time:.4f} 秒")


def process_queries(file_path, url):
    """处理文件中的每一行查询。"""
    total_queries = 0
    hits = 0
    total_inverse_rank = 0

    time1 = time.time()

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            query = line.strip()
            if not query:
                continue

            total_queries += 1

            results = retrieve(query, url)
            if results:
                hit, inverse_rank = check_hit_content(query, results)
                if hit:
                    hits += 1
                    total_inverse_rank += inverse_rank
                    logger.debug(f"命中! 倒数排名: {inverse_rank:.4f}")
                else:
                    logger.debug("未命中")

            logger.debug("-" * 50)

    hit_rate = hits / total_queries if total_queries > 0 else 0
    average_inverse_rank = total_inverse_rank / hits if hits > 0 else 0

    time2 = time.time()
    elapsed_time = time2 - time1

    logger.debug(f"总查询数: {total_queries}")
    logger.debug(f"命中数: {hits}")
    logger.debug(f"命中率: {hit_rate:.2%}")
    logger.debug(f"平均倒数排名: {average_inverse_rank:.4f}")
    logger.debug(f"花费时间: {elapsed_time:.4f} 秒")


def main():
    parser = argparse.ArgumentParser(description="Client for the RAG retrieval system.")
    parser.add_argument("--url", default="http://127.0.0.1:9003", help="URL of the backend server")
    parser.add_argument("--query", default=None, help="a single query")
    parser.add_argument("--file", default=None, help="path to a text file containing queries")
    parser.add_argument("--polluted_file", default='/home/zhangwq/data/query_test/N_原句_N.txt',
                        help="path to a text file containing polluted queries")
    args = parser.parse_args()

    logger.info(f"连接到服务器: {args.url}")

    if args.file:
        log_file_path = '/home/zhangwq/project/shu_new/ai/rag/eval/RAG_hit_rate_and_average_inverse_rank-original_query_es_alone.log'
        logger.add(log_file_path, rotation='20MB', compression='zip')
        process_queries(args.file, args.url)

    elif args.query:
        # Check --query before --polluted_file: --polluted_file has a default value,
        # so testing it first would make the --query branch unreachable.
        results = retrieve(args.query, args.url)
        if results:
            hit, inverse_rank = check_hit_content(args.query, results)
            if hit:
                logger.info(f"Hit! Reciprocal rank: {inverse_rank:.4f}")
            else:
                logger.info("Miss")

    elif args.polluted_file:
        log_file_path = '/home/zhangwq/project/shu_new/ai/rag/eval/RAG_hit_rate_and_average_inverse_rank-N_原句_N_es_alone.log'
        logger.add(log_file_path, rotation='20MB', compression='zip')
        process_queries_with_polluted_file(args.polluted_file, args.url)

    else:
        logger.error("Please provide a query or a query file")


if __name__ == "__main__":
    main()
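
# Usage (a sketch; paths and ports are whatever your deployment uses):
#   python retriever_client.py --query "some question" --url http://127.0.0.1:9003
#   python retriever_client.py --file queries.txt
#   python retriever_client.py --polluted_file polluted_queries.txt
# With no arguments, the script falls back to the default --polluted_file path.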