Commit 4bba7abb authored by lintangsutawika's avatar lintangsutawika
Browse files

no need to iterate over results for greedy_until

parent a38a0450
...@@ -1017,37 +1017,37 @@ class ConfigurableTask(Task): ...@@ -1017,37 +1017,37 @@ class ConfigurableTask(Task):
else: else:
gold = str(gold) gold = str(gold)
result, _ = results
for metric in self._metric_fn_list.keys(): for metric in self._metric_fn_list.keys():
for result in results: if self.multiple_target:
if self.multiple_target: # in the case where we have multiple targets,
# in the case where we have multiple targets, # return true if any are true
# return true if any are true # TODO: this may break for multipLe_target, non zero-or-1 metrics
# TODO: this may break for multipLe_target, non zero-or-1 metrics scores = []
scores = [] for gold_option in gold:
for gold_option in gold: res = self._metric_fn_list[metric](
res = self._metric_fn_list[metric]( references=[gold_option],
references=[gold_option],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
if isinstance(res, dict):
# TODO: this handles the case where HF evaluate returns a dict.
res = res[metric]
scores.append(res)
if any(scores):
result_score = 1.0
else:
result_score = 0.0
else:
result_score = self._metric_fn_list[metric](
references=[gold],
predictions=[result], predictions=[result],
**self._metric_fn_kwargs[metric], **self._metric_fn_kwargs[metric],
) )
if isinstance(result_score, dict): if isinstance(res, dict):
# TODO: this handles the case where HF evaluate returns a dict. # TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric] res = res[metric]
result_dict[metric] = result_score scores.append(res)
if any(scores):
result_score = 1.0
else:
result_score = 0.0
else:
result_score = self._metric_fn_list[metric](
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
result_dict[metric] = result_score
else: else:
raise ValueError( raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment