Commit c2990e29 authored by lintangsutawika's avatar lintangsutawika
Browse files

allow doc_to_target to be list so that multiple_target works

parent c86fd1a7
...@@ -766,7 +766,7 @@ class ConfigurableTask(Task): ...@@ -766,7 +766,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc: dict) -> Union[int, str]: def doc_to_target(self, doc: dict) -> Union[int, str, list]:
if self.prompt is not None: if self.prompt is not None:
doc_to_target = self.prompt doc_to_target = self.prompt
...@@ -785,8 +785,12 @@ class ConfigurableTask(Task): ...@@ -785,8 +785,12 @@ class ConfigurableTask(Task):
target_string = utils.apply_template(doc_to_target, doc) target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit(): if target_string.isdigit():
return ast.literal_eval(target_string) return ast.literal_eval(target_string)
elif (target_string[0] == "[") and (target_string[-1] == "]"):
return ast.literal_eval(target_string)
else: else:
return target_string return target_string
elif type(doc_to_target) == list:
return doc_to_target
elif callable(doc_to_target): elif callable(doc_to_target):
return doc_to_target(doc) return doc_to_target(doc)
# Used when applying a Promptsource template # Used when applying a Promptsource template
...@@ -993,6 +997,7 @@ class ConfigurableTask(Task): ...@@ -993,6 +997,7 @@ class ConfigurableTask(Task):
gold = choices[gold] gold = choices[gold]
for key, result in zip(self._metric_fn_list.keys(), results): for key, result in zip(self._metric_fn_list.keys(), results):
print(result)
if self.multiple_target: if self.multiple_target:
# in the case where we have multiple targets, # in the case where we have multiple targets,
# return true if any are true # return true if any are true
...@@ -1009,20 +1014,20 @@ class ConfigurableTask(Task): ...@@ -1009,20 +1014,20 @@ class ConfigurableTask(Task):
res = res[key] res = res[key]
scores.append(res) scores.append(res)
if any(scores): if any(scores):
result = 1.0 result_score = 1.0
else: else:
result = 0.0 result_score = 0.0
else: else:
result = self._metric_fn_list[key]( result_score = self._metric_fn_list[key](
references=[gold], references=[gold],
predictions=[result], predictions=[result],
**self._metric_fn_kwargs[key], **self._metric_fn_kwargs[key],
) )
if isinstance(result, dict): if isinstance(result_score, dict):
result_dict.update(result) result_dict.update(result_score)
else: else:
result_dict[key] = result result_dict[key] = result_score
else: else:
raise ValueError( raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment