Commit a8ada82f authored by chenych

First commit

parent 537691da
#include <torch/extension.h>
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
CHECK_CUDA(input);
CHECK_CUDA(kernel);
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
import os
import torch
from torch.autograd import Function
from torch.nn import functional as F
from torch.utils.cpp_extension import load, _import_module_from_library
module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
'upfirdn2d',
sources=[
os.path.join(module_path, 'upfirdn2d.cpp'),
os.path.join(module_path, 'upfirdn2d_kernel.cu'),
],
)
#upfirdn2d_op = _import_module_from_library('upfirdn2d', '/tmp/torch_extensions/upfirdn2d', True)
class UpFirDn2dBackward(Function):
@staticmethod
def forward(
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
):
up_x, up_y = up
down_x, down_y = down
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
grad_input = upfirdn2d_op.upfirdn2d(
grad_output,
grad_kernel,
down_x,
down_y,
up_x,
up_y,
g_pad_x0,
g_pad_x1,
g_pad_y0,
g_pad_y1,
)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
ctx.save_for_backward(kernel)
pad_x0, pad_x1, pad_y0, pad_y1 = pad
ctx.up_x = up_x
ctx.up_y = up_y
ctx.down_x = down_x
ctx.down_y = down_y
ctx.pad_x0 = pad_x0
ctx.pad_x1 = pad_x1
ctx.pad_y0 = pad_y0
ctx.pad_y1 = pad_y1
ctx.in_size = in_size
ctx.out_size = out_size
return grad_input
@staticmethod
def backward(ctx, gradgrad_input):
kernel, = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
gradgrad_out = upfirdn2d_op.upfirdn2d(
gradgrad_input,
kernel,
ctx.up_x,
ctx.up_y,
ctx.down_x,
ctx.down_y,
ctx.pad_x0,
ctx.pad_x1,
ctx.pad_y0,
ctx.pad_y1,
)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
)
return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
@staticmethod
def forward(ctx, input, kernel, up, down, pad):
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad
kernel_h, kernel_w = kernel.shape
batch, channel, in_h, in_w = input.shape
ctx.in_size = input.shape
input = input.reshape(-1, in_h, in_w, 1)
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
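# Standard upfirdn output size: upsample by `up`, pad, convolve with the
# kernel, then downsample by `down`:
#   out = (in * up + pad0 + pad1 - kernel) // down + 1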
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
ctx.out_size = (out_h, out_w)
ctx.up = (up_x, up_y)
ctx.down = (down_x, down_y)
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
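# The gradient of upfirdn2d is another upfirdn2d with up/down swapped and a
# flipped kernel; g_pad below is chosen so that the gradient's output size
# matches the original input size.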
g_pad_x0 = kernel_w - pad_x0 - 1
g_pad_y0 = kernel_h - pad_y0 - 1
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
out = upfirdn2d_op.upfirdn2d(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)
return out
@staticmethod
def backward(ctx, grad_output):
kernel, grad_kernel = ctx.saved_tensors
grad_input = UpFirDn2dBackward.apply(
grad_output,
kernel,
grad_kernel,
ctx.up,
ctx.down,
ctx.pad,
ctx.g_pad,
ctx.in_size,
ctx.out_size,
)
return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
out = UpFirDn2d.apply(
input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
)
return out
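# Usage sketch (hedged; assumes a CUDA tensor x of shape (N, C, H, W)): a 2x
# upsample with a separable 4-tap binomial filter, following the StyleGAN2
# upsample padding convention pad = ((p + 1) // 2 + factor - 1, p // 2) with
# p = kernel_size - factor:
#
#   k = torch.tensor([1., 3., 3., 1.])
#   k = k[None, :] * k[:, None]
#   k = k / k.sum() * 4  # DC gain of up ** 2 preserves signal magnitude
#   y = upfirdn2d(x, k.to(x), up=2, down=1, pad=(2, 1))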
def upfirdn2d_native(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
)
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
)
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
return out[:, ::down_y, ::down_x, :]
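
# A minimal parity check, offered as a sketch: it assumes a CUDA device and a
# successful JIT build of the extension, and compares the CUDA op against the
# native reference above on random data.
if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn(2, 3, 16, 16, device='cuda')
    k = torch.randn(4, 4, device='cuda')
    y = upfirdn2d(x, k, up=2, down=1, pad=(1, 1))
    # upfirdn2d_native works on the (batch * channel, h, w, minor) layout.
    y_ref = upfirdn2d_native(
        x.reshape(-1, 16, 16, 1), k, 2, 2, 1, 1, 1, 1, 1, 1
    ).reshape(y.shape)
    print('max abs diff:', (y - y_ref).abs().max().item())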
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
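// Integer division rounding toward negative infinity; plain C++ '/' truncates
// toward zero, which differs for negative numerators.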
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
int c = a / b;
if (c * b > a) {
c--;
}
return c;
}
struct UpFirDn2DKernelParams {
int up_x;
int up_y;
int down_x;
int down_y;
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
int major_dim;
int in_h;
int in_w;
int minor_dim;
int kernel_h;
int kernel_w;
int out_h;
int out_w;
int loop_major;
int loop_x;
};
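// Tiled kernel: each block stages the flipped filter taps and one input tile
// in shared memory, then its threads each accumulate output pixels, looping
// over major (batch * channel) and x tiles as set by loop_major / loop_x.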
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
return;
}
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
__syncthreads();
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
__syncthreads();
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
scalar_t v = 0.0;
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
if (out_x < p.out_w & out_y < p.out_h) {
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
}
}
}
}
}
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
UpFirDn2DKernelParams p;
auto x = input.contiguous();
auto k = kernel.contiguous();
p.major_dim = x.size(0);
p.in_h = x.size(1);
p.in_w = x.size(2);
p.minor_dim = x.size(3);
p.kernel_h = k.size(0);
p.kernel_w = k.size(1);
p.up_x = up_x;
p.up_y = up_y;
p.down_x = down_x;
p.down_y = down_y;
p.pad_x0 = pad_x0;
p.pad_x1 = pad_x1;
p.pad_y0 = pad_y0;
p.pad_y1 = pad_y1;
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
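// Select a specialized template instantiation below; up/down/kernel
// combinations not covered leave mode at -1, so the dispatch switch launches
// nothing.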
int mode = -1;
int tile_out_h = -1;
int tile_out_w = -1;
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 1;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
mode = 2;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 3;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 4;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 5;
tile_out_h = 8;
tile_out_w = 32;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 6;
tile_out_h = 8;
tile_out_w = 32;
}
dim3 block_size;
dim3 grid_size;
if (tile_out_h > 0 && tile_out_w > 0) {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 1;
block_size = dim3(32 * 8, 1, 1);
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
(p.major_dim - 1) / p.loop_major + 1);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
switch (mode) {
case 1:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
);
break;
case 2:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
);
break;
case 3:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
);
break;
case 4:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
);
break;
case 5:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
);
break;
case 6:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
);
break;
}
});
return out;
}
"""
# --------------------------------------------
# define training model
# --------------------------------------------
"""
def define_Model(opt):
model = opt['model'] # one input: L
if model == 'plain':
from models.model_plain import ModelPlain as M
elif model == 'plain2': # two inputs: L, C
from models.model_plain2 import ModelPlain2 as M
elif model == 'plain4': # four inputs: L, k, sf, sigma
from models.model_plain4 import ModelPlain4 as M
elif model == 'gan': # one input: L
from models.model_gan import ModelGAN as M
elif model == 'vrt': # one video input L, for VRT
from models.model_vrt import ModelVRT as M
else:
raise NotImplementedError('Model [{:s}] is not defined.'.format(model))
m = M(opt)
print('Training model [{:s}] is created.'.format(m.__class__.__name__))
return m
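# Usage sketch: 'opt' is the full parsed options dict (see the JSON options
# file below); e.g. define_Model(opt) with opt['model'] == 'plain' returns a
# ModelPlain instance built from opt.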
import functools
import torch
from torch.nn import init
"""
# --------------------------------------------
# select the network of G, D and F
# --------------------------------------------
"""
# --------------------------------------------
# Generator, netG, G
# --------------------------------------------
def define_G(opt):
opt_net = opt['netG']
net_type = opt_net['net_type']
# ----------------------------------------
# denoising task
# ----------------------------------------
# ----------------------------------------
# DnCNN
# ----------------------------------------
if net_type == 'dncnn':
from models.network_dncnn import DnCNN as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'], # total number of conv layers
act_mode=opt_net['act_mode'])
# ----------------------------------------
# RNAN
# ----------------------------------------
elif net_type == 'rnan':
from models.network_rnan import RNAN as net
netG = net()
# ----------------------------------------
# MIRNet
# ----------------------------------------
elif net_type == 'mirnet':
from models.network_mirnet import MIRNet as net
netG = net()
# ----------------------------------------
# RIDNet
# ----------------------------------------
elif net_type == 'ridnet':
from models.network_ridnet import RIDNET as net
netG = net()
# ----------------------------------------
# CNN (5 layers)
# ----------------------------------------
elif net_type == 'cnn':
from models.network_cnn import CNN5Layer as net
netG = net()
# ----------------------------------------
# Flexible DnCNN
# ----------------------------------------
elif net_type == 'fdncnn':
from models.network_dncnn import FDnCNN as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'], # total number of conv layers
act_mode=opt_net['act_mode'])
# ----------------------------------------
# FFDNet
# ----------------------------------------
elif net_type == 'ffdnet':
from models.network_ffdnet import FFDNet as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'],
act_mode=opt_net['act_mode'])
# ----------------------------------------
# others
# ----------------------------------------
# ----------------------------------------
# super-resolution task
# ----------------------------------------
# ----------------------------------------
# SRMD
# ----------------------------------------
elif net_type == 'srmd':
from models.network_srmd import SRMD as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'],
upscale=opt_net['scale'],
act_mode=opt_net['act_mode'],
upsample_mode=opt_net['upsample_mode'])
# ----------------------------------------
# super-resolver prior of DPSR
# ----------------------------------------
elif net_type == 'dpsr':
from models.network_dpsr import MSRResNet_prior as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'],
upscale=opt_net['scale'],
act_mode=opt_net['act_mode'],
upsample_mode=opt_net['upsample_mode'])
# ----------------------------------------
# modified SRResNet v0.0
# ----------------------------------------
elif net_type == 'msrresnet0':
from models.network_msrresnet import MSRResNet0 as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'],
upscale=opt_net['scale'],
act_mode=opt_net['act_mode'],
upsample_mode=opt_net['upsample_mode'])
# ----------------------------------------
# modified SRResNet v0.1
# ----------------------------------------
elif net_type == 'msrresnet1':
from models.network_msrresnet import MSRResNet1 as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'],
upscale=opt_net['scale'],
act_mode=opt_net['act_mode'],
upsample_mode=opt_net['upsample_mode'])
# ----------------------------------------
# RRDB
# ----------------------------------------
elif net_type == 'rrdb': # RRDB
from models.network_rrdb import RRDB as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'],
gc=opt_net['gc'],
upscale=opt_net['scale'],
act_mode=opt_net['act_mode'],
upsample_mode=opt_net['upsample_mode'])
# ----------------------------------------
# RRDBNet
# ----------------------------------------
elif net_type == 'rrdbnet': # RRDBNet
from models.network_rrdbnet import RRDBNet as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nf=opt_net['nf'],
nb=opt_net['nb'],
gc=opt_net['gc'],
sf=opt_net['scale'])
# ----------------------------------------
# IMDB
# ----------------------------------------
elif net_type == 'imdn': # IMDB
from models.network_imdn import IMDN as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'],
upscale=opt_net['scale'],
act_mode=opt_net['act_mode'],
upsample_mode=opt_net['upsample_mode'])
# ----------------------------------------
# USRNet
# ----------------------------------------
elif net_type == 'usrnet': # USRNet
from models.network_usrnet import USRNet as net
netG = net(n_iter=opt_net['n_iter'],
h_nc=opt_net['h_nc'],
in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'],
act_mode=opt_net['act_mode'],
downsample_mode=opt_net['downsample_mode'],
upsample_mode=opt_net['upsample_mode']
)
# ----------------------------------------
# Deep Residual U-Net (drunet)
# ----------------------------------------
elif net_type == 'drunet':
from models.network_unet import UNetRes as net
netG = net(in_nc=opt_net['in_nc'],
out_nc=opt_net['out_nc'],
nc=opt_net['nc'],
nb=opt_net['nb'],
act_mode=opt_net['act_mode'],
downsample_mode=opt_net['downsample_mode'],
upsample_mode=opt_net['upsample_mode'],
bias=opt_net['bias'])
# ----------------------------------------
# SwinIR
# ----------------------------------------
elif net_type == 'swinir':
from models.network_swinir import SwinIR as net
netG = net(upscale=opt_net['upscale'],
in_chans=opt_net['in_chans'],
img_size=opt_net['img_size'],
window_size=opt_net['window_size'],
img_range=opt_net['img_range'],
depths=opt_net['depths'],
embed_dim=opt_net['embed_dim'],
num_heads=opt_net['num_heads'],
mlp_ratio=opt_net['mlp_ratio'],
upsampler=opt_net['upsampler'],
resi_connection=opt_net['resi_connection'],
talking_heads=opt_net['talking_heads'],
use_attn_fn=opt_net['attn_fn'],
head_scale=opt_net['head_scale'],
on_attn=opt_net['on_attn'],
use_mask=opt_net['use_mask'],
mask_ratio1=opt_net['mask_ratio1'],
mask_ratio2=opt_net['mask_ratio2'],
mask_is_diff=opt_net['mask_is_diff'],
type=opt_net['type'],
resi_scale=opt_net['resi_scale'],
opt=opt_net,
)
# ----------------------------------------
# SwinIR diff
# ----------------------------------------
elif net_type == 'swinir_diff':
from models.network_swinir_diff import SwinIR as net
netG = net(upscale=opt_net['upscale'],
in_chans=opt_net['in_chans'],
img_size=opt_net['img_size'],
window_size=opt_net['window_size'],
img_range=opt_net['img_range'],
depths=opt_net['depths'],
embed_dim=opt_net['embed_dim'],
num_heads=opt_net['num_heads'],
mlp_ratio=opt_net['mlp_ratio'],
upsampler=opt_net['upsampler'],
resi_connection=opt_net['resi_connection'],
talking_heads=opt_net['talking_heads'],
use_attn_fn=opt_net['attn_fn'],
head_scale=opt_net['head_scale'],
on_attn=opt_net['on_attn'],
use_mask=opt_net['use_mask'],
mask_ratio1=opt_net['mask_ratio1'],
mask_ratio2=opt_net['mask_ratio2'],
mask_is_diff=opt_net['mask_is_diff'],
type=opt_net['type'],
resi_scale=opt_net['resi_scale'],
opt=opt_net,
)
# ----------------------------------------
# SwinIR Dropout
# ----------------------------------------
elif net_type == 'swinir_dropout':
from models.network_swinir_dropout import SwinIR as net
netG = net(upscale=opt_net['upscale'],
in_chans=opt_net['in_chans'],
img_size=opt_net['img_size'],
window_size=opt_net['window_size'],
img_range=opt_net['img_range'],
depths=opt_net['depths'],
embed_dim=opt_net['embed_dim'],
num_heads=opt_net['num_heads'],
mlp_ratio=opt_net['mlp_ratio'],
upsampler=opt_net['upsampler'],
resi_connection=opt_net['resi_connection'],
talking_heads=opt_net['talking_heads'],
use_attn_fn=opt_net['attn_fn'],
head_scale=opt_net['head_scale'],
on_attn=opt_net['on_attn'],
use_mask=opt_net['use_mask'],
mask_ratio1=opt_net['mask_ratio1'],
mask_ratio2=opt_net['mask_ratio2'],
mask_is_diff=opt_net['mask_is_diff'],
type=opt_net['type'],
resi_scale=opt_net['resi_scale'],
opt=opt_net,
)
# ----------------------------------------
# SwinIR Dropout residual
# ----------------------------------------
elif net_type == 'swinir_dropout_residual':
from models.network_swinir_dropout_residual import SwinIR as net
netG = net(upscale=opt_net['upscale'],
in_chans=opt_net['in_chans'],
img_size=opt_net['img_size'],
window_size=opt_net['window_size'],
img_range=opt_net['img_range'],
depths=opt_net['depths'],
embed_dim=opt_net['embed_dim'],
num_heads=opt_net['num_heads'],
mlp_ratio=opt_net['mlp_ratio'],
upsampler=opt_net['upsampler'],
resi_connection=opt_net['resi_connection'],
talking_heads=opt_net['talking_heads'],
use_attn_fn=opt_net['attn_fn'],
head_scale=opt_net['head_scale'],
on_attn=opt_net['on_attn'],
use_mask=opt_net['use_mask'],
mask_ratio1=opt_net['mask_ratio1'],
mask_ratio2=opt_net['mask_ratio2'],
mask_is_diff=opt_net['mask_is_diff'],
type=opt_net['type'],
resi_scale=opt_net['resi_scale'],
opt=opt_net,
)
# ----------------------------------------
# SwinIR residual
# ----------------------------------------
elif net_type == 'swinir_residual':
from models.network_swinir_residual import SwinIR as net
netG = net(upscale=opt_net['upscale'],
in_chans=opt_net['in_chans'],
img_size=opt_net['img_size'],
window_size=opt_net['window_size'],
img_range=opt_net['img_range'],
depths=opt_net['depths'],
embed_dim=opt_net['embed_dim'],
num_heads=opt_net['num_heads'],
mlp_ratio=opt_net['mlp_ratio'],
upsampler=opt_net['upsampler'],
resi_connection=opt_net['resi_connection'],
talking_heads=opt_net['talking_heads'],
use_attn_fn=opt_net['attn_fn'],
head_scale=opt_net['head_scale'],
on_attn=opt_net['on_attn'],
use_mask=opt_net['use_mask'],
mask_ratio1=opt_net['mask_ratio1'],
mask_ratio2=opt_net['mask_ratio2'],
mask_is_diff=opt_net['mask_is_diff'],
type=opt_net['type'],
resi_scale=opt_net['resi_scale'],
opt=opt_net,
)
# ----------------------------------------
# VRT
# ----------------------------------------
elif net_type == 'vrt':
from models.network_vrt import VRT as net
netG = net(upscale=opt_net['upscale'],
img_size=opt_net['img_size'],
window_size=opt_net['window_size'],
depths=opt_net['depths'],
indep_reconsts=opt_net['indep_reconsts'],
embed_dims=opt_net['embed_dims'],
num_heads=opt_net['num_heads'],
spynet_path=opt_net['spynet_path'],
pa_frames=opt_net['pa_frames'],
deformable_groups=opt_net['deformable_groups'],
nonblind_denoising=opt_net['nonblind_denoising'],
use_checkpoint_attn=opt_net['use_checkpoint_attn'],
use_checkpoint_ffn=opt_net['use_checkpoint_ffn'],
no_checkpoint_attn_blocks=opt_net['no_checkpoint_attn_blocks'],
no_checkpoint_ffn_blocks=opt_net['no_checkpoint_ffn_blocks'])
# ----------------------------------------
# others
# ----------------------------------------
# TODO
else:
raise NotImplementedError('netG [{:s}] is not found.'.format(net_type))
# ----------------------------------------
# initialize weights
# ----------------------------------------
if opt['is_train']:
init_weights(netG,
init_type=opt_net['init_type'],
init_bn_type=opt_net['init_bn_type'],
gain=opt_net['init_gain'])
return netG
# --------------------------------------------
# Discriminator, netD, D
# --------------------------------------------
def define_D(opt):
opt_net = opt['netD']
net_type = opt_net['net_type']
# ----------------------------------------
# discriminator_vgg_96
# ----------------------------------------
if net_type == 'discriminator_vgg_96':
from models.network_discriminator import Discriminator_VGG_96 as discriminator
netD = discriminator(in_nc=opt_net['in_nc'],
base_nc=opt_net['base_nc'],
ac_type=opt_net['act_mode'])
# ----------------------------------------
# discriminator_vgg_128
# ----------------------------------------
elif net_type == 'discriminator_vgg_128':
from models.network_discriminator import Discriminator_VGG_128 as discriminator
netD = discriminator(in_nc=opt_net['in_nc'],
base_nc=opt_net['base_nc'],
ac_type=opt_net['act_mode'])
# ----------------------------------------
# discriminator_vgg_192
# ----------------------------------------
elif net_type == 'discriminator_vgg_192':
from models.network_discriminator import Discriminator_VGG_192 as discriminator
netD = discriminator(in_nc=opt_net['in_nc'],
base_nc=opt_net['base_nc'],
ac_type=opt_net['act_mode'])
# ----------------------------------------
# discriminator_vgg_128_SN
# ----------------------------------------
elif net_type == 'discriminator_vgg_128_SN':
from models.network_discriminator import Discriminator_VGG_128_SN as discriminator
netD = discriminator()
elif net_type == 'discriminator_patchgan':
from models.network_discriminator import Discriminator_PatchGAN as discriminator
netD = discriminator(input_nc=opt_net['in_nc'],
ndf=opt_net['base_nc'],
n_layers=opt_net['n_layers'],
norm_type=opt_net['norm_type'])
elif net_type == 'discriminator_unet':
from models.network_discriminator import Discriminator_UNet as discriminator
netD = discriminator(input_nc=opt_net['in_nc'],
ndf=opt_net['base_nc'])
else:
raise NotImplementedError('netD [{:s}] is not found.'.format(net_type))
# ----------------------------------------
# initialize weights
# ----------------------------------------
init_weights(netD,
init_type=opt_net['init_type'],
init_bn_type=opt_net['init_bn_type'],
gain=opt_net['init_gain'])
return netD
# --------------------------------------------
# VGGfeature, netF, F
# --------------------------------------------
def define_F(opt, use_bn=False):
device = torch.device('cuda' if opt['gpu_ids'] else 'cpu')
from models.network_feature import VGGFeatureExtractor
# pytorch pretrained VGG19-54, before ReLU.
if use_bn:
feature_layer = 49
else:
feature_layer = 34
netF = VGGFeatureExtractor(feature_layer=feature_layer,
use_bn=use_bn,
use_input_norm=True,
device=device)
netF.eval() # No need to train, but need BP to input
return netF
"""
# --------------------------------------------
# weights initialization
# --------------------------------------------
"""
def init_weights(net, init_type='xavier_uniform', init_bn_type='uniform', gain=1):
"""
# Kai Zhang, https://github.com/cszn/KAIR
#
# Args:
# init_type:
# default, none: pass init_weights
# normal; uniform; xavier_normal; xavier_uniform;
# kaiming_normal; kaiming_uniform; orthogonal
# init_bn_type:
# uniform; constant
# gain:
# 0.2
"""
def init_fn(m, init_type='xavier_uniform', init_bn_type='uniform', gain=1):
classname = m.__class__.__name__
if classname.find('Conv') != -1 or classname.find('Linear') != -1:
if init_type == 'normal':
init.normal_(m.weight.data, 0, 0.1)
m.weight.data.clamp_(-1, 1).mul_(gain)
elif init_type == 'uniform':
init.uniform_(m.weight.data, -0.2, 0.2)
m.weight.data.mul_(gain)
elif init_type == 'xavier_normal':
init.xavier_normal_(m.weight.data, gain=gain)
m.weight.data.clamp_(-1, 1)
elif init_type == 'xavier_uniform':
init.xavier_uniform_(m.weight.data, gain=gain)
elif init_type == 'kaiming_normal':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')
m.weight.data.clamp_(-1, 1).mul_(gain)
elif init_type == 'kaiming_uniform':
init.kaiming_uniform_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')
m.weight.data.mul_(gain)
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_type))
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm2d') != -1:
if init_bn_type == 'uniform': # preferred
if m.affine:
init.uniform_(m.weight.data, 0.1, 1.0)
init.constant_(m.bias.data, 0.0)
elif init_bn_type == 'constant':
if m.affine:
init.constant_(m.weight.data, 1.0)
init.constant_(m.bias.data, 0.0)
else:
raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_bn_type))
if init_type not in ['default', 'none']:
print('Initialization method [{:s} + {:s}], gain is [{:.2f}]'.format(init_type, init_bn_type, gain))
fn = functools.partial(init_fn, init_type=init_type, init_bn_type=init_bn_type, gain=gain)
net.apply(fn)
else:
print('Pass this initialization! Initialization was done during network definition!')
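# Typical call, mirroring define_G/define_D above (gain 0.2 as noted in the
# docstring):
#   init_weights(netG, init_type='xavier_uniform', init_bn_type='uniform', gain=0.2)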
{
"task": "input_80_90" // real-world image sr. root/task/images-models-options
, "model": "plain" // "plain" | "plain2" if two inputs
, "gpu_ids": [0,1,2,3]
, "dist": false
, "scale": 1 // broadcast to "datasets"
, "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color
, "path": {
"root": "masked_denoising" // "denoising" | "superresolution" | "masked_denoising"
, "pretrained_netG": null // path of pretrained model
, "pretrained_netE": null // path of pretrained model
},
"datasets": {
"train": {
"name": "train_dataset" // just name
, "dataset_type": "masked_denoising" // "dncnn" | "dnpatch" | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" | "jpeg" | "masked_denoising"
, "dataroot_H": "trainsets/trainH" // path of H training dataset. DIV2K + Flickr2K + BSD500 + WED
, "dataroot_L": null // path of L training dataset
, "H_size": 64
, "lq_patchsize": 64
, "dataloader_shuffle": true
, "dataloader_num_workers": 16
, "dataloader_batch_size": 64 // batch size, bigger is better
, "noise_level": 15 // training noise level
, "if_mask": true // if use input mask
, "mask1": 80 // input mask ratio,
, "mask2": 90 // randomly sampling from [mask1, mask2]
},
"test": {
"name": "test_dataset" // just name
, "dataset_type": "plain" // "dncnn" | "dnpatch" | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" | "jpeg"
, "dataroot_H": "testset/McM/HR" // path of H testing dataset
, "dataroot_L": "testset/McM/McM_poisson_20" // path of L testing dataset, Poisson noise with alpha=2
}
},
"netG": {
"net_type": "swinir"
, "upscale": 1
, "in_chans": 3
, "img_size": 64
, "window_size": 8
, "img_range": 1.0
, "depths": [6, 6, 6, 6]
, "embed_dim": 60
, "num_heads": [6, 6, 6, 6]
, "mlp_ratio": 2
, "upsampler": null // "pixelshuffle" | "pixelshuffledirect" | "nearest+conv" | null
, "resi_connection": "3conv" // "1conv" | "3conv"
, "init_type": "default"
, "talking_heads": false
, "attn_fn": "softmax"
, "head_scale": false
, "on_attn": false
, "use_mask": true // if use attention mask
, "mask_ratio1": 75 // attention mask ratio,
, "mask_ratio2": 75 // randomly sampling from [mask_ratio1, mask_ratio2]
, "mask_is_diff": false
, "type": "stand"
},
"train": {
"manual_seed": 1
, "G_lossfn_type": "l1" // "l1" preferred | "l2sum" | "l2" | "ssim" | "charbonnier"
, "G_lossfn_weight": 1.0 // default
, "E_decay": 0.999 // Exponential Moving Average for netG: set 0 to disable; default setting 0.999
, "G_optimizer_type": "adam" // fixed, adam is enough
, "G_optimizer_lr": 1e-4 // learning rate
, "G_optimizer_wd": 0 // weight decay, default 0
, "G_optimizer_clipgrad": null // unused
, "G_optimizer_reuse": true
, "G_scheduler_type": "MultiStepLR" // "MultiStepLR" is enough
, "G_scheduler_milestones": [] // [250000, 400000, 450000, 475000, 500000]
, "G_scheduler_gamma": 0.5
, "G_regularizer_orthstep": null // unused
, "G_regularizer_clipstep": null // unused
, "G_param_strict": true
, "E_param_strict": true
, "checkpoint_test": 5000 // for testing
, "checkpoint_save": 5000 // for saving model
, "checkpoint_print": 100 // for print
, "save_image": ["img_043_x1", "img_021_x1", "img_024_x1", "img_031_x1", "img_041_x1", "img_032_x1"] // image names to be saved (tensorboard) during testing
}
}
opencv-python
scikit-image
pillow
torchvision
hdf5storage
ninja
lmdb
requests
timm
einops
lpips
entmax
tensorboardX
tqdm
matplotlib
#!/bin/bash
echo "Testing start ..."
python main_test_swinir.py \
--model_path model_zoo/input_mask_80_90.pth \
--name input_mask_80_90/McM_poisson_20 \
--opt model_zoo/input_mask_80_90.json \
--folder_gt testset/McM/HR \
--folder_lq testset/McM/McM_poisson_20
# python main_test_swinir.py \
# --model_path model_zoo/baseline.pth \
# --name baseline/McM_poisson_20 \
# --opt model_zoo/baseline.json \
# --folder_gt testset/McM/HR \
# --folder_lq testset/McM/McM_poisson_20