Commit a75d2bda authored by mashun1's avatar mashun1
Browse files

evtexture

parents
Pipeline #1325 canceled with stages
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from basicsr.metrics.metric_util import reorder_image, to_y_channel
from basicsr.utils.color_util import rgb2ycbcr_pt
from basicsr.utils.registry import METRIC_REGISTRY
@METRIC_REGISTRY.register()
def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: PSNR result.
"""
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
if test_y_channel:
img = to_y_channel(img)
img2 = to_y_channel(img2)
img = img.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img - img2)**2)
if mse == 0:
return float('inf')
return 10. * np.log10(255. * 255. / mse)
@METRIC_REGISTRY.register()
def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
"""Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).
Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Args:
img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: PSNR result.
"""
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if crop_border != 0:
img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
if test_y_channel:
img = rgb2ycbcr_pt(img, y_only=True)
img2 = rgb2ycbcr_pt(img2, y_only=True)
img = img.to(torch.float64)
img2 = img2.to(torch.float64)
mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
return 10. * torch.log10(1. / (mse + 1e-8))
@METRIC_REGISTRY.register()
def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
"""Calculate SSIM (structural similarity).
``Paper: Image quality assessment: From error visibility to structural similarity``
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: SSIM result.
"""
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
if test_y_channel:
img = to_y_channel(img)
img2 = to_y_channel(img2)
img = img.astype(np.float64)
img2 = img2.astype(np.float64)
ssims = []
for i in range(img.shape[2]):
ssims.append(_ssim(img[..., i], img2[..., i]))
return np.array(ssims).mean()
@METRIC_REGISTRY.register()
def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
"""Calculate SSIM (structural similarity) (PyTorch version).
``Paper: Image quality assessment: From error visibility to structural similarity``
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
Args:
img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: SSIM result.
"""
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if crop_border != 0:
img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
if test_y_channel:
img = rgb2ycbcr_pt(img, y_only=True)
img2 = rgb2ycbcr_pt(img2, y_only=True)
img = img.to(torch.float64)
img2 = img2.to(torch.float64)
ssim = _ssim_pth(img * 255., img2 * 255.)
return ssim
def _ssim(img, img2):
"""Calculate SSIM (structural similarity) for one channel images.
It is called by func:`calculate_ssim`.
Args:
img (ndarray): Images with range [0, 255] with order 'HWC'.
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
Returns:
float: SSIM result.
"""
c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
return ssim_map.mean()
def _ssim_pth(img, img2):
"""Calculate SSIM (structural similarity) (PyTorch version).
It is called by func:`calculate_ssim_pt`.
Args:
img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
Returns:
float: SSIM result.
"""
c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device)
mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode
mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq
sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2
cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map
return ssim_map.mean([1, 2, 3])
import importlib
from copy import deepcopy
from os import path as osp
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import MODEL_REGISTRY
__all__ = ['build_model']
# automatically scan and import model modules for registry
# scan all the files under the 'models' folder and collect files ending with '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
def build_model(opt):
"""Build model from options.
Args:
opt (dict): Configuration. It must contain:
model_type (str): Model type.
"""
opt = deepcopy(opt)
model = MODEL_REGISTRY.get(opt['model_type'])(opt)
logger = get_root_logger()
logger.info(f'Model [{model.__class__.__name__}] is created.')
return model
import os
import time
import torch
from collections import OrderedDict
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel
from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils import get_root_logger
from basicsr.utils.dist_util import master_only
class BaseModel():
"""Base model."""
def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.is_train = opt['is_train']
self.schedulers = []
self.optimizers = []
def feed_data(self, data):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
pass
def save(self, epoch, current_iter):
"""Save networks and training state."""
pass
def validation(self, dataloader, current_iter, tb_logger, save_img=False):
"""Validation function.
Args:
dataloader (torch.utils.data.DataLoader): Validation dataloader.
current_iter (int): Current iteration.
tb_logger (tensorboard logger): Tensorboard logger.
save_img (bool): Whether to save images. Default: False.
"""
if self.opt['dist']:
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
else:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
def _initialize_best_metric_results(self, dataset_name):
"""Initialize the best metric results dict for recording the best metric value and iteration."""
if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
return
elif not hasattr(self, 'best_metric_results'):
self.best_metric_results = dict()
# add a dataset record
record = dict()
for metric, content in self.opt['val']['metrics'].items():
better = content.get('better', 'higher')
init_val = float('-inf') if better == 'higher' else float('inf')
record[metric] = dict(better=better, val=init_val, iter=-1)
self.best_metric_results[dataset_name] = record
def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
if val >= self.best_metric_results[dataset_name][metric]['val']:
self.best_metric_results[dataset_name][metric]['val'] = val
self.best_metric_results[dataset_name][metric]['iter'] = current_iter
else:
if val <= self.best_metric_results[dataset_name][metric]['val']:
self.best_metric_results[dataset_name][metric]['val'] = val
self.best_metric_results[dataset_name][metric]['iter'] = current_iter
def model_ema(self, decay=0.999):
net_g = self.get_bare_model(self.net_g)
net_g_params = dict(net_g.named_parameters())
net_g_ema_params = dict(self.net_g_ema.named_parameters())
for k in net_g_ema_params.keys():
net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
def get_current_log(self):
return self.log_dict
def model_to_device(self, net):
"""Model to device. It also warps models with DistributedDataParallel
or DataParallel.
Args:
net (nn.Module)
"""
net = net.to(self.device)
if self.opt['dist']:
find_unused_parameters = self.opt.get('find_unused_parameters', False)
net = DistributedDataParallel(
net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
elif self.opt['num_gpu'] > 1:
net = DataParallel(net)
return net
def get_optimizer(self, optim_type, params, lr, **kwargs):
if optim_type == 'Adam':
optimizer = torch.optim.Adam(params, lr, **kwargs)
elif optim_type == 'AdamW':
optimizer = torch.optim.AdamW(params, lr, **kwargs)
elif optim_type == 'Adamax':
optimizer = torch.optim.Adamax(params, lr, **kwargs)
elif optim_type == 'SGD':
optimizer = torch.optim.SGD(params, lr, **kwargs)
elif optim_type == 'ASGD':
optimizer = torch.optim.ASGD(params, lr, **kwargs)
elif optim_type == 'RMSprop':
optimizer = torch.optim.RMSprop(params, lr, **kwargs)
elif optim_type == 'Rprop':
optimizer = torch.optim.Rprop(params, lr, **kwargs)
else:
raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
return optimizer
def setup_schedulers(self):
"""Set up schedulers."""
train_opt = self.opt['train']
scheduler_type = train_opt['scheduler'].pop('type')
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
for optimizer in self.optimizers:
self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
elif scheduler_type == 'CosineAnnealingRestartLR':
for optimizer in self.optimizers:
self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
else:
raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
def get_bare_model(self, net):
"""Get bare model, especially under wrapping with
DistributedDataParallel or DataParallel.
"""
if isinstance(net, (DataParallel, DistributedDataParallel)):
net = net.module
return net
@master_only
def print_network(self, net):
"""Print the str and parameter number of a network.
Args:
net (nn.Module)
"""
if isinstance(net, (DataParallel, DistributedDataParallel)):
net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
else:
net_cls_str = f'{net.__class__.__name__}'
net = self.get_bare_model(net)
net_str = str(net)
net_params = sum(map(lambda x: x.numel(), net.parameters()))
logger = get_root_logger()
logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
logger.info(net_str)
def _set_lr(self, lr_groups_l):
"""Set learning rate for warm-up.
Args:
lr_groups_l (list): List for lr_groups, each for an optimizer.
"""
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
for param_group, lr in zip(optimizer.param_groups, lr_groups):
param_group['lr'] = lr
def _get_init_lr(self):
"""Get the initial lr, which is set by the scheduler.
"""
init_lr_groups_l = []
for optimizer in self.optimizers:
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
return init_lr_groups_l
def update_learning_rate(self, current_iter, warmup_iter=-1):
"""Update learning rate.
Args:
current_iter (int): Current iteration.
warmup_iter (int): Warm-up iter numbers. -1 for no warm-up.
Default: -1.
"""
if current_iter > 1:
for scheduler in self.schedulers:
scheduler.step()
# set up warm-up learning rate
if current_iter < warmup_iter:
# get initial lr for each group
init_lr_g_l = self._get_init_lr()
# modify warming-up learning rates
# currently only support linearly warm up
warm_up_lr_l = []
for init_lr_g in init_lr_g_l:
warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
# set learning rate
self._set_lr(warm_up_lr_l)
def get_current_learning_rate(self):
return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
@master_only
def save_network(self, net, net_label, current_iter, param_key='params'):
"""Save networks.
Args:
net (nn.Module | list[nn.Module]): Network(s) to be saved.
net_label (str): Network label.
current_iter (int): Current iter number.
param_key (str | list[str]): The parameter key(s) to save network.
Default: 'params'.
"""
if current_iter == -1:
current_iter = 'latest'
save_filename = f'{net_label}_{current_iter}.pth'
save_path = os.path.join(self.opt['path']['models'], save_filename)
net = net if isinstance(net, list) else [net]
param_key = param_key if isinstance(param_key, list) else [param_key]
assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
save_dict = {}
for net_, param_key_ in zip(net, param_key):
net_ = self.get_bare_model(net_)
state_dict = net_.state_dict()
for key, param in state_dict.items():
if key.startswith('module.'): # remove unnecessary 'module.'
key = key[7:]
state_dict[key] = param.cpu()
save_dict[param_key_] = state_dict
# avoid occasional writing errors
retry = 3
while retry > 0:
try:
torch.save(save_dict, save_path)
except Exception as e:
logger = get_root_logger()
logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
time.sleep(1)
else:
break
finally:
retry -= 1
if retry == 0:
logger.warning(f'Still cannot save {save_path}. Just ignore it.')
# raise IOError(f'Cannot save {save_path}.')
def _print_different_keys_loading(self, crt_net, load_net, strict=True):
"""Print keys with different name or different size when loading models.
1. Print keys with different names.
2. If strict=False, print the same key but with different tensor size.
It also ignore these keys with different sizes (not load).
Args:
crt_net (torch model): Current network.
load_net (dict): Loaded network.
strict (bool): Whether strictly loaded. Default: True.
"""
crt_net = self.get_bare_model(crt_net)
crt_net = crt_net.state_dict()
crt_net_keys = set(crt_net.keys())
load_net_keys = set(load_net.keys())
logger = get_root_logger()
if crt_net_keys != load_net_keys:
logger.warning('Current net - loaded net:')
for v in sorted(list(crt_net_keys - load_net_keys)):
logger.warning(f' {v}')
logger.warning('Loaded net - current net:')
for v in sorted(list(load_net_keys - crt_net_keys)):
logger.warning(f' {v}')
# check the size for the same keys
if not strict:
common_keys = crt_net_keys & load_net_keys
for k in common_keys:
if crt_net[k].size() != load_net[k].size():
logger.warning(f'Size different, ignore [{k}]: crt_net: '
f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
load_net[k + '.ignore'] = load_net.pop(k)
def load_network(self, net, load_path, strict=True, param_key='params'):
"""Load network.
Args:
load_path (str): The path of networks to be loaded.
net (nn.Module): Network.
strict (bool): Whether strictly loaded.
param_key (str): The parameter key of loaded network. If set to
None, use the root 'path'.
Default: 'params'.
"""
logger = get_root_logger()
net = self.get_bare_model(net)
load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
if param_key is not None:
if param_key not in load_net and 'params' in load_net:
param_key = 'params'
logger.info('Loading: params_ema does not exist, use params.')
load_net = load_net[param_key]
logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
self._print_different_keys_loading(net, load_net, strict)
net.load_state_dict(load_net, strict=strict)
@master_only
def save_training_state(self, epoch, current_iter):
"""Save training states during training, which will be used for
resuming.
Args:
epoch (int): Current epoch.
current_iter (int): Current iteration.
"""
if current_iter != -1:
state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
for o in self.optimizers:
state['optimizers'].append(o.state_dict())
for s in self.schedulers:
state['schedulers'].append(s.state_dict())
save_filename = f'{current_iter}.state'
save_path = os.path.join(self.opt['path']['training_states'], save_filename)
# avoid occasional writing errors
retry = 3
while retry > 0:
try:
torch.save(state, save_path)
except Exception as e:
logger = get_root_logger()
logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
time.sleep(1)
else:
break
finally:
retry -= 1
if retry == 0:
logger.warning(f'Still cannot save {save_path}. Just ignore it.')
# raise IOError(f'Cannot save {save_path}.')
def resume_training(self, resume_state):
"""Reload the optimizers and schedulers for resumed training.
Args:
resume_state (dict): Resume state.
"""
resume_optimizers = resume_state['optimizers']
resume_schedulers = resume_state['schedulers']
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
for i, o in enumerate(resume_optimizers):
self.optimizers[i].load_state_dict(o)
for i, s in enumerate(resume_schedulers):
self.schedulers[i].load_state_dict(s)
def reduce_loss_dict(self, loss_dict):
"""reduce loss dict.
In distributed training, it averages the losses among different GPUs .
Args:
loss_dict (OrderedDict): Loss dict.
"""
with torch.no_grad():
if self.opt['dist']:
keys = []
losses = []
for name, value in loss_dict.items():
keys.append(name)
losses.append(value)
losses = torch.stack(losses, 0)
torch.distributed.reduce(losses, dst=0)
if self.opt['rank'] == 0:
losses /= self.opt['world_size']
loss_dict = {key: loss for key, loss in zip(keys, losses)}
log_dict = OrderedDict()
for name, value in loss_dict.items():
log_dict[name] = value.mean().item()
return log_dict
import torch
from collections import Counter
from collections import OrderedDict
from os import path as osp
from torch import distributed as dist
from tqdm import tqdm
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.data.transforms import mod_crop
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
from .video_base_model import VideoBaseModel
from torchvision import utils as vutils
@MODEL_REGISTRY.register()
class E2VSRModel(VideoBaseModel):
def __init__(self, opt):
super(E2VSRModel, self).__init__(opt)
self.scale = opt['scale']
if self.is_train:
self.fix_flow_iter = opt['train'].get('fix_flow')
def feed_data(self, data, flag='train'):
self.lq = data['lq'].to(self.device)
self.voxels_f = data['voxels_f'].to(self.device)
self.voxels_b = data['voxels_b'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
def setup_optimizers(self):
train_opt = self.opt['train']
flow_lr_mul = train_opt.get('flow_lr_mul', 1)
logger = get_root_logger()
logger.info(f'Multiple the learning rate for flow network with {flow_lr_mul}.')
if flow_lr_mul == 1:
optim_params = self.net_g.parameters()
else: # separate flow params and normal params for different lr
normal_params = []
flow_params = []
for name, param in self.net_g.named_parameters():
# if 'spynet' in name or (self.is_pretrain and 'unet' in name):
if 'spynet' in name:
flow_params.append(param)
else:
normal_params.append(param)
optim_params = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_g']['lr']
},
{
'params': flow_params,
'lr': train_opt['optim_g']['lr'] * flow_lr_mul
},
]
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
def optimize_parameters(self, current_iter):
if self.fix_flow_iter:
logger = get_root_logger()
if current_iter == 1:
logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
for name, param in self.net_g.named_parameters():
# if 'spynet' in name or 'edvr' in name or (self.is_pretrain and 'unet' in name):
if 'spynet' in name or 'edvr' in name:
param.requires_grad_(False)
elif current_iter == self.fix_flow_iter:
logger.warning('Train all the parameters.')
self.net_g.requires_grad_(True)
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq, self.voxels_f, self.voxels_b)
l_total = 0
loss_dict = OrderedDict()
# pixel loss
if self.cri_pix:
l_pix = self.cri_pix(self.output, self.gt)
l_total += l_pix
loss_dict['l_pix'] = l_pix
# perceptual loss
if self.cri_perceptual:
l_percep, l_style = self.cri_perceptual(self.output, self.gt)
if l_percep is not None:
l_total += l_percep
loss_dict['l_percep'] = l_percep
if l_style is not None:
l_total += l_style
loss_dict['l_style'] = l_style
loss_dict['l_total'] = l_total
l_total.backward()
self.optimizer_g.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset = dataloader.dataset
dataset_name = dataset.opt['name']
with_metrics = self.opt['val']['metrics'] is not None
# initialize self.metric_results
# It is a dict: {
# 'folder1': tensor (num_frame x len(metrics)),
# 'folder2': tensor (num_frame x len(metrics))
# }
if with_metrics:
# if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {}
num_frame_each_folder = Counter(dataset.data_info['folder'])
for folder, num_frame in num_frame_each_folder.items():
self.metric_results[folder] = torch.zeros(
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
# initialize the best metric results
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
rank, world_size = get_dist_info()
if with_metrics:
for _, tensor in self.metric_results.items():
tensor.zero_()
metric_data = dict()
num_folders = len(dataset)
num_pad = (world_size - (num_folders % world_size)) % world_size
if rank == 0:
pbar = tqdm(total=len(dataset), unit='folder')
# Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded.
# (To avoid wait-dead)
for i in range(rank, num_folders + num_pad, world_size):
idx = min(i, num_folders - 1)
val_data = dataset[idx]
folder = val_data['folder']
# compute outputs
val_data['lq'].unsqueeze_(0)
val_data['gt'].unsqueeze_(0)
val_data['voxels_f'].unsqueeze_(0)
val_data['voxels_b'].unsqueeze_(0)
self.feed_data(val_data, flag='val')
val_data['lq'].squeeze_(0)
val_data['gt'].squeeze_(0)
val_data['voxels_f'].squeeze_(0)
val_data['voxels_b'].squeeze_(0)
self.test()
visuals = self.get_current_visuals()
# tentative for out of GPU memory
del self.lq
del self.output
del self.voxels_f
del self.voxels_b
if 'gt' in visuals:
del self.gt
torch.cuda.empty_cache()
# evaluate
if i < num_folders:
for idx in range(visuals['result'].size(1)):
result_img = visuals['result'][0, idx, :, :, :]
result_img = tensor2img([result_img]) # uint8, bgr
# metric_data['img'] = result_img
# if 'gt' in visuals:
# gt = visuals['gt'][0, idx, :, :, :]
# gt_img = tensor2img([gt]) # uint8, bgr
# gt_img = mod_crop(gt_img, self.scale)
# metric_data['img2'] = gt_img
# # calculate metrics
# if with_metrics:
# for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
# result = calculate_metric(metric_data, opt_)
# self.metric_results[folder][idx, metric_idx] += result
if save_img:
if self.opt['is_train']:
raise NotImplementedError('saving image is not supported during training.')
else:
# temp_psnr = self.metric_results[folder][idx, 0]
# img_path = osp.join(self.opt['path']['visualization'], dataset_name,
# osp.splitext(folder)[0],
# f"{idx:06d}_{temp_psnr:.4f}_{self.opt['name']}.png")
img_path = osp.join(self.opt['path']['visualization'], dataset_name,
osp.splitext(folder)[0],
f"{idx:06d}_{self.opt['name']}.png")
imwrite(result_img, img_path)
# progress bar
if rank == 0:
for _ in range(world_size):
pbar.update(1)
pbar.set_description(f'Folder: {folder}')
if rank == 0:
pbar.close()
if with_metrics:
if self.opt['dist']:
# collect data among GPUs
for _, tensor in self.metric_results.items():
dist.reduce(tensor, 0)
dist.barrier()
if rank == 0:
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def test(self):
n = self.lq.size(1)
self.net_g.eval()
with torch.no_grad():
self.output = self.net_g(self.lq, self.voxels_f, self.voxels_b)
self.net_g.train()
def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['lq'] = self.lq.detach().cpu()
out_dict['result'] = self.output.detach().cpu()
if hasattr(self, 'gt'):
out_dict['gt'] = self.gt.detach().cpu()
return out_dict
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
# ----------------- calculate the average values for each folder, and for each metric ----------------- #
# average all frames for each sub-folder
# metric_results_avg is a dict:{
# folder example: 'people_dynamic_wave/split0'
# 'folder1': tensor (len(metrics)),
# 'folder2': tensor (len(metrics))
# }
metric_results_avg = {
folder: torch.mean(tensor, dim=0).cpu()
for (folder, tensor) in self.metric_results.items()
}
# cnt_folder is a dict:{
# 'folder1': int (len(folder1))
# 'folder2': int (len(folder2))
# }
cnt_folder = {folder: tensor.size(0) for (folder, tensor) in self.metric_results.items()}
# total_avg_results is a dict: {
# 'metric1': float,
# 'metric2': float
# }
total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
for folder, tensor in metric_results_avg.items():
for metric_idx, metric in enumerate(total_avg_results.keys()):
total_avg_results[metric] += tensor[metric_idx].item() * cnt_folder[folder]
total_samples_length = 0
for folder, length in cnt_folder.items():
total_samples_length += length
# average among folders
for metric in total_avg_results.keys():
total_avg_results[metric] /= total_samples_length
# update the best metric result
self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter)
# ------------------------------------------ log the metric ------------------------------------------ #
log_str = f'Validation {dataset_name}\n'
for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
log_str += f'\t # {metric}: {value:.4f}'
for folder, tensor in metric_results_avg.items():
report_folder = folder.split('.')[0]
log_str += f'\t # {report_folder}: {tensor[metric_idx].item():.4f}'
if hasattr(self, 'best_metric_results'):
log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
log_str += '\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
for folder, tensor in metric_results_avg.items():
tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter)
\ No newline at end of file
import math
from collections import Counter
from torch.optim.lr_scheduler import _LRScheduler
class MultiStepRestartLR(_LRScheduler):
""" MultiStep with restarts learning rate scheme.
Args:
optimizer (torch.nn.optimizer): Torch optimizer.
milestones (list): Iterations that will decrease learning rate.
gamma (float): Decrease ratio. Default: 0.1.
restarts (list): Restart iterations. Default: [0].
restart_weights (list): Restart weights at each restart iteration.
Default: [1].
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
self.milestones = Counter(milestones)
self.gamma = gamma
self.restarts = restarts
self.restart_weights = restart_weights
assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch in self.restarts:
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
if self.last_epoch not in self.milestones:
return [group['lr'] for group in self.optimizer.param_groups]
return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
def get_position_from_periods(iteration, cumulative_period):
"""Get the position from a period list.
It will return the index of the right-closest number in the period list.
For example, the cumulative_period = [100, 200, 300, 400],
if iteration == 50, return 0;
if iteration == 210, return 2;
if iteration == 300, return 2.
Args:
iteration (int): Current iteration.
cumulative_period (list[int]): Cumulative period list.
Returns:
int: The position of the right-closest number in the period list.
"""
for i, period in enumerate(cumulative_period):
if iteration <= period:
return i
class CosineAnnealingRestartLR(_LRScheduler):
""" Cosine annealing with restarts learning rate scheme.
An example of config:
periods = [10, 10, 10, 10]
restart_weights = [1, 0.5, 0.5, 0.5]
eta_min=1e-7
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
scheduler will restart with the weights in restart_weights.
Args:
optimizer (torch.nn.optimizer): Torch optimizer.
periods (list): Period for each cosine anneling cycle.
restart_weights (list): Restart weights at each restart iteration.
Default: [1].
eta_min (float): The minimum lr. Default: 0.
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
self.periods = periods
self.restart_weights = restart_weights
self.eta_min = eta_min
assert (len(self.periods) == len(
self.restart_weights)), 'periods and restart_weights should have the same length.'
self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
current_weight = self.restart_weights[idx]
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
current_period = self.periods[idx]
return [
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
(1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
for base_lr in self.base_lrs
]
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel
@MODEL_REGISTRY.register()
class SRModel(BaseModel):
"""Base SR model for single image super-resolution."""
def __init__(self, opt):
super(SRModel, self).__init__(opt)
# define network
self.net_g = build_network(opt['network_g'])
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)
# load pretrained models
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_g', 'params')
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
if self.is_train:
self.init_training_settings()
def init_training_settings(self):
self.net_g.train()
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
logger = get_root_logger()
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema is used only for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()
# define losses
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
if self.cri_pix is None and self.cri_perceptual is None:
raise ValueError('Both pixel and perceptual losses are None.')
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def setup_optimizers(self):
train_opt = self.opt['train']
optim_params = []
for k, v in self.net_g.named_parameters():
if v.requires_grad:
optim_params.append(v)
else:
logger = get_root_logger()
logger.warning(f'Params {k} will not be optimized.')
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
def feed_data(self, data):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
def optimize_parameters(self, current_iter):
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_total = 0
loss_dict = OrderedDict()
# pixel loss
if self.cri_pix:
l_pix = self.cri_pix(self.output, self.gt)
l_total += l_pix
loss_dict['l_pix'] = l_pix
# perceptual loss
if self.cri_perceptual:
l_percep, l_style = self.cri_perceptual(self.output, self.gt)
if l_percep is not None:
l_total += l_percep
loss_dict['l_percep'] = l_percep
if l_style is not None:
l_total += l_style
loss_dict['l_style'] = l_style
l_total.backward()
self.optimizer_g.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def test(self):
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
self.output = self.net_g_ema(self.lq)
else:
self.net_g.eval()
with torch.no_grad():
self.output = self.net_g(self.lq)
self.net_g.train()
def test_selfensemble(self):
# TODO: to be tested
# 8 augmentations
# modified from https://github.com/thstkdgus35/EDSR-PyTorch
def _transform(v, op):
# if self.precision != 'single': v = v.float()
v2np = v.data.cpu().numpy()
if op == 'v':
tfnp = v2np[:, :, :, ::-1].copy()
elif op == 'h':
tfnp = v2np[:, :, ::-1, :].copy()
elif op == 't':
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
ret = torch.Tensor(tfnp).to(self.device)
# if self.precision == 'half': ret = ret.half()
return ret
# prepare augmented data
lq_list = [self.lq]
for tf in 'v', 'h', 't':
lq_list.extend([_transform(t, tf) for t in lq_list])
# inference
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
out_list = [self.net_g_ema(aug) for aug in lq_list]
else:
self.net_g.eval()
with torch.no_grad():
out_list = [self.net_g_ema(aug) for aug in lq_list]
self.net_g.train()
# merge results
for i in range(len(out_list)):
if i > 3:
out_list[i] = _transform(out_list[i], 't')
if i % 4 > 1:
out_list[i] = _transform(out_list[i], 'h')
if (i % 4) % 2 == 1:
out_list[i] = _transform(out_list[i], 'v')
output = torch.cat(out_list, dim=0)
self.output = output.mean(dim=0, keepdim=True)
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
if self.opt['rank'] == 0:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
use_pbar = self.opt['val'].get('pbar', False)
if with_metrics:
if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
# initialize the best metric results for each dataset_name (supporting multiple validation datasets)
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
if with_metrics:
self.metric_results = {metric: 0 for metric in self.metric_results}
metric_data = dict()
if use_pbar:
pbar = tqdm(total=len(dataloader), unit='image')
for idx, val_data in enumerate(dataloader):
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['result']])
metric_data['img'] = sr_img
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']])
metric_data['img2'] = gt_img
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}.png')
else:
if self.opt['val']['suffix']:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}.png')
else:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}.png')
imwrite(sr_img, save_img_path)
if with_metrics:
# calculate metrics
for name, opt_ in self.opt['val']['metrics'].items():
self.metric_results[name] += calculate_metric(metric_data, opt_)
if use_pbar:
pbar.update(1)
pbar.set_description(f'Test {img_name}')
if use_pbar:
pbar.close()
if with_metrics:
for metric in self.metric_results.keys():
self.metric_results[metric] /= (idx + 1)
# update the best metric result
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
log_str = f'Validation {dataset_name}\n'
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}'
if hasattr(self, 'best_metric_results'):
log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
log_str += '\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric, value in self.metric_results.items():
tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['lq'] = self.lq.detach().cpu()
out_dict['result'] = self.output.detach().cpu()
if hasattr(self, 'gt'):
out_dict['gt'] = self.gt.detach().cpu()
return out_dict
def save(self, epoch, current_iter):
if hasattr(self, 'net_g_ema'):
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
else:
self.save_network(self.net_g, 'net_g', current_iter)
self.save_training_state(epoch, current_iter)
import torch
from collections import Counter
from os import path as osp
from torch import distributed as dist
from tqdm import tqdm
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel
@MODEL_REGISTRY.register()
class VideoBaseModel(SRModel):
"""Base video SR model."""
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset = dataloader.dataset
dataset_name = dataset.opt['name']
with_metrics = self.opt['val']['metrics'] is not None
# initialize self.metric_results
# It is a dict: {
# 'folder1': tensor (num_frame x len(metrics)),
# 'folder2': tensor (num_frame x len(metrics))
# }
if with_metrics:
if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {}
num_frame_each_folder = Counter(dataset.data_info['folder'])
for folder, num_frame in num_frame_each_folder.items():
self.metric_results[folder] = torch.zeros(
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
# initialize the best metric results
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
rank, world_size = get_dist_info()
if with_metrics:
for _, tensor in self.metric_results.items():
tensor.zero_()
metric_data = dict()
# record all frames (border and center frames)
if rank == 0:
pbar = tqdm(total=len(dataset), unit='frame')
for idx in range(rank, len(dataset), world_size):
val_data = dataset[idx]
val_data['lq'].unsqueeze_(0)
val_data['gt'].unsqueeze_(0)
folder = val_data['folder']
frame_idx, max_idx = val_data['idx'].split('/')
lq_path = val_data['lq_path']
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals()
result_img = tensor2img([visuals['result']])
metric_data['img'] = result_img
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']])
metric_data['img2'] = gt_img
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
raise NotImplementedError('saving image is not supported during training.')
else:
if 'vimeo' in dataset_name.lower(): # vimeo90k dataset
split_result = lq_path.split('/')
img_name = f'{split_result[-3]}_{split_result[-2]}_{split_result[-1].split(".")[0]}'
else: # other datasets, e.g., REDS, Vid4
img_name = osp.splitext(osp.basename(lq_path))[0]
if self.opt['val']['suffix']:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
f'{img_name}_{self.opt["val"]["suffix"]}.png')
else:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
f'{img_name}_{self.opt["name"]}.png')
imwrite(result_img, save_img_path)
if with_metrics:
# calculate metrics
for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
result = calculate_metric(metric_data, opt_)
self.metric_results[folder][int(frame_idx), metric_idx] += result
# progress bar
if rank == 0:
for _ in range(world_size):
pbar.update(1)
pbar.set_description(f'Test {folder}: {int(frame_idx) + world_size}/{max_idx}')
if rank == 0:
pbar.close()
if with_metrics:
if self.opt['dist']:
# collect data among GPUs
for _, tensor in self.metric_results.items():
dist.reduce(tensor, 0)
dist.barrier()
else:
pass # assume use one gpu in non-dist testing
if rank == 0:
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
logger = get_root_logger()
logger.warning('nondist_validation is not implemented. Run dist_validation.')
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
# ----------------- calculate the average values for each folder, and for each metric ----------------- #
# average all frames for each sub-folder
# metric_results_avg is a dict:{
# 'folder1': tensor (len(metrics)),
# 'folder2': tensor (len(metrics))
# }
metric_results_avg = {
folder: torch.mean(tensor, dim=0).cpu()
for (folder, tensor) in self.metric_results.items()
}
# total_avg_results is a dict: {
# 'metric1': float,
# 'metric2': float
# }
total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
for folder, tensor in metric_results_avg.items():
for idx, metric in enumerate(total_avg_results.keys()):
total_avg_results[metric] += metric_results_avg[folder][idx].item()
# average among folders
for metric in total_avg_results.keys():
total_avg_results[metric] /= len(metric_results_avg)
# update the best metric result
self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter)
# ------------------------------------------ log the metric ------------------------------------------ #
log_str = f'Validation {dataset_name}\n'
for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
log_str += f'\t # {metric}: {value:.4f}'
for folder, tensor in metric_results_avg.items():
log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
if hasattr(self, 'best_metric_results'):
log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
log_str += '\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
for folder, tensor in metric_results_avg.items():
tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter)
import torch
from collections import Counter
from os import path as osp
from torch import distributed as dist
from tqdm import tqdm
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
from .video_base_model import VideoBaseModel
@MODEL_REGISTRY.register()
class VideoRecurrentModel(VideoBaseModel):
def __init__(self, opt):
super(VideoRecurrentModel, self).__init__(opt)
if self.is_train:
self.fix_flow_iter = opt['train'].get('fix_flow')
def setup_optimizers(self):
train_opt = self.opt['train']
flow_lr_mul = train_opt.get('flow_lr_mul', 1)
logger = get_root_logger()
logger.info(f'Multiple the learning rate for flow network with {flow_lr_mul}.')
if flow_lr_mul == 1:
optim_params = self.net_g.parameters()
else: # separate flow params and normal params for different lr
normal_params = []
flow_params = []
for name, param in self.net_g.named_parameters():
if 'spynet' in name:
flow_params.append(param)
else:
normal_params.append(param)
optim_params = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_g']['lr']
},
{
'params': flow_params,
'lr': train_opt['optim_g']['lr'] * flow_lr_mul
},
]
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
def optimize_parameters(self, current_iter):
if self.fix_flow_iter:
logger = get_root_logger()
if current_iter == 1:
logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
for name, param in self.net_g.named_parameters():
if 'spynet' in name or 'edvr' in name:
param.requires_grad_(False)
elif current_iter == self.fix_flow_iter:
logger.warning('Train all the parameters.')
self.net_g.requires_grad_(True)
super(VideoRecurrentModel, self).optimize_parameters(current_iter)
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset = dataloader.dataset
dataset_name = dataset.opt['name']
with_metrics = self.opt['val']['metrics'] is not None
# initialize self.metric_results
# It is a dict: {
# 'folder1': tensor (num_frame x len(metrics)),
# 'folder2': tensor (num_frame x len(metrics))
# }
if with_metrics:
if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {}
num_frame_each_folder = Counter(dataset.data_info['folder'])
for folder, num_frame in num_frame_each_folder.items():
self.metric_results[folder] = torch.zeros(
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
# initialize the best metric results
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
rank, world_size = get_dist_info()
if with_metrics:
for _, tensor in self.metric_results.items():
tensor.zero_()
metric_data = dict()
num_folders = len(dataset)
num_pad = (world_size - (num_folders % world_size)) % world_size
if rank == 0:
pbar = tqdm(total=len(dataset), unit='folder')
# Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded.
# (To avoid wait-dead)
for i in range(rank, num_folders + num_pad, world_size):
idx = min(i, num_folders - 1)
val_data = dataset[idx]
folder = val_data['folder']
# compute outputs
val_data['lq'].unsqueeze_(0)
val_data['gt'].unsqueeze_(0)
self.feed_data(val_data)
val_data['lq'].squeeze_(0)
val_data['gt'].squeeze_(0)
self.test()
visuals = self.get_current_visuals()
# tentative for out of GPU memory
del self.lq
del self.output
if 'gt' in visuals:
del self.gt
torch.cuda.empty_cache()
if self.center_frame_only:
visuals['result'] = visuals['result'].unsqueeze(1)
if 'gt' in visuals:
visuals['gt'] = visuals['gt'].unsqueeze(1)
# evaluate
if i < num_folders:
for idx in range(visuals['result'].size(1)):
result = visuals['result'][0, idx, :, :, :]
result_img = tensor2img([result]) # uint8, bgr
metric_data['img'] = result_img
if 'gt' in visuals:
gt = visuals['gt'][0, idx, :, :, :]
gt_img = tensor2img([gt]) # uint8, bgr
metric_data['img2'] = gt_img
if save_img:
if self.opt['is_train']:
raise NotImplementedError('saving image is not supported during training.')
else:
if self.center_frame_only: # vimeo-90k
clip_ = val_data['lq_path'].split('/')[-3]
seq_ = val_data['lq_path'].split('/')[-2]
name_ = f'{clip_}_{seq_}'
img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
f"{name_}_{self.opt['name']}.png")
else: # others
img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
f"{idx:08d}_{self.opt['name']}.png")
# image name only for REDS dataset
imwrite(result_img, img_path)
# calculate metrics
if with_metrics:
for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
result = calculate_metric(metric_data, opt_)
self.metric_results[folder][idx, metric_idx] += result
# progress bar
if rank == 0:
for _ in range(world_size):
pbar.update(1)
pbar.set_description(f'Folder: {folder}')
if rank == 0:
pbar.close()
if with_metrics:
if self.opt['dist']:
# collect data among GPUs
for _, tensor in self.metric_results.items():
dist.reduce(tensor, 0)
dist.barrier()
if rank == 0:
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def test(self):
n = self.lq.size(1)
self.net_g.eval()
flip_seq = self.opt['val'].get('flip_seq', False)
self.center_frame_only = self.opt['val'].get('center_frame_only', False)
if flip_seq:
self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1)
with torch.no_grad():
self.output = self.net_g(self.lq)
if flip_seq:
output_1 = self.output[:, :n, :, :, :]
output_2 = self.output[:, n:, :, :, :].flip(1)
self.output = 0.5 * (output_1 + output_2)
if self.center_frame_only:
self.output = self.output[:, n // 2, :, :, :]
self.net_g.train()
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
modulated_deform_conv)
__all__ = [
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv'
]
import math
import os
import torch
from torch import nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn import functional as F
from torch.nn.modules.utils import _pair, _single
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
deform_conv_ext = load(
'deform_conv',
sources=[
os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
],
)
else:
try:
from . import deform_conv_ext
except ImportError:
pass
# avoid annoying print output
# print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
# '1. compile with BASICSR_EXT=True. or\n '
# '2. set BASICSR_JIT=True during running')
class DeformConvFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
deform_conv_ext.deform_conv_forward(input, weight,
offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
grad_offset, weight, ctx.bufs_[0], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
cur_im2col_step)
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
return output_size
class ModulatedDeformConvFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
class DeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False):
super(DeformConv, self).__init__()
assert not bias
assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}'
assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}'
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
# To fix an assert error in deform_conv_cuda.cpp:128
# input image is smaller than kernel
input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
if input_pad:
pad_h = max(self.kernel_size[0] - x.size(2), 0)
pad_w = max(self.kernel_size[1] - x.size(3), 0)
x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
self.deformable_groups)
if input_pad:
out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
return out
class DeformConvPack(DeformConv):
"""A Deformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(DeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
self.deformable_groups)
class ModulatedDeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.init_weights()
def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_weights()
def init_weights(self):
super(ModulatedDeformConvPack, self).init_weights()
if hasattr(self, 'conv_offset'):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include <torch/extension.h>
#include <ATen/DeviceGuard.h>
#include <cmath>
#include <vector>
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor data_col);
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im);
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const int channels, const int height,
const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor grad_offset);
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor data_col);
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
int padW, int dilationH, int dilationW, int group,
int deformable_group) {
TORCH_CHECK(weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
"but got: %s",
weight.ndimension());
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
TORCH_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
kW);
TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
kW, weight.size(2), weight.size(3));
TORCH_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
TORCH_CHECK(
dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
ndim);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
TORCH_CHECK(nInputPlane % deformable_group == 0,
"input channels must divide deformable group size");
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
outputWidth);
TORCH_CHECK(input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane, input.size(1));
TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
"input image is smaller than kernel");
TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
"invalid spatial size of offset, expected height: %d width: %d, but "
"got height: %d width: %d",
outputHeight, outputWidth, offset.size(2), offset.size(3));
TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
"invalid number of channels of offset");
if (gradOutput != NULL) {
TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
"invalid number of gradOutput planes, expected: %d, but got: %d",
nOutputPlane, gradOutput->size(dimf));
TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
gradOutput->size(dimw) == outputWidth),
"invalid size of gradOutput, expected height: %d width: %d , but "
"got height: %d width: %d",
outputHeight, outputWidth, gradOutput->size(dimh),
gradOutput->size(dimw));
}
}
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
// todo: resize columns to include im2col: done
// todo: add im2col_step as input
// todo: add new output buffer and transpose it to output (or directly
// transpose output) todo: possibly change data indexing because of
// parallel_imgs
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input.unsqueeze_(0);
offset.unsqueeze_(0);
}
// todo: assert batchsize dividable by im2col_step
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
outputHeight, outputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
ones = at::ones({outputHeight, outputWidth}, input.options());
}
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
at::Tensor output_buffer =
at::zeros({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth},
output.options());
output_buffer = output_buffer.view(
{output_buffer.size(0), group, output_buffer.size(1) / group,
output_buffer.size(2), output_buffer.size(3)});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
}
}
output_buffer = output_buffer.view(
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
output_buffer.size(3), output_buffer.size(4)});
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step) {
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view({1, input.size(0), input.size(1), input.size(2)});
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
// change order of grad output
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight,
outputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
gradOutput = gradOutput.view(
{gradOutput.size(0), group, gradOutput.size(1) / group,
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
for (int g = 0; g < group; g++) {
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view(
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, im2col_step, deformable_group,
gradOffset[elt]);
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
gradOffset =
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step) {
// todo: transpose and reshape outGrad
// todo: reshape columns
// todo: add im2col_step as input
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view(
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = gradWeight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
outputHeight, outputWidth});
gradOutputBuffer.copy_(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
// divide into group
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight =
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3)});
for (int g = 0; g < group; g++) {
gradWeight[g] = gradWeight[g]
.flatten(1)
.addmm_(gradOutputBuffer[elt][g].flatten(1),
columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0),
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3),
gradWeight.size(4)});
}
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
}
return 1;
}
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias) {
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
// divide into group
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2),
output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
}
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide int group
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
grad_weight.size(1), grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer ********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <float.h>
using namespace at;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;
inline int GET_BLOCKS(const int N)
{
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
template <typename scalar_t>
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
const int height, const int width, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
const int height, const int width, const scalar_t *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
//const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t val = static_cast<scalar_t>(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const scalar_t map_h = i * dilation_h + offset_h;
//const scalar_t map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
void deformable_im2col(
const at::Tensor data_im, const at::Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor data_col)
{
// num_axes should be smaller than block size
// todo: check parallel_imgs is correctly passed in
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
const int n, const scalar_t *data_col, const scalar_t *data_offset,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
2 * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index];
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
void deformable_col2im(
const at::Tensor data_col, const at::Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im)
{
// todo: make sure parallel_imgs is passed in correctly
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, deformable_group, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
const scalar_t *data_im, const scalar_t *data_offset,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col, scalar_t *grad_offset)
{
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
batch_size * width_col * height_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
inv_h = inv_w = -2;
}
const scalar_t weight = get_coordinate_weight(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos];
cnt += 1;
}
grad_offset[index] = val;
}
}
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
const int stride_w, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
{
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_);
}));
}
template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
const int height, const int width, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
const int height, const int width, const scalar_t *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
//const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t val = static_cast<scalar_t>(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
//if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const float map_h = i * dilation_h + offset_h;
//const float map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
//data_col_ptr += height_col * width_col;
}
}
}
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index] * mask;
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_im,
const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_offset, scalar_t *grad_mask)
{
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0, mval = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
inv_h = inv_w = -2;
}
else
{
mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
}
const scalar_t weight = dmcn_get_coordinate_weight(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos] * mask;
cnt += 1;
}
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
grad_offset[index] = val;
if (offset_c % 2 == 0)
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
}
}
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col)
{
// num_axes should be smaller than block size
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
batch_size, channels, deformable_group, height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im)
{
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, deformable_group, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group,
at::Tensor grad_offset, at::Tensor grad_mask)
{
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
grad_offset_, grad_mask_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
}
}
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include <torch/extension.h>
#include <ATen/DeviceGuard.h>
#include <cmath>
#include <vector>
#define WITH_CUDA // always use cuda
#ifdef WITH_CUDA
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step);
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step);
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step);
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias);
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias);
#endif
int deform_conv_forward(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_forward_cuda(input, weight, offset, output, columns,
ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
deformable_group, im2col_step);
#else
AT_ERROR("deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("deform conv is not implemented on CPU");
}
int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_backward_input_cuda(input, offset, gradOutput,
gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
dilationW, dilationH, group, deformable_group, im2col_step);
#else
AT_ERROR("deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("deform conv is not implemented on CPU");
}
int deform_conv_backward_parameters(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
dilationH, group, deformable_group, scale, im2col_step);
#else
AT_ERROR("deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("deform conv is not implemented on CPU");
}
void modulated_deform_conv_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
offset, mask, output, columns, kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
deformable_group, with_bias);
#else
AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("modulated deform conv is not implemented on CPU");
}
void modulated_deform_conv_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
with_bias);
#else
AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("modulated deform conv is not implemented on CPU");
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_conv_forward", &deform_conv_forward,
"deform forward");
m.def("deform_conv_backward_input", &deform_conv_backward_input,
"deform_conv_backward_input");
m.def("deform_conv_backward_parameters",
&deform_conv_backward_parameters,
"deform_conv_backward_parameters");
m.def("modulated_deform_conv_forward",
&modulated_deform_conv_forward,
"modulated deform conv forward");
m.def("modulated_deform_conv_backward",
&modulated_deform_conv_backward,
"modulated deform conv backward");
}
from .fused_act import FusedLeakyReLU, fused_leaky_relu
__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
import os
import torch
from torch import nn
from torch.autograd import Function
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
fused_act_ext = load(
'fused',
sources=[
os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
],
)
else:
try:
from . import fused_act_ext
except ImportError:
pass
# avoid annoying print output
# print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
# '1. compile with BASICSR_EXT=True. or\n '
# '2. set BASICSR_JIT=True during running')
class FusedLeakyReLUFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, out, negative_slope, scale):
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
empty = grad_output.new_empty(0)
grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
dim = [0]
if grad_input.ndim > 2:
dim += list(range(2, grad_input.ndim))
grad_bias = grad_input.sum(dim).detach()
return grad_input, grad_bias
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
out, = ctx.saved_tensors
gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
ctx.scale)
return gradgrad_out, None, None, None
class FusedLeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)
out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
return out
@staticmethod
def backward(ctx, grad_output):
out, = ctx.saved_tensors
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
return grad_input, grad_bias, None, None
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
super().__init__()
self.bias = nn.Parameter(torch.zeros(channel))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
#include <torch/extension.h>
torch::Tensor fused_bias_act_op(const torch::Tensor& input,
const torch::Tensor& bias,
const torch::Tensor& refer,
int act, int grad, float alpha, float scale);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor fused_bias_act(const torch::Tensor& input,
const torch::Tensor& bias,
const torch::Tensor& refer,
int act, int grad, float alpha, float scale) {
CHECK_CUDA(input);
CHECK_CUDA(bias);
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
scalar_t zero = 0.0;
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
scalar_t x = p_x[xi];
if (use_bias) {
x += p_b[(xi / step_b) % size_b];
}
scalar_t ref = use_ref ? p_ref[xi] : zero;
scalar_t y;
switch (act * 10 + grad) {
default:
case 10: y = x; break;
case 11: y = x; break;
case 12: y = 0.0; break;
case 30: y = (x > 0.0) ? x : x * alpha; break;
case 31: y = (ref > 0.0) ? x : x * alpha; break;
case 32: y = 0.0; break;
}
out[xi] = y * scale;
}
}
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
auto x = input.contiguous();
auto b = bias.contiguous();
auto ref = refer.contiguous();
int use_bias = b.numel() ? 1 : 0;
int use_ref = ref.numel() ? 1 : 0;
int size_x = x.numel();
int size_b = b.numel();
int step_b = 1;
for (int i = 1 + 1; i < x.dim(); i++) {
step_b *= x.size(i);
}
int loop_x = 4;
int block_size = 4 * 32;
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
auto y = torch::empty_like(x);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
y.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
b.data_ptr<scalar_t>(),
ref.data_ptr<scalar_t>(),
act,
grad,
alpha,
scale,
loop_x,
size_x,
step_b,
size_b,
use_bias,
use_ref
);
});
return y;
}
from .upfirdn2d import upfirdn2d
__all__ = ['upfirdn2d']
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
#include <torch/extension.h>
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
CHECK_CUDA(input);
CHECK_CUDA(kernel);
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment