Commit 1d7d3de5 authored by haileyschoelkopf's avatar haileyschoelkopf Committed by lintangsutawika
Browse files

fix small TaskConfig bugs

parent fd33fb97
......@@ -98,13 +98,16 @@ class TaskConfig(dict):
if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.gold_alias
if self.generation_kwargs or self.output_type == "greedy_until":
if self.generation_kwargs:
assert (
self.output_type == "greedy_until"
), "passed `generation_kwargs`, but not using a generation request type!"
elif self.output_type == "greedy_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
# TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
def __getitem__(self, item):
return getattr(self, item)
......@@ -123,6 +126,9 @@ class TaskConfig(dict):
for k, v in list(cfg_dict.items()):
if v is None:
cfg_dict.pop(k)
elif isinstance(v, Callable):
# TODO: this should handle Promptsource template objects as a separate case?
cfg_dict[k] = str(v)
return cfg_dict
......@@ -877,7 +883,7 @@ class ConfigurableTask(Task):
for key, result in zip(self._metric_fn_list.keys(), results):
_dict = self._metric_fn_list[key].compute(
references=[gold], predictions=[result], **self._metric_kwargs[key]
references=[gold], predictions=[result], **self._metric_fn_kwargs[key]
)
result_dict = {**result_dict, **_dict}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment