import argparse
import configparser
import os
import time

import requests
import uvicorn
from fastapi import FastAPI, Request
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.retrievers import ContextualCompressionRetriever
from loguru import logger
from BCEmbedding.tools.langchain import BCERerank
from requests.exceptions import RequestException

from elastic_keywords_search import ElasticKeywordsSearch

app = FastAPI()


class Retriever:

    def __init__(self, config) -> None:
        self.mix, self.es, self.vector = None, None, None
        work_dir = config['default']['work_dir']
        self.es_top_k = int(config['rag']['es_top_k'])
        self.vector_top_k = int(config['rag']['vector_top_k'])
        # These four options are optional; missing or empty values select
        # which retriever gets initialized below.
        embedding_model_path = config.get('rag', 'embedding_model_path', fallback=None) or None
        reranker_model_path = config.get('rag', 'reranker_model_path', fallback=None) or None
        es_url = config.get('rag', 'es_url', fallback=None) or None
        index_name = config.get('rag', 'index_name', fallback=None) or None

        # Mix: both the vector and the ES parameters are configured.
        if embedding_model_path and reranker_model_path and es_url and index_name:
            self.init_mix_retriever(work_dir, embedding_model_path,
                                    reranker_model_path, es_url, index_name)
        # ES only: the vector parameters are missing.
        elif not embedding_model_path or not reranker_model_path:
            if self.is_es_available(es_url, index_name):
                self.es_retriever = ElasticKeywordsSearch(es_url, index_name,
                                                          drop_old=False)
                self.weights = [0.5, 0.5]
                self.es = True
                logger.info('Initializing ES retriever alone!')
            else:
                raise ValueError('Elasticsearch is unreachable or the index is '
                                 'missing; cannot initialize the ES retriever.')
        # Vector only: the ES parameters are missing.
        elif not es_url or not index_name:
            self.init_vector_retriever(work_dir, embedding_model_path,
                                       reranker_model_path)
            self.vector = True
            logger.info('Initializing Vector retriever alone!')
        else:
            raise ValueError('Incomplete configuration. Please specify all '
                             'required parameters for either vector or ES '
                             'retrieval.')
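    # The config.ini this class expects looks roughly like the sketch below.
    # All keys are read in __init__ or rag_retrieve(); the values are
    # illustrative placeholders, not defaults shipped with the project:
    #
    #   [default]
    #   work_dir = /ai/rag/workdir
    #   bind_port = 8888
    #
    #   [rag]
    #   es_top_k = 10
    #   vector_top_k = 10
    #   embedding_model_path = /models/bce-embedding-base_v1
    #   reranker_model_path = /models/bce-reranker-base_v1
    #   es_url = http://localhost:9200
    #   index_name = rag_index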
    def init_vector_retriever(self, work_dir, embedding_model_path,
                              reranker_model_path):
        logger.info('loading text2vec and rerank models')
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_path,
            model_kwargs={'device': 'cuda'},
            encode_kwargs={
                'batch_size': 1,
                'normalize_embeddings': True
            })
        # Run the embedding model in fp16 to halve its memory footprint.
        self.embeddings.client = self.embeddings.client.half()
        reranker_args = {
            'model': reranker_model_path,
            'top_n': self.vector_top_k,
            'device': 'cuda',
            'use_fp16': True
        }
        self.reranker = BCERerank(**reranker_args)
        self.vector_store = FAISS.load_local(
            os.path.join(work_dir, 'db_response'),
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True,
            distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
        retriever = self.vector_store.as_retriever(
            search_type='similarity',
            search_kwargs={
                'score_threshold': 0.15,
                'k': 30
            })
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.reranker, base_retriever=retriever)

    def init_mix_retriever(self, work_dir, embedding_model_path,
                           reranker_model_path, es_url, index_name):
        # Fall back to the vector retriever alone when ES is unreachable.
        if self.is_es_available(es_url, index_name):
            self.es_retriever = ElasticKeywordsSearch(es_url, index_name,
                                                      drop_old=False)
            self.weights = [0.5, 0.5]
            self.init_vector_retriever(work_dir, embedding_model_path,
                                       reranker_model_path)
            self.mix = True
            logger.info('Initializing Mix retriever!')
        else:
            self.init_vector_retriever(work_dir, embedding_model_path,
                                       reranker_model_path)
            self.vector = True
            logger.info('Initializing Vector retriever alone!')

    def is_es_available(self, url, index_name, timeout=5):
        try:
            response = requests.get(f"{url}/_cluster/health", timeout=timeout)
            if response.status_code == 200:
                index_response = requests.head(f"{url}/{index_name}",
                                               timeout=timeout)
                if index_response.status_code == 200:
                    logger.info(f"The index '{index_name}' exists!")
                    return True
                elif index_response.status_code == 404:
                    logger.warning(f"The index '{index_name}' does not exist!")
                else:
                    logger.error('Unexpected status code when checking index: '
                                 f'{index_response.status_code}')
            else:
                logger.error('Elasticsearch service returned non-200 status '
                             f'code: {response.status_code}')
        except RequestException as e:
            logger.error(f'Error connecting to Elasticsearch service: {e}')
        return False

    def weighted_reciprocal_rank(self, es_docs, vector_docs):
        # Create a union of all unique documents in the input doc lists.
        all_documents = set()
        for vector_doc in vector_docs:
            all_documents.add(vector_doc.page_content)
        for es_doc in es_docs:
            all_documents.add(es_doc.page_content)
        rrf_score_dic = {doc: 0.0 for doc in all_documents}
        # Reciprocal-rank fusion: each list contributes weight / (rank + 60)
        # per document; 60 is the standard RRF smoothing constant.
        for rank, vector_doc in enumerate(vector_docs, start=1):
            rrf_score = self.weights[1] * (1 / (rank + 60))
            rrf_score_dic[vector_doc.page_content] += rrf_score
        for rank, es_doc in enumerate(es_docs, start=1):
            rrf_score = self.weights[0] * (1 / (rank + 60))
            rrf_score_dic[es_doc.page_content] += rrf_score
        sorted_documents = sorted(rrf_score_dic.keys(),
                                  key=lambda x: rrf_score_dic[x],
                                  reverse=True)
        # Map the sorted page_content back to the original document objects.
        page_content_to_doc_map = {}
        for doc in es_docs:
            page_content_to_doc_map[doc.page_content] = doc
        for doc in vector_docs:
            page_content_to_doc_map[doc.page_content] = doc
        sorted_docs = [
            page_content_to_doc_map[page_content]
            for page_content in sorted_documents
        ]
        return sorted_docs

    def remove_duplicates(self, sorted_docs):
        seen = set()
        unique_docs = []
        for doc in sorted_docs:
            identifier = (
                doc.metadata.get('source', ''),
                doc.metadata.get('read', ''),
                # doc.page_content  # Need further testing
            )
            if identifier not in seen:
                seen.add(identifier)
                unique_docs.append(doc)
        return unique_docs
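    # A worked fusion example (illustrative numbers, not from a real run):
    # with the default weights [0.5, 0.5], a chunk ranked 1st by both
    # retrievers scores 0.5/(1+60) + 0.5/(1+60) ≈ 0.0164, while a chunk
    # ranked 1st by ES alone scores only 0.5/(1+60) ≈ 0.0082, so chunks the
    # keyword and vector retrievers agree on rise to the top of the fused
    # ordering produced by hybrid_retrieval below.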
    def hybrid_retrieval(self, query):
        es_docs = self.es_retriever.similarity_search_with_score(
            query, k=self.es_top_k)
        vector_docs = self.query(query)
        sorted_docs = self.weighted_reciprocal_rank(es_docs, vector_docs)
        unique_docs = self.remove_duplicates(sorted_docs)
        return unique_docs

    def rag_workflow(self, query):
        chunks = []
        time_1 = time.time()
        if self.mix:
            chunks = self.hybrid_retrieval(query)
        elif self.es:
            chunks = self.es_retriever.similarity_search_with_score(
                query, k=self.es_top_k)
        else:
            chunks = self.query(query)
        time_2 = time.time()
        logger.debug(
            f'query:{query} \nchunks:{chunks} \ntimecost:{time_2 - time_1}')
        return chunks

    def query(self, question: str):
        if question is None or len(question) < 1:
            return None
        if len(question) > 512:
            logger.warning('input too long, truncated to 512 characters')
            question = question[0:512]
        docs = self.compression_retriever.get_relevant_documents(question)
        return docs


retriever = None


@app.post("/retrieve")
async def retrieve(request: Request):
    data = await request.json()
    query = data.get("query")
    chunks = retriever.rag_workflow(query)
    return {"chunks": chunks}


def rag_retrieve(args: argparse.Namespace):
    """Start the web server, accept HTTP requests, and serve them by calling
    the local RAG retrieval service."""
    global retriever
    config = configparser.ConfigParser()
    config.read(args.config_path)
    bind_port = int(config['default']['bind_port'])
    # Only pin the device when one was given; assigning None to os.environ
    # would raise a TypeError.
    if args.dcu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.dcu_id
    retriever = Retriever(config)
    uvicorn.run(app, host="0.0.0.0", port=bind_port)


def parse_args():
    parser = argparse.ArgumentParser(
        description='Feature store for processing directories.')
    parser.add_argument(
        '--config_path',
        default='/ai/rag/config.ini',
        help='path to the config file')
    parser.add_argument(
        '--dcu_id',
        default=None,
        help='DCU device id, exported as CUDA_VISIBLE_DEVICES')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    rag_retrieve(args)


if __name__ == '__main__':
    main()
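# A minimal client sketch for the /retrieve endpoint (assumes the service
# runs locally and that bind_port in config.ini is 8888; adjust both for a
# real deployment):
#
#   import requests
#   resp = requests.post('http://localhost:8888/retrieve',
#                        json={'query': 'how is the FAISS index built?'})
#   print(resp.json()['chunks'])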