from typing import cast, List, Union, Tuple, Optional, Dict import numpy as np from collections import defaultdict import torch from tqdm import tqdm import datasets from transformers import PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding, XLMRobertaForMaskedLM, is_torch_npu_available from torch.utils.data import DataLoader from functools import partial from FlagEmbedding.BGE_M3 import BGEM3ForInference def _transform_func(examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, max_length: int = 8192, ) -> BatchEncoding: inputs = tokenizer(examples['text'], max_length=max_length, padding=True, return_token_type_ids=False, truncation=True, return_tensors='pt') return inputs class BGEM3FlagModel: def __init__( self, model_name_or_path: str = None, pooling_method: str = 'cls', normalize_embeddings: bool = True, use_fp16: bool = True, device: str = None ) -> None: self.model = BGEM3ForInference( model_name=model_name_or_path, normlized=normalize_embeddings, sentence_pooling_method=pooling_method, ) self.tokenizer = self.model.tokenizer if device: self.device = torch.device(device) else: if torch.cuda.is_available(): self.device = torch.device("cuda") elif torch.backends.mps.is_available(): self.device = torch.device("mps") elif is_torch_npu_available(): self.device = torch.device("npu") else: self.device = torch.device("cpu") use_fp16 = False if use_fp16: self.model.half() self.model = self.model.to(self.device) if device is None: self.num_gpus = torch.cuda.device_count() if self.num_gpus > 1: print(f"----------using {self.num_gpus}*GPUs----------") self.model.model = torch.nn.DataParallel(self.model.model) else: self.num_gpus = 1 self.model.eval() def convert_id_to_token(self, lexical_weights: List[Dict]): if isinstance(lexical_weights, dict): lexical_weights = [lexical_weights] new_lexical_weights = [] for item in lexical_weights: new_item = {} for id, weight in item.items(): token = self.tokenizer.decode([int(id)]) new_item[token] = weight new_lexical_weights.append(new_item) if len(new_lexical_weights) == 1: new_lexical_weights = new_lexical_weights[0] return new_lexical_weights def compute_lexical_matching_score(self, lexical_weights_1: Dict, lexical_weights_2: Dict): scores = 0 for token, weight in lexical_weights_1.items(): if token in lexical_weights_2: scores += weight * lexical_weights_2[token] return scores def colbert_score(self, q_reps, p_reps): q_reps, p_reps = torch.from_numpy(q_reps), torch.from_numpy(p_reps) token_scores = torch.einsum('in,jn->ij', q_reps, p_reps) scores, _ = token_scores.max(-1) scores = torch.sum(scores) / q_reps.size(0) return scores @torch.no_grad() def encode(self, sentences: Union[List[str], str], batch_size: int = 12, max_length: int = 8192, return_dense: bool = True, return_sparse: bool = False, return_colbert_vecs: bool = False) -> Dict: if self.num_gpus > 1: batch_size *= self.num_gpus self.model.eval() input_was_string = False if isinstance(sentences, str): sentences = [sentences] input_was_string = True def _process_token_weights(token_weights: np.ndarray, input_ids: list): # conver to dict result = defaultdict(int) unused_tokens = set([self.tokenizer.cls_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id, self.tokenizer.unk_token_id]) # token_weights = np.ceil(token_weights * 100) for w, idx in zip(token_weights, input_ids): if idx not in unused_tokens and w > 0: idx = str(idx) # w = int(w) if w > result[idx]: result[idx] = w return result def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list): # delte the vectors of padding tokens tokens_num = np.sum(attention_mask) return colbert_vecs[:tokens_num - 1] # we don't use the embedding of cls, so select tokens_num-1 all_dense_embeddings, all_lexical_weights, all_colbert_vec = [], [], [] for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings", disable=len(sentences) < 256): sentences_batch = sentences[start_index:start_index + batch_size] batch_data = self.tokenizer( sentences_batch, padding=True, truncation=True, return_tensors='pt', max_length=max_length, ).to(self.device) output = self.model(batch_data, return_dense=return_dense, return_sparse=return_sparse, return_colbert=return_colbert_vecs) if return_dense: all_dense_embeddings.append(output['dense_vecs'].cpu().numpy()) if return_sparse: token_weights = output['sparse_vecs'].squeeze(-1) all_lexical_weights.extend(list(map(_process_token_weights, token_weights.cpu().numpy(), batch_data['input_ids'].cpu().numpy().tolist()))) if return_colbert_vecs: all_colbert_vec.extend(list(map(_process_colbert_vecs, output['colbert_vecs'].cpu().numpy(), batch_data['attention_mask'].cpu().numpy()))) if return_dense: all_dense_embeddings = np.concatenate(all_dense_embeddings, axis=0) if return_dense: if input_was_string: all_dense_embeddings = all_dense_embeddings[0] else: all_dense_embeddings = None if return_sparse: if input_was_string: all_lexical_weights = all_lexical_weights[0] else: all_lexical_weights = None if return_colbert_vecs: if input_was_string: all_colbert_vec = all_colbert_vec[0] else: all_colbert_vec = None return {"dense_vecs": all_dense_embeddings, "lexical_weights": all_lexical_weights, "colbert_vecs": all_colbert_vec} @torch.no_grad() def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 256, max_query_length: int = 512, max_passage_length: int = 8192, weights_for_different_modes: List[float] = None) -> Dict[str, List[float]]: def _tokenize(texts: list, max_length: int): return self.tokenizer( texts, max_length=max_length, padding=True, return_token_type_ids=False, truncation=True, return_tensors='pt' ) if self.num_gpus > 0: batch_size *= self.num_gpus self.model.eval() if isinstance(sentence_pairs, list) and len(sentence_pairs) == 0: return [] if isinstance(sentence_pairs[0], str): one_input_pair = True sentence_pairs = [sentence_pairs] else: one_input_pair = False all_scores = { 'colbert': [], 'sparse': [], 'dense': [], 'sparse+dense': [], 'colbert+sparse+dense': [] } for start_index in tqdm(range(0, len(sentence_pairs), batch_size), desc="Compute Scores", disable=len(sentence_pairs) < 128): sentences_batch = sentence_pairs[start_index:start_index + batch_size] queries_batch = [pair[0] for pair in sentences_batch] corpus_batch = [pair[1] for pair in sentences_batch] queries_inputs = _tokenize(queries_batch, max_length=max_query_length).to(self.device) corpus_inputs = _tokenize(corpus_batch, max_length=max_passage_length).to(self.device) queries_output = self.model(queries_inputs, return_dense=True, return_sparse=True, return_colbert=True, return_sparse_embedding=True) corpus_output = self.model(corpus_inputs, return_dense=True, return_sparse=True, return_colbert=True, return_sparse_embedding=True) q_dense_vecs, q_sparse_vecs, q_colbert_vecs = queries_output['dense_vecs'], queries_output['sparse_vecs'], \ queries_output['colbert_vecs'] p_dense_vecs, p_sparse_vecs, p_colbert_vecs = corpus_output['dense_vecs'], corpus_output['sparse_vecs'], \ corpus_output['colbert_vecs'] dense_scores = self.model.dense_score(q_dense_vecs, p_dense_vecs) sparse_scores = self.model.sparse_score(q_sparse_vecs, p_sparse_vecs) colbert_scores = self.model.colbert_score(q_colbert_vecs, p_colbert_vecs, q_mask=queries_inputs['attention_mask']) if weights_for_different_modes is None: weights_for_different_modes = [1, 1., 1.] weight_sum = 3 print("default weights for dense, sparse, colbert are [1.0, 1.0, 1.0] ") else: assert len(weights_for_different_modes) == 3 weight_sum = sum(weights_for_different_modes) inx = torch.arange(0, len(sentences_batch)) dense_scores, sparse_scores, colbert_scores = dense_scores[inx, inx].float(), sparse_scores[ inx, inx].float(), colbert_scores[inx, inx].float() all_scores['colbert'].extend( colbert_scores.cpu().numpy().tolist() ) all_scores['sparse'].extend( sparse_scores.cpu().numpy().tolist() ) all_scores['dense'].extend( dense_scores.cpu().numpy().tolist() ) all_scores['sparse+dense'].extend( ((sparse_scores * weights_for_different_modes[1] + dense_scores * weights_for_different_modes[0])/(weights_for_different_modes[1]+weights_for_different_modes[0])).cpu().numpy().tolist() ) all_scores['colbert+sparse+dense'].extend( ((colbert_scores * weights_for_different_modes[2] + sparse_scores * weights_for_different_modes[1] + dense_scores * weights_for_different_modes[0])/weight_sum).cpu().numpy().tolist() ) if one_input_pair: return {k: v[0] for k, v in all_scores.items()} return all_scores