import os

import torch
import torchvision
from PIL import Image
from torch.utils import data
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity
from tqdm import tqdm


def hash_str_to_int(s: str) -> int:
    """Hash a string to an integer."""
    modulus = 10**9 + 7  # Large prime modulus
    hash_int = 0
    for char in s:
        hash_int = (hash_int * 31 + ord(char)) % modulus
    return hash_int
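
# Example of the rolling hash above, computed by hand as a quick sanity check:
#   hash_str_to_int("ab") == (ord("a") * 31 + ord("b")) % (10**9 + 7) == 3105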


def already_generate(save_dir: str, num_images: int) -> bool:
    """Return True if `save_dir` already contains exactly `num_images` PNG images."""
    if os.path.exists(save_dir):
        images = [name for name in os.listdir(save_dir) if name.endswith(".png")]
        if len(images) == num_images:
            return True
    return False
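
# Example (with a hypothetical output directory): already_generate("results/run1", 1000)
# returns True only if "results/run1" exists and contains exactly 1000 ".png" files.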


class MultiImageDataset(data.Dataset):
    """Pairs generated and reference images for metric evaluation, matched by sorted filename order."""

    def __init__(self, gen_dirpath_or_image_path: str, ref_dirpath_or_image_path: str):
        super().__init__()
        if os.path.isdir(gen_dirpath_or_image_path):
            self.gen_names = sorted(
                [
                    name
                    for name in os.listdir(gen_dirpath_or_image_path)
                    if name.endswith(".png") or name.endswith(".jpg")
                ]
            )
            self.gen_dirpath = gen_dirpath_or_image_path
        else:
            self.gen_names = [os.path.basename(gen_dirpath_or_image_path)]
            self.gen_dirpath = os.path.dirname(gen_dirpath_or_image_path)
        if os.path.isdir(ref_dirpath_or_image_path):
            self.ref_names = sorted(
                [
                    name
                    for name in os.listdir(ref_dirpath_or_image_path)
                    if name.endswith(".png") or name.endswith(".jpg")
                ]
            )
            self.ref_dirpath = ref_dirpath_or_image_path
        else:
            self.ref_names = [os.path.basename(ref_dirpath_or_image_path)]
            self.ref_dirpath = os.path.dirname(ref_dirpath_or_image_path)

        assert len(self.ref_names) == len(self.gen_names)
        self.transform = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.ref_names)

    def __getitem__(self, idx: int):
        ref_image = Image.open(os.path.join(self.ref_dirpath, self.ref_names[idx])).convert("RGB")
        gen_image = Image.open(os.path.join(self.gen_dirpath, self.gen_names[idx])).convert("RGB")
        gen_size = gen_image.size
        ref_size = ref_image.size
        if ref_size != gen_size:
            ref_image = ref_image.resize(gen_size, Image.Resampling.BICUBIC)
        gen_tensor = self.transform(gen_image)
        ref_tensor = self.transform(ref_image)
        return [gen_tensor, ref_tensor]
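
# Minimal sketch of using the dataset directly (directory names are hypothetical; both
# folders must contain the same number of .png/.jpg images, paired by sorted filename):
#   dataset = MultiImageDataset("outputs/generated", "outputs/reference")
#   gen_tensor, ref_tensor = dataset[0]  # CHW float tensors in [0, 1]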


def compute_lpips(
    ref_dirpath_or_image_path: str,
    gen_dirpath_or_image_path: str,
    batch_size: int = 4,
    num_workers: int = 0,
    device: str | torch.device = "cuda",
) -> float:
    """Compute the mean LPIPS between paired reference and generated images (lower is more similar)."""
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
    dataset = MultiImageDataset(gen_dirpath_or_image_path, ref_dirpath_or_image_path)
    dataloader = data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
    )
    with torch.no_grad():
        desc = os.path.basename(gen_dirpath_or_image_path) + " LPIPS"
        for batch in tqdm(dataloader, desc=desc):
            batch = [tensor.to(device) for tensor in batch]
            metric.update(batch[0], batch[1])
    return metric.compute().item()
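

# Minimal end-to-end usage sketch. The directory paths below are placeholders, and the
# default device="cuda" assumes a GPU is available; both folders must hold matching
# numbers of .png/.jpg images paired by sorted filename.
if __name__ == "__main__":
    lpips_score = compute_lpips("data/reference_images", "data/generated_images", batch_size=4)
    print(f"LPIPS: {lpips_score:.4f}")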