"driver/vscode:/vscode.git/clone" did not exist on "4d70c71b695876cbd949b421e728f88c62462103"
Commit 46654b3d authored by Baber's avatar Baber
Browse files

move multi_target to `exact_match`

parent d58d67b3
+from __future__ import annotations
 import logging
 import math
 import os
@@ -198,13 +200,48 @@ def acc_mutual_info_fn(items): # This is a passthrough function
 # See the License for the specific language governing permissions and
 # limitations under the License.
 def exact_match_hf_evaluate(
-    predictions,
-    references,
-    regexes_to_ignore=None,
-    ignore_case=False,
-    ignore_punctuation=False,
-    ignore_numbers=False,
+    predictions: Iterable[str] | str,
+    references: Iterable[str] | str,
+    regexes_to_ignore: list[str] | None = None,
+    ignore_case: bool = False,
+    ignore_punctuation: bool = False,
+    ignore_numbers: bool = False,
+    multi_target: bool = False,
 ):
"""
Compute exact match scores between predictions and references.
This function computes the exact match score by comparing predictions
and references. It supports optional preprocessing steps such as ignoring
case, punctuation, numbers, and specific regex patterns.
Note:
predictions and references can have different lengths.
numpy broadcasting rule applies
Args:
predictions (Iterable[str] | str): The predicted strings to evaluate.
references (Iterable[str] | str): The reference strings to compare against.
regexes_to_ignore (list[str], optional): A list of regex patterns to remove
from both predictions and references before comparison. Defaults to None.
ignore_case (bool, optional): If True, ignores case differences during comparison.
Defaults to False.
ignore_punctuation (bool, optional): If True, removes punctuation from strings
before comparison. Defaults to False.
ignore_numbers (bool, optional): If True, removes numeric characters from strings
before comparison. Defaults to False.
multi_target (bool, optional): If True, returns 1.0 if any prediction matches any
reference, otherwise 0.0. Defaults to False.
Returns:
dict: A dictionary containing the exact match score:
- "exact_match" (float): The mean exact match score or 1.0/0.0 if `multi_target` is True.
"""
predictions, references = list(predictions), list(references)
assert len(predictions) == len(references) if not multi_target else True, (
"predictions and references must have the same length unless `multi_target` is True"
)
     if regexes_to_ignore is not None:
         for s in regexes_to_ignore:
             predictions = np.array([re.sub(s, "", x) for x in predictions])
@@ -229,7 +266,11 @@ def exact_match_hf_evaluate(
     score_list = predictions == references
-    return {"exact_match": np.mean(score_list)}
+    return {
+        "exact_match": np.mean(score_list)
+        if not multi_target
+        else float(np.any(score_list))
+    }
 ###
@@ -241,8 +282,8 @@ def exact_match_hf_evaluate(
     output_type="generate_until",
     aggregation="mean",
 )
-def exact_match_fn(**kwargs):
-    return exact_match_hf_evaluate(**kwargs)
+def exact_match_fn(references: list[str], predictions: list[str], **kwargs):
+    return exact_match_hf_evaluate(predictions, references, **kwargs)
 @register_metric(
...
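For reference, a minimal usage sketch of the new `multi_target` flag (the import path below is an assumption and may differ; the example values are illustrative):

```python
# Minimal sketch of the new multi_target flag; import path is an assumption.
from lm_eval.api.metrics import exact_match_hf_evaluate

# Default behaviour: element-wise comparison, mean over all pairs.
print(exact_match_hf_evaluate(
    predictions=["Paris", "Lyon"],
    references=["Paris", "Marseille"],
)["exact_match"])  # 0.5

# multi_target=True: one prediction broadcast against several references;
# the score is 1.0 if any reference matches, else 0.0.
print(exact_match_hf_evaluate(
    predictions=["paris"],
    references=["Paris", "City of Paris"],
    ignore_case=True,
    multi_target=True,
)["exact_match"])  # 1.0
```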
@@ -1666,65 +1666,15 @@ class ConfigurableTask(Task):
                 # it assumes that doc_to_target returns a number.
                 choices = self.doc_to_choice(doc)
                 gold = choices[gold]
-            # we expect multiple_targets to be a list.
-            elif self.multiple_target:
-                gold = list(gold)
-            # TODO: handle this better
-            elif type(gold) is not type(result) and not (
-                "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
-            ):
-                # cast gold to the same type as result
-                gold = type(result)(gold)
             for metric in self._metric_fn_list.keys():
-                if self.multiple_target:
-                    # in the case where we have multiple targets,
-                    # return true if any are true
-                    # TODO: this may break for multipLe_target, non zero-or-1 metrics
-                    scores = []
-                    if not isinstance(gold, list):
-                        # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
-                        # print(gold)
-                        gold = [gold]
-                    if metric == "exact_match":
-                        result = [result for _ in range(len(gold))]
-                        scores = self._metric_fn_list[metric](
-                            references=gold,
-                            predictions=result,
-                            **self._metric_fn_kwargs[metric],
-                        )[metric]
-                        result_score = 1.0 if scores > 0.0 else 0.0
-                    else:
-                        for gold_option in gold:
-                            try:
-                                result_score = self._metric_fn_list[metric](
-                                    references=[gold_option],
-                                    predictions=[result],
-                                    **self._metric_fn_kwargs[metric],
-                                )
-                            except (
-                                TypeError
-                            ):  # TODO: this is hacky and I don't want to do it
-                                result_score = self._metric_fn_list[metric](
-                                    [gold_option, result]
-                                )
-                            if isinstance(result_score, dict):
-                                # TODO: this handles the case where HF evaluate returns a dict.
-                                result_score = result_score[metric]
-                            scores.append(result_score)
-                        if any(scores):
-                            result_score = 1.0
-                        else:
-                            result_score = 0.0
-                else:
-                    try:
-                        result_score = self._metric_fn_list[metric](
-                            references=[gold],
-                            predictions=[result],
-                            **self._metric_fn_kwargs[metric],
-                        )
-                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
-                        result_score = self._metric_fn_list[metric]([gold, result])
+                try:
+                    result_score = self._metric_fn_list[metric](
+                        references=[gold] if not isinstance(gold, list) else gold,
+                        predictions=[result],
+                        **self._metric_fn_kwargs[metric],
+                    )
+                except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
+                    result_score = self._metric_fn_list[metric]([gold, result])
                 if isinstance(result_score, dict):
                     # TODO: this handles the case where HF evaluate returns a dict.
                     # This allows for multiple metrics to be returned from the same function
...
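The simplified loop above leans on a dual calling convention: metric functions that accept `references=`/`predictions=` keywords are called directly, and anything else falls back to the positional `[gold, result]` interface on `TypeError`. A toy sketch of that dispatch (illustrative only; these metric functions are made up, not the harness's):

```python
# Toy metrics standing in for entries of self._metric_fn_list (not real harness code).
def kw_style_exact_match(references, predictions, **kwargs):
    # keyword interface, like exact_match_hf_evaluate
    return {"exact_match": float(predictions[0] in references)}

def items_style_metric(items):
    # positional "items" interface used by some other metrics
    gold, result = items
    return float(result in gold if isinstance(gold, list) else result == gold)

gold = ["Paris", "City of Paris"]  # a multi-target gold is passed through as a list
result = "Paris"

for metric_fn in (kw_style_exact_match, items_style_metric):
    try:
        score = metric_fn(
            references=[gold] if not isinstance(gold, list) else gold,
            predictions=[result],
        )
    except TypeError:  # the metric does not take keyword arguments
        score = metric_fn([gold, result])
    if isinstance(score, dict):
        score = score["exact_match"]
    print(metric_fn.__name__, score)  # both print 1.0
```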
@@ -24,3 +24,6 @@ journal = {Transactions of the Association of Computational Linguistics}}
 ### Tasks
 * `nq_open`
+
+### Changelog
+* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
 task: nq_open
-dataset_path: nq_open
+dataset_path: google-research-datasets/nq_open
 output_type: generate_until
 training_split: train
 validation_split: validation
 description: "Answer these questions:\n\n"
 doc_to_text: "Q: {{question}}?\nA:"
-doc_to_target: "{{answer}}" # TODO: should be multi-target
+doc_to_target: "{{answer}}"
 fewshot_delimiter: "\n"
 generation_kwargs:
   until:
@@ -28,5 +28,6 @@ metric_list:
     ignore_punctuation: true
     regexes_to_ignore:
       - "\\b(?:The |the |An |A |The |a |an )"
+    multi_target: true
 metadata:
   version: 4.0
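For intuition, the options in this `metric_list` entry roughly correspond to the normalization sketched below before strings are compared (an approximation; the actual logic lives in `exact_match_hf_evaluate`, and the answer strings are hypothetical):

```python
import re
import string

# Same pattern as regexes_to_ignore above.
ARTICLES = r"\b(?:The |the |An |A |The |a |an )"

def normalize(text: str) -> str:
    text = re.sub(ARTICLES, "", text)   # regexes_to_ignore
    text = text.lower()                 # ignore_case: true
    # ignore_punctuation: true (approximated here with str.translate)
    return text.translate(str.maketrans("", "", string.punctuation))

references = ["the Eiffel Tower", "Eiffel Tower, Paris"]  # hypothetical answer aliases
prediction = "The Eiffel Tower."

# multi_target: true -> the doc scores 1.0 if the prediction matches any alias.
print(float(any(normalize(prediction) == normalize(r) for r in references)))  # 1.0
```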
@@ -49,3 +49,6 @@ If other tasks on this dataset are already supported:
 * [ ] Is the "Main" variant of this task clearly denoted?
 * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
 * [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
+
+### Changelog
+* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
 task: triviaqa
-dataset_path: trivia_qa
+dataset_path: mandarjoshi/trivia_qa
 dataset_name: rc.nocontext
 output_type: generate_until
 training_split: train
@@ -27,5 +27,6 @@ metric_list:
     higher_is_better: true
     ignore_case: true
     ignore_punctuation: true
+    multi_target: true
 metadata:
   version: 3.0