# NOTE: This entire module is commented out, so pytest collects nothing from it.
# Before re-enabling it, be aware that:
#   (a) the scenario functions below are named `requests_caching_*` rather than
#       `test_*`, so pytest's default discovery would still not collect them;
#   (b) the `__main__` runner and `requests_caching_delete` refer to
#       `test_requests_caching_*` names that are not defined in this module.

# import importlib
# import os
# import sys
# from datetime import datetime
# from typing import List, Optional, Tuple
#
# import pytest
# import torch
#
# from lm_eval.caching.cache import PATH
#
#
# MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
#
# # NOTE the script this loads uses simple evaluate
# # TODO potentially test both the helper script and the normal script
# sys.path.append(f"{MODULE_DIR}/../scripts")
# model_loader = importlib.import_module("requests_caching")
# run_model_for_task_caching = model_loader.run_model_for_task_caching
#
# os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
# DEFAULT_TASKS = ["lambada_openai", "sciq"]
#
#
# @pytest.fixture(autouse=True)
# def setup_and_teardown():
#     # Setup
#     torch.use_deterministic_algorithms(False)
#     clear_cache()
#     # Yields control back to the test function
#     yield
#     # Cleanup here
#
#
# def clear_cache():
#     if os.path.exists(PATH):
#         cache_files = os.listdir(PATH)
#         for file in cache_files:
#             file_path = f"{PATH}/{file}"
#             os.unlink(file_path)
#
#
# # leaving tasks here to allow for the option to select specific task files
# def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]:
#     cache_files = os.listdir(PATH)
#
#     file_task_names = []
#
#     for file in cache_files:
#         file_without_prefix = file.split("-")[1]
#         file_without_prefix_and_suffix = file_without_prefix.split(".")[0]
#         file_task_names.extend([file_without_prefix_and_suffix])
#
#     return cache_files, file_task_names
#
#
# def assert_created(tasks: List[str], file_task_names: List[str]):
#     tasks.sort()
#     file_task_names.sort()
#
#     assert tasks == file_task_names
#
#
# @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
# def requests_caching_true(tasks: List[str]):
#     run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
#     cache_files, file_task_names = get_cache_files()
#     print(file_task_names)
#     assert_created(tasks=tasks, file_task_names=file_task_names)
#
#
# @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
# def requests_caching_refresh(tasks: List[str]):
#     run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
#     timestamp_before_test = datetime.now().timestamp()
#
#     run_model_for_task_caching(tasks=tasks, cache_requests="refresh")
#
#     cache_files, file_task_names = get_cache_files()
#
#     for file in cache_files:
#         modification_time = os.path.getmtime(f"{PATH}/{file}")
#         assert modification_time > timestamp_before_test
#
#     tasks.sort()
#     file_task_names.sort()
#
#     assert tasks == file_task_names
#
#
# @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
# def requests_caching_delete(tasks: List[str]):
#     # populate the data first, rerun this test within this test for additional confidence
#     # test_requests_caching_true(tasks=tasks)
#
#     run_model_for_task_caching(tasks=tasks, cache_requests="delete")
#
#     cache_files, file_task_names = get_cache_files()
#
#     assert len(cache_files) == 0
#
#
# # useful for locally running tests through the debugger
# if __name__ == "__main__":
#
#     def run_tests():
#         tests = [
#             # test_requests_caching_true,
#             # test_requests_caching_refresh,
#             # test_requests_caching_delete,
#         ]
#         # Lookups of global names within a loop are inefficient, so copy to a local variable outside of the loop first
#         default_tasks = DEFAULT_TASKS
#         for test_func in tests:
#             clear_cache()
#             test_func(tasks=default_tasks)
#
#         print("Tests pass")
#
#     run_tests()