"""Score model predictions against reference labels via sentence-embedding similarity.

Each non-blank line of the input file is parsed as a JSON object whose
"labels" and "predict" fields are embedded with a SentenceTransformer model.
Their cosine similarity is stored in a new "score" field.  Records whose score
is >= --threshold_score are written to <save_path>/pos.txt, the rest to
<save_path>/neg.txt.  Finally the fraction of positive records is printed.
"""
import os
import sys
import json
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, required=True, help='txt path')
parser.add_argument('--threshold_score', type=float, default=0.8)
parser.add_argument('--model_name_or_path', type=str, default="all-MiniLM-L6-v2")
parser.add_argument('--save_path', type=str, default='./results')


def write_txt(infos, save_root_path='./results', save_name='pos'):
    """Write each item of ``infos`` as one JSON line to ``<save_root_path>/<save_name>.txt``.

    :param infos: iterable of JSON-serializable objects; each becomes one line.
    :param save_root_path: directory to write into, created if missing.
        Defaults to './results'.
    :param save_name: file name stem; the final file is ``save_name + '.txt'``.
        Defaults to 'pos'.

    Serialization and write failures are reported on stderr and the function
    returns None (best-effort, as before).  A failure to create the directory
    still propagates.
    """
    # Serialize everything *before* touching the file, so a bad item cannot
    # leave a truncated or empty output file behind.
    try:
        lines = [json.dumps(info, ensure_ascii=False) + '\n' for info in infos]
    except (TypeError, ValueError) as e:
        # TypeError: unsupported type; ValueError: e.g. circular reference.
        print("Error: Failed to serialize 'infos' to JSON. "
              f"Ensure each 'info' item is JSON serializable. ({e})",
              file=sys.stderr)
        return

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_root_path, exist_ok=True)
    save_path = os.path.join(save_root_path, save_name + '.txt')
    try:
        with open(save_path, 'w', encoding='utf-8') as wfile:
            # One batched write instead of one call per record.
            wfile.writelines(lines)
    except OSError as e:
        print(f"Error: Failed to write to file '{save_path}'. ({e})", file=sys.stderr)
        return


def main():
    """Parse CLI arguments, score every record, write pos/neg splits, report accuracy."""
    # Parse first so argument errors and --help exit before any model loading.
    args = parser.parse_args()

    # Imported lazily so importing this module (e.g. to reuse write_txt) or
    # running --help does not require loading sentence_transformers.
    from sentence_transformers import SentenceTransformer, util

    model = SentenceTransformer(args.model_name_or_path)
    threshold_score = args.threshold_score

    neg_sentence = []
    pos_sentence = []
    with open(args.data_path, 'r', encoding='utf-8') as rfile:
        # Iterate the file lazily rather than materializing readlines().
        for raw_line in rfile:
            line = raw_line.strip()
            if not line:
                # Blank lines would make json.loads raise; skip them.
                continue
            print('dealing with:', line)
            json_info = json.loads(line)

            # Sentences are encoded by calling model.encode()
            label_emb = model.encode(json_info.get("labels"))
            pred_emb = model.encode(json_info.get("predict"))
            # .item() needs a single similarity value, i.e. both fields are
            # single strings. NOTE(review): confirm list-valued fields never occur.
            score = util.cos_sim(label_emb, pred_emb).item()
            json_info["score"] = score
            print("Cosine-Similarity:", score)

            if score >= threshold_score:
                pos_sentence.append(json_info)
            else:
                neg_sentence.append(json_info)

    save_root_path = args.save_path
    # save results and score in txt
    write_txt(pos_sentence, save_root_path, 'pos')
    write_txt(neg_sentence, save_root_path, 'neg')

    total = len(pos_sentence) + len(neg_sentence)
    if total == 0:
        # Avoid ZeroDivisionError on an empty (or all-blank) input file.
        print('dealing end. no records were scored.')
    else:
        print('dealing end. the acc is {}'.format(len(pos_sentence) / total))


if __name__ == "__main__":
    main()