infer.py 2.79 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import json
import argparse

from sentence_transformers import SentenceTransformer, util

# Command-line options for the script: input file, similarity threshold,
# embedding model name/path and output directory. The arguments are parsed
# at module level, so `args` is populated as soon as this file is executed.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_path",
    type=str,
    help="txt path",
)
parser.add_argument(
    "--threshold_score",
    type=float,
    default=0.8,
)
parser.add_argument(
    "--model_name_or_path",
    type=str,
    default="all-MiniLM-L6-v2",
)
parser.add_argument(
    "--save_path",
    type=str,
    default="./results",
)
args = parser.parse_args()


def write_txt(infos, save_root_path='./results', save_name='pos'):
    """
    Write a list of records to a text file, one JSON document per line.

    All records are serialized before the output file is opened, so a record
    that cannot be serialized does not truncate a file that already exists.

    :param infos: Iterable of records; each item must be serializable by
        json.dumps.
    :param save_root_path: Directory to write into; it is created when
        missing. Defaults to './results'.
    :param save_name: File name without extension; the file written is
        save_name + '.txt'. Defaults to 'pos'.
    :return: None. Serialization failures (TypeError) and write failures
        (OSError) are reported on stdout and the function returns without
        raising.
    :raises OSError: If the output directory cannot be created.
    """
    # exist_ok=True removes the check-then-create race of a separate
    # os.path.exists() test followed by os.makedirs().
    os.makedirs(save_root_path, exist_ok=True)

    # Full path of the output file.
    save_path = os.path.join(save_root_path, save_name + '.txt')

    # Serialize everything up front: opening with 'w' truncates the file, so
    # doing this inside the `with` block would leave an empty file behind
    # whenever one record turns out not to be JSON serializable.
    try:
        lines = [json.dumps(info, ensure_ascii=False) + '\n' for info in infos]
    except TypeError:
        print("Error: Failed to serialize 'infos' to JSON. Ensure each 'info' item is JSON serializable.")
        return

    # Single writelines() call keeps the number of write operations low.
    try:
        with open(save_path, 'w', encoding='utf-8') as wfile:
            wfile.writelines(lines)
    except OSError:
        print(f"Error: Failed to write to file '{save_path}'.")
        return


if __name__ == "__main__":
    # Script body: score every record of the input file by the cosine
    # similarity between the embeddings of its "labels" and "predict" fields,
    # split the records by the threshold, and write both groups to disk.
    txt_path = args.data_path
    model_name_or_path = args.model_name_or_path
    threshold_score = args.threshold_score

    model = SentenceTransformer(model_name_or_path)

    neg_sentence = []
    pos_sentence = []

    with open(txt_path, 'r', encoding='utf-8') as rfile:
        # Iterate the file object directly instead of readlines() so the
        # whole file is not loaded into memory at once.
        for line in rfile:
            # Blank lines (e.g. a trailing empty line) are not JSON documents;
            # skip them instead of letting json.loads() raise.
            if not line.strip():
                continue
            print('dealing with:', line.strip())
            json_info = json.loads(line)
            # Sentences are encoded by calling model.encode()
            # NOTE(review): assumes "labels" and "predict" each hold a single
            # string, so that cos_sim has exactly one element for .item() --
            # confirm against the data format.
            label_emb = model.encode(json_info.get("labels"))
            pred_emb = model.encode(json_info.get("predict"))
            cos_sim = util.cos_sim(label_emb, pred_emb)
            score = cos_sim.item()
            json_info["score"] = score
            print("Cosine-Similarity:", score)
            if cos_sim >= threshold_score:
                pos_sentence.append(json_info)
            else:
                neg_sentence.append(json_info)

    save_root_path = args.save_path
    # save results and score in txt
    write_txt(pos_sentence, save_root_path, 'pos')
    write_txt(neg_sentence, save_root_path, 'neg')

    # Guard the ratio: an input without any record would otherwise raise
    # ZeroDivisionError after the result files have already been written.
    total = len(pos_sentence) + len(neg_sentence)
    acc = len(pos_sentence) / total if total else 0.0
    print('dealing end. the acc is {}'.format(acc))