from typing import Sequence, Dict, Union

import math
import time

import numpy as np
import cv2
from PIL import Image
import torch.utils.data as data

from utils.file import load_file_list
from utils.image import center_crop_arr, augment, random_crop_arr
from utils.degradation import (
    random_mixed_kernels,
    random_add_gaussian_noise,
    random_add_jpg_compression
)


class CodeformerDataset(data.Dataset):
    """Paired GT/LQ dataset for blind face restoration training.

    Each item is a dict with:
      - ``jpg``:  ground-truth image, RGB, float32 in [-1, 1], shape (H, W, 3)
      - ``txt``:  empty caption string
      - ``hint``: degraded (LQ) image, RGB, float32 in [0, 1], shape (H, W, 3)

    The LQ image is synthesized on the fly from the GT image via the
    CodeFormer-style degradation pipeline:
    blur -> downsample -> gaussian noise -> JPEG compression -> resize back.
    """

    def __init__(
        self,
        file_list: str,
        out_size: int,
        crop_type: str,
        use_hflip: bool,
        blur_kernel_size: int,
        kernel_list: Sequence[str],
        kernel_prob: Sequence[float],
        blur_sigma: Sequence[float],
        downsample_range: Sequence[float],
        noise_range: Sequence[float],
        jpeg_range: Sequence[int]
    ) -> None:
        """
        Args:
            file_list: path to a text file listing GT image paths
                (parsed by ``load_file_list``).
            out_size: target square side length for GT crops.
            crop_type: one of "none", "center", "random".
            use_hflip: whether to apply random horizontal flip augmentation.
            blur_kernel_size: size of the random blur kernel.
            kernel_list: candidate blur-kernel type names.
            kernel_prob: sampling probability for each kernel type.
            blur_sigma: [min, max] sigma range for the blur kernels.
            downsample_range: [min, max] downscale factor range.
            noise_range: gaussian-noise sigma range, or None to disable.
            jpeg_range: JPEG quality range, or None to disable.
        """
        super(CodeformerDataset, self).__init__()
        self.file_list = file_list
        self.paths = load_file_list(file_list)
        self.out_size = out_size
        self.crop_type = crop_type
        assert self.crop_type in ["none", "center", "random"]
        self.use_hflip = use_hflip
        # degradation configurations
        self.blur_kernel_size = blur_kernel_size
        self.kernel_list = kernel_list
        self.kernel_prob = kernel_prob
        self.blur_sigma = blur_sigma
        self.downsample_range = downsample_range
        self.noise_range = noise_range
        self.jpeg_range = jpeg_range

    def _load_gt_image(self, gt_path: str, max_retry: int = 3) -> Image.Image:
        """Open *gt_path* as an RGB PIL image, retrying transient failures.

        Retries up to ``max_retry`` times, sleeping 1 s between attempts
        (useful when images live on flaky network storage).

        Raises:
            RuntimeError: if the image cannot be loaded after all retries.
        """
        for attempt in range(max_retry):
            try:
                return Image.open(gt_path).convert("RGB")
            except Exception:
                # Narrow "best effort" retry: any load error triggers another
                # attempt, but KeyboardInterrupt/SystemExit propagate.
                if attempt < max_retry - 1:
                    time.sleep(1)
        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # which would let a broken image silently flow into the pipeline.
        raise RuntimeError(f"failed to load image {gt_path}")

    def __getitem__(self, index: int) -> Dict[str, Union[np.ndarray, str]]:
        # load gt image
        gt_path = self.paths[index]
        pil_img = self._load_gt_image(gt_path)

        # crop to (out_size, out_size); "none" assumes the file already
        # has the target resolution (checked below)
        if self.crop_type == "center":
            pil_img_gt = center_crop_arr(pil_img, self.out_size)
        elif self.crop_type == "random":
            pil_img_gt = random_crop_arr(pil_img, self.out_size)
        else:
            pil_img_gt = np.array(pil_img)
        assert pil_img_gt.shape[:2] == (self.out_size, self.out_size)
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        # (BGR because the degradation helpers follow OpenCV conventions.)
        img_gt = (pil_img_gt[..., ::-1] / 255.0).astype(np.float32)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.use_hflip, rotation=False, return_status=False)
        h, w, _ = img_gt.shape

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma,
            [-math.pi, math.pi],
            noise_range=None
        )
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample by a random factor within downsample_range
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)
        # resize back to original size so GT and LQ are spatially aligned
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # BGR to RGB, [-1, 1]
        target = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
        # BGR to RGB, [0, 1]
        source = img_lq[..., ::-1].astype(np.float32)
        return dict(jpg=target, txt="", hint=source)

    def __len__(self) -> int:
        """Number of GT image paths in the file list."""
        return len(self.paths)