# MJHQ.py - dataset loading script for MJHQ-30K (committed by muyangli).
import json
import os
import random

import datasets
from PIL import Image

# BibTeX entry passed to `datasets.DatasetInfo(citation=...)`.
_CITATION = """\
@misc{li2024playground,
      title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation}, 
      author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
      year={2024},
      eprint={2402.17245},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
"""

# Human-readable summary passed to `datasets.DatasetInfo(description=...)`.
_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality. 
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""

# Dataset landing page, reported as the `homepage` of the dataset info.
_HOMEPAGE = "https://huggingface.co/datasets/playgroundai/MJHQ-30K"

# License name and link, reported as the `license` of the dataset info.
# The two adjacent literals are concatenated into a single string.
_LICENSE = (
    "Playground v2.5 Community License "
    "(https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md)"
)

# Zip archive of the images; downloaded and extracted in `_split_generators`.
IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/mjhq30k_imgs.zip"

# JSON metadata file; `_generate_examples` reads a "category" and a "prompt"
# from each of its top-level entries.
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"


class MJHQConfig(datasets.BuilderConfig):
    """Builder configuration carrying two MJHQ-specific options.

    Args:
        max_dataset_size: Upper bound on the number of examples to generate.
            Values ``<= 0`` (the default ``-1``) leave the dataset unbounded.
        return_gt: When true, the builder opens each image file and returns
            it in the ``"image"`` field; otherwise that field is ``None``.
        **kwargs: Only ``name``, ``version``, ``data_dir``, ``data_files`` and
            ``description`` are read and forwarded to ``datasets.BuilderConfig``;
            any other keyword is ignored.
    """

    def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
        # Base-class fields with the fallback used when the caller omits them.
        base_defaults = {
            "name": "default",
            "version": "0.0.0",
            "data_dir": None,
            "data_files": None,
            "description": None,
        }
        forwarded = {field: kwargs.get(field, fallback) for field, fallback in base_defaults.items()}
        super().__init__(**forwarded)

        self.max_dataset_size = max_dataset_size
        self.return_gt = return_gt


class DCI(datasets.GeneratorBasedBuilder):
    """Dataset builder for the MJHQ-30K benchmark.

    Downloads the metadata JSON and the image archive, then yields one example
    per metadata entry in a single ``train`` split.

    NOTE(review): the class is named ``DCI`` although it builds MJHQ; the name
    is left unchanged so existing references keep working.
    """

    VERSION = datasets.Version("0.0.0")

    BUILDER_CONFIG_CLASS = MJHQConfig
    BUILDER_CONFIGS = [MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset")]
    DEFAULT_CONFIG_NAME = "MJHQ"

    def _info(self):
        """Return the dataset metadata and the schema of each example.

        The feature names must match the keys yielded by
        ``_generate_examples`` one-to-one.
        """
        features = datasets.Features(
            {
                "filename": datasets.Value("string"),
                "category": datasets.Value("string"),
                "image": datasets.Image(),
                "prompt": datasets.Value("string"),
                # Was declared as "prompt_path", a key no example ever carries;
                # the generator yields the metadata location as "meta_path".
                "meta_path": datasets.Value("string"),
                "image_root": datasets.Value("string"),
                "image_path": datasets.Value("string"),
                "split": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.download.DownloadManager):
        """Fetch the metadata file and image archive and define the train split.

        Args:
            dl_manager: Download manager used to download ``META_URL`` and to
                download and extract ``IMAGE_URL``.

        Returns:
            A one-element list with the ``TRAIN`` split generator, whose
            ``gen_kwargs`` are forwarded to ``_generate_examples``.
        """
        meta_path = dl_manager.download(META_URL)
        image_root = dl_manager.download_and_extract(IMAGE_URL)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
            ),
        ]

    def _generate_examples(self, meta_path: str, image_root: str):
        """Yield ``(index, example)`` pairs for the entries of the metadata file.

        Args:
            meta_path: Local path of the downloaded metadata JSON.
            image_root: Directory the image archive was extracted into; images
                are looked up at ``<image_root>/<category>/<name>.jpg``.

        Yields:
            Tuples of a running integer index and a dict with the keys declared
            in ``_info``. ``"image"`` is an opened PIL image when
            ``config.return_gt`` is true and ``None`` otherwise.
        """
        # JSON is UTF-8 by specification; do not rely on the platform's
        # default locale encoding when decoding the prompts.
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)

        names = list(meta)
        if self.config.max_dataset_size > 0:
            # Fixed seed so the same subset is selected on every run; the
            # subset is then sorted to give a stable iteration order.
            random.Random(0).shuffle(names)
            names = names[: self.config.max_dataset_size]
            names = sorted(names)

        for i, name in enumerate(names):
            entry = meta[name]
            category = entry["category"]
            prompt = entry["prompt"]
            image_path = os.path.join(image_root, category, f"{name}.jpg")
            yield i, {
                "filename": name,
                "category": category,
                "image": Image.open(image_path) if self.config.return_gt else None,
                "prompt": prompt,
                "meta_path": meta_path,
                "image_root": image_root,
                "image_path": image_path,
                "split": self.config.name,
            }