"docs_zh-CN/tutorials/index.rst" did not exist on "54538ad2df2a27be0224c4dfdc30913acac97b5e"
utils.py 3.61 KB
Newer Older
muyangli's avatar
muyangli committed
1
import os
2
3
from os import PathLike
from pathlib import Path
muyangli's avatar
muyangli committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

import datasets
import torch
import torchvision
from PIL import Image
from torch.utils import data
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity
from tqdm import tqdm


def hash_str_to_int(s: str) -> int:
    """Map a string to a deterministic integer in ``[0, 10**9 + 7)``.

    A polynomial rolling hash: every character's code point is folded into
    the running value with multiplier 31, reducing modulo a large prime
    after each step. The empty string hashes to 0.
    """
    base = 31
    prime = 10**9 + 7  # large prime modulus
    acc = 0
    for code_point in map(ord, s):
        acc = (acc * base + code_point) % prime
    return acc


23
24
25
26
27
28
29
def already_generate(save_dir: str | PathLike[str], num_images) -> bool:
    if isinstance(save_dir, str):
        save_dir = Path(save_dir)
    assert isinstance(save_dir, Path)
    if save_dir.exists():
        images = list(save_dir.iterdir())
        images = [_ for _ in images if _.name.endswith(".png")]
muyangli's avatar
muyangli committed
30
31
32
33
34
35
        if len(images) == num_images:
            return True
    return False


class MultiImageDataset(data.Dataset):
    """Paired dataset of generated and reference images read from disk.

    Each constructor argument is either a directory, in which case all
    entries whose name ends with ``.png`` or ``.jpg`` are used in sorted
    order, or a path to a single image file. Generated and reference images
    are paired by position in their sorted listings, so both sides must
    contain the same number of images.
    """

    # File-name suffixes that count as images when scanning a directory.
    _IMAGE_SUFFIXES = (".png", ".jpg")

    def __init__(self, gen_dirpath_or_image_path: str, ref_dirpath_or_image_path: str | datasets.Dataset):
        """Collect the image file names for both sides of the comparison.

        Args:
            gen_dirpath_or_image_path: Directory of generated images, or the
                path of a single generated image.
            ref_dirpath_or_image_path: Directory of reference images, or the
                path of a single reference image.
                NOTE(review): the annotation also admits ``datasets.Dataset``,
                but this class only ever treats the value as a filesystem
                path -- confirm whether dataset objects are meant to be
                supported.

        Raises:
            ValueError: If the two sides do not contain the same number of
                images.
        """
        # Zero-argument super() starts resolution at this class's parent;
        # super(data.Dataset, self) would begin *after* Dataset and skip it.
        super().__init__()
        self.gen_dirpath, self.gen_names = self._resolve_images(gen_dirpath_or_image_path)
        self.ref_dirpath, self.ref_names = self._resolve_images(ref_dirpath_or_image_path)

        # Explicit check rather than assert, so it still runs under ``python -O``.
        if len(self.ref_names) != len(self.gen_names):
            raise ValueError(
                f"Image count mismatch: {len(self.ref_names)} reference image(s) "
                f"vs {len(self.gen_names)} generated image(s)."
            )
        self.transform = torchvision.transforms.ToTensor()

    @classmethod
    def _resolve_images(cls, dirpath_or_image_path):
        """Return ``(directory, sorted image file names)`` for a directory or single-file path."""
        if os.path.isdir(dirpath_or_image_path):
            names = sorted(
                name for name in os.listdir(dirpath_or_image_path) if name.endswith(cls._IMAGE_SUFFIXES)
            )
            return dirpath_or_image_path, names
        # Anything that is not a directory is treated as one image file.
        return os.path.dirname(dirpath_or_image_path), [os.path.basename(dirpath_or_image_path)]

    @staticmethod
    def _load_rgb(image_path):
        """Open ``image_path``, convert it to RGB, and close the underlying file."""
        # convert() returns a new, fully loaded image, so the file opened by
        # Image.open can be closed as soon as the conversion is done.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __len__(self):
        """Return the number of (generated, reference) image pairs."""
        return len(self.ref_names)

    def __getitem__(self, idx: int):
        """Load the ``idx``-th pair and return ``[gen_tensor, ref_tensor]``.

        The reference image is resized (bicubic) to the generated image's
        size when the two differ, then both are passed through
        ``torchvision.transforms.ToTensor``.
        """
        ref_image = self._load_rgb(os.path.join(self.ref_dirpath, self.ref_names[idx]))
        gen_image = self._load_rgb(os.path.join(self.gen_dirpath, self.gen_names[idx]))
        gen_size = gen_image.size
        if ref_image.size != gen_size:
            # The generated image dictates the comparison resolution.
            ref_image = ref_image.resize(gen_size, Image.Resampling.BICUBIC)
        gen_tensor = self.transform(gen_image)
        ref_tensor = self.transform(ref_image)
        return [gen_tensor, ref_tensor]


def compute_lpips(
    ref_dirpath_or_image_path: str,
    gen_dirpath_or_image_path: str,
    batch_size: int = 4,
    num_workers: int = 0,
    device: str | torch.device = "cuda",
) -> float:
    """Compute LPIPS between generated images and their references.

    Args:
        ref_dirpath_or_image_path: Directory of reference images, or the
            path of a single reference image.
        gen_dirpath_or_image_path: Directory of generated images, or the
            path of a single generated image.
        batch_size: Number of image pairs per batch.
        num_workers: Number of ``DataLoader`` worker processes.
        device: Device the metric and the image batches are moved to.

    Returns:
        The value of ``metric.compute()`` after all pairs have been
        accumulated, as a Python float.

    Side effects:
        Sets the ``TOKENIZERS_PARALLELISM`` environment variable to
        ``"false"`` for the current process.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # NOTE(review): normalize=True is understood (per torchmetrics docs) to
    # mean inputs are expected in [0, 1] -- confirm against the installed version.
    metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
    # The dataset takes the generated path first, the reference path second.
    dataset = MultiImageDataset(gen_dirpath_or_image_path, ref_dirpath_or_image_path)
    dataloader = data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
    )
    # normpath drops a trailing separator, so "out/run1/" is still labelled
    # "run1" instead of producing an empty name.
    desc = os.path.basename(os.path.normpath(gen_dirpath_or_image_path)) + " LPIPS"
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            # Each batch is the [generated, reference] pair the dataset returns.
            gen_batch, ref_batch = (tensor.to(device) for tensor in batch)
            metric.update(gen_batch, ref_batch)
    return metric.compute().item()