utils.py 5.46 KB
Newer Older
Baber's avatar
Baber committed
1
2
from typing import List

Baber's avatar
Baber committed
3
4
import datasets

Baber's avatar
Baber committed
5
6
7
from lm_eval.api.metrics import exact_match_fn


Baber's avatar
Baber committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
PROMPTS = [
    {
        "rep": 'Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:".',
        "subtask_name": "en",
    },
    {
        "rep": 'Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:".',
        "subtask_name": "ru",
    },
    {
        "rep": 'Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:".',
        "subtask_name": "sw",
    },
    {
        "rep": 'Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N\'ajoutez rien d\'autre que la réponse entière après "Réponse:".',
        "subtask_name": "fr",
    },
    {
        "rep": "ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు.",
        "subtask_name": "te",
    },
    {
        "rep": 'แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:',
        "subtask_name": "th",
    },
    {
        "rep": 'の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。',
        "subtask_name": "ja",
    },
    {
        "rep": 'Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu.',
        "subtask_name": "de",
    },
    {
        "rep": 'এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।.',
        "subtask_name": "bn",
    },
    {
        "rep": '解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。',
        "subtask_name": "zh",
    },
    {
        "rep": 'Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:".',
        "subtask_name": "es",
    },
]


def number_variations(n):
    formats = []
    # Generate each pattern twice
    for _ in range(2):
        # Basic string representation
        formats.append(str(n))
        formats.append(f"{n}.")

        # With one decimal place
        formats.append(f"{n}.0")
        formats.append(f"{n}.0.")

        # With two decimal places
        formats.append(f"{n}.00")
        formats.append(f"{n}.00.")

    return formats


def process_docs(lang: str, df: datasets.Dataset) -> datasets.Dataset:
    def map_(doc: dict):
        suffix = [x for x in PROMPTS if x["subtask_name"] == lang][0]["rep"]
        doc["question"] = suffix + r"\n\n" + doc["question"].split(":", 1)[-1]
        doc["answers"] = number_variations(doc["answer_number"])
        return doc

    return df.map(map_)


Baber's avatar
Baber committed
85
86
87
88
89
def process_results_mgsm(doc, prediction):
    gold: List = doc["input_correct_responses"]
    return {
        "exact_match": int(
            exact_match_fn(
Baber's avatar
Baber committed
90
91
92
                predictions=[x.strip() for x in prediction] * len(gold),
                references=gold,
                ignore_case=True,
Baber's avatar
Baber committed
93
94
95
96
            )["exact_match"]
            > 0
        )
    }