Unverified Commit a8ac0446 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Ship with the exact_match function already in use; don't call evaluate.load() on import (#2045)

parent 2a6acc88
import logging import logging
import math import math
import random import random
import re
import string
from collections.abc import Iterable from collections.abc import Iterable
from typing import List from typing import List
import evaluate as hf_evaluate
import numpy as np import numpy as np
import sacrebleu import sacrebleu
import sklearn.metrics import sklearn.metrics
...@@ -166,7 +167,60 @@ def acc_mutual_info_fn(items): # This is a passthrough function ...@@ -166,7 +167,60 @@ def acc_mutual_info_fn(items): # This is a passthrough function
return items return items
exact_match = hf_evaluate.load("exact_match") ### the code used in the `exact_match_hf_evaluate` function is ported from
### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
### which is under the apache license.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def exact_match_hf_evaluate(
    predictions,
    references,
    regexes_to_ignore=None,
    ignore_case=False,
    ignore_punctuation=False,
    ignore_numbers=False,
):
    """Compute the exact-match rate between predictions and references.

    Args:
        predictions: Sequence of predicted strings.
        references: Sequence of reference strings, compared element-wise
            against ``predictions``.
        regexes_to_ignore: Optional iterable of regex patterns; every match
            is removed from both predictions and references before the
            other normalizations are applied.
        ignore_case: If True, lowercase both sides before comparing.
        ignore_punctuation: If True, strip all ``string.punctuation``
            characters from both sides before comparing.
        ignore_numbers: If True, strip all ``string.digits`` characters
            from both sides before comparing.

    Returns:
        A dict ``{"exact_match": score}`` where ``score`` is the mean of the
        element-wise equality between the normalized predictions and
        references.
    """
    if regexes_to_ignore is not None:
        for pattern in regexes_to_ignore:
            predictions = [re.sub(pattern, "", x) for x in predictions]
            references = [re.sub(pattern, "", x) for x in references]

    # Always convert to arrays, even when `regexes_to_ignore` is an empty
    # iterable: otherwise plain lists would reach the `==` below and be
    # compared as a whole (one bool) instead of element-wise.
    predictions = np.asarray(predictions)
    references = np.asarray(references)

    if ignore_case:
        predictions = np.char.lower(predictions)
        references = np.char.lower(references)

    if ignore_punctuation:
        repl_table = str.maketrans("", "", string.punctuation)
        predictions = np.char.translate(predictions, table=repl_table)
        references = np.char.translate(references, table=repl_table)

    if ignore_numbers:
        repl_table = str.maketrans("", "", string.digits)
        predictions = np.char.translate(predictions, table=repl_table)
        references = np.char.translate(references, table=repl_table)

    score_list = predictions == references

    return {"exact_match": np.mean(score_list)}
###
@register_metric( @register_metric(
...@@ -176,7 +230,7 @@ exact_match = hf_evaluate.load("exact_match") ...@@ -176,7 +230,7 @@ exact_match = hf_evaluate.load("exact_match")
aggregation="mean", aggregation="mean",
) )
def exact_match_fn(**kwargs): def exact_match_fn(**kwargs):
return exact_match.compute(**kwargs) return exact_match_hf_evaluate(**kwargs)
@register_metric( @register_metric(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment