Commit 98d75af0 authored by researcher2, committed by researcher2

Changes for PR

Remove the --task_type and --exclude_tasks arguments from main.py.

Add "decontamination" prefix to ngrams arguments.
parent dcef7c8f
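
With the rename, a decontamination run would be invoked roughly as follows (model, task, path, and n-gram size are placeholder values, not taken from this commit):

python main.py --model gpt2 --tasks lambada --decontaminate --decontaminate_ngrams_path /path/to/training_ngrams --decontaminate_ngrams_n_size 13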
@@ -14,7 +14,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
                     num_fewshot=0, batch_size=None, device=None,
                     no_cache=False, limit=None, bootstrap_iters=100000,
                     description_dict=None, decontaminate=False,
-                    ngrams_path=None, ngrams_n_size=None):
+                    decontaminate_ngrams_path=None, decontaminate_ngrams_n_size=None):
     """Instantiate and evaluate a model on a list of tasks.
 
     :param model: Union[str, LM]
@@ -69,8 +69,8 @@ def simple_evaluate(model, model_args=None, tasks=[],
         limit=limit,
         description_dict=description_dict,
         decontaminate=decontaminate,
-        ngrams_path=ngrams_path,
-        ngrams_n_size=ngrams_n_size
+        decontaminate_ngrams_path=decontaminate_ngrams_path,
+        decontaminate_ngrams_n_size=decontaminate_ngrams_n_size
     )
 
     # add info about the model and few shot config
@@ -92,7 +92,7 @@ decontaminate_suffix = "_decontaminate"
 
 @positional_deprecated
 def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None,
-             decontaminate=False, ngrams_path=None, ngrams_n_size=None):
+             decontaminate=False, decontaminate_ngrams_path=None, decontaminate_ngrams_n_size=None):
     """Instantiate and evaluate a model on a list of tasks.
 
     :param lm: obj
@@ -121,7 +121,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
 
     if decontaminate:
-        assert ngrams_path and ngrams_n_size
+        assert decontaminate_ngrams_path and decontaminate_ngrams_n_size
 
     task_dict_items = [
         (name, task)
@@ -193,7 +193,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     # Compare all tasks/sets at once to ensure a single training set scan
     if decontaminate:
         print("Finding train/test overlap, please wait...")
-        overlaps = get_train_overlap(docs_for_decontamination, ngrams_path, ngrams_n_size, limit)
+        overlaps = get_train_overlap(docs_for_decontamination, decontaminate_ngrams_path, decontaminate_ngrams_n_size, limit)
 
     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
...
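The renamed keyword arguments are simply threaded from simple_evaluate through to evaluate and on to get_train_overlap. A hypothetical direct call using the new names (model, task, and values are illustrative placeholders, not part of this commit) would look like:

results = simple_evaluate(
    model="gpt2",
    tasks=["lambada"],
    num_fewshot=0,
    decontaminate=True,
    decontaminate_ngrams_path="/path/to/training_ngrams",
    decontaminate_ngrams_n_size=13,
)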
@@ -23,16 +23,11 @@ class MultiChoice:
         for choice in self.choices:
             yield choice
 
-# Get task base classes for filtering
-task_types = list(set([task.__bases__[0].__name__ for task in tasks.TASK_REGISTRY.values()]))
-
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', required=True)
     parser.add_argument('--model_args', default="")
     parser.add_argument('--tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
-    parser.add_argument('--task_type', default=None, choices=MultiChoice(task_types))
-    parser.add_argument('--exclude_tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
     parser.add_argument('--provide_description', action="store_true")
     parser.add_argument('--num_fewshot', type=int, default=0)
     parser.add_argument('--batch_size', type=int, default=None)
@@ -41,8 +36,8 @@ def parse_args():
     parser.add_argument('--limit', type=int, default=None)
     parser.add_argument('--no_cache', action="store_true")
     parser.add_argument('--decontaminate', action="store_true")
-    parser.add_argument('--ngrams_path', default=None)
-    parser.add_argument('--ngrams_n_size', type=int, default=None)
+    parser.add_argument('--decontaminate_ngrams_path', default=None)
+    parser.add_argument('--decontaminate_ngrams_n_size', type=int, default=None)
     parser.add_argument('--description_dict_path', default=None)
 
     return parser.parse_args()
@@ -50,10 +45,10 @@ def parse_args():
 
 def ensure_correct_decontamination_params(args):
     valid = True
     if args.decontaminate:
-        if not args.ngrams_n_size:
+        if not args.decontaminate_ngrams_n_size:
             print("Please specify n size of training set n-grams. (--ngrams_n_size)")
             valid = False
-        if not args.ngrams_path:
+        if not args.decontaminate_ngrams_path:
             print("Please specify path containing training set n-grams. (--ngrams_path)")
             valid = False
@@ -78,26 +73,11 @@ def main():
     if args.limit:
         print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
 
-    if args.task_type:
-        task_types = args.task_type.split(",")
-        task_names = list(dict(filter(lambda x: x[1].__bases__[0].__name__ in task_types,
-                                      tasks.TASK_REGISTRY.items())
-                               ).keys())
-
     if args.tasks is None:
-        if args.task_type is None:
-            task_names = tasks.ALL_TASKS
+        task_names = tasks.ALL_TASKS
     else:
         task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
 
-    if args.exclude_tasks:
-        exclude_tasks = pattern_match(args.exclude_tasks.split(","), task_names)
-        task_names = list(filter(lambda x: x not in exclude_tasks, task_names))
-        if len(task_names) == 0:
-            print("You must have excluded the tasks you specified, exiting.")
-            return
-
     print(f"Selected Tasks: {task_names}")
 
     description_dict = {}
@@ -116,8 +96,8 @@ def main():
         limit=args.limit,
         description_dict=description_dict,
         decontaminate=args.decontaminate,
-        ngrams_path=args.ngrams_path,
-        ngrams_n_size=args.ngrams_n_size
+        decontaminate_ngrams_path=args.decontaminate_ngrams_path,
+        decontaminate_ngrams_n_size=args.decontaminate_ngrams_n_size
     )
 
     dumped = json.dumps(results, indent=2)
...
@@ -107,18 +107,18 @@ def get_train_overlap(docs_by_task_set, ngrams_path, ngrams_n_size, limit):
     non_matching_unique = 0
 
     current_ngram = ""
-    for line in reader.read_tqdm():
+    for line in reader.read_tqdm(): # Scan training set ngrams file
         total_ngrams += 1
         [ngram, document_id] = line.rsplit(" ", 1)
-        if ngram != current_ngram:
+        if ngram != current_ngram: # Only need to match the ngram once in training set
             unique_ngrams += 1
             current_ngram = ngram
 
         if ngram in merged_lookup:
-            matched_ngrams.append(ngram)
+            matched_ngrams.append(ngram) # For logging
             matching_unique += 1
             for task_name, task_set, doc_ids in merged_lookup[ngram]:
                 task_doc_set = duplicates[(task_name, task_set)]
-                for doc_id in doc_ids:
+                for doc_id in doc_ids: # Record contamination across all relevant task/set combos
                     task_doc_set.add(doc_id)
             del merged_lookup[ngram] # No point matching again
         else:
...
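The comments added above describe a single pass over a sorted "<ngram> <document_id>" training file, matching each unique n-gram against a lookup built from the evaluation docs. A minimal self-contained sketch of that scan, assuming the same data shapes as the diff (an illustration, not the harness implementation):

import collections

def find_overlaps(ngram_lines, merged_lookup):
    # merged_lookup: ngram -> list of (task_name, task_set, doc_ids) built from evaluation docs
    duplicates = collections.defaultdict(set)
    current_ngram = ""
    for line in ngram_lines:  # one pass over the sorted training n-gram lines
        ngram, _document_id = line.rsplit(" ", 1)
        if ngram == current_ngram:
            continue  # sorted input: each n-gram only needs to be checked once
        current_ngram = ngram
        if ngram in merged_lookup:
            for task_name, task_set, doc_ids in merged_lookup[ngram]:
                duplicates[(task_name, task_set)].update(doc_ids)  # mark contaminated docs
            del merged_lookup[ngram]  # no point matching again
    return duplicates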