Commit 2ac5586e authored by Rayyyyy

first commit

import cv2
import numpy as np
import os.path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import paths_from_lmdb, scandir
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.matlab_functions import imresize, rgb2ycbcr
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class ImageNetPairedDataset(data.Dataset):
def __init__(self, opt):
super(ImageNetPairedDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder = opt['dataroot_gt']
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.gt_folder]
self.io_backend_opt['client_keys'] = ['gt']
self.paths = paths_from_lmdb(self.gt_folder)
elif 'meta_info_file' in self.opt:
with open(self.opt['meta_info_file'], 'r') as fin:
                self.paths = [osp.join(self.gt_folder, line.strip().split(' ')[0]) for line in fin]
else:
self.paths = sorted(list(scandir(self.gt_folder, full_path=True)))
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
        # Load the gt image. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32. The lq image is generated from gt below.
gt_path = self.paths[index]
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
# modcrop
size_h, size_w, _ = img_gt.shape
size_h = size_h - size_h % scale
size_w = size_w - size_w % scale
img_gt = img_gt[0:size_h, 0:size_w, :]
# generate training pairs
size_h = max(size_h, self.opt['gt_size'])
size_w = max(size_w, self.opt['gt_size'])
img_gt = cv2.resize(img_gt, (size_w, size_h))
img_lq = imresize(img_gt, 1 / scale)
img_gt = np.ascontiguousarray(img_gt, dtype=np.float32)
img_lq = np.ascontiguousarray(img_lq, dtype=np.float32)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None]
img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
# crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
# TODO: It is better to update the datasets, rather than force to crop
if self.opt['phase'] != 'train':
img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)
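# A minimal usage sketch (hypothetical; the option values below are illustrative only
# and would normally come from a BasicSR-style YAML config):
if __name__ == '__main__':
    example_opt = {
        'dataroot_gt': './datasets/ImageNet/train',  # hypothetical path
        'io_backend': {'type': 'disk'},
        'scale': 4,
        'gt_size': 256,
        'phase': 'train',
        'use_hflip': True,
        'use_rot': True,
    }
    dataset = ImageNetPairedDataset(example_opt)
    sample = dataset[0]
    # lq is 1/scale the size of gt; both are CHW RGB tensors in [0, 1]
    print(sample['lq'].shape, sample['gt'].shape, sample['gt_path'])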
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torch.utils import data as data
from basicsr.data.data_util import scandir
@DATASET_REGISTRY.register()
class RealESRGANDataset(data.Dataset):
"""Dataset used for Real-ESRGAN model:
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It loads gt (Ground-Truth) images, and augments them.
It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed as tensors on GPUs for faster processing.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
meta_info (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (implemented by vertical flip and transposing h and w).
        Please see more options in the code.
"""
def __init__(self, opt):
super(RealESRGANDataset, self).__init__()
self.opt = opt
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.gt_folder]
self.io_backend_opt['client_keys'] = ['gt']
if not self.gt_folder.endswith('.lmdb'):
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
# disk backend with meta_info
# Each line in the meta_info describes the relative path to an image
with open(self.opt['meta_info']) as fin:
paths = [line.strip().split(' ')[0] for line in fin]
self.paths = [os.path.join(self.gt_folder, v) for v in paths]
else:
self.paths = sorted(list(scandir(self.gt_folder, full_path=True)))
# blur settings for the first degradation
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
self.blur_sigma = opt['blur_sigma']
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
# blur settings for the second degradation
self.blur_kernel_size2 = opt['blur_kernel_size2']
self.kernel_list2 = opt['kernel_list2']
self.kernel_prob2 = opt['kernel_prob2']
self.blur_sigma2 = opt['blur_sigma2']
self.betag_range2 = opt['betag_range2']
self.betap_range2 = opt['betap_range2']
self.sinc_prob2 = opt['sinc_prob2']
# a final sinc filter
self.final_sinc_prob = opt['final_sinc_prob']
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
        # TODO: the kernel range is hard-coded for now; it should come from the config file
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
self.pulse_tensor[10, 10] = 1
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# -------------------------------- Load gt images -------------------------------- #
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
gt_path = self.paths[index]
# avoid errors caused by high latency in reading files
retry = 3
while retry > 0:
try:
img_bytes = self.file_client.get(gt_path, 'gt')
except (IOError, OSError) as e:
logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # switch to another random file to read
                index = random.randint(0, self.__len__() - 1)
gt_path = self.paths[index]
time.sleep(1) # sleep 1s for occasional server congestion
else:
break
finally:
retry -= 1
img_gt = imfrombytes(img_bytes, float32=True)
# -------------------- Do augmentation for training: flip, rotation -------------------- #
img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
# crop or pad to 400
# TODO: 400 is hard-coded. You may change it accordingly
h, w = img_gt.shape[0:2]
crop_pad_size = 400
# pad
if h < crop_pad_size or w < crop_pad_size:
pad_h = max(0, crop_pad_size - h)
pad_w = max(0, crop_pad_size - w)
img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
# crop
if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
h, w = img_gt.shape[0:2]
# randomly choose top and left coordinates
top = random.randint(0, h - crop_pad_size)
left = random.randint(0, w - crop_pad_size)
img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.opt['sinc_prob']:
# this sinc filter setting is for kernels ranging from [7, 21]
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel = random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
self.betag_range,
self.betap_range,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
if np.random.uniform() < self.opt['sinc_prob2']:
if kernel_size < 13:
omega_c = np.random.uniform(np.pi / 3, np.pi)
else:
omega_c = np.random.uniform(np.pi / 5, np.pi)
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
else:
kernel2 = random_mixed_kernels(
self.kernel_list2,
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
self.blur_sigma2, [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
noise_range=None)
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
# ------------------------------------- the final sinc kernel ------------------------------------- #
if np.random.uniform() < self.opt['final_sinc_prob']:
kernel_size = random.choice(self.kernel_range)
omega_c = np.random.uniform(np.pi / 3, np.pi)
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
sinc_kernel = torch.FloatTensor(sinc_kernel)
else:
sinc_kernel = self.pulse_tensor
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)
return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
return return_d
def __len__(self):
return len(self.paths)
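# A minimal usage sketch (hypothetical; the degradation hyper-parameters below are
# illustrative only and would normally come from the training YAML config):
if __name__ == '__main__':
    example_opt = {
        'dataroot_gt': './datasets/DF2K/HR',  # hypothetical path
        'io_backend': {'type': 'disk'},
        'use_hflip': True, 'use_rot': False,
        'blur_kernel_size': 21, 'kernel_list': ['iso', 'aniso'], 'kernel_prob': [0.5, 0.5],
        'blur_sigma': [0.2, 3], 'betag_range': [0.5, 4], 'betap_range': [1, 2], 'sinc_prob': 0.1,
        'blur_kernel_size2': 21, 'kernel_list2': ['iso', 'aniso'], 'kernel_prob2': [0.5, 0.5],
        'blur_sigma2': [0.2, 1.5], 'betag_range2': [0.5, 4], 'betap_range2': [1, 2], 'sinc_prob2': 0.1,
        'final_sinc_prob': 0.8,
    }
    dataset = RealESRGANDataset(example_opt)
    item = dataset[0]
    # gt is a CHW RGB tensor; the kernels are 21x21 tensors consumed later by the model
    print(item['gt'].shape, item['kernel1'].shape, item['kernel2'].shape, item['sinc_kernel'].shape)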
import importlib
from os import path as osp
from basicsr.utils import scandir
# automatically scan and import model modules for registry
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'hat.models.{file_name}') for file_name in model_filenames]
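# For reference, a new model only needs to live in a `*_model.py` file in this folder
# and register itself; the scan above then imports it automatically. A hypothetical example:
#
#     # my_sr_model.py (illustrative)
#     from basicsr.utils.registry import MODEL_REGISTRY
#     from basicsr.models.sr_model import SRModel
#
#     @MODEL_REGISTRY.register()
#     class MySRModel(SRModel):
#         ...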
import torch
from torch.nn import functional as F
from basicsr.utils.registry import MODEL_REGISTRY
from basicsr.models.sr_model import SRModel
from basicsr.metrics import calculate_metric
from basicsr.utils import imwrite, tensor2img
import math
from tqdm import tqdm
from os import path as osp
@MODEL_REGISTRY.register()
class HATModel(SRModel):
def pre_process(self):
# pad to multiplication of window_size
window_size = self.opt['network_g']['window_size']
self.scale = self.opt.get('scale', 1)
self.mod_pad_h, self.mod_pad_w = 0, 0
_, _, h, w = self.lq.size()
if h % window_size != 0:
self.mod_pad_h = window_size - h % window_size
if w % window_size != 0:
self.mod_pad_w = window_size - w % window_size
self.img = F.pad(self.lq, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
def process(self):
# model inference
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
self.output = self.net_g_ema(self.img)
else:
self.net_g.eval()
with torch.no_grad():
self.output = self.net_g(self.img)
# self.net_g.train()
def tile_process(self):
"""It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one image.
Modified from: https://github.com/ata4/esrgan-launcher
"""
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
output_shape = (batch, channel, output_height, output_width)
# start with black image
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.opt['tile']['tile_size'])
tiles_y = math.ceil(height / self.opt['tile']['tile_size'])
# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * self.opt['tile']['tile_size']
ofs_y = y * self.opt['tile']['tile_size']
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + self.opt['tile']['tile_size'], width)
input_start_y = ofs_y
input_end_y = min(ofs_y + self.opt['tile']['tile_size'], height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - self.opt['tile']['tile_pad'], 0)
input_end_x_pad = min(input_end_x + self.opt['tile']['tile_pad'], width)
input_start_y_pad = max(input_start_y - self.opt['tile']['tile_pad'], 0)
input_end_y_pad = min(input_end_y + self.opt['tile']['tile_pad'], height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
# upscale tile
try:
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
output_tile = self.net_g_ema(input_tile)
else:
self.net_g.eval()
with torch.no_grad():
output_tile = self.net_g(input_tile)
except RuntimeError as error:
print('Error', error)
print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
# output tile area on total image
output_start_x = input_start_x * self.opt['scale']
output_end_x = input_end_x * self.opt['scale']
output_start_y = input_start_y * self.opt['scale']
output_end_y = input_end_y * self.opt['scale']
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * self.opt['scale']
output_end_x_tile = output_start_x_tile + input_tile_width * self.opt['scale']
output_start_y_tile = (input_start_y - input_start_y_pad) * self.opt['scale']
output_end_y_tile = output_start_y_tile + input_tile_height * self.opt['scale']
# put tile into output image
self.output[:, :, output_start_y:output_end_y,
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile]
def post_process(self):
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
use_pbar = self.opt['val'].get('pbar', False)
if with_metrics:
if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
# initialize the best metric results for each dataset_name (supporting multiple validation datasets)
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
if with_metrics:
self.metric_results = {metric: 0 for metric in self.metric_results}
metric_data = dict()
if use_pbar:
pbar = tqdm(total=len(dataloader), unit='image')
for idx, val_data in enumerate(dataloader):
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
self.feed_data(val_data)
self.pre_process()
if 'tile' in self.opt:
self.tile_process()
else:
self.process()
self.post_process()
visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['result']])
metric_data['img'] = sr_img
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']])
metric_data['img2'] = gt_img
del self.gt
            # tentative fix for out-of-GPU-memory issues
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}.png')
else:
if self.opt['val']['suffix']:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}.png')
else:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}.png')
imwrite(sr_img, save_img_path)
if with_metrics:
# calculate metrics
for name, opt_ in self.opt['val']['metrics'].items():
self.metric_results[name] += calculate_metric(metric_data, opt_)
if use_pbar:
pbar.update(1)
pbar.set_description(f'Test {img_name}')
if use_pbar:
pbar.close()
if with_metrics:
for metric in self.metric_results.keys():
self.metric_results[metric] /= (idx + 1)
# update the best metric result
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
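# A small self-contained sketch of the window-size padding performed in pre_process()
# and undone in post_process() above (tensor sizes are illustrative only):
if __name__ == '__main__':
    window_size, scale = 16, 4
    lq = torch.rand(1, 3, 50, 70)  # hypothetical LQ input
    _, _, h, w = lq.size()
    mod_pad_h = (window_size - h % window_size) % window_size
    mod_pad_w = (window_size - w % window_size) % window_size
    padded = F.pad(lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
    # the network upscales `padded`; post_process() then crops back to (h * scale, w * scale)
    print(padded.shape, (h * scale, w * scale))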
import numpy as np
import random
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F
@MODEL_REGISTRY.register()
class RealHATGANModel(SRGANModel):
"""GAN-based Real_HAT Model.
It mainly performs:
1. randomly synthesize LQ images in GPU tensors
2. optimize the networks with GAN training.
"""
def __init__(self, opt):
super(RealHATGANModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
self.queue_size = opt.get('queue_size', 180)
@torch.no_grad()
def _dequeue_and_enqueue(self):
"""It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch cannot have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize
b, c, h, w = self.lq.size()
if not hasattr(self, 'queue_lr'):
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue
# shuffle
idx = torch.randperm(self.queue_size)
self.queue_lr = self.queue_lr[idx]
self.queue_gt = self.queue_gt[idx]
# get first b samples
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
# update the queue
self.queue_lr[0:b, :, :, :] = self.lq.clone()
self.queue_gt[0:b, :, :, :] = self.gt.clone()
self.lq = lq_dequeue
self.gt = gt_dequeue
else:
# only do enqueue
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
self.queue_ptr = self.queue_ptr + b
@torch.no_grad()
def feed_data(self, data):
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
"""
if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)
self.sinc_kernel = data['sinc_kernel'].to(self.device)
ori_h, ori_w = self.gt.size()[2:4]
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(self.gt_usm, self.kernel1)
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.opt['resize_range'][1])
elif updown_type == 'down':
scale = np.random.uniform(self.opt['resize_range'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = self.opt['gray_noise_prob']
if np.random.uniform() < self.opt['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if np.random.uniform() < self.opt['second_blur_prob']:
out = filter2D(out, self.kernel2)
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.opt['resize_range2'][1])
elif updown_type == 'down':
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
# add noise
gray_noise_prob = self.opt['gray_noise_prob2']
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
out = filter2D(out, self.sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
out = filter2D(out, self.sinc_kernel)
# clamp and round
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
# random crop
gt_size = self.opt['gt_size']
(self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
self.opt['scale'])
# training pair pool
self._dequeue_and_enqueue()
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
self.gt_usm = self.usm_sharpener(self.gt)
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
else:
# for paired training or validation
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# do not use the synthetic process during validation
self.is_train = False
super(RealHATGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
self.is_train = True
def optimize_parameters(self, current_iter):
# usm sharpening
l1_gt = self.gt_usm
percep_gt = self.gt_usm
gan_gt = self.gt_usm
if self.opt['l1_gt_usm'] is False:
l1_gt = self.gt
if self.opt['percep_gt_usm'] is False:
percep_gt = self.gt
if self.opt['gan_gt_usm'] is False:
gan_gt = self.gt
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_g_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, l1_gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# perceptual loss
if self.cri_perceptual:
l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
if l_g_percep is not None:
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if l_g_style is not None:
l_g_total += l_g_style
loss_dict['l_g_style'] = l_g_style
# gan loss
fake_g_pred = self.net_d(self.output)
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan'] = l_g_gan
l_g_total.backward()
self.optimizer_g.step()
# optimize net_d
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# real
real_d_pred = self.net_d(gan_gt)
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
loss_dict['l_d_real'] = l_d_real
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
l_d_real.backward()
# fake
fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d_fake'] = l_d_fake
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
l_d_fake.backward()
self.optimizer_d.step()
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
self.log_dict = self.reduce_loss_dict(loss_dict)
def test(self):
# pad to multiplication of window_size
window_size = self.opt['network_g']['window_size']
scale = self.opt.get('scale', 1)
mod_pad_h, mod_pad_w = 0, 0
_, _, h, w = self.lq.size()
if h % window_size != 0:
mod_pad_h = window_size - h % window_size
if w % window_size != 0:
mod_pad_w = window_size - w % window_size
img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
self.output = self.net_g_ema(img)
else:
self.net_g.eval()
with torch.no_grad():
self.output = self.net_g(img)
self.net_g.train()
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
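# A CPU-only sketch of the training-pair pool logic in _dequeue_and_enqueue() above
# (illustrative only; the real method operates on self.lq / self.gt CUDA tensors):
if __name__ == '__main__':
    queue_size, b, c, h, w = 8, 2, 3, 4, 4
    queue_lr = torch.zeros(queue_size, c, h, w)
    queue_ptr = 0
    for step in range(6):
        lq = torch.full((b, c, h, w), float(step))  # stand-in for a dataloader batch
        if queue_ptr == queue_size:
            # pool is full: shuffle, dequeue the first b samples, enqueue the new batch
            idx = torch.randperm(queue_size)
            queue_lr = queue_lr[idx]
            lq_dequeue = queue_lr[0:b].clone()
            queue_lr[0:b] = lq
            lq = lq_dequeue
        else:
            # pool not full yet: only enqueue
            queue_lr[queue_ptr:queue_ptr + b] = lq
            queue_ptr += b
        print(step, queue_ptr, lq.flatten()[0].item())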
import numpy as np
import random
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.sr_model import SRModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
from torch.nn import functional as F
@MODEL_REGISTRY.register()
class RealHATMSEModel(SRModel):
"""MSE-based Real_HAT Model.
It is trained without GAN losses.
It mainly performs:
1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with pixel losses (without GAN training).
"""
def __init__(self, opt):
super(RealHATMSEModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
self.queue_size = opt.get('queue_size', 180)
@torch.no_grad()
def _dequeue_and_enqueue(self):
"""It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch cannot have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize
b, c, h, w = self.lq.size()
if not hasattr(self, 'queue_lr'):
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue
# shuffle
idx = torch.randperm(self.queue_size)
self.queue_lr = self.queue_lr[idx]
self.queue_gt = self.queue_gt[idx]
# get first b samples
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
# update the queue
self.queue_lr[0:b, :, :, :] = self.lq.clone()
self.queue_gt[0:b, :, :, :] = self.gt.clone()
self.lq = lq_dequeue
self.gt = gt_dequeue
else:
# only do enqueue
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
self.queue_ptr = self.queue_ptr + b
@torch.no_grad()
def feed_data(self, data):
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
"""
if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis
self.gt = data['gt'].to(self.device)
# USM sharpen the GT images
if self.opt['gt_usm'] is True:
self.gt = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)
self.sinc_kernel = data['sinc_kernel'].to(self.device)
ori_h, ori_w = self.gt.size()[2:4]
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(self.gt, self.kernel1)
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.opt['resize_range'][1])
elif updown_type == 'down':
scale = np.random.uniform(self.opt['resize_range'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = self.opt['gray_noise_prob']
if np.random.uniform() < self.opt['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if np.random.uniform() < self.opt['second_blur_prob']:
out = filter2D(out, self.kernel2)
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.opt['resize_range2'][1])
elif updown_type == 'down':
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
# add noise
gray_noise_prob = self.opt['gray_noise_prob2']
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
out = filter2D(out, self.sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
out = filter2D(out, self.sinc_kernel)
# clamp and round
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
# random crop
gt_size = self.opt['gt_size']
self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
# training pair pool
self._dequeue_and_enqueue()
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
else:
# for paired training or validation
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# do not use the synthetic process during validation
self.is_train = False
super(RealHATMSEModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
self.is_train = True
def test(self):
# pad to multiplication of window_size
window_size = self.opt['network_g']['window_size']
scale = self.opt.get('scale', 1)
mod_pad_h, mod_pad_w = 0, 0
_, _, h, w = self.lq.size()
if h % window_size != 0:
mod_pad_h = window_size - h % window_size
if w % window_size != 0:
mod_pad_w = window_size - w % window_size
img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
self.output = self.net_g_ema(img)
else:
self.net_g.eval()
with torch.no_grad():
self.output = self.net_g(img)
self.net_g.train()
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
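# A small sketch of the random-resize step used in the degradation pipeline above
# (illustrative only; the probabilities and ranges normally come from the YAML options):
if __name__ == '__main__':
    out = torch.rand(1, 3, 64, 64)
    resize_prob, resize_range = [0.2, 0.7, 0.1], [0.15, 1.5]  # hypothetical values
    updown_type = random.choices(['up', 'down', 'keep'], resize_prob)[0]
    if updown_type == 'up':
        scale = np.random.uniform(1, resize_range[1])
    elif updown_type == 'down':
        scale = np.random.uniform(resize_range[0], 1)
    else:
        scale = 1
    mode = random.choice(['area', 'bilinear', 'bicubic'])
    out = F.interpolate(out, scale_factor=float(scale), mode=mode)
    print(updown_type, float(scale), mode, tuple(out.shape))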
# flake8: noqa
import os
import sys
import os.path as osp
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import hat.archs
import hat.data
import hat.models
from basicsr.test import test_pipeline
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
test_pipeline(root_path)
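# Example invocation (the option file name is illustrative; any test config under
# options/test/ in this repo can be used):
#   python hat/test.py -opt options/test/HAT_SRx2.yml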
# flake8: noqa
import os
import sys
import os.path as osp
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import hat.archs
import hat.data
import hat.models
from basicsr.train import train_pipeline
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path)
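# Example invocation (single GPU; the option file name is illustrative):
#   python hat/train.py -opt options/train/train_HAT_SRx2_from_scratch.yml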
# Unique model identifier
modelCode=532
# Model name
modelName=hat_pytorch
# Model description
modelDescription=Combines channel attention with window-based self-attention, and introduces an overlapping cross-attention module to strengthen interaction between neighboring window features, aggregate cross-window information more effectively, and activate more pixels for reconstruction, significantly improving performance.
# Application scenarios
appScenario=inference,training,image reconstruction,transportation,public security,manufacturing
# Framework type
frameType=PyTorch
name: HAT-L_SRx2_ImageNet-pretrain
model_type: HATModel
scale: 2
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
datasets:
test_1: # the 1st test dataset
name: Set5
type: PairedImageDataset
dataroot_gt: ./datasets/Set5/GTmod2
dataroot_lq: ./datasets/Set5/LRbicx2
io_backend:
type: disk
# test_2: # the 2nd test dataset
# name: Set14
# type: PairedImageDataset
# dataroot_gt: ./datasets/Set14/GTmod2
# dataroot_lq: ./datasets/Set14/LRbicx2
# io_backend:
# type: disk
# test_3:
# name: Urban100
# type: PairedImageDataset
# dataroot_gt: ./datasets/urban100/GTmod2
# dataroot_lq: ./datasets/urban100/LRbicx2
# io_backend:
# type: disk
# test_4:
# name: BSDS100
# type: PairedImageDataset
# dataroot_gt: ./datasets/BSDS100/GTmod2
# dataroot_lq: ./datasets/BSDS100/LRbicx2
# io_backend:
# type: disk
# test_5:
# name: Manga109
# type: PairedImageDataset
# dataroot_gt: ./datasets/manga109/GTmod2
# dataroot_lq: ./datasets/manga109/LRbicx2
# io_backend:
# type: disk
# network structures
network_g:
type: HAT
upscale: 2
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 3
squeeze_factor: 30
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
embed_dim: 180
num_heads: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/HAT-L_SRx2_ImageNet-pretrain.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 2
test_y_channel: true
ssim:
type: calculate_ssim
crop_border: 2
test_y_channel: true
name: HAT-L_SRx3_ImageNet-pretrain
model_type: HATModel
scale: 3
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
datasets:
test_1: # the 1st test dataset
name: Set5
type: PairedImageDataset
dataroot_gt: ./datasets/Set5/GTmod3
dataroot_lq: ./datasets/Set5/LRbicx3
io_backend:
type: disk
# test_2: # the 2nd test dataset
# name: Set14
# type: PairedImageDataset
# dataroot_gt: ./datasets/Set14/GTmod3
# dataroot_lq: ./datasets/Set14/LRbicx3
# io_backend:
# type: disk
# test_3:
# name: Urban100
# type: PairedImageDataset
# dataroot_gt: ./datasets/urban100/GTmod3
# dataroot_lq: ./datasets/urban100/LRbicx3
# io_backend:
# type: disk
# test_4:
# name: BSDS100
# type: PairedImageDataset
# dataroot_gt: ./datasets/BSDS100/GTmod3
# dataroot_lq: ./datasets/BSDS100/LRbicx3
# io_backend:
# type: disk
# test_5:
# name: Manga109
# type: PairedImageDataset
# dataroot_gt: ./datasets/manga109/GTmod3
# dataroot_lq: ./datasets/manga109/LRbicx3
# io_backend:
# type: disk
# network structures
network_g:
type: HAT
upscale: 3
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 3
squeeze_factor: 30
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
embed_dim: 180
num_heads: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/HAT-L_SRx3_ImageNet-pretrain.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 3
test_y_channel: true
ssim:
type: calculate_ssim
crop_border: 3
test_y_channel: true
name: HAT-L_SRx4_ImageNet-pretrain
model_type: HATModel
scale: 4
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
datasets:
test_1: # the 1st test dataset
name: Set5
type: PairedImageDataset
dataroot_gt: ./datasets/Set5/GTmod4
dataroot_lq: ./datasets/Set5/LRbicx4
io_backend:
type: disk
# test_2: # the 2nd test dataset
# name: Set14
# type: PairedImageDataset
# dataroot_gt: ./datasets/Set14/GTmod4
# dataroot_lq: ./datasets/Set14/LRbicx4
# io_backend:
# type: disk
# test_3:
# name: Urban100
# type: PairedImageDataset
# dataroot_gt: ./datasets/urban100/GTmod4
# dataroot_lq: ./datasets/urban100/LRbicx4
# io_backend:
# type: disk
# test_4:
# name: BSDS100
# type: PairedImageDataset
# dataroot_gt: ./datasets/BSDS100/GTmod4
# dataroot_lq: ./datasets/BSDS100/LRbicx4
# io_backend:
# type: disk
# test_5:
# name: Manga109
# type: PairedImageDataset
# dataroot_gt: ./datasets/manga109/GTmod4
# dataroot_lq: ./datasets/manga109/LRbicx4
# io_backend:
# type: disk
# network structures
network_g:
type: HAT
upscale: 4
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 3
squeeze_factor: 30
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
embed_dim: 180
num_heads: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/HAT-L_SRx4_ImageNet-pretrain.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: true
ssim:
type: calculate_ssim
crop_border: 4
test_y_channel: true
name: HAT-S_SRx2
model_type: HATModel
scale: 2
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
datasets:
test_1: # the 1st test dataset
name: Set5
type: PairedImageDataset
dataroot_gt: ./datasets/Set5/GTmod2
dataroot_lq: ./datasets/Set5/LRbicx2
io_backend:
type: disk
test_2: # the 2nd test dataset
name: Set14
type: PairedImageDataset
dataroot_gt: ./datasets/Set14/GTmod2
dataroot_lq: ./datasets/Set14/LRbicx2
io_backend:
type: disk
test_3:
name: Urban100
type: PairedImageDataset
dataroot_gt: ./datasets/urban100/GTmod2
dataroot_lq: ./datasets/urban100/LRbicx2
io_backend:
type: disk
test_4:
name: BSDS100
type: PairedImageDataset
dataroot_gt: ./datasets/BSDS100/GTmod2
dataroot_lq: ./datasets/BSDS100/LRbicx2
io_backend:
type: disk
test_5:
name: Manga109
type: PairedImageDataset
dataroot_gt: ./datasets/manga109/GTmod2
dataroot_lq: ./datasets/manga109/LRbicx2
io_backend:
type: disk
# network structures
network_g:
type: HAT
upscale: 2
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 24
squeeze_factor: 24
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6]
embed_dim: 144
num_heads: [6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/HAT-S_SRx2.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 2
test_y_channel: true
ssim:
type: calculate_ssim
crop_border: 2
test_y_channel: true
name: HAT-S_SRx3
model_type: HATModel
scale: 3
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
datasets:
test_1: # the 1st test dataset
name: Set5
type: PairedImageDataset
dataroot_gt: ./datasets/Set5/GTmod3
dataroot_lq: ./datasets/Set5/LRbicx3
io_backend:
type: disk
test_2: # the 2nd test dataset
name: Set14
type: PairedImageDataset
dataroot_gt: ./datasets/Set14/GTmod3
dataroot_lq: ./datasets/Set14/LRbicx3
io_backend:
type: disk
test_3:
name: Urban100
type: PairedImageDataset
dataroot_gt: ./datasets/urban100/GTmod3
dataroot_lq: ./datasets/urban100/LRbicx3
io_backend:
type: disk
test_4:
name: BSDS100
type: PairedImageDataset
dataroot_gt: ./datasets/BSDS100/GTmod3
dataroot_lq: ./datasets/BSDS100/LRbicx3
io_backend:
type: disk
test_5:
name: Manga109
type: PairedImageDataset
dataroot_gt: ./datasets/manga109/GTmod3
dataroot_lq: ./datasets/manga109/LRbicx3
io_backend:
type: disk
# network structures
network_g:
type: HAT
upscale: 3
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 24
squeeze_factor: 24
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6]
embed_dim: 144
num_heads: [6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/HAT-S_SRx3.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 3
test_y_channel: true
ssim:
type: calculate_ssim
crop_border: 3
test_y_channel: true
name: HAT-S_SRx4
model_type: HATModel
scale: 4
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
datasets:
test_1: # the 1st test dataset
name: Set5
type: PairedImageDataset
dataroot_gt: ./datasets/Set5/GTmod4
dataroot_lq: ./datasets/Set5/LRbicx4
io_backend:
type: disk
test_2: # the 2nd test dataset
name: Set14
type: PairedImageDataset
dataroot_gt: ./datasets/Set14/GTmod4
dataroot_lq: ./datasets/Set14/LRbicx4
io_backend:
type: disk
test_3:
name: Urban100
type: PairedImageDataset
dataroot_gt: ./datasets/urban100/GTmod4
dataroot_lq: ./datasets/urban100/LRbicx4
io_backend:
type: disk
test_4:
name: BSDS100
type: PairedImageDataset
dataroot_gt: ./datasets/BSDS100/GTmod4
dataroot_lq: ./datasets/BSDS100/LRbicx4
io_backend:
type: disk
test_5:
name: Manga109
type: PairedImageDataset
dataroot_gt: ./datasets/manga109/GTmod4
dataroot_lq: ./datasets/manga109/LRbicx4
io_backend:
type: disk
# network structures
network_g:
type: HAT
upscale: 4
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 24
squeeze_factor: 24
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6]
embed_dim: 144
num_heads: [6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/HAT-S_SRx4.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: true
ssim:
type: calculate_ssim
crop_border: 4
test_y_channel: true
name: HAT_GAN_Real_SRx4
model_type: HATModel
scale: 4
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
tile: # use tile mode when GPU memory is limited during testing.
  tile_size: 512 # larger tiles use more GPU memory but deviate less from full-image results; must be an integer multiple of the window size.
  tile_pad: 32 # overlap between adjacent tiles; must be an integer multiple of the window size.
datasets:
test_1: # the 1st test dataset
name: custom
type: SingleImageDataset
dataroot_lq: input_dir
io_backend:
type: disk
# network structures
network_g:
type: HAT
upscale: 4
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 3
squeeze_factor: 30
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6]
embed_dim: 180
num_heads: [6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/Real_HAT_GAN_SRx4.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
name: HAT_SRx2
model_type: HATModel
scale: 2
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
datasets:
test_1: # the 1st test dataset
name: Set5
type: PairedImageDataset
dataroot_gt: ./datasets/Set5/GTmod2
dataroot_lq: ./datasets/Set5/LRbicx2
io_backend:
type: disk
# test_2: # the 2nd test dataset
# name: Set14
# type: PairedImageDataset
# dataroot_gt: ./datasets/Set14/GTmod2
# dataroot_lq: ./datasets/Set14/LRbicx2
# io_backend:
# type: disk
# test_3:
# name: Urban100
# type: PairedImageDataset
# dataroot_gt: ./datasets/urban100/GTmod2
# dataroot_lq: ./datasets/urban100/LRbicx2
# io_backend:
# type: disk
# test_4:
# name: BSDS100
# type: PairedImageDataset
# dataroot_gt: ./datasets/BSDS100/GTmod2
# dataroot_lq: ./datasets/BSDS100/LRbicx2
# io_backend:
# type: disk
# test_5:
# name: Manga109
# type: PairedImageDataset
# dataroot_gt: ./datasets/manga109/GTmod2
# dataroot_lq: ./datasets/manga109/LRbicx2
# io_backend:
# type: disk
# network structures
network_g:
type: HAT
upscale: 2
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 3
squeeze_factor: 30
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6]
embed_dim: 180
num_heads: [6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/HAT_SRx2.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 2
test_y_channel: true
ssim:
type: calculate_ssim
crop_border: 2
test_y_channel: true
name: HAT_SRx2_ImageNet-pretrain
model_type: HATModel
scale: 2
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
datasets:
test_1: # the 1st test dataset
name: Set5
type: PairedImageDataset
dataroot_gt: ./datasets/Set5/GTmod2
dataroot_lq: ./datasets/Set5/LRbicx2
io_backend:
type: disk
# test_2: # the 2nd test dataset
# name: Set14
# type: PairedImageDataset
# dataroot_gt: ./datasets/Set14/GTmod2
# dataroot_lq: ./datasets/Set14/LRbicx2
# io_backend:
# type: disk
# test_3:
# name: Urban100
# type: PairedImageDataset
# dataroot_gt: ./datasets/urban100/GTmod2
# dataroot_lq: ./datasets/urban100/LRbicx2
# io_backend:
# type: disk
# test_4:
# name: BSDS100
# type: PairedImageDataset
# dataroot_gt: ./datasets/BSDS100/GTmod2
# dataroot_lq: ./datasets/BSDS100/LRbicx2
# io_backend:
# type: disk
# test_5:
# name: Manga109
# type: PairedImageDataset
# dataroot_gt: ./datasets/manga109/GTmod2
# dataroot_lq: ./datasets/manga109/LRbicx2
# io_backend:
# type: disk
# network structures
network_g:
type: HAT
upscale: 2
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 3
squeeze_factor: 30
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6]
embed_dim: 180
num_heads: [6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/HAT_SRx2_ImageNet-pretrain.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 2
test_y_channel: true
ssim:
type: calculate_ssim
crop_border: 2
test_y_channel: true
name: HAT_SRx3
model_type: HATModel
scale: 3
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 0
datasets:
test_1: # the 1st test dataset
name: Set5
type: PairedImageDataset
dataroot_gt: ./datasets/Set5/GTmod3
dataroot_lq: ./datasets/Set5/LRbicx3
io_backend:
type: disk
# test_2: # the 2nd test dataset
# name: Set14
# type: PairedImageDataset
# dataroot_gt: ./datasets/Set14/GTmod3
# dataroot_lq: ./datasets/Set14/LRbicx3
# io_backend:
# type: disk
# test_3:
# name: Urban100
# type: PairedImageDataset
# dataroot_gt: ./datasets/urban100/GTmod3
# dataroot_lq: ./datasets/urban100/LRbicx3
# io_backend:
# type: disk
# test_4:
# name: BSDS100
# type: PairedImageDataset
# dataroot_gt: ./datasets/BSDS100/GTmod3
# dataroot_lq: ./datasets/BSDS100/LRbicx3
# io_backend:
# type: disk
# test_5:
# name: Manga109
# type: PairedImageDataset
# dataroot_gt: ./datasets/manga109/GTmod3
# dataroot_lq: ./datasets/manga109/LRbicx3
# io_backend:
# type: disk
# network structures
network_g:
type: HAT
upscale: 3
in_chans: 3
img_size: 64
window_size: 16
compress_ratio: 3
squeeze_factor: 30
conv_scale: 0.01
overlap_ratio: 0.5
img_range: 1.
depths: [6, 6, 6, 6, 6, 6]
embed_dim: 180
num_heads: [6, 6, 6, 6, 6, 6]
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: ./experiments/pretrained_models/HAT_SRx3.pth
strict_load_g: true
param_key_g: 'params_ema'
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 3
test_y_channel: true
ssim:
type: calculate_ssim
crop_border: 3
test_y_channel: true