Commit 6449ab1a authored by haileyschoelkopf

add use_cache arg

parent 1c0ff968
@@ -39,7 +39,7 @@ def simple_evaluate(
     batch_size=None,
     max_batch_size=None,
     device=None,
-    no_cache=False,
+    use_cache=None,
     limit=None,
     bootstrap_iters=100000,
     check_integrity=False,
@@ -64,8 +64,8 @@ def simple_evaluate(
         Maximal batch size to try with automatic batch size detection
     :param device: str, optional
         PyTorch device (e.g. "cpu" or "cuda:0") for running models
-    :param no_cache: bool
-        Whether or not to cache
+    :param use_cache: str, optional
+        A path to a sqlite db file for caching model responses. `None` if not caching.
     :param limit: int or float, optional
         Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
     :param bootstrap_iters:
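
A hedged usage sketch for the new argument (the import path, model object, and task name below are assumptions for illustration, not part of this commit; running it requires lm-evaluation-harness installed and a concrete model):

from lm_eval import evaluator  # assumed location of simple_evaluate

# `my_lm` stands in for an already-constructed lm_eval.api.model.LM instance.
results = evaluator.simple_evaluate(
    model=my_lm,
    tasks=["lambada_openai"],      # placeholder task name
    use_cache="lm_cache/mymodel",  # rank r caches to lm_cache/mymodel_rank<r>.db
)
# use_cache=None (the default) skips response caching entirely.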
@@ -99,6 +99,16 @@ def simple_evaluate(
         assert isinstance(model, lm_eval.api.model.LM)
         lm = model
 
+    if use_cache is not None:
+        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
+        lm = lm_eval.api.model.CachingLM(
+            lm,
+            use_cache
+            # each rank receives a different cache db.
+            # necessary to avoid multiple writes to cache at once
+            + "_rank" + str(lm.rank) + ".db",
+        )
+
     task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
 
     if check_integrity:
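
To make the naming scheme concrete, here is a minimal runnable sketch; the helper `rank_cache_path` is hypothetical (the commit builds the suffix inline from `lm.rank`), but it produces the same paths as the block above:

def rank_cache_path(base: str, rank: int) -> str:
    # One sqlite file per rank, so concurrent ranks never write to the
    # same cache database at once (the motivation stated in the comment above).
    return base + "_rank" + str(rank) + ".db"

assert rank_cache_path("lm_cache/mymodel", 0) == "lm_cache/mymodel_rank0.db"
assert rank_cache_path("lm_cache/mymodel", 3) == "lm_cache/mymodel_rank3.db"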
@@ -127,7 +137,7 @@ def simple_evaluate(
         if hasattr(lm, "batch_sizes")
         else [],
         "device": device,
-        "no_cache": no_cache,
+        "use_cache": use_cache,
         "limit": limit,
         "bootstrap_iters": bootstrap_iters,
     }
...
@@ -39,7 +39,7 @@ def parse_args():
         "If <1, limit is a percentage of the total number of examples.",
     )
     parser.add_argument("--data_sampling", type=float, default=None)
-    parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--use_cache", type=str, default=None)
     parser.add_argument("--decontamination_ngrams_path", default=None)
     parser.add_argument("--check_integrity", action="store_true")
     parser.add_argument("--write_out", action="store_true", default=False)
@@ -85,7 +85,7 @@ def main():
         batch_size=args.batch_size,
         max_batch_size=args.max_batch_size,
         device=args.device,
-        no_cache=args.no_cache,
+        use_cache=args.use_cache,
         limit=args.limit,
         decontamination_ngrams_path=args.decontamination_ngrams_path,
         check_integrity=args.check_integrity,
...
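
The CLI side mirrors the API change: `--use_cache` is now a string path prefix rather than a boolean switch, so on the command line one would pass e.g. `--use_cache lm_cache/mymodel`. A self-contained sketch of the new flag's parsing behavior (the parser below is a stripped-down stand-in for the script's full `parse_args`):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_cache", type=str, default=None)

args = parser.parse_args(["--use_cache", "lm_cache/mymodel"])
assert args.use_cache == "lm_cache/mymodel"  # forwarded to simple_evaluate(use_cache=...)

args = parser.parse_args([])
assert args.use_cache is None  # caching disabled by default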