"vscode:/vscode.git/clone" did not exist on "35f709a210feaebfc7c1ce02c6564e290ab08c7a"
make_table_tasks.py 3.94 KB
Newer Older
jon-tow's avatar
jon-tow committed
1
2
"""
Usage:
    Run this script directly. It writes the task table both as CSV and as
    Markdown to ``csv_file`` and ``md_file`` (defined in the ``__main__`` block
    below), both located in ``<parent of this script's directory>/docs/``.
"""
import logging
import os
from pathlib import Path
from typing import List, Optional, Union

import datasets
import pandas as pd

from lm_eval import tasks
from lm_eval.tasks import TASK_REGISTRY
from lm_eval.utils import load_yaml_config
16

Leo Gao's avatar
Leo Gao committed
17

jon-tow's avatar
jon-tow committed
18
19
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level setup: runs on import, before any table is built.
# NOTE(review): disabling the `datasets` cache presumably avoids filling the
# local cache while every task's documents are loaded — confirm intent.
datasets.disable_caching()
# Presumably populates `tasks.TASK_REGISTRY`, which `maketable` reads.
tasks.initialize_tasks()
Leo Gao's avatar
Leo Gao committed
22

Fabrizio Milo's avatar
Fabrizio Milo committed
23

baberabb's avatar
test  
baberabb committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def load_changed_files(file_path: str) -> List[str]:
    """Return the whitespace-separated entries of a text file.

    The file is presumably the changed-files list written by CI (see
    ``new_tasks``). Any run of whitespace, including newlines, separates
    entries, so a path that itself contains whitespace is split apart.

    Args:
        file_path: Path of the text file to read.

    Returns:
        The entries in file order; an empty list for an empty file.
    """
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read().split()


def parser(full_path: List[str]) -> List[str]:
    """Map changed file paths to the names of the tasks they affect.

    * A ``.yaml`` path contributes the ``task`` entry of that config.
    * A ``.py`` path contributes the ``task`` entry of every ``*.yaml`` file
      in the same directory, since a changed helper module may affect them all.
    * Any other path is ignored.

    Listed paths that no longer exist (for example, files removed by the
    change) and configs without a single string ``task`` entry are skipped
    with a warning instead of aborting the whole run.

    Args:
        full_path: File paths, e.g. as returned by ``load_changed_files``.

    Returns:
        The unique task names, sorted so the output is deterministic.
    """
    # First collect every YAML config touched by the change set.
    yaml_files: List[str] = []
    for file_name in full_path:
        if file_name.endswith(".yaml"):
            yaml_files.append(file_name)
        elif file_name.endswith(".py"):
            yaml_files.extend(str(p) for p in Path(file_name).parent.glob("*.yaml"))

    task_names = set()
    for yaml_file in yaml_files:
        if not os.path.isfile(yaml_file):
            logger.warning("Skipping %s: file not found", yaml_file)
            continue
        config = load_yaml_config(yaml_file)
        task_name = config.get("task") if isinstance(config, dict) else None
        if not isinstance(task_name, str):
            # e.g. no `task` key, or a list of tasks; neither is a single name.
            logger.warning("Skipping %s: no single 'task' name in config", yaml_file)
            continue
        task_names.add(task_name)
    return sorted(task_names)


def new_tasks() -> Optional[List[str]]:
    """Decide which tasks the table should be refreshed for.

    The checks run in this order:

    1. If the changed-files list exists (a path relative to the current
       working directory, presumably written by a CI workflow), parse it and
       return the names of the tasks whose configs, or sibling Python modules,
       changed.
    2. Otherwise, if the ``API`` environment variable is set (to any value,
       even an empty string), return a fixed smoke-test set of tasks.
    3. Otherwise return ``None``. Callers must treat this as "no tasks".

    Returns:
        A list of task names, or ``None`` when neither signal is present.
    """
    changed_files_path = ".github/outputs/tasks_all_changed_and_modified_files.txt"
    if os.path.exists(changed_files_path):
        return parser(load_changed_files(changed_files_path))
    if os.getenv("API") is not None:
        # The API changed rather than a task: run a small representative set.
        return ["arc_easy", "hellaswag", "piqa", "wikitext"]
    # Neither signal is present. Note that there is no fallback task here.
    return None


jon-tow's avatar
jon-tow committed
57
def check(tf):
    """Render a flag as a table cell: a check mark when truthy, a space otherwise."""
    return "✓" if tf else " "

Leo Gao's avatar
Leo Gao committed
63

baberabb's avatar
baberabb committed
64
65
66
67
68
69
70
71
72
73
74
def _select_tasks(df: Optional[pd.DataFrame]) -> List[tuple]:
    """Return the (name, task class) pairs to tabulate for ``maketable``."""
    if df is None:
        # Building from scratch: every registered task, sorted by name.
        return sorted(tasks.TASK_REGISTRY.items(), key=lambda item: item[0])
    selected = []
    # new_tasks() returns None when there is nothing to refresh.
    for name in new_tasks() or []:
        task_cls = tasks.TASK_REGISTRY.get(name)
        if task_cls is None:
            logger.warning("Skipping %r: not found in TASK_REGISTRY", name)
            continue
        selected.append((name, task_cls))
    return selected


def _task_row(tname: str, task_cls) -> list:
    """Instantiate ``task_cls`` and return its table row, in header order."""
    task = task_cls()
    has_train = task.has_training_docs()
    has_val = task.has_validation_docs()
    has_test = task.has_test_docs()
    # Count the documents of the first available split: test, then validation,
    # and fall back to training.
    if has_test:
        docs = task.test_docs()
    elif has_val:
        docs = task.validation_docs()
    else:
        docs = task.training_docs()
    return [
        tname,
        task.config.group,
        check(has_train),
        check(has_val),
        check(has_test),
        len(list(docs)),
        task.config.output_type,
        ", ".join(task.aggregation().keys()),
    ]


def _merge_rows(df: pd.DataFrame, rows: List[list]) -> pd.DataFrame:
    """Overwrite rows of ``df`` with a matching "Task Name", append the rest, and sort by name."""
    # Keyed by task name: insertion order is kept and the last row for a name wins.
    pending = {}
    for row in rows:
        mask = df["Task Name"] == row[0]
        if mask.any():
            df.loc[mask] = row
        else:
            pending[row[0]] = row
    if pending:
        # Build all new rows with a single concat rather than one concat per row.
        appended = pd.DataFrame(list(pending.values()), columns=df.columns)
        df = pd.concat([df, appended], ignore_index=True)
    return df.sort_values(by=["Task Name"])


def maketable(df: Optional[pd.DataFrame], max_tasks: Optional[int] = 10):
    """Build the task table as a DataFrame and as a Markdown string.

    Args:
        df: The previously saved table, or ``None``.

            * ``None``: build from scratch, with one row per entry of
              ``tasks.TASK_REGISTRY`` (sorted by name).
            * A DataFrame: refresh only the tasks reported by ``new_tasks()``.
              Rows with a matching "Task Name" are overwritten in place (``df``
              is mutated), other rows are appended, and the result is sorted by
              name. Names missing from the registry are skipped with a warning.
        max_tasks: Stop after this many rows have been computed. The default of
            10 matches the previous hard-coded cap; ``None`` removes the cap.

    Returns:
        ``(df, table)``: the resulting DataFrame and its Markdown rendering.
    """
    headers = [
        "Task Name",
        "Group",
        "Train",
        "Val",
        "Test",
        "Val/Test Docs",
        "Request Type",
        "Metrics",
    ]
    values = []
    for tname, task_cls in _select_tasks(df):
        if max_tasks is not None and len(values) >= max_tasks:
            break
        row = _task_row(tname, task_cls)
        logger.info(row)
        values.append(row)

    # Test `is None` rather than `not df`: the truth value of a DataFrame is
    # ambiguous, and pandas raises ValueError for it.
    if df is None:
        df = pd.DataFrame(values, columns=headers)
    else:
        df = _merge_rows(df, values)
    return df, df.to_markdown(index=False)


if __name__ == "__main__":
    # Both outputs live in <parent of this script's directory>/docs/.
    docs_dir = Path(__file__).parent.parent.resolve() / "docs"
    csv_file = docs_dir / "task_table.csv"
    md_file = docs_dir / "task_table.md"

    # Start from the existing CSV if there is one; otherwise maketable()
    # receives None.
    try:
        df = pd.read_csv(csv_file)
    except FileNotFoundError:
        df = None

    df, table = maketable(df=df)

    # Write explicitly as UTF-8. The table contains the non-ASCII "✓" mark,
    # which the platform default encoding (e.g. cp1252) may not represent.
    docs_dir.mkdir(parents=True, exist_ok=True)
    md_file.write_text(table, encoding="utf-8")
    df.to_csv(csv_file, index=False, encoding="utf-8")