""" This scripts runs the evaluation (dev & test) for the AskUbuntu dataset Usage: python eval_askubuntu.py [sbert_model_name_or_path] """ from sentence_transformers import SentenceTransformer, LoggingHandler from sentence_transformers import util, evaluation import logging import os import gzip import sys #### Just some code to print debug information to stdout logging.basicConfig( format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] ) #### /print debug information to stdout model = SentenceTransformer(sys.argv[1]) ################# Download AskUbuntu and extract training corpus ################# askubuntu_folder = "askubuntu" training_corpus = os.path.join(askubuntu_folder, "train.unsupervised.txt") ## Download the AskUbuntu dataset from https://github.com/taolei87/askubuntu for filename in ["text_tokenized.txt.gz", "dev.txt", "test.txt", "train_random.txt"]: filepath = os.path.join(askubuntu_folder, filename) if not os.path.exists(filepath): util.http_get("https://github.com/taolei87/askubuntu/raw/master/" + filename, filepath) # Read the corpus corpus = {} dev_test_ids = set() with gzip.open(os.path.join(askubuntu_folder, "text_tokenized.txt.gz"), "rt", encoding="utf8") as fIn: for line in fIn: splits = line.strip().split("\t") id = splits[0] title = splits[1] corpus[id] = title # Read dev & test dataset def read_eval_dataset(filepath): dataset = [] with open(filepath) as fIn: for line in fIn: query_id, relevant_id, candidate_ids, bm25_scores = line.strip().split("\t") if len(relevant_id) == 0: # Skip examples without relevant entries continue relevant_id = relevant_id.split(" ") candidate_ids = candidate_ids.split(" ") negative_ids = set(candidate_ids) - set(relevant_id) dataset.append( { "query": corpus[query_id], "positive": [corpus[pid] for pid in relevant_id], "negative": [corpus[pid] for pid in negative_ids], } ) dev_test_ids.add(query_id) dev_test_ids.update(candidate_ids) return dataset 
# Parse both evaluation splits up front (dev first, then test), before any scoring starts.
dev_dataset, test_dataset = (
    read_eval_dataset(os.path.join(askubuntu_folder, split_file)) for split_file in ("dev.txt", "test.txt")
)

# Dev split: build the reranking evaluator, announce it, then run it on the loaded model.
dev_evaluator = evaluation.RerankingEvaluator(
    dev_dataset,
    name="AskUbuntu dev",
)
logging.info("Dev performance before training")
dev_evaluator(model)

# Test split: same three steps as for the dev split.
test_evaluator = evaluation.RerankingEvaluator(
    test_dataset,
    name="AskUbuntu test",
)
logging.info("Test performance before training")
test_evaluator(model)