Commit 194a806d authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

add description to config, remove from cmdline args

parent 41677741
......@@ -63,10 +63,11 @@ class TaskConfig(dict):
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = None
aliases: Union[str, list] = None
doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None
use_prompt: str = None
delimiter: str = "\n\n"
description: str = ""
num_fewshot: int = 0
batch_size: int = 1
......@@ -76,7 +77,6 @@ class TaskConfig(dict):
gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
delimiter: str = "\n\n"
filter_list: Union[str, list] = None
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
......@@ -433,35 +433,10 @@ class Task(abc.ABC):
), "A `random.Random` generator argument must be provided to `rnd`"
if num_fewshot == 0:
labeled_examples = ""
# always prepend the (possibly empty) task description
labeled_examples = self._config.description
else:
labeled_examples = self.sampler.get_context(doc, num_fewshot)
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
# if self.has_training_docs():
# fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
# else:
# if self._fewshot_docs is None:
# self._fewshot_docs = list(
# self.validation_docs()
# if self.has_validation_docs()
# else self.test_docs()
# )
# fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# # get rid of the doc that's the one we're evaluating, if it's in the fewshot
# fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
# labeled_examples = (
# "\n\n".join(
# [
# self.doc_to_text(doc) + self.doc_to_target(doc)
# for doc in fewshotex
# ]
# )
# + "\n\n"
# )
labeled_examples = self._config.description + self.sampler.get_context(doc, num_fewshot)
example = self.doc_to_text(doc)
return labeled_examples + example
......
......@@ -41,7 +41,6 @@ def parse_args():
parser.add_argument("--data_sampling", type=float, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
parser.add_argument("--output_base_path", type=str, default=None)
......@@ -78,12 +77,6 @@ def main():
eval_logger.info(f"Selected Tasks: {task_names}")
# TODO: description_dict?
# description_dict = {}
# if args.description_dict_path:
# with open(args.description_dict_path, "r") as f:
# description_dict = json.load(f)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
......@@ -94,7 +87,6 @@ def main():
device=args.device,
no_cache=args.no_cache,
limit=args.limit,
# description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
......
......@@ -13,12 +13,10 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--output_base_path", required=True)
parser.add_argument("--tasks", default="all_tasks")
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--sets", type=str, default="val") # example: val,test
parser.add_argument("--num_fewshot", type=int, default=1)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--num_examples", type=int, default=1)
parser.add_argument("--description_dict_path", default=None)
return parser.parse_args()
......@@ -32,11 +30,6 @@ def main():
task_names = args.tasks.split(",")
task_dict = tasks.get_task_dict(task_names)
# description_dict = {}
# if args.description_dict_path:
# with open(args.description_dict_path, "r") as f:
# description_dict = json.load(f)
os.makedirs(args.output_base_path, exist_ok=True)
for task_name, task in task_dict.items():
rnd = random.Random()
......@@ -55,12 +48,6 @@ def main():
docs = join_iters(iters)
# description = (
# description_dict[task_name]
# if description_dict and task_name in description_dict
# else ""
# )
with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in (
zip(range(args.num_examples), docs)
......@@ -72,7 +59,6 @@ def main():
doc=doc,
num_fewshot=args.num_fewshot,
rnd=rnd,
# description=description,
)
f.write(ctx + "\n")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment