"""Split prediction records into "pos"/"neg" files by embedding cosine similarity.

Each non-blank line of the input file is parsed as a JSON object. Its
``labels`` and ``predict`` values are encoded with a SentenceTransformer
model, their cosine similarity is stored under ``score``, and the record is
written to ``pos.txt`` (similarity >= threshold) or ``neg.txt`` otherwise.
"""
import argparse
import json
import os

parser = argparse.ArgumentParser()
# Required: without a path there is nothing to read, so fail with a usage
# message instead of a TypeError from open(None).
parser.add_argument('--data_path', type=str, required=True, help='txt path')
parser.add_argument('--threshold_score', type=float, default=0.8)
parser.add_argument('--model_name_or_path', type=str, default="all-MiniLM-L6-v2")
parser.add_argument('--save_path', type=str, default='./results')


def write_txt(infos, save_root_path='./results', save_name='pos'):
    """Write each item of ``infos`` as one JSON line to ``<save_root_path>/<save_name>.txt``.

    Args:
        infos: iterable of JSON-serialisable objects.
        save_root_path: output directory; created (with parents) if missing.
        save_name: file name without the ``.txt`` extension.

    An existing file at the target path is overwritten. Non-ASCII characters
    are written as-is (UTF-8), not as ``\\u`` escapes.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_root_path, exist_ok=True)
    save_path = os.path.join(save_root_path, save_name + '.txt')
    # The context manager closes the file; no explicit close() is needed.
    with open(save_path, 'w', encoding='utf-8') as wfile:
        for info in infos:
            wfile.write(json.dumps(info, ensure_ascii=False) + '\n')


def _accuracy(pos_count, neg_count):
    """Return the fraction of positive records, or 0.0 when there are none at all."""
    total = pos_count + neg_count
    if total == 0:
        # An empty input would otherwise raise ZeroDivisionError.
        return 0.0
    return pos_count / total


def main(argv=None):
    """Score every record in ``--data_path`` and save the pos/neg splits.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parser.parse_args(argv)

    # Imported here rather than at module level so that ``--help`` and reuse
    # of ``write_txt`` do not pay for loading the embedding library.
    from sentence_transformers import SentenceTransformer, util

    threshold_score = args.threshold_score
    model = SentenceTransformer(args.model_name_or_path)

    neg_sentence = []
    pos_sentence = []
    with open(args.data_path, 'r', encoding='utf-8') as rfile:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in rfile:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) are not JSON records.
                continue
            print('dealing with:', line.strip())
            json_info = json.loads(line)

            # Sentences are encoded by calling model.encode()
            label_emb = model.encode(json_info.get("labels"))
            pred_emb = model.encode(json_info.get("predict"))

            cos_sim = util.cos_sim(label_emb, pred_emb)
            score = cos_sim.item()
            json_info["score"] = score
            print("Cosine-Similarity:", score)

            # Compare the similarity object itself, as the original did, so
            # threshold behaviour is unchanged.
            if cos_sim >= threshold_score:
                pos_sentence.append(json_info)
            else:
                neg_sentence.append(json_info)

    save_root_path = args.save_path
    # save results and score in txt
    write_txt(pos_sentence, save_root_path, 'pos')
    write_txt(neg_sentence, save_root_path, 'neg')
    print('dealing end. the acc is {}'.format(
        _accuracy(len(pos_sentence), len(neg_sentence))))


if __name__ == "__main__":
    main()