Unverified Commit fa0ba222 authored by Jess's avatar Jess Committed by GitHub
Browse files

Merge pull request #9 from JessicaOjo/africamgsm

remove added metrics -afrimgsm
parents 6bb95bbe 58692fb5
......@@ -58,20 +58,6 @@ def f1_score(items):
return np.max(fscore)
@register_aggregation("squad_f1")
def squad_f1_score(items):
gold_squad, pred_squad = [], []
for index, (ref, pred) in enumerate(items):
pred_dict = {'prediction_text': str(pred), 'id': str(index)}
ref_dict = {'answers': {'answer_start': [0], 'text': str(ref)}, 'id': str(index)}
gold_squad.append(ref_dict)
pred_squad.append(pred_dict)
squad_metric = hf_evaluate.load("squad")
results_squad = squad_metric.compute(predictions=pred_squad, references=gold_squad)
return results_squad['f1']
@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
unzipped_list = list(zip(*items))
......@@ -192,16 +178,6 @@ def exact_match_fn(**kwargs):
return exact_match.compute(**kwargs)
@register_metric(
metric="squad",
higher_is_better=True,
output_type="generate_until",
aggregation="squad_f1"
)
def squad_fn(items):
return items
@register_metric(
metric="perplexity",
higher_is_better=False,
......
......@@ -1294,7 +1294,6 @@ class ConfigurableTask(Task):
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"squad": (gold, pred)} if "squad" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**(
{"brier_score": (gold, prob_norm)}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment