Commit 98d75af0 authored by researcher2's avatar researcher2 Committed by researcher2
Browse files

Changes for PR

Remove arguments from main.py.

Add "decontamination" prefix to ngrams arguments.
parent dcef7c8f
......@@ -14,7 +14,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None, decontaminate=False,
ngrams_path=None, ngrams_n_size=None):
decontaminate_ngrams_path=None, decontaminate_ngrams_n_size=None):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
......@@ -69,8 +69,8 @@ def simple_evaluate(model, model_args=None, tasks=[],
limit=limit,
description_dict=description_dict,
decontaminate=decontaminate,
ngrams_path=ngrams_path,
ngrams_n_size=ngrams_n_size
decontaminate_ngrams_path=decontaminate_ngrams_path,
decontaminate_ngrams_n_size=decontaminate_ngrams_n_size
)
# add info about the model and few shot config
......@@ -92,7 +92,7 @@ decontaminate_suffix = "_decontaminate"
@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None,
decontaminate=False, ngrams_path=None, ngrams_n_size=None):
decontaminate=False, decontaminate_ngrams_path=None, decontaminate_ngrams_n_size=None):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
......@@ -121,7 +121,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
if decontaminate:
assert ngrams_path and ngrams_n_size
assert decontaminate_ngrams_path and decontaminate_ngrams_n_size
task_dict_items = [
(name, task)
......@@ -193,7 +193,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# Compare all tasks/sets at once to ensure a single training set scan
if decontaminate:
print("Finding train/test overlap, please wait...")
overlaps = get_train_overlap(docs_for_decontamination, ngrams_path, ngrams_n_size, limit)
overlaps = get_train_overlap(docs_for_decontamination, decontaminate_ngrams_path, decontaminate_ngrams_n_size, limit)
# all responses for each (task, doc)
process_res_queue = collections.defaultdict(list)
......
......@@ -23,16 +23,11 @@ class MultiChoice:
for choice in self.choices:
yield choice
# Get task base classes for filtering
task_types = list(set([task.__bases__[0].__name__ for task in tasks.TASK_REGISTRY.values()]))
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True)
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
parser.add_argument('--task_type', default=None, choices=MultiChoice(task_types))
parser.add_argument('--exclude_tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
......@@ -41,8 +36,8 @@ def parse_args():
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--decontaminate', action="store_true")
parser.add_argument('--ngrams_path', default=None)
parser.add_argument('--ngrams_n_size', type=int, default=None)
parser.add_argument('--decontaminate_ngrams_path', default=None)
parser.add_argument('--decontaminate_ngrams_n_size', type=int, default=None)
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args()
......@@ -50,10 +45,10 @@ def parse_args():
def ensure_correct_decontamination_params(args):
valid = True
if args.decontaminate:
if not args.ngrams_n_size:
if not args.decontaminate_ngrams_n_size:
print("Please specify n size of training set n-grams. (--decontaminate_ngrams_n_size)")
valid = False
if not args.ngrams_path:
if not args.decontaminate_ngrams_path:
print("Please specify path containing training set n-grams. (--decontaminate_ngrams_path)")
valid = False
......@@ -78,26 +73,11 @@ def main():
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
if args.task_type:
task_types = args.task_type.split(",")
task_names = list(dict(filter(lambda x: x[1].__bases__[0].__name__ in task_types,
tasks.TASK_REGISTRY.items())
).keys())
if args.tasks is None:
if args.task_type is None:
task_names = tasks.ALL_TASKS
task_names = tasks.ALL_TASKS
else:
task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
if args.exclude_tasks:
exclude_tasks = pattern_match(args.exclude_tasks.split(","), task_names)
task_names = list(filter(lambda x: x not in exclude_tasks, task_names))
if len(task_names) == 0:
print("You must have excluded the tasks you specified, exiting.")
return
print(f"Selected Tasks: {task_names}")
description_dict = {}
......@@ -116,8 +96,8 @@ def main():
limit=args.limit,
description_dict=description_dict,
decontaminate=args.decontaminate,
ngrams_path=args.ngrams_path,
ngrams_n_size=args.ngrams_n_size
decontaminate_ngrams_path=args.decontaminate_ngrams_path,
decontaminate_ngrams_n_size=args.decontaminate_ngrams_n_size
)
dumped = json.dumps(results, indent=2)
......
......@@ -107,18 +107,18 @@ def get_train_overlap(docs_by_task_set, ngrams_path, ngrams_n_size, limit):
non_matching_unique = 0
current_ngram = ""
for line in reader.read_tqdm():
for line in reader.read_tqdm(): # Scan training set ngrams file
total_ngrams += 1
[ngram, document_id] = line.rsplit(" ", 1)
if ngram != current_ngram:
if ngram != current_ngram: # Only need to match the ngram once in training set
unique_ngrams += 1
current_ngram = ngram
if ngram in merged_lookup:
matched_ngrams.append(ngram)
matched_ngrams.append(ngram) # For logging
matching_unique += 1
for task_name, task_set, doc_ids in merged_lookup[ngram]:
task_doc_set = duplicates[(task_name, task_set)]
for doc_id in doc_ids:
for doc_id in doc_ids: # Record contamination across all relevant task/set combos
task_doc_set.add(doc_id)
del merged_lookup[ngram] # No point matching again
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment