__init__.py 1.83 KB
Newer Older
muyangli's avatar
muyangli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import random

import datasets
import yaml
from nunchaku.utils import fetch_or_download

# Public API of this package: only get_dataset is exported by `from ... import *`;
# load_dataset_yaml is an internal helper.
__all__ = ["get_dataset"]


def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
    """Read a ``name -> prompt`` YAML file and expand it into dataset columns.

    Args:
        meta_path: Path to a YAML file whose top level is a mapping from an
            entry name to its prompt.
        max_dataset_size: If positive, keep at most this many entries. The
            names are shuffled with a fixed seed (``random.Random(0)``) so the
            chosen subset is reproducible, and the kept names are then sorted.
            Values ``<= 0`` keep every entry, in the order ``yaml.safe_load``
            returned them.
        repeat: Number of rows emitted per entry. Row ``j`` of entry ``name``
            gets filename ``"{name}-{j}"`` and the same prompt.

    Returns:
        A dict of three equal-length lists, ``"filename"``, ``"prompt"`` and
        ``"meta_path"``, with ``len(kept_names) * repeat`` rows each. It is
        laid out for ``datasets.Dataset.from_dict``. An empty YAML file yields
        empty lists.

    Raises:
        ValueError: If the YAML document's top level is not a mapping.
    """
    # YAML is UTF-8 by spec, so don't depend on the platform's locale encoding.
    # The context manager closes the handle deterministically instead of
    # leaving it open until garbage collection.
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = yaml.safe_load(f)
    if meta is None:  # safe_load returns None for an empty document
        meta = {}
    if not isinstance(meta, dict):
        raise ValueError(
            f"{meta_path}: expected a top-level mapping of name -> prompt, got {type(meta).__name__}"
        )

    names = list(meta.keys())
    if max_dataset_size > 0:
        # A fixed seed gives the same subset on every run. Sorting afterwards
        # makes the row order independent of the shuffle.
        random.Random(0).shuffle(names)
        names = sorted(names[:max_dataset_size])

    ret = {"filename": [], "prompt": [], "meta_path": []}
    for name in names:
        prompt = meta[name]
        for j in range(repeat):
            ret["filename"].append(f"{name}-{j}")
            ret["prompt"].append(prompt)
            ret["meta_path"].append(meta_path)
    return ret


def get_dataset(
    name: str,
    config_name: str | None = None,
    split: str = "train",
    return_gt: bool = False,
    max_dataset_size: int = 5000,
) -> datasets.Dataset:
    """Load a prompt dataset by name.

    ``"MJHQ"`` is loaded with ``datasets.load_dataset`` from the directory
    ``<this package's directory>/MJHQ``. That directory presumably holds a
    dataset loading script, given ``trust_remote_code=True`` (TODO confirm).
    ``config_name``, ``split``, ``return_gt`` and ``max_dataset_size`` are all
    forwarded to that call.

    Every other ``name`` is resolved by fetching
    ``mit-han-lab/nunchaku-test/<name>.yaml`` through ``fetch_or_download``
    and expanding it with ``load_dataset_yaml`` (one row per entry). The result
    is an in-memory ``datasets.Dataset`` with the string columns ``filename``,
    ``prompt`` and ``meta_path``. On this path ``config_name``, ``split`` and
    ``return_gt`` are ignored.

    Args:
        name: Dataset name, either ``"MJHQ"`` or the stem of a YAML prompt file.
        config_name: Builder config name (MJHQ only).
        split: Split to load (MJHQ only).
        return_gt: Forwarded to the MJHQ builder. Presumably it asks for
            ground-truth data alongside prompts; verify against the script.
        max_dataset_size: Upper bound on the number of entries to load.

    Returns:
        The loaded ``datasets.Dataset``.
    """
    if name != "MJHQ":
        yaml_file = fetch_or_download(f"mit-han-lab/nunchaku-test/{name}.yaml", repo_type="dataset")
        columns = load_dataset_yaml(yaml_file, max_dataset_size=max_dataset_size, repeat=1)
        schema = datasets.Features(
            {column: datasets.Value("string") for column in ("filename", "prompt", "meta_path")}
        )
        return datasets.Dataset.from_dict(columns, features=schema)

    # Local loader directory that sits next to this module.
    loader_dir = os.path.join(os.path.dirname(__file__), f"{name}")
    return datasets.load_dataset(
        loader_dir,
        return_gt=return_gt,
        name=config_name,
        split=split,
        trust_remote_code=True,
        token=True,
        max_dataset_size=max_dataset_size,
    )