utils.py 381 Bytes
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import List

from lm_eval.api.metrics import exact_match_fn


def process_results_mgsm(doc, prediction):
    """Score one MGSM generation against all accepted gold responses.

    Args:
        doc: Dataset row. ``doc["input_correct_responses"]`` holds the
            accepted answer strings.
        prediction: The model output for this document. It is presumably the
            harness's per-document results list, whose first element is the
            generated string (TODO confirm against the task config). A bare
            string is also accepted.

    Returns:
        ``{"exact_match": 1}`` if the prediction matches at least one gold
        response (compared with ``ignore_case=True``), else
        ``{"exact_match": 0}``.
    """
    gold: List = doc["input_correct_responses"]
    # With no references there is nothing to match. This also avoids
    # averaging over zero pairs inside exact_match_fn.
    if not gold:
        return {"exact_match": 0}

    # Reduce `prediction` to exactly one string. Without this, `str * n`
    # would repeat the characters, and a list of length != 1 would not line
    # up element-for-element with `gold`.
    if isinstance(prediction, str):
        pred = prediction
    elif not prediction:
        return {"exact_match": 0}
    else:
        pred = prediction[0]

    score = exact_match_fn(
        predictions=[pred] * len(gold), references=gold, ignore_case=True
    )["exact_match"]
    # The score is assumed to be the mean over (pred, gold_i) pairs, so any
    # positive value means at least one gold response matched.
    return {"exact_match": int(score > 0)}