Commit 76b9024b authored by yangzhong

git init
import math
from collections import Counter
from torch.optim.lr_scheduler import _LRScheduler
class MultiStepRestartLR(_LRScheduler):
""" MultiStep with restarts learning rate scheme.
Args:
        optimizer (torch.optim.Optimizer): Torch optimizer.
milestones (list): Iterations that will decrease learning rate.
gamma (float): Decrease ratio. Default: 0.1.
restarts (list): Restart iterations. Default: [0].
restart_weights (list): Restart weights at each restart iteration.
Default: [1].
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
self.milestones = Counter(milestones)
self.gamma = gamma
self.restarts = restarts
self.restart_weights = restart_weights
assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch in self.restarts:
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
if self.last_epoch not in self.milestones:
return [group['lr'] for group in self.optimizer.param_groups]
return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
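# A minimal usage sketch for MultiStepRestartLR (illustrative only: the toy
# linear model, base lr and schedule values below are assumptions, not part of
# any training config in this repo).
def _demo_multistep_restart_lr():
    import torch
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = MultiStepRestartLR(
        optimizer, milestones=[5, 8], gamma=0.5, restarts=[10], restart_weights=[0.5])
    lrs = []
    for _ in range(12):
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])
    # lr halves at iterations 5 and 8, then restarts to 0.5 * initial_lr at 10
    return lrs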
def get_position_from_periods(iteration, cumulative_period):
"""Get the position from a period list.
It will return the index of the right-closest number in the period list.
For example, the cumulative_period = [100, 200, 300, 400],
if iteration == 50, return 0;
if iteration == 210, return 2;
if iteration == 300, return 2.
Args:
iteration (int): Current iteration.
cumulative_period (list[int]): Cumulative period list.
Returns:
int: The position of the right-closest number in the period list.
"""
for i, period in enumerate(cumulative_period):
if iteration <= period:
return i
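# A few sanity checks mirroring the docstring example above (an illustrative
# sketch; the numbers come straight from the docstring).
def _demo_get_position_from_periods():
    cumulative_period = [100, 200, 300, 400]
    assert get_position_from_periods(50, cumulative_period) == 0
    assert get_position_from_periods(210, cumulative_period) == 2
    assert get_position_from_periods(300, cumulative_period) == 2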
class CosineAnnealingRestartLR(_LRScheduler):
""" Cosine annealing with restarts learning rate scheme.
An example of config:
periods = [10, 10, 10, 10]
restart_weights = [1, 0.5, 0.5, 0.5]
eta_min=1e-7
    It has four cycles, each with 10 iterations. At the 10th, 20th and 30th
    iterations, the scheduler will restart with the weights in restart_weights.
Args:
        optimizer (torch.optim.Optimizer): Torch optimizer.
        periods (list): Period for each cosine annealing cycle.
restart_weights (list): Restart weights at each restart iteration.
Default: [1].
eta_min (float): The minimum lr. Default: 0.
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
self.periods = periods
self.restart_weights = restart_weights
self.eta_min = eta_min
        assert len(self.periods) == len(self.restart_weights), 'periods and restart_weights should have the same length.'
self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
current_weight = self.restart_weights[idx]
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
current_period = self.periods[idx]
return [
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
(1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
for base_lr in self.base_lrs
]
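# A minimal usage sketch for CosineAnnealingRestartLR, wired up with the
# example config from the class docstring (the toy model and base lr are
# assumptions).
def _demo_cosine_annealing_restart_lr():
    import torch
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    scheduler = CosineAnnealingRestartLR(
        optimizer, periods=[10, 10, 10, 10], restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)
    lrs = []
    for _ in range(40):
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])
    # within each 10-iteration cycle the lr follows a cosine curve from
    # (restart_weight * base_lr) down towards eta_min, then restarts
    return lrs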
import numpy as np
import random
import torch
from collections import OrderedDict
from torch.nn import functional as F
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.losses.loss_util import get_refined_artifact_map
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
@MODEL_REGISTRY.register(suffix='basicsr')
class RealESRGANModel(SRGANModel):
"""RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It mainly performs:
1. randomly synthesize LQ images in GPU tensors
2. optimize the networks with GAN training.
"""
def __init__(self, opt):
super(RealESRGANModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
self.queue_size = opt.get('queue_size', 180)
@torch.no_grad()
def _dequeue_and_enqueue(self):
"""It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize
b, c, h, w = self.lq.size()
if not hasattr(self, 'queue_lr'):
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue
# shuffle
idx = torch.randperm(self.queue_size)
self.queue_lr = self.queue_lr[idx]
self.queue_gt = self.queue_gt[idx]
# get first b samples
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
# update the queue
self.queue_lr[0:b, :, :, :] = self.lq.clone()
self.queue_gt[0:b, :, :, :] = self.gt.clone()
self.lq = lq_dequeue
self.gt = gt_dequeue
else:
# only do enqueue
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
self.queue_ptr = self.queue_ptr + b
@torch.no_grad()
def feed_data(self, data):
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
"""
if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)
self.sinc_kernel = data['sinc_kernel'].to(self.device)
ori_h, ori_w = self.gt.size()[2:4]
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(self.gt_usm, self.kernel1)
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.opt['resize_range'][1])
elif updown_type == 'down':
scale = np.random.uniform(self.opt['resize_range'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = self.opt['gray_noise_prob']
if np.random.uniform() < self.opt['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if np.random.uniform() < self.opt['second_blur_prob']:
out = filter2D(out, self.kernel2)
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.opt['resize_range2'][1])
elif updown_type == 'down':
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
# add noise
gray_noise_prob = self.opt['gray_noise_prob2']
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
out = filter2D(out, self.sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
out = filter2D(out, self.sinc_kernel)
# clamp and round
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
# random crop
gt_size = self.opt['gt_size']
(self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
self.opt['scale'])
# training pair pool
self._dequeue_and_enqueue()
            # sharpen self.gt again, as self.gt has been changed by self._dequeue_and_enqueue
self.gt_usm = self.usm_sharpener(self.gt)
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
else:
# for paired training or validation
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# do not use the synthetic process during validation
self.is_train = False
super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
self.is_train = True
def optimize_parameters(self, current_iter):
# usm sharpening
l1_gt = self.gt_usm
percep_gt = self.gt_usm
gan_gt = self.gt_usm
if self.opt['l1_gt_usm'] is False:
l1_gt = self.gt
if self.opt['percep_gt_usm'] is False:
percep_gt = self.gt
if self.opt['gan_gt_usm'] is False:
gan_gt = self.gt
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
if self.cri_ldl:
self.output_ema = self.net_g_ema(self.lq)
l_g_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, l1_gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
if self.cri_ldl:
pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7)
l_g_ldl = self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt))
l_g_total += l_g_ldl
loss_dict['l_g_ldl'] = l_g_ldl
# perceptual loss
if self.cri_perceptual:
l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
if l_g_percep is not None:
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if l_g_style is not None:
l_g_total += l_g_style
loss_dict['l_g_style'] = l_g_style
# gan loss
fake_g_pred = self.net_d(self.output)
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan'] = l_g_gan
l_g_total.backward()
self.optimizer_g.step()
# optimize net_d
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# real
real_d_pred = self.net_d(gan_gt)
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
loss_dict['l_d_real'] = l_d_real
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
l_d_real.backward()
# fake
fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d_fake'] = l_d_fake
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
l_d_fake.backward()
self.optimizer_d.step()
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
self.log_dict = self.reduce_loss_dict(loss_dict)
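# A standalone sketch of the training-pair-pool mechanics from
# _dequeue_and_enqueue above, on CPU with toy sizes (queue_size, batch size
# and tensor shapes are assumptions): enqueue until the pool is full, then
# shuffle and swap each new batch against the first b pooled samples.
def _demo_pair_pool(queue_size=8, b=2, steps=6):
    import torch
    queue = torch.zeros(queue_size, 3, 4, 4)
    ptr = 0
    for step in range(steps):
        batch = torch.full((b, 3, 4, 4), float(step))
        if ptr == queue_size:  # pool full: dequeue and enqueue
            idx = torch.randperm(queue_size)
            queue = queue[idx]
            dequeued = queue[0:b].clone()
            queue[0:b] = batch
            batch = dequeued  # the batch actually used for training
        else:  # pool not full yet: only enqueue
            queue[ptr:ptr + b] = batch
            ptr += b
    return queue, ptr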
import numpy as np
import random
import torch
from torch.nn import functional as F
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.sr_model import SRModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
@MODEL_REGISTRY.register(suffix='basicsr')
class RealESRNetModel(SRModel):
"""RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It is trained without GAN losses.
It mainly performs:
1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with pixel (and optional perceptual) losses, without GAN training.
"""
def __init__(self, opt):
super(RealESRNetModel, self).__init__(opt)
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
self.queue_size = opt.get('queue_size', 180)
@torch.no_grad()
def _dequeue_and_enqueue(self):
"""It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize
b, c, h, w = self.lq.size()
if not hasattr(self, 'queue_lr'):
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue
# shuffle
idx = torch.randperm(self.queue_size)
self.queue_lr = self.queue_lr[idx]
self.queue_gt = self.queue_gt[idx]
# get first b samples
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
# update the queue
self.queue_lr[0:b, :, :, :] = self.lq.clone()
self.queue_gt[0:b, :, :, :] = self.gt.clone()
self.lq = lq_dequeue
self.gt = gt_dequeue
else:
# only do enqueue
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
self.queue_ptr = self.queue_ptr + b
@torch.no_grad()
def feed_data(self, data):
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
"""
if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis
self.gt = data['gt'].to(self.device)
# USM sharpen the GT images
if self.opt['gt_usm'] is True:
self.gt = self.usm_sharpener(self.gt)
self.kernel1 = data['kernel1'].to(self.device)
self.kernel2 = data['kernel2'].to(self.device)
self.sinc_kernel = data['sinc_kernel'].to(self.device)
ori_h, ori_w = self.gt.size()[2:4]
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(self.gt, self.kernel1)
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.opt['resize_range'][1])
elif updown_type == 'down':
scale = np.random.uniform(self.opt['resize_range'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = self.opt['gray_noise_prob']
if np.random.uniform() < self.opt['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
# blur
if np.random.uniform() < self.opt['second_blur_prob']:
out = filter2D(out, self.kernel2)
# random resize
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
if updown_type == 'up':
scale = np.random.uniform(1, self.opt['resize_range2'][1])
elif updown_type == 'down':
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
# add noise
gray_noise_prob = self.opt['gray_noise_prob2']
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if np.random.uniform() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
out = filter2D(out, self.sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
out = filter2D(out, self.sinc_kernel)
# clamp and round
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
# random crop
gt_size = self.opt['gt_size']
self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
# training pair pool
self._dequeue_and_enqueue()
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
else:
# for paired training or validation
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
self.gt_usm = self.usm_sharpener(self.gt)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
# do not use the synthetic process during validation
self.is_train = False
super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
self.is_train = True
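# A small numeric sketch of the resize arithmetic in the second degradation
# stage above (the concrete sizes are assumptions): the random resize targets
# int(ori_h / scale * s) x int(ori_w / scale * s), and the final resize-back
# lands exactly on the LQ size (ori_h // scale, ori_w // scale).
def _demo_second_stage_sizes(ori_h=400, ori_w=400, scale=4, s=0.75):
    intermediate = (int(ori_h / scale * s), int(ori_w / scale * s))  # (75, 75)
    final = (ori_h // scale, ori_w // scale)  # (100, 100)
    return intermediate, final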
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel
@MODEL_REGISTRY.register()
class SRModel(BaseModel):
"""Base SR model for single image super-resolution."""
def __init__(self, opt):
super(SRModel, self).__init__(opt)
# define network
self.net_g = build_network(opt['network_g'])
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)
# load pretrained models
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_g', 'params')
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
if self.is_train:
self.init_training_settings()
def init_training_settings(self):
self.net_g.train()
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
logger = get_root_logger()
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema is used only for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()
# define losses
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
if self.cri_pix is None and self.cri_perceptual is None:
raise ValueError('Both pixel and perceptual losses are None.')
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def setup_optimizers(self):
train_opt = self.opt['train']
optim_params = []
for k, v in self.net_g.named_parameters():
if v.requires_grad:
optim_params.append(v)
else:
logger = get_root_logger()
logger.warning(f'Params {k} will not be optimized.')
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
def feed_data(self, data):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
def optimize_parameters(self, current_iter):
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_total = 0
loss_dict = OrderedDict()
# pixel loss
if self.cri_pix:
l_pix = self.cri_pix(self.output, self.gt)
l_total += l_pix
loss_dict['l_pix'] = l_pix
# perceptual loss
if self.cri_perceptual:
l_percep, l_style = self.cri_perceptual(self.output, self.gt)
if l_percep is not None:
l_total += l_percep
loss_dict['l_percep'] = l_percep
if l_style is not None:
l_total += l_style
loss_dict['l_style'] = l_style
l_total.backward()
self.optimizer_g.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def test(self):
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
self.output = self.net_g_ema(self.lq)
else:
self.net_g.eval()
with torch.no_grad():
self.output = self.net_g(self.lq)
self.net_g.train()
def test_selfensemble(self):
# TODO: to be tested
# 8 augmentations
# modified from https://github.com/thstkdgus35/EDSR-PyTorch
def _transform(v, op):
# if self.precision != 'single': v = v.float()
v2np = v.data.cpu().numpy()
if op == 'v':
tfnp = v2np[:, :, :, ::-1].copy()
elif op == 'h':
tfnp = v2np[:, :, ::-1, :].copy()
elif op == 't':
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
ret = torch.Tensor(tfnp).to(self.device)
# if self.precision == 'half': ret = ret.half()
return ret
# prepare augmented data
lq_list = [self.lq]
for tf in 'v', 'h', 't':
lq_list.extend([_transform(t, tf) for t in lq_list])
# inference
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
out_list = [self.net_g_ema(aug) for aug in lq_list]
else:
self.net_g.eval()
with torch.no_grad():
                out_list = [self.net_g(aug) for aug in lq_list]
self.net_g.train()
# merge results
for i in range(len(out_list)):
if i > 3:
out_list[i] = _transform(out_list[i], 't')
if i % 4 > 1:
out_list[i] = _transform(out_list[i], 'h')
if (i % 4) % 2 == 1:
out_list[i] = _transform(out_list[i], 'v')
output = torch.cat(out_list, dim=0)
self.output = output.mean(dim=0, keepdim=True)
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
if self.opt['rank'] == 0:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
use_pbar = self.opt['val'].get('pbar', False)
if with_metrics:
if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
# initialize the best metric results for each dataset_name (supporting multiple validation datasets)
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
if with_metrics:
self.metric_results = {metric: 0 for metric in self.metric_results}
metric_data = dict()
if use_pbar:
pbar = tqdm(total=len(dataloader), unit='image')
for idx, val_data in enumerate(dataloader):
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['result']])
metric_data['img'] = sr_img
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']])
metric_data['img2'] = gt_img
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}.png')
else:
if self.opt['val']['suffix']:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}.png')
else:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}.png')
imwrite(sr_img, save_img_path)
if with_metrics:
# calculate metrics
for name, opt_ in self.opt['val']['metrics'].items():
self.metric_results[name] += calculate_metric(metric_data, opt_)
if use_pbar:
pbar.update(1)
pbar.set_description(f'Test {img_name}')
if use_pbar:
pbar.close()
if with_metrics:
for metric in self.metric_results.keys():
self.metric_results[metric] /= (idx + 1)
# update the best metric result
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
log_str = f'Validation {dataset_name}\n'
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}'
if hasattr(self, 'best_metric_results'):
log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
log_str += '\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric, value in self.metric_results.items():
tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['lq'] = self.lq.detach().cpu()
out_dict['result'] = self.output.detach().cpu()
if hasattr(self, 'gt'):
out_dict['gt'] = self.gt.detach().cpu()
return out_dict
def save(self, epoch, current_iter):
if hasattr(self, 'net_g_ema'):
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
else:
self.save_network(self.net_g, 'net_g', current_iter)
self.save_training_state(epoch, current_iter)
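# A standalone sketch of the self-ensemble geometry used in test_selfensemble
# above: each of the 'v', 'h' and 't' ops is an involution, so applying it
# twice recovers the input (the toy tensor shape is an assumption).
def _demo_selfensemble_roundtrip():
    import torch

    def _tf(v, op):
        arr = v.numpy()
        if op == 'v':
            arr = arr[:, :, :, ::-1].copy()
        elif op == 'h':
            arr = arr[:, :, ::-1, :].copy()
        elif op == 't':
            arr = arr.transpose((0, 1, 3, 2)).copy()
        return torch.from_numpy(arr)

    x = torch.randn(1, 3, 8, 8)
    for op in ('v', 'h', 't'):
        assert torch.equal(_tf(_tf(x, op), op), x)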
import torch
from collections import OrderedDict
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel
@MODEL_REGISTRY.register()
class SRGANModel(SRModel):
"""SRGAN model for single image super-resolution."""
def init_training_settings(self):
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
logger = get_root_logger()
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema is used only for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()
# define network net_d
self.net_d = build_network(self.opt['network_d'])
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# load pretrained models
load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_d', 'params')
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
self.net_g.train()
self.net_d.train()
# define losses
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
if train_opt.get('ldl_opt'):
self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device)
else:
self.cri_ldl = None
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
if train_opt.get('gan_opt'):
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def setup_optimizers(self):
train_opt = self.opt['train']
# optimizer g
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
# optimizer d
optim_type = train_opt['optim_d'].pop('type')
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
self.optimizers.append(self.optimizer_d)
def optimize_parameters(self, current_iter):
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_g_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# perceptual loss
if self.cri_perceptual:
l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
if l_g_percep is not None:
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if l_g_style is not None:
l_g_total += l_g_style
loss_dict['l_g_style'] = l_g_style
# gan loss
fake_g_pred = self.net_d(self.output)
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan'] = l_g_gan
l_g_total.backward()
self.optimizer_g.step()
# optimize net_d
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# real
real_d_pred = self.net_d(self.gt)
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
loss_dict['l_d_real'] = l_d_real
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
l_d_real.backward()
# fake
fake_d_pred = self.net_d(self.output.detach())
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d_fake'] = l_d_fake
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
l_d_fake.backward()
self.optimizer_d.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def save(self, epoch, current_iter):
if hasattr(self, 'net_g_ema'):
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
else:
self.save_network(self.net_g, 'net_g', current_iter)
self.save_network(self.net_d, 'net_d', current_iter)
self.save_training_state(epoch, current_iter)
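# A minimal sketch of the generator/discriminator alternation in
# optimize_parameters above, with toy networks (the linear layers and losses
# are assumptions): net_d is frozen during the generator step so it
# accumulates no gradients, and fake samples are detached during the
# discriminator step so the generator is untouched.
def _demo_gan_alternation():
    import torch
    net_g = torch.nn.Linear(8, 8)
    net_d = torch.nn.Linear(8, 1)
    optimizer_g = torch.optim.Adam(net_g.parameters(), lr=1e-4)
    optimizer_d = torch.optim.Adam(net_d.parameters(), lr=1e-4)
    x = torch.randn(4, 8)
    # optimize net_g: freeze net_d
    for p in net_d.parameters():
        p.requires_grad = False
    optimizer_g.zero_grad()
    net_d(net_g(x)).mean().backward()
    optimizer_g.step()
    # optimize net_d: unfreeze net_d, detach the generator output
    for p in net_d.parameters():
        p.requires_grad = True
    optimizer_d.zero_grad()
    net_d(net_g(x).detach()).mean().backward()
    optimizer_d.step()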
import cv2
import math
import numpy as np
import random
import torch
from collections import OrderedDict
from os import path as osp
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.gan_loss import g_path_regularize, r1_penalty
from basicsr.utils import imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel
@MODEL_REGISTRY.register()
class StyleGAN2Model(BaseModel):
"""StyleGAN2 model."""
def __init__(self, opt):
super(StyleGAN2Model, self).__init__(opt)
# define network net_g
self.net_g = build_network(opt['network_g'])
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_g', 'params')
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
# latent dimension: self.num_style_feat
self.num_style_feat = opt['network_g']['num_style_feat']
num_val_samples = self.opt['val'].get('num_val_samples', 16)
self.fixed_sample = torch.randn(num_val_samples, self.num_style_feat, device=self.device)
if self.is_train:
self.init_training_settings()
def init_training_settings(self):
train_opt = self.opt['train']
# define network net_d
self.net_d = build_network(self.opt['network_d'])
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_d', 'params')
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema only used for testing on one GPU and saving, do not need to
# wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g.train()
self.net_d.train()
self.net_g_ema.eval()
# define losses
# gan loss (wgan)
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
# regularization weights
self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
self.path_reg_weight = train_opt['path_reg_weight'] # for generator
self.net_g_reg_every = train_opt['net_g_reg_every']
self.net_d_reg_every = train_opt['net_d_reg_every']
self.mixing_prob = train_opt['mixing_prob']
self.mean_path_length = 0
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def setup_optimizers(self):
train_opt = self.opt['train']
# optimizer g
net_g_reg_ratio = self.net_g_reg_every / (self.net_g_reg_every + 1)
if self.opt['network_g']['type'] == 'StyleGAN2GeneratorC':
normal_params = []
style_mlp_params = []
modulation_conv_params = []
for name, param in self.net_g.named_parameters():
if 'modulation' in name:
normal_params.append(param)
elif 'style_mlp' in name:
style_mlp_params.append(param)
elif 'modulated_conv' in name:
modulation_conv_params.append(param)
else:
normal_params.append(param)
optim_params_g = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_g']['lr']
},
{
'params': style_mlp_params,
'lr': train_opt['optim_g']['lr'] * 0.01
},
{
'params': modulation_conv_params,
'lr': train_opt['optim_g']['lr'] / 3
}
]
else:
normal_params = []
for name, param in self.net_g.named_parameters():
normal_params.append(param)
optim_params_g = [{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_g']['lr']
}]
optim_type = train_opt['optim_g'].pop('type')
lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
self.optimizers.append(self.optimizer_g)
# optimizer d
net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
if self.opt['network_d']['type'] == 'StyleGAN2DiscriminatorC':
normal_params = []
linear_params = []
for name, param in self.net_d.named_parameters():
if 'final_linear' in name:
linear_params.append(param)
else:
normal_params.append(param)
optim_params_d = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_d']['lr']
},
{
'params': linear_params,
'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512))
}
]
else:
normal_params = []
for name, param in self.net_d.named_parameters():
normal_params.append(param)
optim_params_d = [{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_d']['lr']
}]
optim_type = train_opt['optim_d'].pop('type')
lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
self.optimizers.append(self.optimizer_d)
def feed_data(self, data):
self.real_img = data['gt'].to(self.device)
def make_noise(self, batch, num_noise):
if num_noise == 1:
noises = torch.randn(batch, self.num_style_feat, device=self.device)
else:
noises = torch.randn(num_noise, batch, self.num_style_feat, device=self.device).unbind(0)
return noises
def mixing_noise(self, batch, prob):
if random.random() < prob:
return self.make_noise(batch, 2)
else:
return [self.make_noise(batch, 1)]
def optimize_parameters(self, current_iter):
loss_dict = OrderedDict()
# optimize net_d
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
batch = self.real_img.size(0)
noise = self.mixing_noise(batch, self.mixing_prob)
fake_img, _ = self.net_g(noise)
fake_pred = self.net_d(fake_img.detach())
real_pred = self.net_d(self.real_img)
# wgan loss with softplus (logistic loss) for discriminator
l_d = self.cri_gan(real_pred, True, is_disc=True) + self.cri_gan(fake_pred, False, is_disc=True)
loss_dict['l_d'] = l_d
# In wgan, real_score should be positive and fake_score should be
# negative
loss_dict['real_score'] = real_pred.detach().mean()
loss_dict['fake_score'] = fake_pred.detach().mean()
l_d.backward()
if current_iter % self.net_d_reg_every == 0:
self.real_img.requires_grad = True
real_pred = self.net_d(self.real_img)
l_d_r1 = r1_penalty(real_pred, self.real_img)
l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
# TODO: why do we need to add 0 * real_pred, otherwise, a runtime
# error will arise: RuntimeError: Expected to have finished
# reduction in the prior iteration before starting a new one.
# This error indicates that your module has parameters that were
# not used in producing loss.
loss_dict['l_d_r1'] = l_d_r1.detach().mean()
l_d_r1.backward()
self.optimizer_d.step()
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
noise = self.mixing_noise(batch, self.mixing_prob)
fake_img, _ = self.net_g(noise)
fake_pred = self.net_d(fake_img)
# wgan loss with softplus (non-saturating loss) for generator
l_g = self.cri_gan(fake_pred, True, is_disc=False)
loss_dict['l_g'] = l_g
l_g.backward()
if current_iter % self.net_g_reg_every == 0:
path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink'])
noise = self.mixing_noise(path_batch_size, self.mixing_prob)
fake_img, latents = self.net_g(noise, return_latents=True)
l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length)
l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0])
# TODO: why do we need to add 0 * fake_img[0, 0, 0, 0]
l_g_path.backward()
loss_dict['l_g_path'] = l_g_path.detach().mean()
loss_dict['path_length'] = path_lengths
self.optimizer_g.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
# EMA
self.model_ema(decay=0.5**(32 / (10 * 1000)))
def test(self):
with torch.no_grad():
self.net_g_ema.eval()
self.output, _ = self.net_g_ema([self.fixed_sample])
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
if self.opt['rank'] == 0:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
assert dataloader is None, 'Validation dataloader should be None.'
self.test()
result = tensor2img(self.output, min_max=(-1, 1))
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'], 'train', f'train_{current_iter}.png')
else:
save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png')
imwrite(result, save_img_path)
# add sample images to tb_logger
result = (result / 255.).astype(np.float32)
result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
if tb_logger is not None:
tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC')
def save(self, epoch, current_iter):
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
self.save_network(self.net_d, 'net_d', current_iter)
self.save_training_state(epoch, current_iter)
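# A small numeric sketch of the lazy-regularization correction in
# setup_optimizers above: since the R1 / path-length penalties are applied
# only every `reg_every` iterations, the effective lr and Adam betas are
# scaled by reg_every / (reg_every + 1) (the example values are assumptions).
def _demo_lazy_reg_adjustment(base_lr=2e-3, reg_every=16):
    ratio = reg_every / (reg_every + 1)
    lr = base_lr * ratio
    betas = (0**ratio, 0.99**ratio)
    return lr, betas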
import torch
from torch.nn import functional as F
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel
@MODEL_REGISTRY.register()
class SwinIRModel(SRModel):
def test(self):
        # pad to a multiple of window_size
window_size = self.opt['network_g']['window_size']
scale = self.opt.get('scale', 1)
mod_pad_h, mod_pad_w = 0, 0
_, _, h, w = self.lq.size()
if h % window_size != 0:
mod_pad_h = window_size - h % window_size
if w % window_size != 0:
mod_pad_w = window_size - w % window_size
img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
self.output = self.net_g_ema(img)
else:
self.net_g.eval()
with torch.no_grad():
self.output = self.net_g(img)
self.net_g.train()
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
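# A small sketch of the reflect-padding arithmetic in SwinIRModel.test above
# (the concrete sizes are assumptions): inputs are padded up to a multiple of
# window_size, and the output is cropped back to exactly (h * scale, w * scale).
def _demo_window_padding(h=30, w=45, window_size=8, scale=4):
    mod_pad_h = (window_size - h % window_size) % window_size  # 2
    mod_pad_w = (window_size - w % window_size) % window_size  # 3
    padded = (h + mod_pad_h, w + mod_pad_w)  # (32, 48)
    out_h = padded[0] * scale - mod_pad_h * scale  # 120 == h * scale
    out_w = padded[1] * scale - mod_pad_w * scale  # 180 == w * scale
    return padded, (out_h, out_w)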
import torch
from collections import Counter
from os import path as osp
from torch import distributed as dist
from tqdm import tqdm
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel
@MODEL_REGISTRY.register()
class VideoBaseModel(SRModel):
"""Base video SR model."""
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset = dataloader.dataset
dataset_name = dataset.opt['name']
with_metrics = self.opt['val']['metrics'] is not None
# initialize self.metric_results
# It is a dict: {
# 'folder1': tensor (num_frame x len(metrics)),
# 'folder2': tensor (num_frame x len(metrics))
# }
if with_metrics:
if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {}
num_frame_each_folder = Counter(dataset.data_info['folder'])
for folder, num_frame in num_frame_each_folder.items():
self.metric_results[folder] = torch.zeros(
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
# initialize the best metric results
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
rank, world_size = get_dist_info()
if with_metrics:
for _, tensor in self.metric_results.items():
tensor.zero_()
metric_data = dict()
# record all frames (border and center frames)
if rank == 0:
pbar = tqdm(total=len(dataset), unit='frame')
for idx in range(rank, len(dataset), world_size):
val_data = dataset[idx]
val_data['lq'].unsqueeze_(0)
val_data['gt'].unsqueeze_(0)
folder = val_data['folder']
frame_idx, max_idx = val_data['idx'].split('/')
lq_path = val_data['lq_path']
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals()
result_img = tensor2img([visuals['result']])
metric_data['img'] = result_img
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']])
metric_data['img2'] = gt_img
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
raise NotImplementedError('saving image is not supported during training.')
else:
if 'vimeo' in dataset_name.lower(): # vimeo90k dataset
split_result = lq_path.split('/')
img_name = f'{split_result[-3]}_{split_result[-2]}_{split_result[-1].split(".")[0]}'
else: # other datasets, e.g., REDS, Vid4
img_name = osp.splitext(osp.basename(lq_path))[0]
if self.opt['val']['suffix']:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
f'{img_name}_{self.opt["val"]["suffix"]}.png')
else:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
f'{img_name}_{self.opt["name"]}.png')
imwrite(result_img, save_img_path)
if with_metrics:
# calculate metrics
for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
result = calculate_metric(metric_data, opt_)
self.metric_results[folder][int(frame_idx), metric_idx] += result
# progress bar
if rank == 0:
for _ in range(world_size):
pbar.update(1)
pbar.set_description(f'Test {folder}: {int(frame_idx) + world_size}/{max_idx}')
if rank == 0:
pbar.close()
if with_metrics:
if self.opt['dist']:
# collect data among GPUs
for _, tensor in self.metric_results.items():
dist.reduce(tensor, 0)
dist.barrier()
else:
                pass  # assume a single GPU is used in non-distributed testing
if rank == 0:
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
logger = get_root_logger()
logger.warning('nondist_validation is not implemented. Run dist_validation.')
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
# ----------------- calculate the average values for each folder, and for each metric ----------------- #
# average all frames for each sub-folder
# metric_results_avg is a dict:{
# 'folder1': tensor (len(metrics)),
# 'folder2': tensor (len(metrics))
# }
metric_results_avg = {
folder: torch.mean(tensor, dim=0).cpu()
for (folder, tensor) in self.metric_results.items()
}
# total_avg_results is a dict: {
# 'metric1': float,
# 'metric2': float
# }
total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
for folder, tensor in metric_results_avg.items():
for idx, metric in enumerate(total_avg_results.keys()):
total_avg_results[metric] += metric_results_avg[folder][idx].item()
# average among folders
for metric in total_avg_results.keys():
total_avg_results[metric] /= len(metric_results_avg)
# update the best metric result
self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter)
# ------------------------------------------ log the metric ------------------------------------------ #
log_str = f'Validation {dataset_name}\n'
for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
log_str += f'\t # {metric}: {value:.4f}'
for folder, tensor in metric_results_avg.items():
log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
if hasattr(self, 'best_metric_results'):
log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
log_str += '\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
for folder, tensor in metric_results_avg.items():
tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter)
from basicsr.utils.registry import MODEL_REGISTRY
from .srgan_model import SRGANModel
from .video_base_model import VideoBaseModel
@MODEL_REGISTRY.register()
class VideoGANModel(SRGANModel, VideoBaseModel):
"""Video GAN model.
Use multiple inheritance.
It will first use the functions of :class:`SRGANModel`:
- :func:`init_training_settings`
- :func:`setup_optimizers`
- :func:`optimize_parameters`
- :func:`save`
Then find functions in :class:`VideoBaseModel`.
"""
import torch
from collections import OrderedDict
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
from .video_recurrent_model import VideoRecurrentModel
@MODEL_REGISTRY.register()
class VideoRecurrentGANModel(VideoRecurrentModel):
def init_training_settings(self):
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
logger = get_root_logger()
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
# build network net_g with Exponential Moving Average (EMA)
# net_g_ema only used for testing on one GPU and saving.
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()
# define network net_d
self.net_d = build_network(self.opt['network_d'])
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# load pretrained models
load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_d', 'params')
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
self.net_g.train()
self.net_d.train()
# define losses
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
if train_opt.get('gan_opt'):
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def setup_optimizers(self):
train_opt = self.opt['train']
if train_opt['fix_flow']:
normal_params = []
flow_params = []
for name, param in self.net_g.named_parameters():
if 'spynet' in name: # The fix_flow now only works for spynet.
flow_params.append(param)
else:
normal_params.append(param)
optim_params = [
{ # add flow params first
'params': flow_params,
'lr': train_opt['lr_flow']
},
{
'params': normal_params,
'lr': train_opt['optim_g']['lr']
},
]
else:
optim_params = self.net_g.parameters()
# optimizer g
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
# optimizer d
optim_type = train_opt['optim_d'].pop('type')
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
self.optimizers.append(self.optimizer_d)
def optimize_parameters(self, current_iter):
logger = get_root_logger()
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
if self.fix_flow_iter:
if current_iter == 1:
logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
for name, param in self.net_g.named_parameters():
if 'spynet' in name or 'edvr' in name:
param.requires_grad_(False)
elif current_iter == self.fix_flow_iter:
logger.warning('Train all the parameters.')
self.net_g.requires_grad_(True)
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
_, _, c, h, w = self.output.size()
l_g_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# perceptual loss
if self.cri_perceptual:
l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), self.gt.view(-1, c, h, w))
if l_g_percep is not None:
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if l_g_style is not None:
l_g_total += l_g_style
loss_dict['l_g_style'] = l_g_style
# gan loss
fake_g_pred = self.net_d(self.output.view(-1, c, h, w))
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan'] = l_g_gan
l_g_total.backward()
self.optimizer_g.step()
# optimize net_d
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# real
# reshape to (b*n, c, h, w)
real_d_pred = self.net_d(self.gt.view(-1, c, h, w))
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
loss_dict['l_d_real'] = l_d_real
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
l_d_real.backward()
# fake
# reshape to (b*n, c, h, w)
fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach())
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d_fake'] = l_d_fake
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
l_d_fake.backward()
self.optimizer_d.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def save(self, epoch, current_iter):
if self.ema_decay > 0:
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
else:
self.save_network(self.net_g, 'net_g', current_iter)
self.save_network(self.net_d, 'net_d', current_iter)
self.save_training_state(epoch, current_iter)
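# A minimal sketch of the fix-flow warm-up used in optimize_parameters above,
# on a toy module (the module layout and iteration values are assumptions):
# 'spynet' parameters are frozen at iteration 1 and everything is unfrozen
# once fix_flow_iter is reached.
def _demo_fix_flow_schedule(fix_flow_iter=5000):
    import torch

    class ToyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.spynet = torch.nn.Linear(4, 4)
            self.backbone = torch.nn.Linear(4, 4)

    net = ToyNet()
    for current_iter in (1, fix_flow_iter):
        if current_iter == 1:
            for name, param in net.named_parameters():
                if 'spynet' in name:
                    param.requires_grad_(False)
        elif current_iter == fix_flow_iter:
            net.requires_grad_(True)
    return net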
import torch
from collections import Counter
from os import path as osp
from torch import distributed as dist
from tqdm import tqdm
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
from .video_base_model import VideoBaseModel
@MODEL_REGISTRY.register()
class VideoRecurrentModel(VideoBaseModel):
def __init__(self, opt):
super(VideoRecurrentModel, self).__init__(opt)
if self.is_train:
self.fix_flow_iter = opt['train'].get('fix_flow')
def setup_optimizers(self):
train_opt = self.opt['train']
flow_lr_mul = train_opt.get('flow_lr_mul', 1)
logger = get_root_logger()
        logger.info(f'Multiply the learning rate of the flow network by {flow_lr_mul}.')
if flow_lr_mul == 1:
optim_params = self.net_g.parameters()
else: # separate flow params and normal params for different lr
normal_params = []
flow_params = []
for name, param in self.net_g.named_parameters():
if 'spynet' in name:
flow_params.append(param)
else:
normal_params.append(param)
optim_params = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_g']['lr']
},
{
'params': flow_params,
'lr': train_opt['optim_g']['lr'] * flow_lr_mul
},
]
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
def optimize_parameters(self, current_iter):
if self.fix_flow_iter:
logger = get_root_logger()
if current_iter == 1:
logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
for name, param in self.net_g.named_parameters():
if 'spynet' in name or 'edvr' in name:
param.requires_grad_(False)
elif current_iter == self.fix_flow_iter:
logger.warning('Train all the parameters.')
self.net_g.requires_grad_(True)
super(VideoRecurrentModel, self).optimize_parameters(current_iter)
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset = dataloader.dataset
dataset_name = dataset.opt['name']
with_metrics = self.opt['val']['metrics'] is not None
# initialize self.metric_results
# It is a dict: {
# 'folder1': tensor (num_frame x len(metrics)),
# 'folder2': tensor (num_frame x len(metrics))
# }
if with_metrics:
if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {}
num_frame_each_folder = Counter(dataset.data_info['folder'])
for folder, num_frame in num_frame_each_folder.items():
self.metric_results[folder] = torch.zeros(
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
# initialize the best metric results
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
rank, world_size = get_dist_info()
if with_metrics:
for _, tensor in self.metric_results.items():
tensor.zero_()
metric_data = dict()
num_folders = len(dataset)
num_pad = (world_size - (num_folders % world_size)) % world_size
if rank == 0:
pbar = tqdm(total=len(dataset), unit='folder')
# Evaluate (num_folders + num_pad) times so that every rank runs the same
# number of iterations (avoiding deadlock); only the first num_folders
# results are recorded.
for i in range(rank, num_folders + num_pad, world_size):
idx = min(i, num_folders - 1)
val_data = dataset[idx]
folder = val_data['folder']
# compute outputs
val_data['lq'].unsqueeze_(0)
val_data['gt'].unsqueeze_(0)
self.feed_data(val_data)
val_data['lq'].squeeze_(0)
val_data['gt'].squeeze_(0)
self.test()
visuals = self.get_current_visuals()
# free tensors to avoid running out of GPU memory
del self.lq
del self.output
if 'gt' in visuals:
del self.gt
torch.cuda.empty_cache()
if self.center_frame_only:
visuals['result'] = visuals['result'].unsqueeze(1)
if 'gt' in visuals:
visuals['gt'] = visuals['gt'].unsqueeze(1)
# evaluate
if i < num_folders:
for idx in range(visuals['result'].size(1)):
result = visuals['result'][0, idx, :, :, :]
result_img = tensor2img([result]) # uint8, bgr
metric_data['img'] = result_img
if 'gt' in visuals:
gt = visuals['gt'][0, idx, :, :, :]
gt_img = tensor2img([gt]) # uint8, bgr
metric_data['img2'] = gt_img
if save_img:
if self.opt['is_train']:
raise NotImplementedError('saving image is not supported during training.')
else:
if self.center_frame_only: # vimeo-90k
clip_ = val_data['lq_path'].split('/')[-3]
seq_ = val_data['lq_path'].split('/')[-2]
name_ = f'{clip_}_{seq_}'
img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
f"{name_}_{self.opt['name']}.png")
else: # others
img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
f"{idx:08d}_{self.opt['name']}.png")
# zero-padded frame index as image name (e.g., for the REDS dataset)
imwrite(result_img, img_path)
# calculate metrics
if with_metrics:
for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
result = calculate_metric(metric_data, opt_)
self.metric_results[folder][idx, metric_idx] += result
# progress bar
if rank == 0:
for _ in range(world_size):
pbar.update(1)
pbar.set_description(f'Folder: {folder}')
if rank == 0:
pbar.close()
if with_metrics:
if self.opt['dist']:
# collect data among GPUs
for _, tensor in self.metric_results.items():
dist.reduce(tensor, 0)
dist.barrier()
if rank == 0:
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def test(self):
n = self.lq.size(1)
self.net_g.eval()
flip_seq = self.opt['val'].get('flip_seq', False)
self.center_frame_only = self.opt['val'].get('center_frame_only', False)
if flip_seq:
self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1)
with torch.no_grad():
self.output = self.net_g(self.lq)
if flip_seq:
output_1 = self.output[:, :n, :, :, :]
output_2 = self.output[:, n:, :, :, :].flip(1)
self.output = 0.5 * (output_1 + output_2)
if self.center_frame_only:
self.output = self.output[:, n // 2, :, :, :]
self.net_g.train()
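# --- Editor's note: a self-contained sketch of the temporal-flip self-ensemble
# implemented in test() above: run the network on [clip, reversed clip], flip
# the second half of the output back, and average the two passes. The identity
# "network" below is a placeholder for self.net_g.
import torch

def net(x):  # placeholder network
    return x

lq = torch.rand(1, 5, 3, 8, 8)  # (b, n, c, h, w)
n = lq.size(1)
lq_aug = torch.cat([lq, lq.flip(1)], dim=1)    # 2n frames
out = net(lq_aug)
out = 0.5 * (out[:, :n] + out[:, n:].flip(1))  # average original and flipped pass
assert out.shape == lq.shape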
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
modulated_deform_conv)
__all__ = [
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv'
]
import math
import os
import torch
from torch import nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn import functional as F
from torch.nn.modules.utils import _pair, _single
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
deform_conv_ext = load(
'deform_conv',
sources=[
os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
],
)
else:
try:
from . import deform_conv_ext
except ImportError:
pass
# avoid annoying print output
# print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
# '1. compile with BASICSR_EXT=True. or\n '
# '2. set BASICSR_JIT=True during running')
class DeformConvFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
deform_conv_ext.deform_conv_forward(input, weight,
offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
grad_offset, weight, ctx.bufs_[0], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
cur_im2col_step)
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
return output_size
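# --- Editor's note: a worked example of the per-dimension size formula in
# _output_size above, out = (in + 2*pad - (dilation*(k-1)+1)) // stride + 1,
# checked against an ordinary nn.Conv2d with the same hyper-parameters.
import torch
from torch import nn

in_size, k, pad, dil, stride = 32, 3, 1, 2, 2
kernel_extent = dil * (k - 1) + 1                        # 5
out = (in_size + 2 * pad - kernel_extent) // stride + 1  # (32 + 2 - 5) // 2 + 1 = 15
conv = nn.Conv2d(1, 1, k, stride=stride, padding=pad, dilation=dil)
assert conv(torch.rand(1, 1, in_size, in_size)).shape[-1] == out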
class ModulatedDeformConvFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
class DeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False):
super(DeformConv, self).__init__()
assert not bias
assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}'
assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}'
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
# Pad the input when it is smaller than the kernel, to avoid an
# assert error in deform_conv_cuda.cpp:128
input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
if input_pad:
pad_h = max(self.kernel_size[0] - x.size(2), 0)
pad_w = max(self.kernel_size[1] - x.size(3), 0)
x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
self.deformable_groups)
if input_pad:
out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
return out
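# --- Editor's note: a hedged usage sketch for the DeformConv class above. The
# caller supplies the offset field, which must have
# deformable_groups * 2 * kh * kw channels and the output spatial size; a
# compiled CUDA deform_conv_ext is assumed to be available.
import torch

dconv = DeformConv(16, 32, kernel_size=3, padding=1).cuda()
x = torch.rand(2, 16, 24, 24, device='cuda')
offset = torch.zeros(2, 18, 24, 24, device='cuda')  # 2 * 3 * 3 = 18 channels
out = dconv(x, offset)  # all-zero offsets reduce to a regular convolution
print(out.shape)  # torch.Size([2, 32, 24, 24])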
class DeformConvPack(DeformConv):
"""A Deformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool): Same as nn.Conv2d. DeformConv does not support a bias
term, so this must be False. Default: False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(DeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
self.deformable_groups)
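# --- Editor's note: a hedged usage sketch for DeformConvPack above, which
# predicts its own offsets via conv_offset and is therefore called like a
# plain conv layer; a CUDA build of the extension is assumed.
import torch

pack = DeformConvPack(16, 32, kernel_size=3, padding=1).cuda()
x = torch.rand(2, 16, 24, 24, device='cuda')
out = pack(x)  # offsets start at zero (init_offset), i.e. a plain conv at init
print(out.shape)  # torch.Size([2, 32, 24, 24])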
class ModulatedDeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.init_weights()
def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
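# --- Editor's note: unlike the Pack variant below, ModulatedDeformConv expects
# the caller to supply both the offset field (2 * kh * kw channels per
# deformable group) and the modulation mask (kh * kw channels, usually
# sigmoid-activated); a CUDA build of the extension is assumed.
import torch

mdconv = ModulatedDeformConv(16, 32, kernel_size=3, padding=1).cuda()
x = torch.rand(2, 16, 24, 24, device='cuda')
offset = torch.zeros(2, 18, 24, 24, device='cuda')
mask = torch.ones(2, 9, 24, 24, device='cuda')  # all-ones mask: no modulation
out = mdconv(x, offset, mask)
print(out.shape)  # torch.Size([2, 32, 24, 24])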
class ModulatedDeformConvPack(ModulatedDeformConv):
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool): Same as nn.Conv2d. Default: True.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_weights()
def init_weights(self):
super(ModulatedDeformConvPack, self).init_weights()
if hasattr(self, 'conv_offset'):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
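# --- Editor's note: a hedged usage sketch for ModulatedDeformConvPack (DCNv2).
# conv_offset predicts 3 * kh * kw channels per deformable group, which
# forward() splits into two offset maps plus a sigmoid mask; a CUDA build of
# the extension is assumed.
import torch

dcn = ModulatedDeformConvPack(16, 32, kernel_size=3, padding=1).cuda()
x = torch.rand(2, 16, 24, 24, device='cuda')
print(dcn.conv_offset.out_channels)  # 1 * 3 * 3 * 3 = 27
out = dcn(x)
print(out.shape)  # torch.Size([2, 32, 24, 24])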
from .fused_act import FusedLeakyReLU, fused_leaky_relu
__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']