make_table_tasks.py 1.31 KB
Newer Older
bzantium's avatar
bzantium committed
1
2
3
4
5
6
"""
Usage:
   python make_table_tasks.py --output <markdown_filename>
"""
import argparse
import logging
Leo Gao's avatar
Leo Gao committed
7
8
9
10
from lm_eval import tasks
from pytablewriter import MarkdownTableWriter


bzantium's avatar
bzantium committed
11
12
# Logger for this module's progress messages.
logger = logging.getLogger(name=__name__)
# Configure the root logger so INFO-level records are actually emitted.
logging.basicConfig(level=logging.INFO)
Leo Gao's avatar
Leo Gao committed
13

bzantium's avatar
bzantium committed
14
15

def check(tf):
Leo Gao's avatar
Leo Gao committed
16
    if tf:
bzantium's avatar
bzantium committed
17
        return "✓"
Leo Gao's avatar
Leo Gao committed
18
    else:
bzantium's avatar
bzantium committed
19
        return " "
Leo Gao's avatar
Leo Gao committed
20
21


bzantium's avatar
bzantium committed
22
23
24
25
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", type=str, default="task_table.md")
    args = parser.parse_args()
Leo Gao's avatar
Leo Gao committed
26

bzantium's avatar
bzantium committed
27
28
29
    writer = MarkdownTableWriter()
    writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
    values = []
Leo Gao's avatar
Leo Gao committed
30

bzantium's avatar
bzantium committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
    tasks = tasks.TASK_REGISTRY.items()
    tasks = sorted(tasks, key=lambda x: x[0])
    for tname, Task in tasks:
        task = Task()
        v = [
            tname,
            check(task.has_training_docs()),
            check(task.has_validation_docs()),
            check(task.has_test_docs()),
            len(
                list(
                    task.test_docs() if task.has_test_docs() else task.validation_docs()
                )
            ),
            ", ".join(task.aggregation().keys()),
        ]
        logger.info(v)
        values.append(v)
    writer.value_matrix = values
    table = writer.dumps()
    with open(args.output, "w") as f:
        f.write(table)