import argparse
import ast
import json
import os
import re
import time

import requests
from loguru import logger


def retrieve(query, url):
    """Send a retrieval request to the backend server."""
    endpoint = f"{url}/retrieve"
    payload = {"query": query}
    try:
        response = requests.post(endpoint, json=payload)
        response.raise_for_status()  # raise an exception on HTTP error status codes
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"Error while sending request: {e}")
        return None


def remove_punctuation(query):
    """Strip punctuation and underscores, keeping word characters and whitespace."""
    return re.sub(r'[^\w\s]|[_]', '', query)


def check_hit_content(query, results):
    """Check whether the query text appears in any retrieved chunk's content."""
    logger.info(f"Query: {query}")
    query_cleaned = remove_punctuation(query)
    chunks = results.get('chunks', [])
    for rank, chunk in enumerate(chunks, 1):
        content = remove_punctuation(chunk['page_content'])
        if query_cleaned.lower() in content.lower():
            return True, 1 / rank
    return False, None


def check_hit(query, results):
    """Check whether the query hits any retrieved document."""
    json2text_path = '/home/zhangwq/data/for_test_new/json2txt'
    logger.info(f"Query: {query}")
    chunks = results.get('chunks', [])
    query_cleaned = remove_punctuation(query)
    for rank, chunk in enumerate(chunks, 1):
        metadata = chunk.get('metadata', {})
        source = metadata.get('source', '')
        logger.info(f"Document: {source}")
        source_original = os.path.splitext(source)[0]
        source_cleaned = remove_punctuation(source_original)
        if re.match(r'json_obj_\d+\.txt', source):
            # Handle json_obj_{i}.txt files: match against the full file content.
            json_file_path = os.path.join(json2text_path, source)
            logger.info(f"Checking JSON file: {json_file_path}")
            try:
                with open(json_file_path, 'r', encoding='utf-8') as f:
                    json_content = f.read()
                json_content_cleaned = remove_punctuation(json_content)
                if query_cleaned.lower() in json_content_cleaned.lower():
                    logger.info(f"Match found in JSON file {source}!")
                    return True, 1 / rank
            except Exception as e:
                logger.error(f"Error reading or processing {json_file_path}: {e}")
        # Case-insensitive match against the source file name.
        if query_cleaned.lower() in source_cleaned.lower():
            return True, 1 / rank
    return False, None


def process_queries_with_polluted_file(file_path, url):
    """Process each polluted-query line in the file."""
    total_queries = 0
    hits = 0
    total_inverse_rank = 0
    time1 = time.time()
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            try:
                # Each line is a Python dict literal; for strictly JSON-formatted
                # lines, data = json.loads(line.strip()) would work as well.
                data = ast.literal_eval(line.strip())
                polluted_query = data['Polluted']
                original_query = data['original']
            except (ValueError, SyntaxError):
                logger.error(f"Could not parse line: {line.strip()}")
                continue
            except KeyError:
                logger.error(f"Line is missing a required key: {line.strip()}")
                continue
            total_queries += 1
            logger.debug(f"Processing query {total_queries}: polluted query = "
                         f"'{polluted_query}', original query = '{original_query}'")
            results = retrieve(polluted_query, url)
            if results:
                hit, inverse_rank = check_hit(original_query, results)
                if hit:
                    hits += 1
                    total_inverse_rank += inverse_rank
                    logger.debug(f"Hit! Inverse rank: {inverse_rank:.4f}")
                else:
                    logger.debug("Miss")
            logger.debug("-" * 50)
    hit_rate = hits / total_queries if total_queries > 0 else 0
    average_inverse_rank = total_inverse_rank / hits if hits > 0 else 0
    time2 = time.time()
    elapsed_time = time2 - time1
    logger.debug(f"Total queries: {total_queries}")
    logger.debug(f"Hits: {hits}")
    logger.debug(f"Hit rate: {hit_rate:.2%}")
    logger.debug(f"Average inverse rank: {average_inverse_rank:.4f}")
    logger.debug(f"Elapsed time: {elapsed_time:.4f} seconds")


def process_queries(file_path, url):
    """Process each query line in the file."""
    total_queries = 0
    hits = 0
    total_inverse_rank = 0
    time1 = time.time()
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            query = line.strip()
            if not query:
                continue
            total_queries += 1
            results = retrieve(query, url)
            if results:
                hit, inverse_rank = check_hit_content(query, results)
                if hit:
                    hits += 1
                    total_inverse_rank += inverse_rank
                    logger.debug(f"Hit! Inverse rank: {inverse_rank:.4f}")
                else:
                    logger.debug("Miss")
            logger.debug("-" * 50)
    hit_rate = hits / total_queries if total_queries > 0 else 0
    average_inverse_rank = total_inverse_rank / hits if hits > 0 else 0
    time2 = time.time()
    elapsed_time = time2 - time1
    logger.debug(f"Total queries: {total_queries}")
    logger.debug(f"Hits: {hits}")
    logger.debug(f"Hit rate: {hit_rate:.2%}")
    logger.debug(f"Average inverse rank: {average_inverse_rank:.4f}")
    logger.debug(f"Elapsed time: {elapsed_time:.4f} seconds")


def main():
    parser = argparse.ArgumentParser(description="Client for the RAG retrieval system.")
    parser.add_argument("--url", default="http://127.0.0.1:9003", help="Backend server URL")
    parser.add_argument("--query", default=None, help="A single query")
    parser.add_argument("--file", default=None, help="Path to a text file of queries")
    parser.add_argument("--polluted_file",
                        default='/home/zhangwq/data/query_test/N_原句_N.txt',
                        help="Path to a text file of polluted queries")
    args = parser.parse_args()

    logger.info(f"Connecting to server: {args.url}")

    if args.file:
        log_file_path = '/home/zhangwq/project/shu_new/ai/rag/eval/RAG_hit_rate_and_average_inverse_rank-original_query_es_alone.log'
        logger.add(log_file_path, rotation='20MB', compression='zip')
        process_queries(args.file, args.url)
    elif args.query:
        # Check --query before --polluted_file: the latter has a default value,
        # so testing it first would make the single-query branch unreachable.
        results = retrieve(args.query, args.url)
        if results:
            hit, inverse_rank = check_hit_content(args.query, results)
            if hit:
                logger.info(f"Hit! Inverse rank: {inverse_rank:.4f}")
            else:
                logger.info("Miss")
    elif args.polluted_file:
        log_file_path = '/home/zhangwq/project/shu_new/ai/rag/eval/RAG_hit_rate_and_average_inverse_rank-N_原句_N_es_alone.log'
        logger.add(log_file_path, rotation='20MB', compression='zip')
        process_queries_with_polluted_file(args.polluted_file, args.url)
    else:
        logger.error("Please provide a query or a query file")


if __name__ == "__main__":
    main()