_generate_config.py 1.13 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from pathlib import Path

import datasets


# Metrics averaged (weighted by subset size) in the group config.
AGGREGATED_METRICS = ("acc", "acc_norm", "acc_bytes")


def write_task_configs(subsets, parent: Path) -> None:
    """Write one ``<subset>.yaml`` task config per subset into *parent*.

    Each file includes the ``_template_mc`` template and sets the task name
    (``mrl_<subset>``) and the dataset config name (``<subset>``).

    Args:
        subsets: Iterable of dataset config names.
        parent: Directory the YAML files are written to.
    """
    for subset in subsets:
        lines = [
            "include: '_template_mc'\n",
            f"task: mrl_{subset}\n",
            f"dataset_name: {subset}\n",
        ]
        # Explicit encoding so the output does not depend on the locale.
        (parent / f"{subset}.yaml").write_text("".join(lines), encoding="utf-8")


def write_group_config(subsets, parent: Path) -> None:
    """Write ``_global_piqa.yaml``, the group config listing every subset task.

    The file declares the ``global_piqa`` group, one ``mrl_<subset>`` task
    entry (aliased to ``<subset>``) per subset, a size-weighted mean
    aggregation for each metric in ``AGGREGATED_METRICS``, and a metadata
    version.

    Args:
        subsets: Iterable of dataset config names, in output order.
        parent: Directory the YAML file is written to.
    """
    lines = ["group: global_piqa\n", "task:\n"]
    for subset in subsets:
        lines.append(f"  - task: mrl_{subset}\n")
        lines.append(f"    task_alias: {subset}\n")
    lines.append("aggregate_metric_list:\n")
    for metric in AGGREGATED_METRICS:
        lines.append(f"  - metric: {metric}\n")
        lines.append("    aggregation: mean\n")
        lines.append("    weight_by_size: true\n")
    lines.append("metadata:\n")
    lines.append("  version: 1.0\n")
    (parent / "_global_piqa.yaml").write_text("".join(lines), encoding="utf-8")


def main() -> None:
    """Regenerate all YAML configs next to this script.

    Fetches the config names of ``mrlbenchmarks/global-piqa-nonparallel``,
    drops those starting with ``dev``, and writes the per-subset task
    configs plus the group config into this file's directory.
    """
    subsets = [
        name
        for name in datasets.get_dataset_config_names(
            "mrlbenchmarks/global-piqa-nonparallel"
        )
        if not name.startswith("dev")
    ]
    parent = Path(__file__).parent
    write_task_configs(subsets, parent)
    write_group_config(subsets, parent)


# Everything that touches the network or the filesystem runs only when the
# file is executed as a script, so importing it has no side effects.
if __name__ == "__main__":
    main()