test_requests_caching.py 3.63 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# import importlib
# import os
# import sys
# from datetime import datetime
# from typing import List, Optional, Tuple
#
# import pytest
# import torch
#
# from lm_eval.caching.cache import PATH
#
#
# MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
#
# # NOTE: the script loaded below (requests_caching) uses simple_evaluate.
# # TODO: potentially test both the helper script and the normal script.
# sys.path.append(f"{MODULE_DIR}/../scripts")
# model_loader = importlib.import_module("requests_caching")
# run_model_for_task_caching = model_loader.run_model_for_task_caching
#
# os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
# DEFAULT_TASKS = ["lambada_openai", "sciq"]
#
#
# @pytest.fixture(autouse=True)
# def setup_and_teardown():
#     # Setup
#     torch.use_deterministic_algorithms(False)
#     clear_cache()
#     # Yields control back to the test function
#     yield
#     # Cleanup would go here; nothing is currently cleaned up after the test
#
#
# def clear_cache():
#     if os.path.exists(PATH):
#         cache_files = os.listdir(PATH)
#         for file in cache_files:
#             file_path = f"{PATH}/{file}"
#             os.unlink(file_path)
#
#
# # `tasks` is currently unused; it is kept so that specific task files can be selected later
# def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]:
#     cache_files = os.listdir(PATH)
#
#     file_task_names = []
#
#     for file in cache_files:
#         file_without_prefix = file.split("-")[1]
#         file_without_prefix_and_suffix = file_without_prefix.split(".")[0]
#         file_task_names.extend([file_without_prefix_and_suffix])
#
#     return cache_files, file_task_names
#
#
# def assert_created(tasks: List[str], file_task_names: List[str]):
#     tasks.sort()
#     file_task_names.sort()
#
#     assert tasks == file_task_names
#
#
# @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
# def requests_caching_true(tasks: List[str]):
#     run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
#     cache_files, file_task_names = get_cache_files()
#     print(file_task_names)
#     assert_created(tasks=tasks, file_task_names=file_task_names)
#
#
# @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
# def requests_caching_refresh(tasks: List[str]):
#     run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
#     timestamp_before_test = datetime.now().timestamp()
#
#     run_model_for_task_caching(tasks=tasks, cache_requests="refresh")
#
#     cache_files, file_task_names = get_cache_files()
#
#     for file in cache_files:
#         modification_time = os.path.getmtime(f"{PATH}/{file}")
#         assert modification_time > timestamp_before_test
#
#     tasks.sort()
#     file_task_names.sort()
#
#     assert tasks == file_task_names
#
#
# @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
# def requests_caching_delete(tasks: List[str]):
#     # Optionally populate the cache first by running the caching-true test from within this test, for additional confidence
#     # test_requests_caching_true(tasks=tasks)
#
#     run_model_for_task_caching(tasks=tasks, cache_requests="delete")
#
#     cache_files, file_task_names = get_cache_files()
#
#     assert len(cache_files) == 0
#
#
# # useful for locally running tests through the debugger
# if __name__ == "__main__":
#
#     def run_tests():
#         tests = [
#             # test_requests_caching_true,
#             # test_requests_caching_refresh,
#             # test_requests_caching_delete,
#         ]
#         # Looking up global names within a loop is inefficient, so copy to a local variable outside the loop first
#         default_tasks = DEFAULT_TASKS
#         for test_func in tests:
#             clear_cache()
#             test_func(tasks=default_tasks)
#
#         print("Tests pass")
#
#     run_tests()