utils.py 2.71 KB
Newer Older
Geun, Lim's avatar
Geun, Lim committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from typing import List

from datasets import Dataset


def get_context(doc) -> str:
    """Build the multiple-choice prompt (in Korean) for one example.

    The prompt instructs the model to read the question -- and the passage,
    when ``doc["paragraph"]`` is truthy -- and answer with a single option
    letter. It lists every option next to its letter and ends with "정답:"
    ("Answer:").

    Items whose id contains "CSAT" are offered five letters (A-E); all other
    items are offered four (A-D). This matches the letter sets used by
    ``get_choices`` and ``get_target``.

    Args:
        doc: Mapping with the keys ``"id"``, ``"paragraph"``, ``"question"``
            and ``"choices"`` (an indexable sequence of option strings).

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: if one of the keys above is missing.
        IndexError: if ``doc["choices"]`` has fewer entries than letters
            offered (4, or 5 for CSAT ids).
    """
    ctx = doc["paragraph"]
    q = doc["question"]
    opt = doc["choices"]
    # CSAT items are scored against five letters elsewhere in this module.
    # Listing only four options here hid the fifth choice from the model, so
    # the offered letters are derived from the same CSAT rule.
    letters = ["A", "B", "C", "D", "E"] if "CSAT" in doc["id"] else ["A", "B", "C", "D"]
    letter_list = ", ".join(letters)
    # The first option is rendered "A:<text>" with no space after the colon,
    # unlike the others. This is kept as-is so four-option prompts stay
    # byte-identical to earlier runs.
    # TODO(review): normalize to "A: " if changing the prompt is acceptable.
    options = f"{letters[0]}:{opt[0]}" + "".join(
        f", {letter}: {opt[i]}" for i, letter in enumerate(letters[1:], start=1)
    )
    if ctx:
        res = f"주어진 맥락을 천천히 읽고, 질문에 대한 적절한 정답을 {letter_list} 중에 골라 알파벳 하나로 답하시오.\n\n맥락: {ctx}\n질문: {q}\n보기:\n{options}\n정답:"
    else:
        res = f"주어진 질문을 천천히 읽고, 적절한 정답을 {letter_list} 중에 골라 알파벳 하나로 답하시오.\n\n질문: {q}\n보기:\n{options}\n정답:"

    return res


def get_target(doc) -> str:
    """Return the letter of the gold answer for ``doc``.

    The letter is the position of ``doc["answer"]`` inside ``doc["choices"]``.
    That position is mapped onto A-E for ids containing "CSAT", and onto A-D
    otherwise.

    Raises:
        ValueError: if the answer string is not one of the choices.
        IndexError: if the answer's position exceeds the available letters.
    """
    gold = doc["answer"]
    letters = ["A", "B", "C", "D", "E"] if "CSAT" in doc["id"] else ["A", "B", "C", "D"]
    position = doc["choices"].index(gold)
    return letters[position]


def get_choices(doc) -> List[str]:
    """Return the option letters an item is scored over.

    Items whose id contains "CSAT" use five letters (A-E); all others use
    four (A-D). A new list is returned on every call.
    """
    option_count = 5 if "CSAT" in doc["id"] else 4
    return list("ABCDE")[:option_count]


def extract_text(dataset: Dataset) -> Dataset:
    """Return the subset of ``dataset`` whose examples count as "text" items.

    An example is kept when its ``"id"``:
      * contains "CSAT_korean_22", or
      * contains "CSAT_korean_23" and its last "_"-separated field is an
        integer below 35, or
      * contains "TK" and its last "_"-separated field is an integer above 4.

    All checks are plain substring tests. The last field is only parsed with
    ``int`` when the preceding substring test matched, so a non-numeric last
    field raises ``ValueError`` only for those ids.
    """

    def keep(example) -> bool:
        item_id = example["id"]
        if "CSAT_korean_22" in item_id:
            return True
        if "CSAT_korean_23" in item_id and int(item_id.split("_")[-1]) < 35:
            return True
        return "TK" in item_id and int(item_id.split("_")[-1]) > 4

    return dataset.filter(keep)


def extract_grammar(dataset: Dataset) -> Dataset:
    """Return the subset of ``dataset`` whose examples count as "grammar" items.

    An example is kept when its ``"id"``:
      * contains "CSAT_korean", its third "_"-field (index 2) is below 21,
        and its fourth "_"-field (index 3) is above 10
        (NOTE(review): index 2 presumably is the exam year -- confirm); or
      * contains "Kedu_1" (a substring test, so e.g. "Kedu_16" matches too),
        and either its second "_"-field is not "16" or its question mentions
        none of "대화" / "발화" / "질의" (dialogue / utterance / inquiry); or
      * contains "TK" and its last "_"-field is below 5.

    Kedu_16 items that do mention those words are excluded here;
    ``extract_function`` selects them instead.
    """

    def keep(example) -> bool:
        item_id = example["id"]
        if "CSAT_korean" in item_id:
            fields = item_id.split("_")
            if int(fields[2]) < 21 and int(fields[3]) > 10:
                return True
        if "Kedu_1" in item_id:
            # The question text is only consulted for the Kedu_16 sub-case.
            if item_id.split("_")[1] != "16" or not any(
                word in example["question"] for word in ("대화", "발화", "질의")
            ):
                return True
        return "TK" in item_id and int(item_id.split("_")[-1]) < 5

    return dataset.filter(keep)


def extract_function(dataset: Dataset) -> Dataset:
    """Return the subset of ``dataset`` whose examples count as "function" items.

    An example is kept when its ``"id"``:
      * contains "CSAT_korean" and either its last "_"-field is above 34, or
        its third field (index 2) is below 21 and its fourth field (index 3)
        is below 11; or
      * contains "Kedu_16" and its question mentions "대화", "발화" or "질의"
        (dialogue / utterance / inquiry); or
      * contains "PSE_korean".

    NOTE(review): these conditions are not disjoint from the other selectors
    in this module. For example, any "CSAT_korean_22" id with a last field
    above 34 is also kept by ``extract_text``. Confirm the overlap is
    intended.
    """

    def keep(example) -> bool:
        item_id = example["id"]
        if "CSAT_korean" in item_id:
            fields = item_id.split("_")
            if int(fields[-1]) > 34:
                return True
            if int(fields[2]) < 21 and int(fields[3]) < 11:
                return True
        if "Kedu_16" in item_id:
            question = example["question"]
            if "대화" in question or "발화" in question or "질의" in question:
                return True
        return "PSE_korean" in item_id

    return dataset.filter(keep)