fid.py 4.17 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
from datetime import datetime

import numpy as np
import torch
import torchvision
from cleanfid import fid
from cleanfid.resize import build_resizer
from datasets import Dataset
from tqdm import tqdm


def get_dataset_features(
    dataset: Dataset,
    model,
    mode: str = "clean",
    batch_size: int = 128,
    device: str | torch.device = "cuda",
) -> np.ndarray:
    """Run ``model`` over the ``"image"`` column of ``dataset`` and collect the features.

    Args:
        dataset: Dataset iterated in batches; each batch must expose an
            ``"image"`` entry whose items support ``.convert("RGB")``
            (presumably PIL images -- confirm against the caller).
        model: Feature extractor handed unchanged to ``fid.get_batch_features``.
        mode: Resizer mode passed to ``cleanfid.resize.build_resizer``.
        batch_size: Number of images per batch; the final partial batch is kept.
        device: Device forwarded to ``fid.get_batch_features``.

    Returns:
        The per-batch outputs of ``fid.get_batch_features`` concatenated along
        axis 0.
    """
    as_tensor = torchvision.transforms.ToTensor()
    resize = build_resizer(mode)

    def _to_model_input(array) -> torch.Tensor:
        # uint8 arrays are multiplied by 255 after ToTensor, other dtypes are
        # passed through as-is. NOTE(review): presumably this undoes ToTensor's
        # [0, 1] scaling of uint8 input so both branches share one value range
        # -- confirm against the torchvision documentation.
        tensor = as_tensor(array)
        return tensor * 255 if array.dtype == "uint8" else tensor

    progress = tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc=f"Extracting {dataset.config_name} features",
        # Ceiling division: a trailing partial batch still counts as one step.
        total=(len(dataset) + batch_size - 1) // batch_size,
    )

    feature_chunks = []
    for batch in progress:
        arrays = [resize(np.array(picture.convert("RGB"))) for picture in batch["image"]]
        pixels = torch.stack([_to_model_input(array) for array in arrays], dim=0)
        feature_chunks.append(fid.get_batch_features(pixels, model, device))
    return np.concatenate(feature_chunks, axis=0)


def get_fid_features(
    dataset_or_folder: str | Dataset | None = None,
    cache_path: str | None = None,
    num: int | None = None,
    mode: str = "clean",
    num_workers: int = 8,
    batch_size: int = 64,
    device: str | torch.device = "cuda",
    force_overwrite: bool = False,
    verbose: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """Return the feature mean and covariance used for FID, with optional caching.

    If ``cache_path`` exists and ``force_overwrite`` is false, the statistics
    are read from that file and nothing else is computed. Otherwise features
    are extracted from ``dataset_or_folder``, reduced to ``(mu, sigma)``, and
    written to ``cache_path`` when one is given.

    Args:
        dataset_or_folder: Folder path (handled by ``fid.get_folder_features``)
            or a ``Dataset`` (handled by ``get_dataset_features``). May be
            ``None`` only when the cache is going to be hit.
        cache_path: File holding the ``mu`` / ``sigma`` arrays. The archive is
            written to exactly this path, whatever its extension.
        num: Forwarded to ``fid.get_folder_features``; folder input only.
        mode: Mode for the feature extractor and the resizer.
        num_workers: Forwarded to ``fid.get_folder_features``; folder input only.
        batch_size: Batch size for feature extraction.
        device: Device the feature extractor runs on.
        force_overwrite: Recompute and rewrite even if ``cache_path`` exists.
        verbose: Forwarded to ``fid.get_folder_features``; folder input only.

    Returns:
        ``(mu, sigma)``: the mean over axis 0 of the features and their
        covariance with one variable per column (``rowvar=False``).

    Raises:
        AssertionError: On a cache miss when ``dataset_or_folder`` is neither a
            ``str`` nor a ``Dataset``.
    """
    if cache_path is not None and os.path.exists(cache_path) and not force_overwrite:
        # np.load returns an NpzFile that keeps the archive open; the arrays
        # are materialised on access, so it can be closed right after reading.
        with np.load(cache_path) as npz:
            mu, sigma = npz["mu"], npz["sigma"]
        return mu, sigma

    feat_model = fid.build_feature_extractor(mode, device)
    if isinstance(dataset_or_folder, str):
        np_feats = fid.get_folder_features(
            dataset_or_folder,
            feat_model,
            num_workers=num_workers,
            num=num,
            batch_size=batch_size,
            device=device,
            verbose=verbose,
            mode=mode,
            description=f"Extracting {dataset_or_folder} features",
        )
    else:
        assert isinstance(dataset_or_folder, Dataset), (
            f"dataset_or_folder must be a str or Dataset, got {type(dataset_or_folder).__name__}"
        )
        np_feats = get_dataset_features(
            dataset_or_folder, model=feat_model, mode=mode, batch_size=batch_size, device=device
        )

    mu = np.mean(np_feats, axis=0)
    sigma = np.cov(np_feats, rowvar=False)
    if cache_path is not None:
        os.makedirs(os.path.abspath(os.path.dirname(cache_path)), exist_ok=True)
        # Write through a file handle: given a string path, np.savez appends
        # ".npz" when it is missing, so the os.path.exists(cache_path) check
        # above would never find the file and the cache would never be reused.
        with open(cache_path, "wb") as cache_file:
            np.savez(cache_file, mu=mu, sigma=sigma)

    return mu, sigma


def compute_fid(
    ref_dirpath_or_dataset: str | Dataset,
    gen_dirpath: str,
    ref_cache_path: str | None = None,
    gen_cache_path: str | None = None,
    use_symlink: bool = True,
    timestamp: str | None = None,
) -> float:
    """Compute the FID between a reference set and a folder of generated images.

    Statistics for both sides come from ``get_fid_features`` (called with its
    default mode, batch size and device) and are compared with
    ``fid.frechet_distance``.

    Args:
        ref_dirpath_or_dataset: Reference images, as a folder path or a ``Dataset``.
        gen_dirpath: Folder containing the generated images.
        ref_cache_path: Optional statistics cache file for the reference side.
        gen_cache_path: Optional statistics cache file for the generated side.
        use_symlink: If true, folder inputs are accessed through temporary
            symlinks created under ``.tmp`` in the current working directory.
            The links are removed before returning; the ``.tmp`` directory
            itself is left in place.
        timestamp: Suffix used in the symlink names. Defaults to the current
            local time formatted as ``%y%m%d.%H%M%S``.

    Returns:
        The FID score as a Python ``float``.
    """
    # Absolute paths of the links created here, so cleanup does not depend on
    # the working directory still being the same when we finish.
    created_links: list[str] = []
    try:
        if use_symlink:
            if timestamp is None:
                timestamp = datetime.now().strftime("%y%m%d.%H%M%S")

            os.makedirs(".tmp", exist_ok=True)

            if isinstance(ref_dirpath_or_dataset, str):
                sym_ref_dirpath = os.path.join(".tmp", f"ref-{hash(str(ref_dirpath_or_dataset))}-{timestamp}")
                ref_link = os.path.abspath(sym_ref_dirpath)
                os.symlink(os.path.abspath(ref_dirpath_or_dataset), ref_link)
                created_links.append(ref_link)
                ref_dirpath_or_dataset = sym_ref_dirpath

            sym_gen_dirpath = os.path.join(".tmp", f"gen-{hash(str(gen_dirpath))}-{timestamp}")
            gen_link = os.path.abspath(sym_gen_dirpath)
            os.symlink(os.path.abspath(gen_dirpath), gen_link)
            created_links.append(gen_link)
            gen_dirpath = sym_gen_dirpath

        mu1, sigma1 = get_fid_features(dataset_or_folder=ref_dirpath_or_dataset, cache_path=ref_cache_path)
        mu2, sigma2 = get_fid_features(dataset_or_folder=gen_dirpath, cache_path=gen_cache_path)
        return float(fid.frechet_distance(mu1, sigma1, mu2, sigma2))
    finally:
        # Always remove the links, including when feature extraction raises or
        # when creating the second link fails after the first one succeeded.
        for link in created_links:
            os.remove(link)