infer.py 1.99 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import json
import argparse

from sentence_transformers import SentenceTransformer, util

# Command-line interface. Arguments are parsed at module level so that `args`
# is available to the script body under the `__main__` guard.
parser = argparse.ArgumentParser()
# Input file read line by line; each line is parsed as one JSON object.
# Required: without it the script would only fail later, at open(None).
parser.add_argument('--data_path', type=str, required=True, help='txt path')
parser.add_argument('--threshold_score', type=float, default=0.8,
                    help='minimum cosine similarity for a record to count as positive')
parser.add_argument('--model_name_or_path', type=str, default="all-MiniLM-L6-v2",
                    help='name or path passed to SentenceTransformer')
parser.add_argument('--save_path', type=str, default='./results',
                    help='directory where pos.txt and neg.txt are written')
args = parser.parse_args()


def write_txt(infos, save_root_path='./results', save_name='pos'):
    """Write each item of *infos* as one JSON line to ``<save_root_path>/<save_name>.txt``.

    Args:
        infos: Iterable of JSON-serializable objects; each becomes one line.
        save_root_path: Output directory; created (with parents) if missing.
        save_name: File name without the ``.txt`` extension.

    The target file is overwritten if it already exists. Non-ASCII characters
    are written as-is (``ensure_ascii=False``) in UTF-8.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_root_path, exist_ok=True)

    save_path = os.path.join(save_root_path, save_name + '.txt')

    # The context manager closes the file; no explicit close() is needed.
    with open(save_path, 'w', encoding='utf-8') as wfile:
        wfile.writelines(
            json.dumps(info, ensure_ascii=False) + '\n' for info in infos
        )


if __name__ == "__main__":
    # Score every record of the input file by the cosine similarity between
    # the embeddings of its "labels" and "predict" fields, split the records
    # into positives/negatives around the threshold, and save both groups.
    txt_path = args.data_path
    model_name_or_path = args.model_name_or_path
    threshold_score = args.threshold_score

    model = SentenceTransformer(model_name_or_path)

    neg_sentence = []
    pos_sentence = []

    with open(txt_path, 'r', encoding='utf-8') as rfile:
        # Iterate the file lazily instead of loading every line up front.
        for line in rfile:
            # Skip blank lines (e.g. a trailing newline), which json.loads
            # would reject.
            if not line.strip():
                continue
            print('dealing with:', line.strip())
            json_info = json.loads(line)
            # Sentences are encoded by calling model.encode()
            label_emb = model.encode(json_info.get("labels"))
            pred_emb = model.encode(json_info.get("predict"))
            # Extract the similarity once as a plain float; it is stored in
            # the record, printed, and compared against the threshold.
            score = util.cos_sim(label_emb, pred_emb).item()
            json_info["score"] = score
            print("Cosine-Similarity:", score)
            if score >= threshold_score:
                pos_sentence.append(json_info)
            else:
                neg_sentence.append(json_info)

    save_root_path = args.save_path
    # save results and score in txt
    write_txt(pos_sentence, save_root_path, 'pos')
    write_txt(neg_sentence, save_root_path, 'neg')

    total = len(pos_sentence) + len(neg_sentence)
    # An input with no records would otherwise raise ZeroDivisionError here.
    acc = len(pos_sentence) / total if total else 0.0
    print('dealing end. the acc is {}'.format(acc))