Commit c2990e29 authored by lintangsutawika's avatar lintangsutawika
Browse files

allow doc_to_target to be list so that multiple_target works

parent c86fd1a7
......@@ -766,7 +766,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: dict) -> Union[int, str]:
def doc_to_target(self, doc: dict) -> Union[int, str, list]:
if self.prompt is not None:
doc_to_target = self.prompt
......@@ -785,8 +785,12 @@ class ConfigurableTask(Task):
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit():
return ast.literal_eval(target_string)
elif (target_string[0] == "[") and (target_string[-1] == "]"):
return ast.literal_eval(target_string)
else:
return target_string
elif type(doc_to_target) == list:
return doc_to_target
elif callable(doc_to_target):
return doc_to_target(doc)
# Used when applying a Promptsource template
......@@ -993,6 +997,7 @@ class ConfigurableTask(Task):
gold = choices[gold]
for key, result in zip(self._metric_fn_list.keys(), results):
print(result)
if self.multiple_target:
# in the case where we have multiple targets,
# return true if any are true
......@@ -1009,20 +1014,20 @@ class ConfigurableTask(Task):
res = res[key]
scores.append(res)
if any(scores):
result = 1.0
result_score = 1.0
else:
result = 0.0
result_score = 0.0
else:
result = self._metric_fn_list[key](
result_score = self._metric_fn_list[key](
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[key],
)
if isinstance(result, dict):
result_dict.update(result)
if isinstance(result_score, dict):
result_dict.update(result_score)
else:
result_dict[key] = result
result_dict[key] = result_score
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment