Commit 46654b3d authored by Baber

move multi_target to `exact_match`

parent d58d67b3
from __future__ import annotations
import logging
import math
import os
@@ -198,13 +200,48 @@ def acc_mutual_info_fn(items): # This is a passthrough function
# See the License for the specific language governing permissions and
# limitations under the License.
def exact_match_hf_evaluate(
predictions,
references,
regexes_to_ignore=None,
ignore_case=False,
ignore_punctuation=False,
ignore_numbers=False,
predictions: Iterable[str] | str,
references: Iterable[str] | str,
regexes_to_ignore: list[str] | None = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
multi_target: bool = False,
):
"""
    Compute exact match scores between predictions and references.
    The comparison supports optional preprocessing: ignoring case, punctuation,
    numbers, and user-supplied regex patterns.
    Note:
        When `multi_target` is True, predictions and references may have
        different lengths and NumPy broadcasting rules apply to the comparison;
        otherwise both sequences must have the same length.
Args:
predictions (Iterable[str] | str): The predicted strings to evaluate.
references (Iterable[str] | str): The reference strings to compare against.
regexes_to_ignore (list[str], optional): A list of regex patterns to remove
from both predictions and references before comparison. Defaults to None.
ignore_case (bool, optional): If True, ignores case differences during comparison.
Defaults to False.
ignore_punctuation (bool, optional): If True, removes punctuation from strings
before comparison. Defaults to False.
ignore_numbers (bool, optional): If True, removes numeric characters from strings
before comparison. Defaults to False.
multi_target (bool, optional): If True, returns 1.0 if any prediction matches any
reference, otherwise 0.0. Defaults to False.
Returns:
dict: A dictionary containing the exact match score:
            - "exact_match" (float): The mean of the element-wise match scores, or 1.0/0.0 if `multi_target` is True.
"""
predictions, references = list(predictions), list(references)
    assert multi_target or len(predictions) == len(references), (
        "predictions and references must have the same length unless `multi_target` is True"
    )
if regexes_to_ignore is not None:
for s in regexes_to_ignore:
predictions = np.array([re.sub(s, "", x) for x in predictions])
@@ -229,7 +266,11 @@ def exact_match_hf_evaluate(
score_list = predictions == references
return {"exact_match": np.mean(score_list)}
return {
"exact_match": np.mean(score_list)
if not multi_target
else float(np.any(score_list))
}
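# --- Editor's illustration (not part of the commit) ---------------------------
# A minimal usage sketch of the return value above, using the function defined
# in this module. Values are hypothetical:
#
#   exact_match_hf_evaluate(["paris", "london"], ["paris", "berlin"])
#   # -> {"exact_match": 0.5}   (element-wise mean, default behaviour)
#
#   exact_match_hf_evaluate(
#       predictions=["Paris"],
#       references=["paris", "Paris, France"],
#       ignore_case=True,
#       ignore_punctuation=True,
#       multi_target=True,
#   )
#   # -> {"exact_match": 1.0}   (single prediction broadcast against every
#   #                            reference, collapsed with np.any)
# -------------------------------------------------------------------------------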
###
@@ -241,8 +282,8 @@ def exact_match_hf_evaluate(
output_type="generate_until",
aggregation="mean",
)
def exact_match_fn(**kwargs):
return exact_match_hf_evaluate(**kwargs)
def exact_match_fn(references: list[str], predictions: list[str], **kwargs):
return exact_match_hf_evaluate(predictions, references, **kwargs)
@register_metric(
......
@@ -1666,65 +1666,15 @@ class ConfigurableTask(Task):
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
# we expect multiple_targets to be a list.
elif self.multiple_target:
gold = list(gold)
# TODO: handle this better
elif type(gold) is not type(result) and not (
"bypass" in self._metric_fn_list.keys() or isinstance(result, list)
):
# cast gold to the same type as result
gold = type(result)(gold)
for metric in self._metric_fn_list.keys():
if self.multiple_target:
# in the case where we have multiple targets,
# return true if any are true
                # TODO: this may break for multiple_target, non zero-or-1 metrics
scores = []
if not isinstance(gold, list):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold)
gold = [gold]
if metric == "exact_match":
result = [result for _ in range(len(gold))]
scores = self._metric_fn_list[metric](
references=gold,
predictions=result,
**self._metric_fn_kwargs[metric],
)[metric]
result_score = 1.0 if scores > 0.0 else 0.0
else:
for gold_option in gold:
try:
result_score = self._metric_fn_list[metric](
references=[gold_option],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except (
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold_option, result]
)
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
scores.append(result_score)
if any(scores):
result_score = 1.0
else:
result_score = 0.0
else:
try:
result_score = self._metric_fn_list[metric](
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
try:
result_score = self._metric_fn_list[metric](
references=[gold] if not isinstance(gold, list) else gold,
predictions=[result],
**self._metric_fn_kwargs[metric],
)
            except TypeError:  # fallback: some in-house metrics take a single [gold, result] pair rather than the HF Evaluate keyword interface
result_score = self._metric_fn_list[metric]([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
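# --- Editor's illustration (not part of the commit) ---------------------------
# Why the per-target loop removed above is no longer needed: for a
# multiple_target doc the whole gold list is now passed to the metric in a
# single call, and for `exact_match` the 1.0/0.0 collapse happens inside the
# metric via `multi_target=True` (set in the task YAML below). Hypothetical
# values:
#
#   gold = ["Lake Victoria", "Victoria Nyanza"]    # multiple accepted answers
#   result = "Lake Victoria"                       # single model generation
#   self._metric_fn_list["exact_match"](
#       references=gold,
#       predictions=[result],
#       **self._metric_fn_kwargs["exact_match"],   # e.g. ignore_case, multi_target
#   )
#   # -> {"exact_match": 1.0}
# -------------------------------------------------------------------------------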
......
@@ -24,3 +24,6 @@ journal = {Transactions of the Association of Computational Linguistics}}
### Tasks
* `nq_open`
### Changelog
* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
task: nq_open
dataset_path: nq_open
dataset_path: google-research-datasets/nq_open
output_type: generate_until
training_split: train
validation_split: validation
description: "Answer these questions:\n\n"
doc_to_text: "Q: {{question}}?\nA:"
doc_to_target: "{{answer}}" # TODO: should be multi-target
doc_to_target: "{{answer}}"
fewshot_delimiter: "\n"
generation_kwargs:
until:
@@ -28,5 +28,6 @@ metric_list:
ignore_punctuation: true
regexes_to_ignore:
- "\\b(?:The |the |An |A |The |a |an )"
multi_target: true
metadata:
version: 4.0
@@ -49,3 +49,6 @@ If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
### Changelog
* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
task: triviaqa
dataset_path: trivia_qa
dataset_path: mandarjoshi/trivia_qa
dataset_name: rc.nocontext
output_type: generate_until
training_split: train
@@ -27,5 +27,6 @@ metric_list:
higher_is_better: true
ignore_case: true
ignore_punctuation: true
multi_target: true
metadata:
version: 3.0