"""RAG retrieval service: FAISS-backed retriever with BCE reranking.

Provides a ``Retriever`` (vector search + reranker + reject threshold),
a ``CacheRetriever`` LRU-style cache of retriever instances, an aiohttp
HTTP endpoint (``rag_retrieve``) and a simple CLI test entry point.
"""
import os
import argparse
import configparser
import time
from multiprocessing import Value

import numpy as np
from aiohttp import web
from loguru import logger
from sklearn.metrics import precision_recall_curve
from torch.cuda import empty_cache

from BCEmbedding.tools.langchain import BCERerank
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy


class Retriever:
    """FAISS similarity retriever with a rejection threshold and a reranker."""

    def __init__(self, embeddings, reranker, work_dir: str,
                 reject_throttle: float) -> None:
        """Load the FAISS stores under ``work_dir`` and build the pipeline.

        Args:
            embeddings: embedding model used by the FAISS stores.
            reranker: compressor (BCERerank) applied on top of retrieval.
            work_dir: directory containing the ``db_response`` FAISS index.
            reject_throttle: similarity score below which queries are rejected.
        """
        self.reject_throttle = reject_throttle
        self.embeddings = embeddings
        self.reranker = reranker

        # `rejecter` uses the default distance strategy and serves only to
        # score queries against the threshold.
        self.rejecter = FAISS.load_local(
            os.path.join(work_dir, 'db_response'),
            embeddings=embeddings,
            allow_dangerous_deserialization=True)

        # The retrieval store uses inner-product similarity for ranking.
        self.vector_store = FAISS.load_local(
            os.path.join(work_dir, 'db_response'),
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
            distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
        self.retriever = self.vector_store.as_retriever(
            search_type='similarity',
            search_kwargs={
                'score_threshold': self.reject_throttle,
                'k': 30
            })
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=reranker, base_retriever=self.retriever)

        if self.rejecter is None:
            logger.warning('rejecter is None')
        if self.retriever is None:
            logger.warning('retriever is None')

    def is_relative(self, sample, k=30, disable_throttle=False):
        """Decide whether ``sample`` is answerable from the database.

        Args:
            sample: query text.
            k: number of candidates fetched when throttling is enabled.
            disable_throttle: when True, fetch only the single best match
                and skip threshold filtering (used for threshold tuning).

        Returns:
            ``(relative, docs)`` where ``docs`` is a list of
            ``(Document, score)`` pairs (at most one entry — the best match).
        """
        if self.rejecter is None:
            return False, []

        if disable_throttle:
            # For searching throttle during threshold update: top-1 only.
            docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
                sample, k=1)
            if not docs_with_score:
                return False, docs_with_score
            return True, docs_with_score

        # Retrieval path: if no chunk passes the throttle, still report
        # the highest-scoring candidate so callers can cite a source.
        docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
            sample, k=k)
        ret = []
        max_score = -1
        top1 = None
        for doc, score in docs_with_score:
            if score >= self.reject_throttle:
                ret.append(doc)
            if score > max_score:
                max_score = score
                top1 = (doc, score)
        relative = len(ret) > 0
        return relative, [top1]

    def update_throttle(self,
                        work_dir: str,
                        config_path: str = 'config.ini',
                        positive_sample=None,
                        negative_sample=None):
        """Tune ``reject_throttle`` from labeled samples and persist it.

        Computes a precision-recall curve over the best-match scores of the
        positive/negative samples, picks the threshold with the highest
        precision, saves a PR-curve plot under ``work_dir`` and writes the
        threshold back to ``config_path``.

        Args:
            work_dir: directory where the PR-curve PNG is saved.
            config_path: ini file whose
                ``[feature_database] reject_throttle`` entry is updated.
            positive_sample: queries the database should answer.
            negative_sample: queries the database should reject.

        Returns:
            The optimal threshold as a float.

        Raises:
            Exception: if either sample list is empty or missing.
        """
        # Lazy import: matplotlib is only needed for this maintenance path.
        import matplotlib.pyplot as plt

        # NOTE: mutable-default-argument fix — previously `=[]`.
        positive_sample = positive_sample or []
        negative_sample = negative_sample or []
        if not positive_sample or not negative_sample:
            raise Exception('positive and negative samples cannot be empty.')

        all_samples = positive_sample + negative_sample
        predictions = []
        for sample in all_samples:
            # Disable the threshold so every sample yields a score.
            self.reject_throttle = -1
            _, docs = self.is_relative(sample=sample, disable_throttle=True)
            score = docs[0][1]
            predictions.append(max(0, score))

        labels = [1] * len(positive_sample) + [0] * len(negative_sample)
        precision, recall, thresholds = precision_recall_curve(
            labels, predictions)

        plt.figure(figsize=(10, 8))
        plt.plot(recall, precision, label='Precision-Recall curve')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.legend(loc='best')
        plt.grid(True)
        plt.savefig(os.path.join(work_dir, 'precision_recall_curve.png'),
                    format='png')
        plt.close()
        logger.debug("Figure have been saved!")

        # `thresholds` is one element shorter than precision/recall; pad it.
        thresholds = np.append(thresholds, 1)
        max_precision = np.max(precision)
        indices_with_max_precision = np.where(precision == max_precision)
        optimal_recall = recall[indices_with_max_precision[0][0]]
        optimal_threshold = thresholds[indices_with_max_precision[0][0]]
        logger.debug('Optimal threshold with the highest recall under the '
                     f'highest precision is: {optimal_threshold}')
        logger.debug(f"The corresponding precision is: {max_precision}")
        logger.debug(f"The corresponding recall is: {optimal_recall}")

        config = configparser.ConfigParser()
        config.read(config_path)
        config['feature_database']['reject_throttle'] = str(optimal_threshold)
        with open(config_path, 'w') as configfile:
            config.write(configfile)
        logger.info(
            f'Update optimal threshold: {optimal_threshold} to {config_path}'  # noqa E501
        )
        return optimal_threshold

    def query(self, question: str):
        """Retrieve reranked chunks and source references for ``question``.

        Args:
            question: user query; truncated to 512 characters.

        Returns:
            ``(chunks, references)``. ``chunks`` is a list of
            single-element lists of page content (empty on rejection);
            ``references`` holds source file names (basenames on success,
            the raw best-match source path on rejection).
        """
        time_1 = time.time()
        if question is None or len(question) < 1:
            # Fix: previously returned a 3-tuple here while every other
            # path (and the caller in `rag_retrieve`) expects 2 values.
            return [], []
        if len(question) > 512:
            logger.warning('input too long, truncate to 512')
            question = question[0:512]

        chunks = []
        references = []
        relative, docs = self.is_relative(sample=question)
        if relative:
            docs = self.compression_retriever.get_relevant_documents(question)
            # Each chunk is wrapped in a single-element list (kept for
            # downstream compatibility).
            chunks = [[doc.page_content] for doc in docs]
            references = [doc.metadata['source'] for doc in docs]
            time_2 = time.time()
            logger.debug('query:{} \nchunks:{} \nreferences:{} \ntimecost:{}'
                         .format(question, chunks, references,
                                 time_2 - time_1))
            return chunks, [os.path.basename(r) for r in references]

        # Rejected: still surface the best match's source if one exists.
        # Guard against `docs == [None]` when the store returned nothing.
        if docs and docs[0] is not None:
            references.append(docs[0][0].metadata['source'])
        logger.info('feature database rejected!')
        return chunks, references


class CacheRetriever:
    """Keeps up to ``max_len`` Retriever instances, evicting the oldest."""

    def __init__(self, embedding_model_path: str, reranker_model_path: str,
                 max_len: int = 4):
        """Load the embedding and rerank models shared by all retrievers.

        Args:
            embedding_model_path: HuggingFace embedding model path.
            reranker_model_path: BCE reranker model path.
            max_len: maximum number of cached retrievers.
        """
        self.cache = dict()
        self.max_len = max_len

        # Load text2vec and rerank models once; retrievers share them.
        logger.info('loading text2vec and rerank models')
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_path,
            model_kwargs={'device': 'cuda'},
            encode_kwargs={
                'batch_size': 1,
                'normalize_embeddings': True
            })
        # Run the embedding model in fp16 to halve GPU memory.
        self.embeddings.client = self.embeddings.client.half()
        reranker_args = {
            'model': reranker_model_path,
            'top_n': 3,
            'device': 'cuda',
            'use_fp16': True
        }
        self.reranker = BCERerank(**reranker_args)

    def get(self,
            reject_throttle: float,
            fs_id: str = 'default',
            work_dir='workdir'):
        """Return a cached Retriever for ``fs_id``, creating it if needed.

        Evicts the least-recently-used entry when the cache is full.

        Args:
            reject_throttle: threshold for a newly created Retriever.
            fs_id: cache key identifying the feature store.
            work_dir: directory holding the FAISS index for new retrievers.
        """
        if fs_id in self.cache:
            self.cache[fs_id]['time'] = time.time()
            return self.cache[fs_id]['retriever']

        if len(self.cache) >= self.max_len:
            # Drop the entry with the oldest access time.
            del_key = None
            min_time = time.time()
            for key, value in self.cache.items():
                cur_time = value['time']
                if cur_time < min_time:
                    min_time = cur_time
                    del_key = key
            if del_key is not None:
                del_value = self.cache.pop(del_key)
                del del_value['retriever']

        retriever = Retriever(embeddings=self.embeddings,
                              reranker=self.reranker,
                              work_dir=work_dir,
                              reject_throttle=reject_throttle)
        self.cache[fs_id] = {'retriever': retriever, 'time': time.time()}
        return retriever

    def pop(self, fs_id: str):
        """Remove ``fs_id`` from the cache and drop its retriever."""
        if fs_id not in self.cache:
            return
        del_value = self.cache.pop(fs_id)
        # Manually release the reference so memory can be reclaimed.
        del del_value


def rag_retrieve(config_path: str, server_ready):
    """Start a web server exposing the local RAG retrieval service.

    Reads model paths, port and working directory from ``config_path``,
    builds the retriever, then serves POST ``/retrieve`` requests whose
    JSON body is ``{"query": ...}``.

    Args:
        config_path: path to the ini configuration file.
        server_ready: multiprocessing.Value set to 1 on success, -1 on
            failure (the exception is re-raised).
    """
    config = configparser.ConfigParser()
    config.read(config_path)
    bind_port = int(config['default']['bind_port'])
    work_dir = config['default']['work_dir']
    try:
        # Fix: CacheRetriever takes model paths (not config_path) and
        # `get` takes reject_throttle/work_dir — the previous call used
        # non-existent keyword arguments and could never run.
        cache = CacheRetriever(
            embedding_model_path=config['feature_database']
            ['embedding_model_path'],
            reranker_model_path=config['feature_database']
            ['reranker_model_path'])
        retriever = cache.get(
            reject_throttle=float(
                config['feature_database']['reject_throttle']),
            work_dir=work_dir)
        server_ready.value = 1
    except Exception:
        server_ready.value = -1
        raise

    async def retrieve(request):
        input_json = await request.json()
        query = input_json['query']
        chunks, ref = retriever.query(query)
        return web.json_response({'chunks': chunks, 'ref': ref})

    app = web.Application()
    app.add_routes([web.post('/retrieve', retrieve)])
    web.run_app(app, host='0.0.0.0', port=bind_port)


def test_query(retriever: Retriever, real_questions):
    """Simple test of the response pipeline over a list of questions."""
    if not real_questions:
        logger.error("No questions provided or real_questions is empty.")
        return None
    logger.add('logs/feature_store_query.log', rotation='4MB')
    for example in real_questions:
        # Truncate long questions before querying.
        example = example[0:400]
        retriever.query(example)
        empty_cache()
    empty_cache()


def set_envs(dcu_ids):
    """Set CUDA_VISIBLE_DEVICES to the given comma-separated device ids.

    Raises:
        ValueError: if the environment variable cannot be set
            (e.g. ``dcu_ids`` is not a string).
    """
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
        logger.info(
            f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
    except Exception as e:
        logger.error(f"{e}, but got {dcu_ids}")
        raise ValueError(f"{e}")


def parse_args():
    """Parse command-line arguments for the feature-store test CLI."""
    parser = argparse.ArgumentParser(
        description='Feature store for processing directories.')
    parser.add_argument(
        '--config_path',
        default='/path/of/config.ini',
        help='config目录')
    # Fix: without nargs a user-supplied --query was a single string that
    # test_query iterated character-by-character; '+' keeps it a list.
    parser.add_argument(
        '--query',
        nargs='+',
        default=['先有鸡还是先有蛋?', '写一首五言律诗?'],
        help='提问的问题.')
    parser.add_argument(
        '--DCU_ID',
        type=str,
        default='0',
        help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
    return parser.parse_args()


def main():
    """CLI entry point: build a retriever from config and run test queries."""
    args = parse_args()
    # Fix: previously passed the whole Namespace instead of the device ids.
    # NOTE(review): torch is imported at module load, so setting
    # CUDA_VISIBLE_DEVICES here may have no effect — confirm ordering.
    set_envs(args.DCU_ID)
    config = configparser.ConfigParser()
    config.read(args.config_path)
    embedding_model_path = config['feature_database']['embedding_model_path']
    reranker_model_path = config['feature_database']['reranker_model_path']
    cache = CacheRetriever(embedding_model_path=embedding_model_path,
                           reranker_model_path=reranker_model_path)
    retriever = cache.get(
        reject_throttle=float(config['feature_database']['reject_throttle']),
        work_dir=config['default']['work_dir'])
    test_query(retriever, args.query)
    # server_ready = Value('i', 0)
    # rag_retrieve(config_path=args.config_path,
    #              server_ready=server_ready)


if __name__ == '__main__':
    main()