import argparse


def use_hf(model_path):
    """Embed two demo sentences with plain HF Transformers.

    Uses CLS-token pooling followed by L2 normalization — the pooling scheme
    recommended for the BGE family of embedding models.

    Args:
        model_path: HF Hub id or local path of the embedding model.
    """
    from transformers import AutoTokenizer, AutoModel
    import torch

    # Sentences we want sentence embeddings for
    sentences = ["怎么优化网卡性能及丢包问题", "按电源键启动几秒后机器自动掉电"]

    # Load model from HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)
    model.eval()

    # Tokenize sentences
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # For s2p (short query to long passage) retrieval, add an instruction to each
    # query (passages get no instruction):
    # encoded_input = tokenizer([instruction + q for q in queries], padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings without tracking gradients (inference only)
    with torch.no_grad():
        model_output = model(**encoded_input)
        # Perform pooling — here, CLS pooling (first token of the last hidden state)
        sentence_embeddings = model_output[0][:, 0]
    # Normalize so dot products equal cosine similarity
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    print("Sentence embeddings:", sentence_embeddings)


def use_stf(model_path):
    """Embed and compare sentences via the sentence-transformers API.

    Prints a pairwise cosine-similarity matrix for two sentence lists, then a
    query-vs-passage score matrix for an s2p retrieval example (queries get the
    BGE retrieval instruction prefix; passages do not).

    Args:
        model_path: HF Hub id or local path of the embedding model.
    """
    from sentence_transformers import SentenceTransformer

    sentences_1 = ["怎么优化网卡性能及丢包问题", "按电源键启动几秒后机器自动掉电"]
    sentences_2 = ["存储设备开机关机的顺序有哪些要求", "按电源键启动几秒自动掉电"]
    model = SentenceTransformer(model_path)
    embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)
    embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)
    # Embeddings are L2-normalized, so the matrix product is cosine similarity.
    similarity = embeddings_1 @ embeddings_2.T
    print(similarity)

    queries = ['售前咨询', '存储设备开机关机的顺序有哪些要求']
    passages = ["售前咨询需要联系您所在地平台,平台联系方式可以通过以下途径进行获取:\n<平台联系方式>",
                "存储设备开机关机顺序:\n开机流程:启动交换机和存储扩展柜--启动存储控制器--启动物理服务器--启动物理服务器上的业务\n关机流程:停止物理服务器上的业务--关闭物理服务器--关闭存储控制器--关闭交换机和存储扩展柜"]
    instruction = "为这个句子生成表示以用于检索相关文章:"
    # FIX: the original reloaded SentenceTransformer(model_path) a second time
    # here, paying the full model-load cost again for the identical model.
    # Reuse the instance created above.
    q_embeddings = model.encode([instruction + q for q in queries], normalize_embeddings=True)
    p_embeddings = model.encode(passages, normalize_embeddings=True)
    scores = q_embeddings @ p_embeddings.T
    print(scores)


def use_flagE(model_path):
    """Embed and compare sentences via the FlagEmbedding FlagModel API.

    Same demo data as use_stf(); FlagModel applies the query instruction
    automatically through encode_queries().

    Args:
        model_path: HF Hub id or local path of the embedding model.
    """
    from FlagEmbedding import FlagModel

    sentences_1 = ["怎么优化网卡性能及丢包问题", "按电源键启动几秒后机器自动掉电"]
    sentences_2 = ["存储设备开机关机的顺序有哪些要求", "按电源键启动几秒自动掉电"]
    model = FlagModel(model_path,
                      query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                      use_fp16=True)  # use_fp16=True speeds up computation with a slight performance degradation
    embeddings_1 = model.encode(sentences_1)
    embeddings_2 = model.encode(sentences_2)
    similarity = embeddings_1 @ embeddings_2.T
    print(similarity)

    # For s2p (short query to long passage) retrieval, use encode_queries(),
    # which automatically prepends the instruction to each query. The corpus
    # can still use encode() / encode_corpus() since passages need no instruction.
    queries = ['售前咨询', '存储设备开机关机的顺序有哪些要求']
    passages = ["售前咨询需要联系您所在地平台,平台联系方式可以通过以下途径进行获取:\n<平台联系方式>",
                "存储设备开机关机顺序:\n开机流程:启动交换机和存储扩展柜--启动存储控制器--启动物理服务器--启动物理服务器上的业务\n关机流程:停止物理服务器上的业务--关闭物理服务器--关闭存储控制器--关闭交换机和存储扩展柜"]
    q_embeddings = model.encode_queries(queries)
    p_embeddings = model.encode(passages)
    scores = q_embeddings @ p_embeddings.T
    print(scores)


def get_args():
    """Parse command-line options selecting the backend and model path.

    Returns:
        argparse.Namespace with model_path (str), use_hf (bool), use_stf (bool).
    """
    # NOTE(review): the description says "reranker" but this script tests
    # embedding models — presumably copied from a sibling script; confirm
    # before changing the user-visible text.
    parse = argparse.ArgumentParser('Testing reranker in FlagEmbedding or Transformers.')
    parse.add_argument('--model_path', default='BAAI/bge-large-zh-v1.5')
    parse.add_argument('--use_hf', action='store_true')
    parse.add_argument('--use_stf', action='store_true')
    args = parse.parse_args()
    return args


if __name__ == '__main__':
    args = get_args()
    model_path = args.model_path
    # Backend selection: --use_hf wins over --use_stf; default is FlagEmbedding.
    if args.use_hf:
        use_hf(model_path)
    elif args.use_stf:
        use_stf(model_path)
    else:
        use_flagE(model_path)