"""Test a BGE reranker model via either HuggingFace Transformers or FlagEmbedding.

Usage:
    python this_script.py [--model_path PATH] [--use_hf]
"""
import argparse


def use_hf(model_path):
    """Score query/passage pairs with the raw HuggingFace sequence-classification model.

    Args:
        model_path: HuggingFace hub id or local path of the reranker checkpoint.

    Prints the raw (un-normalized) relevance logits for two hard-coded pairs.
    """
    # Heavy third-party imports are kept function-local so the module can be
    # imported (e.g. for --help) without torch/transformers installed.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()

    pairs = [['what is panda?', 'hi'],
             ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True,
                           return_tensors='pt', max_length=512)
        # .view(-1) flattens the (batch, 1) logits to a 1-D score vector.
        scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
        print(scores)


def use_flagE(model_path):
    """Score query/passage pairs with the FlagEmbedding high-level reranker API.

    Args:
        model_path: HuggingFace hub id or local path of the reranker checkpoint.

    Prints sigmoid-normalized scores for a single pair and for a batch of pairs.
    """
    from FlagEmbedding import FlagReranker

    reranker = FlagReranker(model_path, use_fp16=True)  # use fp16 can speed up computing

    score = reranker.compute_score(['query', 'passage'], normalize=True)
    print(score)

    scores = reranker.compute_score(
        [['what is panda?', 'hi'],
         ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']],
        normalize=True)
    print(scores)


def get_args(argv=None):
    """Parse command-line options.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``.
              (Accepting an explicit list keeps the function testable.)

    Returns:
        argparse.Namespace with ``model_path`` (str) and ``use_hf`` (bool).
    """
    parse = argparse.ArgumentParser('Testing reranker in FlagEmbedding or Transformers.')
    parse.add_argument('--model_path', default='BAAI/bge-reranker-base')
    parse.add_argument('--use_hf', action='store_true')
    args = parse.parse_args(argv)
    return args


if __name__ == '__main__':
    args = get_args()
    model_path = args.model_path
    if args.use_hf:
        use_hf(model_path)
    else:
        use_flagE(model_path)