Unverified Commit e57e1aef authored by Jess's avatar Jess Committed by GitHub
Browse files

Merge pull request #16 from JessicaOjo/africamgsm

add squad to afrimgsm
parents ddc2674e 03ebd203
...@@ -58,6 +58,19 @@ def f1_score(items): ...@@ -58,6 +58,19 @@ def f1_score(items):
return np.max(fscore) return np.max(fscore)
@register_aggregation("squad_f1")
def squad_f1_score(items):
    """Aggregate (reference, prediction) pairs into one SQuAD-style F1 value.

    Each pair in `items` is unpacked as `(ref, pred)`; both are converted to
    strings and wrapped in the record layout passed to the "squad" metric
    loaded through `hf_evaluate`. The position of a pair in `items` is used
    as its record id, and a fixed `answer_start` of `[0]` is supplied for
    every reference.

    Returns:
        The metric's ``"f1"`` entry divided by 100.
    """
    # Build both records for a sample in one pass so `items` is only
    # iterated once.
    paired_records = [
        (
            {"prediction_text": str(pred), "id": str(position)},
            {
                "answers": {"answer_start": [0], "text": [str(ref)]},
                "id": str(position),
            },
        )
        for position, (ref, pred) in enumerate(items)
    ]
    predictions = [prediction for prediction, _ in paired_records]
    references = [reference for _, reference in paired_records]

    scorer = hf_evaluate.load("squad")
    scores = scorer.compute(predictions=predictions, references=references)
    return scores["f1"] / 100
@register_aggregation("matthews_corrcoef") @register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items): def matthews_corrcoef(items):
unzipped_list = list(zip(*items)) unzipped_list = list(zip(*items))
...@@ -178,6 +191,15 @@ def exact_match_fn(**kwargs): ...@@ -178,6 +191,15 @@ def exact_match_fn(**kwargs):
return exact_match.compute(**kwargs) return exact_match.compute(**kwargs)
@register_metric(
    metric="squad",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="squad_f1",
)
def squad_fn(items):
    """Return `items` untouched.

    No scoring happens here; the registration above names "squad_f1" as the
    aggregation for this metric.
    """
    return items
@register_metric( @register_metric(
metric="perplexity", metric="perplexity",
higher_is_better=False, higher_is_better=False,
......
...@@ -1293,6 +1293,7 @@ class ConfigurableTask(Task): ...@@ -1293,6 +1293,7 @@ class ConfigurableTask(Task):
**({"acc": acc} if "acc" in use_metric else {}), **({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}), **({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}), **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"squad": (gold, pred)} if "squad" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}), **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}), **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**( **(
......
...@@ -49,6 +49,48 @@ class RegexFilter(Filter): ...@@ -49,6 +49,48 @@ class RegexFilter(Filter):
return filtered_resps return filtered_resps
@register_filter("regex-numbers")
class RegexNumberFilter(Filter):
    """Extract a regex match from each response and strip its separators.

    For every response the compiled pattern is run with `findall`; the match
    at index `group_select` is kept, stripped of surrounding whitespace,
    has every ``,`` and ``.`` removed, and is lower-cased. Because both
    characters are removed, a decimal point is dropped as well
    (``"1,234.5"`` becomes ``"12345"``).
    """

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: "str | int" = 0,
    ) -> None:
        """
        pass a string `regex` to run `re.compile(r"regex")` on.
        `group_select` is the index into the list of matches found in a response.
        `fallback` defines the output returned if no matches for the regex are located.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = fallback

    def _extract(self, resp):
        """Return the cleaned match for one response, or the fallback."""
        matches = self.regex.findall(resp)
        if not matches:
            return self.fallback

        # Raises IndexError if `group_select` is outside the list of matches.
        selected = matches[self.group_select]
        if isinstance(selected, tuple):
            # With several capture groups `findall` yields tuples; keep the
            # first non-empty group. If every group is empty there is nothing
            # usable, so use the fallback instead of indexing an empty list.
            selected = next((group for group in selected if group), None)
            if selected is None:
                return self.fallback

        return selected.strip().replace(',', '').replace('.', '').lower()

    def apply(self, resps, docs):
        """Filter every response set in `resps`; `docs` is not used.

        Returns a list with one inner list per element of `resps`, holding
        the extracted value (or the fallback) for each response in it.
        """
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
        # independently (and keep them a list.)
        return [[self._extract(resp) for resp in inst] for inst in resps]
@register_filter("remove_whitespace") @register_filter("remove_whitespace")
class WhitespaceFilter(Filter): class WhitespaceFilter(Filter):
""" """ """ """
......
...@@ -16,16 +16,18 @@ generation_kwargs: ...@@ -16,16 +16,18 @@ generation_kwargs:
- </s> - </s>
- <|im_end|> - <|im_end|>
filter_list: filter_list:
- name: remove_whitespace
filter:
- function: remove_whitespace
- function: take_first
- filter: - filter:
- function: regex - function: regex
group_select: -1 group_select: -1
regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+) regex_pattern: (-?[0-9.,]{2,})|(-?[0-9]+)
- function: take_first - function: take_first
name: flexible-extract name: flexible-extract
- filter:
- function: regex-numbers
group_select: -1
regex_pattern: (\d{1,10}(?:,\d{3})*(?:[.,]\d{3})?)([^\d()]*)
- function: take_first
name: flexible-extract-new
metric_list: metric_list:
- metric: exact_match - metric: exact_match
aggregation: mean aggregation: mean
...@@ -38,5 +40,11 @@ metric_list: ...@@ -38,5 +40,11 @@ metric_list:
higher_is_better: True higher_is_better: True
ignore_case: true ignore_case: true
ignore_punctuation: true ignore_punctuation: true
- metric: squad
aggregation: squad_f1
average: weighted
higher_is_better: True
ignore_case: true
ignore_punctuation: true
metadata: metadata:
version: 2.0 version: 2.0
dataset_name: eng dataset_name: ewe
include: afrimmlu_common_yaml include: afrimmlu_common_yaml
task: afrimmlu_ewe task: afrimmlu_ewe
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment