"test/vscode:/vscode.git/clone" did not exist on "75e6a7cde144fdb60683fcdcdbc3f4e0a8411ef9"
Commit 98d75af0 authored by researcher2, committed by researcher2

Changes for PR

Remove the --task_type and --exclude_tasks arguments from main.py.

Add "decontamination" prefix to ngrams arguments.
parent dcef7c8f
@@ -14,7 +14,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
                     num_fewshot=0, batch_size=None, device=None,
                     no_cache=False, limit=None, bootstrap_iters=100000,
                     description_dict=None, decontaminate=False,
-                    ngrams_path=None, ngrams_n_size=None):
+                    decontaminate_ngrams_path=None, decontaminate_ngrams_n_size=None):
     """Instantiate and evaluate a model on a list of tasks.
 
     :param model: Union[str, LM]
@@ -69,8 +69,8 @@ def simple_evaluate(model, model_args=None, tasks=[],
         limit=limit,
         description_dict=description_dict,
         decontaminate=decontaminate,
-        ngrams_path=ngrams_path,
-        ngrams_n_size=ngrams_n_size
+        decontaminate_ngrams_path=decontaminate_ngrams_path,
+        decontaminate_ngrams_n_size=decontaminate_ngrams_n_size
     )
 
     # add info about the model and few shot config
@@ -92,7 +92,7 @@ decontaminate_suffix = "_decontaminate"
 @positional_deprecated
 def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None,
-             decontaminate=False, ngrams_path=None, ngrams_n_size=None):
+             decontaminate=False, decontaminate_ngrams_path=None, decontaminate_ngrams_n_size=None):
     """Instantiate and evaluate a model on a list of tasks.
 
     :param lm: obj
@@ -121,7 +121,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
 
     if decontaminate:
-        assert ngrams_path and ngrams_n_size
+        assert decontaminate_ngrams_path and decontaminate_ngrams_n_size
 
     task_dict_items = [
         (name, task)
@@ -193,7 +193,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     # Compare all tasks/sets at once to ensure a single training set scan
     if decontaminate:
         print("Finding train/test overlap, please wait...")
-        overlaps = get_train_overlap(docs_for_decontamination, ngrams_path, ngrams_n_size, limit)
+        overlaps = get_train_overlap(docs_for_decontamination, decontaminate_ngrams_path, decontaminate_ngrams_n_size, limit)
 
     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
...
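For reference, a minimal sketch of a caller updated for the renamed keyword arguments. It assumes the package's usual lm_eval.evaluator import path; the model name, task list, n-grams path, and n size below are placeholders, not part of this commit.

# Hypothetical caller using the renamed keyword arguments.
from lm_eval import evaluator  # assumes the usual import path

results = evaluator.simple_evaluate(
    model="gpt2",                                   # placeholder model
    tasks=["lambada"],                              # placeholder task list
    decontaminate=True,
    decontaminate_ngrams_path="/path/to/ngrams",    # placeholder path to training set n-grams
    decontaminate_ngrams_n_size=13,                 # placeholder n-gram size
)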
@@ -23,16 +23,11 @@ class MultiChoice:
         for choice in self.choices:
             yield choice
 
-# Get task base classes for filtering
-task_types = list(set([task.__bases__[0].__name__ for task in tasks.TASK_REGISTRY.values()]))
-
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', required=True)
     parser.add_argument('--model_args', default="")
     parser.add_argument('--tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
-    parser.add_argument('--task_type', default=None, choices=MultiChoice(task_types))
-    parser.add_argument('--exclude_tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
     parser.add_argument('--provide_description', action="store_true")
     parser.add_argument('--num_fewshot', type=int, default=0)
     parser.add_argument('--batch_size', type=int, default=None)
@@ -41,8 +36,8 @@ def parse_args():
     parser.add_argument('--limit', type=int, default=None)
     parser.add_argument('--no_cache', action="store_true")
     parser.add_argument('--decontaminate', action="store_true")
-    parser.add_argument('--ngrams_path', default=None)
-    parser.add_argument('--ngrams_n_size', type=int, default=None)
+    parser.add_argument('--decontaminate_ngrams_path', default=None)
+    parser.add_argument('--decontaminate_ngrams_n_size', type=int, default=None)
     parser.add_argument('--description_dict_path', default=None)
 
     return parser.parse_args()
@@ -50,10 +45,10 @@ def parse_args():
 def ensure_correct_decontamination_params(args):
     valid = True
     if args.decontaminate:
-        if not args.ngrams_n_size:
+        if not args.decontaminate_ngrams_n_size:
             print("Please specify n size of training set n-grams. (--ngrams_n_size)")
             valid = False
-        if not args.ngrams_path:
+        if not args.decontaminate_ngrams_path:
             print("Please specify path containing training set n-grams. (--ngrams_path)")
             valid = False
@@ -78,26 +73,11 @@ def main():
     if args.limit:
         print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
 
-    if args.task_type:
-        task_types = args.task_type.split(",")
-        task_names = list(dict(filter(lambda x: x[1].__bases__[0].__name__ in task_types,
-                                      tasks.TASK_REGISTRY.items())
-                               ).keys())
-
     if args.tasks is None:
-        if args.task_type is None:
-            task_names = tasks.ALL_TASKS
+        task_names = tasks.ALL_TASKS
     else:
         task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
 
-    if args.exclude_tasks:
-        exclude_tasks = pattern_match(args.exclude_tasks.split(","), task_names)
-        task_names = list(filter(lambda x: x not in exclude_tasks, task_names))
-        if len(task_names) == 0:
-            print("You must have excluded the tasks you specified, exiting.")
-            return
-
     print(f"Selected Tasks: {task_names}")
 
     description_dict = {}
@@ -116,8 +96,8 @@ def main():
         limit=args.limit,
         description_dict=description_dict,
         decontaminate=args.decontaminate,
-        ngrams_path=args.ngrams_path,
-        ngrams_n_size=args.ngrams_n_size
+        decontaminate_ngrams_path=args.decontaminate_ngrams_path,
+        decontaminate_ngrams_n_size=args.decontaminate_ngrams_n_size
    )
 
     dumped = json.dumps(results, indent=2)
...
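A standalone sketch of the renamed CLI flags and the requirement that ensure_correct_decontamination_params enforces; the path and n size values are placeholders.

# Illustration only: parse the renamed flags and mirror the validation above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--decontaminate', action="store_true")
parser.add_argument('--decontaminate_ngrams_path', default=None)
parser.add_argument('--decontaminate_ngrams_n_size', type=int, default=None)

args = parser.parse_args([
    '--decontaminate',
    '--decontaminate_ngrams_path', '/path/to/ngrams',  # placeholder path
    '--decontaminate_ngrams_n_size', '13',             # placeholder n size
])

# When --decontaminate is set, both n-grams flags must be supplied.
assert args.decontaminate_ngrams_path and args.decontaminate_ngrams_n_size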
@@ -107,18 +107,18 @@ def get_train_overlap(docs_by_task_set, ngrams_path, ngrams_n_size, limit):
     non_matching_unique = 0
 
     current_ngram = ""
-    for line in reader.read_tqdm():
+    for line in reader.read_tqdm():  # Scan training set ngrams file
         total_ngrams += 1
         [ngram, document_id] = line.rsplit(" ", 1)
-        if ngram != current_ngram:
+        if ngram != current_ngram:  # Only need to match the ngram once in training set
             unique_ngrams += 1
             current_ngram = ngram
             if ngram in merged_lookup:
-                matched_ngrams.append(ngram)
+                matched_ngrams.append(ngram)  # For logging
                 matching_unique += 1
                 for task_name, task_set, doc_ids in merged_lookup[ngram]:
                     task_doc_set = duplicates[(task_name, task_set)]
-                    for doc_id in doc_ids:
+                    for doc_id in doc_ids:  # Record contamination across all relevant task/set combos
                         task_doc_set.add(doc_id)
                 del merged_lookup[ngram]  # No point matching again
             else:
...
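The scan above splits each line of the n-grams file from the right, so the n-gram itself may contain spaces. A small sketch of that assumed "<ngram> <document_id>" line format, using a made-up line:

# Hypothetical line from a training set n-grams file (format inferred from the rsplit above).
line = "the quick brown fox jumped over doc_00042"
[ngram, document_id] = line.rsplit(" ", 1)
assert ngram == "the quick brown fox jumped over"
assert document_id == "doc_00042"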