# test_embedding.py
# Compare sentence-embedding backends (HuggingFace Transformers,
# sentence-transformers, FlagEmbedding) for a BGE-style Chinese model.
import argparse

def use_hf(model_path):
    """Embed two Chinese sentences with a raw HuggingFace model.

    Uses CLS pooling (first-token hidden state) followed by L2
    normalization, then prints the resulting embedding tensor.

    Args:
        model_path: HuggingFace Hub id or local path of the embedding model.
    """
    import torch
    from transformers import AutoModel, AutoTokenizer

    # Sentences we want sentence embeddings for
    sentences = ["怎么优化网卡性能及丢包问题", "按电源键启动几秒后机器自动掉电"]

    # Load tokenizer and model from the Hub (or a local directory).
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)
    model.eval()

    # Tokenize the batch to padded/truncated PyTorch tensors.
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # For s2p (short query to long passage) retrieval, prepend an instruction
    # to each query; passages are embedded without one:
    # batch = tokenizer([instruction + q for q in queries], padding=True, truncation=True, return_tensors='pt')

    # Forward pass without gradient tracking; CLS pooling takes the
    # hidden state of the first token of the last layer.
    with torch.no_grad():
        outputs = model(**batch)
        embeddings = outputs[0][:, 0]
    # L2-normalize so dot products equal cosine similarity.
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    print("Sentence embeddings:", embeddings)



def use_stf(model_path):
    """Embed sentences with sentence-transformers and print similarity scores.

    First prints pairwise cosine similarities between two sentence lists,
    then a query-vs-passage score matrix for an s2p retrieval example
    (queries get the BGE retrieval instruction prepended; passages do not).

    Args:
        model_path: model id or local path accepted by SentenceTransformer.
    """
    from sentence_transformers import SentenceTransformer

    sentences_1 = ["怎么优化网卡性能及丢包问题", "按电源键启动几秒后机器自动掉电"]
    sentences_2 = ["存储设备开机关机的顺序有哪些要求", "按电源键启动几秒自动掉电"]
    model = SentenceTransformer(model_path)
    # normalize_embeddings=True makes the dot product a cosine similarity.
    embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)
    embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)
    similarity = embeddings_1 @ embeddings_2.T
    print(similarity)

    queries = ['售前咨询', '存储设备开机关机的顺序有哪些要求']
    passages = ["售前咨询需要联系您所在地平台,平台联系方式可以通过以下途径进行获取:\n<平台联系方式>", "存储设备开机关机顺序:\n开机流程:启动交换机和存储扩展柜--启动存储控制器--启动物理服务器--启动物理服务器上的业务\n关机流程:停止物理服务器上的业务--关闭物理服务器--关闭存储控制器--关闭交换机和存储扩展柜"]

    # BGE convention: only queries carry the retrieval instruction.
    instruction = "为这个句子生成表示以用于检索相关文章:"

    # Fix: reuse the model loaded above instead of instantiating a second
    # identical SentenceTransformer (the original reloaded the full model).
    q_embeddings = model.encode([instruction + q for q in queries], normalize_embeddings=True)
    p_embeddings = model.encode(passages, normalize_embeddings=True)
    scores = q_embeddings @ p_embeddings.T
    print(scores)


def use_flagE(model_path):
    """Embed sentences with FlagEmbedding's FlagModel and print scores.

    Prints pairwise similarities for two sentence lists, then a
    query-vs-passage score matrix for an s2p retrieval example.

    Args:
        model_path: model id or local path accepted by FlagModel.
    """
    from FlagEmbedding import FlagModel

    sentences_1 = ["怎么优化网卡性能及丢包问题", "按电源键启动几秒后机器自动掉电"]
    sentences_2 = ["存储设备开机关机的顺序有哪些要求", "按电源键启动几秒自动掉电"]
    # use_fp16=True speeds up computation with a slight quality degradation.
    model = FlagModel(
        model_path,
        query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
        use_fp16=True,
    )
    first = model.encode(sentences_1)
    second = model.encode(sentences_2)
    print(first @ second.T)

    # For s2p (short query to long passage) retrieval, encode_queries()
    # automatically prepends the instruction to each query; the corpus is
    # encoded with plain encode()/encode_corpus() (no instruction needed).
    queries = ['售前咨询', '存储设备开机关机的顺序有哪些要求']
    passages = ["售前咨询需要联系您所在地平台,平台联系方式可以通过以下途径进行获取:\n<平台联系方式>", "存储设备开机关机顺序:\n开机流程:启动交换机和存储扩展柜--启动存储控制器--启动物理服务器--启动物理服务器上的业务\n关机流程:停止物理服务器上的业务--关闭物理服务器--关闭存储控制器--关闭交换机和存储扩展柜"]
    query_vecs = model.encode_queries(queries)
    passage_vecs = model.encode(passages)
    print(query_vecs @ passage_vecs.T)


def get_args(argv=None):
    """Parse command-line options for the embedding test script.

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:]
            (passing a list makes the parser testable without touching
            sys.argv — backward compatible with the zero-arg call).

    Returns:
        argparse.Namespace with `model_path`, `use_hf`, and `use_stf`.
    """
    # Fix: the original description said "reranker", but this script
    # exercises embedding models.
    parser = argparse.ArgumentParser(
        'Testing embeddings in FlagEmbedding, sentence-transformers or Transformers.')
    parser.add_argument('--model_path', default='BAAI/bge-large-zh-v1.5')
    parser.add_argument('--use_hf', action='store_true')
    parser.add_argument('--use_stf', action='store_true')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Pick the backend from the CLI flags; FlagEmbedding is the default.
    cli = get_args()
    path = cli.model_path

    if cli.use_hf:
        runner = use_hf
    elif cli.use_stf:
        runner = use_stf
    else:
        runner = use_flagE
    runner(path)