Commit 0f283a9c authored by Leo Gao's avatar Leo Gao
Browse files

further interface changes

parent b2460099
...@@ -14,7 +14,7 @@ try: ...@@ -14,7 +14,7 @@ try:
import janitor_util import janitor_util
JANITOR_CPP = True JANITOR_CPP = True
except Exception as e: except Exception as e:
print("WARNING: C++ module could not be loaded. Janitor running in python mode") # print("WARNING: C++ module could not be loaded. Janitor running in python mode")
JANITOR_CPP = False JANITOR_CPP = False
# Was used for testing the evaluator decoupled from the full logic below # Was used for testing the evaluator decoupled from the full logic below
...@@ -41,9 +41,11 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size): ...@@ -41,9 +41,11 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# We cache the task+set lookups as well as the overlaps. # We cache the task+set lookups as well as the overlaps.
# #
# Currently calculating some per file ngram stats for interest, might remove before merging into main # Currently calculating some per file ngram stats for interest, might remove before merging into main
def get_train_overlap(docs_by_task_set, ngrams_path, ngrams_n_size, limit): def get_train_overlap(docs_by_task_set, ngrams_path, limit):
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size) # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
# TODO: infer ngrams_n_size from ngrams_path
janitor = Janitor() janitor = Janitor()
# Build lookup for each dataset first in case we use different task combinations later # Build lookup for each dataset first in case we use different task combinations later
......
...@@ -13,8 +13,7 @@ from lm_eval.utils import positional_deprecated ...@@ -13,8 +13,7 @@ from lm_eval.utils import positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[], def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None, num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000, no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None, decontaminate=False, description_dict=None, decontamination_ngrams_path=None):
decontaminate_ngrams_path=None, decontaminate_ngrams_n_size=None):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM] :param model: Union[str, LM]
...@@ -68,9 +67,7 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -68,9 +67,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=num_fewshot, num_fewshot=num_fewshot,
limit=limit, limit=limit,
description_dict=description_dict, description_dict=description_dict,
decontaminate=decontaminate, decontamination_ngrams_path=decontamination_ngrams_path,
decontaminate_ngrams_path=decontaminate_ngrams_path,
decontaminate_ngrams_n_size=decontaminate_ngrams_n_size
) )
# add info about the model and few shot config # add info about the model and few shot config
...@@ -92,7 +89,7 @@ decontaminate_suffix = "_decontaminate" ...@@ -92,7 +89,7 @@ decontaminate_suffix = "_decontaminate"
@positional_deprecated @positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None, def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None,
decontaminate=False, decontaminate_ngrams_path=None, decontaminate_ngrams_n_size=None): decontamination_ngrams_path=None):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
:param lm: obj :param lm: obj
...@@ -120,8 +117,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -120,8 +117,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# nudge people to not specify it at all # nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
if decontaminate: decontaminate = decontamination_ngrams_path is not None
assert decontaminate_ngrams_path and decontaminate_ngrams_n_size
task_dict_items = [ task_dict_items = [
(name, task) (name, task)
...@@ -193,7 +189,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -193,7 +189,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# Compare all tasks/sets at once to ensure a single training set scan # Compare all tasks/sets at once to ensure a single training set scan
if decontaminate: if decontaminate:
print("Finding train/test overlap, please wait...") print("Finding train/test overlap, please wait...")
overlaps = lm_eval.decontamination.get_train_overlap(docs_for_decontamination, decontaminate_ngrams_path, decontaminate_ngrams_n_size, limit) overlaps = lm_eval.decontamination.get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
# all responses for each (task, doc) # all responses for each (task, doc)
process_res_queue = collections.defaultdict(list) process_res_queue = collections.defaultdict(list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment