import configparser
import json
import os
import time

import httpx
import numpy as np
import torch
from BCEmbedding.tools.langchain import BCERerank
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from loguru import logger
from sklearn.metrics import precision_recall_curve
from transformers import BertForSequenceClassification, BertTokenizer


def build_history_messages(prompt, history, system: str = None):
    """Convert (user, assistant) history pairs into OpenAI-style chat messages."""
    history_messages = []
    if system is not None and len(system) > 0:
        history_messages.append({'role': 'system', 'content': system})
    for item in history:
        history_messages.append({'role': 'user', 'content': item[0]})
        history_messages.append({'role': 'assistant', 'content': item[1]})
    history_messages.append({'role': 'user', 'content': prompt})
    return history_messages


class OpenAPIClient:
    """Async client for an OpenAI-compatible /v1/chat/completions endpoint."""

    def __init__(self, url: str, model_name):
        self.url = '%s/v1/chat/completions' % url
        self.model_name = model_name

    async def get_streaming_response(self, headers, data):
        async with httpx.AsyncClient() as client:
            async with client.stream('POST', self.url, json=data,
                                     headers=headers, timeout=300) as response:
                async for line in response.aiter_lines():
                    # Skip keep-alive blanks and the terminating '[DONE]' event.
                    if not line or 'DONE' in line:
                        continue
                    try:
                        # SSE lines look like 'data: {...}'; parse the JSON payload.
                        result = json.loads(line.split('data:')[1])
                        output = result['choices'][0]['delta'].get('content')
                    except Exception as e:
                        logger.error('Model response parse failed: {}', e)
                        raise Exception('Model response parse failed.') from e
                    if not output:
                        continue
                    yield output

    async def get_response(self, headers, data):
        async with httpx.AsyncClient() as client:
            resp = await client.post(self.url, json=data, headers=headers,
                                     timeout=300)
            try:
                result = json.loads(resp.content.decode('utf-8'))
                output = result['choices'][0]['message']['content']
            except Exception as e:
                logger.error('Model response parse failed: {}', e)
                raise Exception('Model response parse failed.') from e
            return output

    async def chat(self, prompt: str, history=None, stream=False):
        if history is None:
            history = []
        headers = {'Content-Type': 'application/json'}
        data = {
            'model': self.model_name,
            'messages': build_history_messages(prompt, history),
            'stream': stream
        }
        logger.info('Request openapi param: {}'.format(data))
        if stream:
            return self.get_streaming_response(headers, data)
        return await self.get_response(headers, data)


class ClassifyModel:
    """BERT-based sequence classifier used to route queries."""

    def __init__(self, model_path, dcu_id):
        # Restrict visible devices before any CUDA context is created;
        # setting CUDA_VISIBLE_DEVICES after torch touches the GPU has no effect.
        os.environ['CUDA_VISIBLE_DEVICES'] = dcu_id
        logger.info(f'Set environment variable CUDA_VISIBLE_DEVICES to {dcu_id}')
        logger.info('Initializing BERT classification model')
        self.cls_model = BertForSequenceClassification.from_pretrained(
            model_path).float().cuda()
        self.cls_model.load_state_dict(
            torch.load(os.path.join(model_path, 'bert_cls_model.pth')))
        self.cls_model.eval()
        self.cls_tokenizer = BertTokenizer.from_pretrained(model_path)

    def classification(self, sentence):
        inputs = self.cls_tokenizer(sentence,
                                    max_length=512,
                                    truncation='longest_first',
                                    return_tensors='pt')
        inputs = inputs.to('cuda')
        with torch.no_grad():
            outputs = self.cls_model(**inputs)
            logits = outputs[0]
        # torch.max over dim 1 returns (values, indices); take the predicted class index.
        pred = torch.max(logits.data, 1)[1].tolist()
        logger.info('Classification result: {}, {}'.format(pred[0], sentence))
        return float(pred[0])


class CacheRetriever:
    """LRU cache of Retriever instances keyed by feature-store id."""

    def __init__(self, embedding_model_path: str, reranker_model_path: str,
                 max_len: int = 4):
        self.cache = dict()
        self.max_len = max_len
        # Load text2vec embedding and rerank models.
        logger.info('loading text2vec and rerank models')
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_path,
            model_kwargs={'device': 'cuda'},
            encode_kwargs={
                'batch_size': 1,
                'normalize_embeddings': True
            })
        # Run the embedding model in half precision to save GPU memory.
        self.embeddings.client = self.embeddings.client.half()
        reranker_args = {
            'model': reranker_model_path,
            'top_n': 3,
            'device': 'cuda',
            'use_fp16': True
        }
        self.reranker = BCERerank(**reranker_args)

    def get(self,
            reject_throttle: float,
            fs_id: str = 'default',
            work_dir='workdir'):
        if fs_id in self.cache:
            # Cache hit: refresh the access time and reuse the retriever.
            self.cache[fs_id]['time'] = time.time()
            return self.cache[fs_id]['retriever']
        if len(self.cache) >= self.max_len:
            # Cache full: evict the least recently used entry.
            del_key = None
            min_time = time.time()
            for key, value in self.cache.items():
                cur_time = value['time']
                if cur_time < min_time:
                    min_time = cur_time
                    del_key = key
            if del_key is not None:
                del_value = self.cache.pop(del_key)
                # Drop the reference so the retriever can be garbage collected.
                del del_value['retriever']
        retriever = Retriever(embeddings=self.embeddings,
                              reranker=self.reranker,
                              work_dir=work_dir,
                              reject_throttle=reject_throttle)
        self.cache[fs_id] = {'retriever': retriever, 'time': time.time()}
        return retriever

    def pop(self, fs_id: str):
        if fs_id not in self.cache:
            return
        del_value = self.cache.pop(fs_id)
        # Manually free memory.
        del del_value


class Retriever:
    """FAISS-backed retriever with a relevance-score rejection gate and reranking."""

    def __init__(self, embeddings, reranker, work_dir: str,
                 reject_throttle: float) -> None:
        self.reject_throttle = reject_throttle
        self.embeddings = embeddings
        self.reranker = reranker
        # The same index is loaded twice: once with the default distance
        # strategy for rejection scoring, once with inner product for retrieval.
        self.rejecter = FAISS.load_local(
            os.path.join(work_dir, 'db_response'),
            embeddings=embeddings,
            allow_dangerous_deserialization=True)
        self.vector_store = FAISS.load_local(
            os.path.join(work_dir, 'db_response'),
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
            distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
        self.retriever = self.vector_store.as_retriever(
            search_type='similarity',
            search_kwargs={
                'score_threshold': self.reject_throttle,
                'k': 30
            })
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=reranker, base_retriever=self.retriever)
        if self.rejecter is None:
            logger.warning('rejecter is None')
        if self.retriever is None:
            logger.warning('retriever is None')

    def is_relative(self, sample, k=30, disable_throttle=False):
        """Reject the query if no search result scores at or above the threshold."""
        if self.rejecter is None:
            return False, []
        if disable_throttle:
            # Used while searching for a threshold during update_throttle.
            docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
                sample, k=1)
            if len(docs_with_score) < 1:
                return False, docs_with_score
            return True, docs_with_score
        # Normal retrieval: if no chunk passes the threshold, return the best match.
        docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
            sample, k=k)
        ret = []
        max_score = -1
        top1 = None
        for (doc, score) in docs_with_score:
            if score >= self.reject_throttle:
                ret.append(doc)
            if score > max_score:
                max_score = score
                top1 = (doc, score)
        relative = len(ret) > 0
        return relative, [top1]

    def update_throttle(self,
                        work_dir: str,
                        config_path: str = 'config.ini',
                        positive_sample=None,
                        negative_sample=None):
        """Update the reject throttle based on positive and negative examples."""
        import matplotlib.pyplot as plt
        if not positive_sample or not negative_sample:
            raise Exception('positive and negative samples cannot be empty.')
        all_samples = positive_sample + negative_sample
        # Disable the throttle while collecting top-1 scores for every sample.
        self.reject_throttle = -1
        predictions = []
        for sample in all_samples:
            _, docs = self.is_relative(sample=sample, disable_throttle=True)
            score = docs[0][1]
            # Clamp negative relevance scores to 0.
            predictions.append(max(0, score))
        labels = [1 for _ in range(len(positive_sample))
                  ] + [0 for _ in range(len(negative_sample))]
        precision, recall, thresholds = precision_recall_curve(
            labels, predictions)
        plt.figure(figsize=(10, 8))
        plt.plot(recall, precision, label='Precision-Recall curve')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.legend(loc='best')
        plt.grid(True)
        plt.savefig(os.path.join(work_dir, 'precision_recall_curve.png'),
                    format='png')
        plt.close()
        logger.debug('Figure has been saved!')
        # precision_recall_curve returns one fewer threshold than points;
        # append 1 so thresholds aligns with precision/recall.
        thresholds = np.append(thresholds, 1)
        max_precision = np.max(precision)
        indices_with_max_precision = np.where(precision == max_precision)
        optimal_recall = recall[indices_with_max_precision[0][0]]
        optimal_threshold = thresholds[indices_with_max_precision[0][0]]
        logger.debug('Optimal threshold with the highest recall under the '
                     f'highest precision is: {optimal_threshold}')
        logger.debug(f'The corresponding precision is: {max_precision}')
        logger.debug(f'The corresponding recall is: {optimal_recall}')
        # Restore the throttle to the newly found optimum (it was set to -1 above).
        self.reject_throttle = optimal_threshold
        config = configparser.ConfigParser()
        config.read(config_path)
        config['feature_database']['reject_throttle'] = str(optimal_threshold)
        with open(config_path, 'w') as configfile:
            config.write(configfile)
        logger.info(
            f'Update optimal threshold: {optimal_threshold} to {config_path}')
        return optimal_threshold

    def query(self, question: str):
        time_1 = time.time()
        if question is None or len(question) < 1:
            return [], []
        if len(question) > 512:
            logger.warning('input too long, truncate to 512')
            question = question[0:512]
        chunks = []
        references = []
        relative, docs = self.is_relative(sample=question)
        if relative:
            # Rerank the raw FAISS hits before returning them.
            docs = self.compression_retriever.get_relevant_documents(question)
            chunks = [doc.page_content for doc in docs]
            references = [doc.metadata['source'] for doc in docs]
            time_2 = time.time()
            logger.debug('query:{} \nchunks:{} \nreferences:{} \ntimecost:{}'
                         .format(question, chunks, references,
                                 time_2 - time_1))
            return chunks, [os.path.basename(r) for r in references]
        if len(docs) > 0:
            references.append(docs[0][0].metadata['source'])
        logger.info('feature database rejected!')
        return chunks, references
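

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. The model paths,
# the 'workdir' layout (which must already contain a FAISS index under
# 'workdir/db_response'), the reject throttle value, and the endpoint URL
# below are all illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import asyncio

    # Hypothetical local paths to a BCE embedding model and reranker.
    cache = CacheRetriever(
        embedding_model_path='models/bce-embedding-base_v1',
        reranker_model_path='models/bce-reranker-base_v1')
    retriever = cache.get(reject_throttle=0.35, work_dir='workdir')

    # Returns ([], [...]) with only the best-match source when the query
    # is rejected by the throttle.
    chunks, references = retriever.query('How do I configure the service?')
    print(chunks, references)

    # Hypothetical OpenAI-compatible endpoint serving `model_name`.
    client = OpenAPIClient(url='http://localhost:8000', model_name='qwen')
    answer = asyncio.run(client.chat('Hello'))
    print(answer)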