Commit 892f40a9 authored by lintangsutawika's avatar lintangsutawika
Browse files

add comments

parent d2804132
...@@ -250,6 +250,11 @@ class Task(abc.ABC): ...@@ -250,6 +250,11 @@ class Task(abc.ABC):
download_mode=download_mode, download_mode=download_mode,
) )
@property
def config(self):
"""Returns the TaskConfig associated with this class."""
return self._config
@abc.abstractmethod @abc.abstractmethod
def has_training_docs(self): def has_training_docs(self):
"""Whether the task has a training set""" """Whether the task has a training set"""
...@@ -352,7 +357,7 @@ class Task(abc.ABC): ...@@ -352,7 +357,7 @@ class Task(abc.ABC):
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
eval_logger.info( eval_logger.info(
f"Building contexts for task '{self._config.task}' on rank {rank}..." f"Building contexts for task '{self.config.task}' on rank {rank}..."
) )
instances = [] instances = []
...@@ -362,14 +367,14 @@ class Task(abc.ABC): ...@@ -362,14 +367,14 @@ class Task(abc.ABC):
# sample fewshot context #TODO: need to offset doc_id by rank now! # sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context( fewshot_ctx = self.fewshot_context(
doc, doc,
self._config.num_fewshot, self.config.num_fewshot,
) )
# TODO: we should override self._config.repeats if doing greedy gen so users don't waste time+compute # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
inst = self.construct_requests( inst = self.construct_requests(
doc=doc, doc=doc,
ctx=fewshot_ctx, ctx=fewshot_ctx,
metadata=(self._config["task"], doc_id, self._config.repeats), metadata=(self.config["task"], doc_id, self.config.repeats),
) )
if not isinstance(inst, list): if not isinstance(inst, list):
...@@ -457,9 +462,9 @@ class Task(abc.ABC): ...@@ -457,9 +462,9 @@ class Task(abc.ABC):
if num_fewshot == 0: if num_fewshot == 0:
# always prepend the (possibly empty) task description # always prepend the (possibly empty) task description
labeled_examples = self._config.description labeled_examples = self.config.description
else: else:
labeled_examples = self._config.description + self.sampler.get_context( labeled_examples = self.config.description + self.sampler.get_context(
doc, num_fewshot doc, num_fewshot
) )
...@@ -469,7 +474,7 @@ class Task(abc.ABC): ...@@ -469,7 +474,7 @@ class Task(abc.ABC):
elif type(example) == list: elif type(example) == list:
return [labeled_examples + ex for ex in example] return [labeled_examples + ex for ex in example]
elif type(example) == int: elif type(example) == int:
if self._config.doc_to_choice is not None: if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
return labeled_examples + choices[example] return labeled_examples + choices[example]
else: else:
...@@ -491,7 +496,7 @@ class Task(abc.ABC): ...@@ -491,7 +496,7 @@ class Task(abc.ABC):
""" """
# TODO: this should only return the overrides applied to a non-YAML task's configuration. # TODO: this should only return the overrides applied to a non-YAML task's configuration.
# (num_fewshot) # (num_fewshot)
return self._config.to_dict() return self.config.to_dict()
class ConfigurableTask(Task): class ConfigurableTask(Task):
...@@ -506,35 +511,35 @@ class ConfigurableTask(Task): ...@@ -506,35 +511,35 @@ class ConfigurableTask(Task):
self._config = self.CONFIG self._config = self.CONFIG
# Use new configurations if there was no preconfiguration # Use new configurations if there was no preconfiguration
if self._config is None: if self.config is None:
self._config = TaskConfig(**config) self._config = TaskConfig(**config)
# Overwrite configs # Overwrite configs
else: else:
if config is not None: if config is not None:
self._config.__dict__.update(config) self._config.__dict__.update(config)
if self._config is None: if self.config is None:
raise ValueError( raise ValueError(
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg" "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
) )
if self._config.output_type is not None: if self.config.output_type is not None:
assert self._config.output_type in ALL_OUTPUT_TYPES assert self.config.output_type in ALL_OUTPUT_TYPES
self.OUTPUT_TYPE = self._config.output_type self.OUTPUT_TYPE = self.config.output_type
if self._config.dataset_path is not None: if self.config.dataset_path is not None:
self.DATASET_PATH = self._config.dataset_path self.DATASET_PATH = self.config.dataset_path
if self._config.dataset_name is not None: if self.config.dataset_name is not None:
self.DATASET_NAME = self._config.dataset_name self.DATASET_NAME = self.config.dataset_name
self._metric_fn_list = {} self._metric_fn_list = {}
self._metric_fn_kwargs = {} self._metric_fn_kwargs = {}
self._aggregation_list = {} self._aggregation_list = {}
self._higher_is_better = {} self._higher_is_better = {}
_metric_list = DEFAULT_METRIC_REGISTRY[self._config.output_type] _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
if self._config.metric_list is None: if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ? # TODO: handle this in TaskConfig.__post_init__ ?
for metric_name in _metric_list: for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name) self._metric_fn_list[metric_name] = get_metric(metric_name)
...@@ -543,7 +548,7 @@ class ConfigurableTask(Task): ...@@ -543,7 +548,7 @@ class ConfigurableTask(Task):
) )
self._higher_is_better[metric_name] = is_higher_better(metric_name) self._higher_is_better[metric_name] = is_higher_better(metric_name)
else: else:
for metric_config in self._config.metric_list: for metric_config in self.config.metric_list:
assert "metric" in metric_config assert "metric" in metric_config
metric_name = metric_config["metric"] metric_name = metric_config["metric"]
kwargs = { kwargs = {
...@@ -552,7 +557,7 @@ class ConfigurableTask(Task): ...@@ -552,7 +557,7 @@ class ConfigurableTask(Task):
if key not in ["metric", "aggregation", "higher_is_better"] if key not in ["metric", "aggregation", "higher_is_better"]
} }
if self._config.process_results is not None: if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {} self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name): elif callable(metric_name):
...@@ -594,13 +599,13 @@ class ConfigurableTask(Task): ...@@ -594,13 +599,13 @@ class ConfigurableTask(Task):
) )
self._higher_is_better[metric_name] = is_higher_better(metric_name) self._higher_is_better[metric_name] = is_higher_better(metric_name)
self.download(self._config.dataset_kwargs) self.download(self.config.dataset_kwargs)
self._training_docs = None self._training_docs = None
self._fewshot_docs = None self._fewshot_docs = None
if self._config.filter_list is not None: if self.config.filter_list is not None:
self._filters = [] self._filters = []
for filter_config in self._config.filter_list: for filter_config in self.config.filter_list:
for filter_pipeline in filter_config: for filter_pipeline in filter_config:
filter_name = filter_config["name"] filter_name = filter_config["name"]
filter_functions = filter_config["filter"] filter_functions = filter_config["filter"]
...@@ -615,10 +620,10 @@ class ConfigurableTask(Task): ...@@ -615,10 +620,10 @@ class ConfigurableTask(Task):
else: else:
self._filters = [build_filter_ensemble("none", [["take_first", None]])] self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if self._config.use_prompt is not None: if self.config.use_prompt is not None:
eval_logger.info(f"loading prompt {self._config.use_prompt}") eval_logger.info(f"loading prompt {self.config.use_prompt}")
self.prompt = get_prompt( self.prompt = get_prompt(
self._config.use_prompt, self.DATASET_PATH, self.DATASET_NAME self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
) )
else: else:
self.prompt = None self.prompt = None
...@@ -645,7 +650,7 @@ class ConfigurableTask(Task): ...@@ -645,7 +650,7 @@ class ConfigurableTask(Task):
test_text = self.doc_to_text(test_doc) test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc) test_target = self.doc_to_target(test_doc)
if self._config.doc_to_choice is not None: if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc) test_choice = self.doc_to_choice(test_doc)
if type(test_choice) is not list: if type(test_choice) is not list:
eval_logger.error("doc_to_choice must return list") eval_logger.error("doc_to_choice must return list")
...@@ -671,9 +676,9 @@ class ConfigurableTask(Task): ...@@ -671,9 +676,9 @@ class ConfigurableTask(Task):
check_choices = [test_target] check_choices = [test_target]
for choice in check_choices: for choice in check_choices:
choice_has_whitespace = True if choice.startswith(" ") or choice.endswith(" ") else False choice_has_whitespace = True if " " in choice else False
delimiter_has_whitespace = ( delimiter_has_whitespace = (
True if (self._config.target_delimiter.startswith(" ") or self._config.target_delimiter.endswith(" ")) else False True if " " in self.config.target_delimiter else False
) )
if delimiter_has_whitespace and choice_has_whitespace: if delimiter_has_whitespace and choice_has_whitespace:
...@@ -693,52 +698,52 @@ class ConfigurableTask(Task): ...@@ -693,52 +698,52 @@ class ConfigurableTask(Task):
) )
def has_training_docs(self) -> bool: def has_training_docs(self) -> bool:
if self._config.training_split is not None: if self.config.training_split is not None:
return True return True
else: else:
return False return False
def has_validation_docs(self) -> bool: def has_validation_docs(self) -> bool:
if self._config.validation_split is not None: if self.config.validation_split is not None:
return True return True
else: else:
return False return False
def has_test_docs(self) -> bool: def has_test_docs(self) -> bool:
if self._config.test_split is not None: if self.config.test_split is not None:
return True return True
else: else:
return False return False
def training_docs(self) -> datasets.Dataset: def training_docs(self) -> datasets.Dataset:
if self.has_training_docs(): if self.has_training_docs():
if self._config.process_docs is not None: if self.config.process_docs is not None:
return self._config.process_docs( return self.config.process_docs(
self.dataset[self._config.training_split] self.dataset[self.config.training_split]
) )
return self.dataset[self._config.training_split] return self.dataset[self.config.training_split]
def validation_docs(self) -> datasets.Dataset: def validation_docs(self) -> datasets.Dataset:
if self.has_validation_docs(): if self.has_validation_docs():
if self._config.process_docs is not None: if self.config.process_docs is not None:
return self._config.process_docs( return self.config.process_docs(
self.dataset[self._config.validation_split] self.dataset[self.config.validation_split]
) )
return self.dataset[self._config.validation_split] return self.dataset[self.config.validation_split]
def test_docs(self) -> datasets.Dataset: def test_docs(self) -> datasets.Dataset:
if self.has_test_docs(): if self.has_test_docs():
if self._config.process_docs is not None: if self.config.process_docs is not None:
return self._config.process_docs(self.dataset[self._config.test_split]) return self.config.process_docs(self.dataset[self.config.test_split])
return self.dataset[self._config.test_split] return self.dataset[self.config.test_split]
def fewshot_docs(self): def fewshot_docs(self):
if self._config.fewshot_split is not None: if self.config.fewshot_split is not None:
return self.dataset[self._config.fewshot_split] return self.dataset[self.config.fewshot_split]
else: else:
if self._config.num_fewshot > 0: if self.config.num_fewshot > 0:
eval_logger.warning( eval_logger.warning(
f"Task '{self._config.task}': " f"Task '{self.config.task}': "
"num_fewshot > 0 but fewshot_split is None. " "num_fewshot > 0 but fewshot_split is None. "
"using preconfigured rule." "using preconfigured rule."
) )
...@@ -754,15 +759,15 @@ class ConfigurableTask(Task): ...@@ -754,15 +759,15 @@ class ConfigurableTask(Task):
return self._instances return self._instances
def should_decontaminate(self): def should_decontaminate(self):
return self._config.should_decontaminate return self.config.should_decontaminate
def doc_to_decontamination_query(self, doc): def doc_to_decontamination_query(self, doc):
if self._config.should_decontaminate: if self.config.should_decontaminate:
if self._config.doc_to_decontamination_query in self.features: if self.config.doc_to_decontamination_query in self.features:
return doc[self._config.doc_to_decontamination_query] return doc[self.config.doc_to_decontamination_query]
else: else:
return ast.literal_eval( return ast.literal_eval(
utils.apply_template(self._config.doc_to_decontamination_query, doc) utils.apply_template(self.config.doc_to_decontamination_query, doc)
) )
def _process_doc(self, doc): def _process_doc(self, doc):
...@@ -780,13 +785,13 @@ class ConfigurableTask(Task): ...@@ -780,13 +785,13 @@ class ConfigurableTask(Task):
if self.prompt is not None: if self.prompt is not None:
doc_to_text = self.prompt doc_to_text = self.prompt
else: else:
doc_to_text = self._config.doc_to_text doc_to_text = self.config.doc_to_text
if type(doc_to_text) == int: if type(doc_to_text) == int:
return doc_to_text return doc_to_text
elif type(doc_to_text) == str: elif type(doc_to_text) == str:
if doc_to_text in self.features: if doc_to_text in self.features:
# if self._config.doc_to_choice is not None: # if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]] # return self.doc_to_choice(doc)[doc[doc_to_text]]
# else: # else:
return doc[doc_to_text] return doc[doc_to_text]
...@@ -805,7 +810,7 @@ class ConfigurableTask(Task): ...@@ -805,7 +810,7 @@ class ConfigurableTask(Task):
return applied_prompt[0] return applied_prompt[0]
else: else:
eval_logger.warning("Applied prompt returns empty string") eval_logger.warning("Applied prompt returns empty string")
return self._config.fewshot_delimiter return self.config.fewshot_delimiter
else: else:
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
...@@ -814,13 +819,13 @@ class ConfigurableTask(Task): ...@@ -814,13 +819,13 @@ class ConfigurableTask(Task):
if self.prompt is not None: if self.prompt is not None:
doc_to_target = self.prompt doc_to_target = self.prompt
else: else:
doc_to_target = self._config.doc_to_target doc_to_target = self.config.doc_to_target
if type(doc_to_target) == int: if type(doc_to_target) == int:
return doc_to_target return doc_to_target
elif type(doc_to_target) == str: elif type(doc_to_target) == str:
if doc_to_target in self.features: if doc_to_target in self.features:
# if self._config.doc_to_choice is not None: # if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]] # return self.doc_to_choice(doc)[doc[doc_to_target]]
# else: # else:
return doc[doc_to_target] return doc[doc_to_target]
...@@ -847,17 +852,17 @@ class ConfigurableTask(Task): ...@@ -847,17 +852,17 @@ class ConfigurableTask(Task):
return applied_prompt[1] return applied_prompt[1]
else: else:
eval_logger.warning("Applied prompt returns empty string") eval_logger.warning("Applied prompt returns empty string")
return self._config.fewshot_delimiter return self.config.fewshot_delimiter
else: else:
raise TypeError raise TypeError
def doc_to_choice(self, doc: Any) -> List[str]: def doc_to_choice(self, doc: Any) -> List[str]:
if self.prompt is not None: if self.prompt is not None:
doc_to_choice = self.prompt doc_to_choice = self.prompt
elif self._config.doc_to_choice is None: elif self.config.doc_to_choice is None:
eval_logger.error("doc_to_choice was called but not set in config") eval_logger.error("doc_to_choice was called but not set in config")
else: else:
doc_to_choice = self._config.doc_to_choice doc_to_choice = self.config.doc_to_choice
if type(doc_to_choice) == str: if type(doc_to_choice) == str:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc)) return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
...@@ -878,8 +883,8 @@ class ConfigurableTask(Task): ...@@ -878,8 +883,8 @@ class ConfigurableTask(Task):
# in multiple_choice tasks, this should be castable to an int corresponding to the index # in multiple_choice tasks, this should be castable to an int corresponding to the index
# within the answer choices, while doc_to_target is the string version of {{answer_choices[gold]}}. # within the answer choices, while doc_to_target is the string version of {{answer_choices[gold]}}.
if self._config.gold_alias is not None: if self.config.gold_alias is not None:
doc_to_target = self._config.gold_alias doc_to_target = self.config.gold_alias
else: else:
return self.doc_to_target(doc) return self.doc_to_target(doc)
...@@ -901,7 +906,7 @@ class ConfigurableTask(Task): ...@@ -901,7 +906,7 @@ class ConfigurableTask(Task):
arguments = (self.doc_to_target(doc),) arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
target_delimiter = self._config.target_delimiter target_delimiter = self.config.target_delimiter
if self.multiple_input: if self.multiple_input:
# If there are multiple inputs, choices are placed in the ctx # If there are multiple inputs, choices are placed in the ctx
cont = self.doc_to_target(doc) cont = self.doc_to_target(doc)
...@@ -943,15 +948,16 @@ class ConfigurableTask(Task): ...@@ -943,15 +948,16 @@ class ConfigurableTask(Task):
return request_list return request_list
elif self.OUTPUT_TYPE == "greedy_until": elif self.OUTPUT_TYPE == "greedy_until":
arguments = (ctx, self._config.generation_kwargs) arguments = (ctx, self.config.generation_kwargs)
return Instance( return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
) )
def process_results(self, doc, results): def process_results(self, doc, results):
if callable(self._config.process_results):
return self._config.process_results(doc, results) if callable(self.config.process_results):
return self.config.process_results(doc, results)
result_dict = {} result_dict = {}
use_metric = list(self._metric_fn_list.keys()) use_metric = list(self._metric_fn_list.keys())
...@@ -1056,7 +1062,7 @@ class ConfigurableTask(Task): ...@@ -1056,7 +1062,7 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "greedy_until": elif self.OUTPUT_TYPE == "greedy_until":
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
if self._config.doc_to_choice is not None: if self.config.doc_to_choice is not None:
# If you set doc_to_choice, # If you set doc_to_choice,
# it assumes that doc_to_target returns a number. # it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
......
...@@ -218,11 +218,11 @@ def evaluate( ...@@ -218,11 +218,11 @@ def evaluate(
# stores the amount to pad out reqs per req. type so that # stores the amount to pad out reqs per req. type so that
# number of fwd passes per distributed rank is equal # number of fwd passes per distributed rank is equal
padding_requests = collections.defaultdict(int) padding_requests = collections.defaultdict(int)
# store the hierarchy to do proper ordering
task_hierarchy = collections.defaultdict(list) task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups
task_order = collections.defaultdict(int) task_order = collections.defaultdict(int)
# store the aggregation for aggregating across tasks in the same group
sample_agg_fn = collections.defaultdict(dict) sample_agg_fn = collections.defaultdict(dict)
# get lists of each type of request # get lists of each type of request
...@@ -437,7 +437,7 @@ def evaluate( ...@@ -437,7 +437,7 @@ def evaluate(
task_to_group[task].append(group) task_to_group[task].append(group)
else: else:
task_to_group[task] = [group] task_to_group[task] = [group]
### Aggregate results over all datapoints ### ### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs # aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items(): for (task_name, key, metric), items in vals.items():
...@@ -459,7 +459,7 @@ def evaluate( ...@@ -459,7 +459,7 @@ def evaluate(
results[grouping][metric_key].append(task_score) results[grouping][metric_key].append(task_score)
else: else:
results[grouping][metric_key] = [task_score] results[grouping][metric_key] = [task_score]
if sample_metric_key in results[grouping]: if sample_metric_key in results[grouping]:
results[grouping][sample_metric_key] += items results[grouping][sample_metric_key] += items
else: else:
...@@ -486,36 +486,33 @@ def evaluate( ...@@ -486,36 +486,33 @@ def evaluate(
for metric in results[task_or_group].keys(): for metric in results[task_or_group].keys():
if type(results[task_or_group][metric]) == list: if type(results[task_or_group][metric]) == list:
if "(sample agg)" in metric: if "(sample agg)" in metric:
results[task_or_group][metric] = sample_agg_fn[task_or_group][metric](results[task_or_group][metric]) results[task_or_group][metric] = sample_agg_fn[
task_or_group
][metric](results[task_or_group][metric])
else: else:
results[task_or_group][metric] = np.average(results[task_or_group][metric]) results[task_or_group][metric] = np.average(
results[task_or_group][metric]
)
versions[task_or_group] = "N/A" versions[task_or_group] = "N/A"
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
if type(task) == tuple: if type(task) == tuple:
group_name, task = task group_name, task = task
order = task_order[group_name] order = task_order[group_name]
tabbed_name = "-"*order+group_name tabbed_name = "-" * order + group_name
results_agg[tabbed_name] = results[group_name] results_agg[tabbed_name] = results[group_name]
versions[tabbed_name] = versions[group_name] versions[tabbed_name] = versions[group_name]
if order == 0: if order == 0:
groups_agg[group_name] = results[group_name] groups_agg[group_name] = results[group_name]
order = task_order[task_name] order = task_order[task_name]
tabbed_name = "-"*order+task_name tabbed_name = "-" * order + task_name
results_agg[tabbed_name] = results[task_name] results_agg[tabbed_name] = results[task_name]
versions[tabbed_name] = versions[task_name] versions[tabbed_name] = versions[task_name]
results_dict = { results_dict = {
"results": dict(results_agg.items()), "results": dict(results_agg.items()),
**( **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
{
"groups": dict(groups_agg.items())
}
if bool(groups_agg)
else {}
),
"configs": dict(sorted(configs.items())), "configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())), "versions": dict(sorted(versions.items())),
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment