Commit 6449ab1a authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

add use_cache arg

parent 1c0ff968
......@@ -39,7 +39,7 @@ def simple_evaluate(
batch_size=None,
max_batch_size=None,
device=None,
no_cache=False,
use_cache=None,
limit=None,
bootstrap_iters=100000,
check_integrity=False,
......@@ -64,8 +64,8 @@ def simple_evaluate(
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param no_cache: bool
Whether or not to cache
:param use_cache: str, optional
A path to a sqlite db file for caching model responses. `None` if not caching.
:param limit: int or float, optional
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
:param bootstrap_iters:
......@@ -99,6 +99,16 @@ def simple_evaluate(
assert isinstance(model, lm_eval.api.model.LM)
lm = model
if use_cache is not None:
print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
lm = lm_eval.api.model.CachingLM(
lm,
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ "_rank" + str(lm.rank) + ".db",
)
task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
if check_integrity:
......@@ -127,7 +137,7 @@ def simple_evaluate(
if hasattr(lm, "batch_sizes")
else [],
"device": device,
"no_cache": no_cache,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
}
......
......@@ -39,7 +39,7 @@ def parse_args():
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument("--data_sampling", type=float, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_cache", type=str, default=None)
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
......@@ -85,7 +85,7 @@ def main():
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
no_cache=args.no_cache,
use_cache=args.use_cache,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment