Commit 1d7d3de5 authored by haileyschoelkopf's avatar haileyschoelkopf Committed by lintangsutawika
Browse files

fix small TaskConfig bugs

parent fd33fb97
...@@ -98,13 +98,16 @@ class TaskConfig(dict): ...@@ -98,13 +98,16 @@ class TaskConfig(dict):
if type(self.gold_alias) == str: if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.gold_alias self.gold_alias = self.template_aliases + self.gold_alias
if self.generation_kwargs or self.output_type == "greedy_until": if self.generation_kwargs:
assert ( assert (
self.output_type == "greedy_until" self.output_type == "greedy_until"
), "passed `generation_kwargs`, but not using a generation request type!" ), "passed `generation_kwargs`, but not using a generation request type!"
elif self.output_type == "greedy_until":
# ensure that we greedily generate in absence of explicit arguments otherwise # ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0} self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
# TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
...@@ -123,6 +126,9 @@ class TaskConfig(dict): ...@@ -123,6 +126,9 @@ class TaskConfig(dict):
for k, v in list(cfg_dict.items()): for k, v in list(cfg_dict.items()):
if v is None: if v is None:
cfg_dict.pop(k) cfg_dict.pop(k)
elif isinstance(v, Callable):
# TODO: this should handle Promptsource template objects as a separate case?
cfg_dict[k] = str(v)
return cfg_dict return cfg_dict
...@@ -877,7 +883,7 @@ class ConfigurableTask(Task): ...@@ -877,7 +883,7 @@ class ConfigurableTask(Task):
for key, result in zip(self._metric_fn_list.keys(), results): for key, result in zip(self._metric_fn_list.keys(), results):
_dict = self._metric_fn_list[key].compute( _dict = self._metric_fn_list[key].compute(
references=[gold], predictions=[result], **self._metric_kwargs[key] references=[gold], predictions=[result], **self._metric_fn_kwargs[key]
) )
result_dict = {**result_dict, **_dict} result_dict = {**result_dict, **_dict}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment