Unverified Commit 933b052d authored by Rui Xu's avatar Rui Xu Committed by GitHub
Browse files

[Feature] Add cuda ops: UpFirDn2d and fused_bias_leakyrelu (#900)



* add upfirdn2d op

* fix bug in pybind

* add fused bias leakyrelu

* fix bug in fused-bias-leakyrelu

* fix lint error

* fix bug in build cpu version

* fix bug in build cpu version

* fix lint

* fix comment from zww
Co-authored-by: default avatarzhangshilong <zhangshilong@sensetime.com>
parent 371a2172
...@@ -12,6 +12,7 @@ from .deprecated_wrappers import Linear_deprecated as Linear ...@@ -12,6 +12,7 @@ from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss) sigmoid_focal_loss, softmax_focal_loss)
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .info import (get_compiler_version, get_compiling_cuda_version, from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path) get_onnxruntime_op_path)
from .masked_conv import MaskedConv2d, masked_conv2d from .masked_conv import MaskedConv2d, masked_conv2d
...@@ -27,6 +28,7 @@ from .roi_pool import RoIPool, roi_pool ...@@ -27,6 +28,7 @@ from .roi_pool import RoIPool, roi_pool
from .saconv import SAConv2d from .saconv import SAConv2d
from .sync_bn import SyncBatchNorm from .sync_bn import SyncBatchNorm
from .tin_shift import TINShift, tin_shift from .tin_shift import TINShift, tin_shift
from .upfirdn2d import upfirdn2d
__all__ = [ __all__ = [
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
...@@ -41,5 +43,6 @@ __all__ = [ ...@@ -41,5 +43,6 @@ __all__ = [
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask', 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated' 'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu'
] ]
// Modified from
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input,
const torch::Tensor& bias,
const torch::Tensor& refer, int act,
int grad, float alpha, float scale);
#endif
torch::Tensor fused_bias_leakyrelu(const torch::Tensor& input,
const torch::Tensor& bias,
const torch::Tensor& refer, int act,
int grad, float alpha, float scale) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA(input);
CHECK_CUDA(bias);
return fused_bias_leakyrelu_op(input, bias, refer, act, grad, alpha, scale);
#else
AT_ERROR("Fused bias leakyrelu is not compiled with GPU support");
#endif
}
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(
scalar_t* out, const scalar_t* p_x, const scalar_t* p_b,
const scalar_t* p_ref, int act, int grad, scalar_t alpha, scalar_t scale,
int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
scalar_t zero = 0.0;
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
loop_idx++, xi += blockDim.x) {
scalar_t x = p_x[xi];
if (use_bias) {
x += p_b[(xi / step_b) % size_b];
}
scalar_t ref = use_ref ? p_ref[xi] : zero;
scalar_t y;
// act = 1: linear layer
// act = 3: leaky relu layer
// grad = 0: direct forward path
// grad = 1: first order deviation
// grad = 2: second order deviation
switch (act * 10 + grad) {
default:
case 10:
y = x;
break;
case 11:
y = x;
break;
case 12:
y = 0.0;
break;
case 30:
y = (x > 0.0) ? x : x * alpha;
break;
case 31:
y = (ref > 0.0) ? x : x * alpha;
break;
case 32:
y = 0.0;
break;
}
out[xi] = y * scale;
}
}
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input,
const torch::Tensor& bias,
const torch::Tensor& refer, int act,
int grad, float alpha, float scale) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
auto x = input.contiguous();
auto b = bias.contiguous();
auto ref = refer.contiguous();
int use_bias = b.numel() ? 1 : 0;
int use_ref = ref.numel() ? 1 : 0;
int size_x = x.numel();
int size_b = b.numel();
int step_b = 1;
for (int i = 1 + 1; i < x.dim(); i++) {
step_b *= x.size(i);
}
int loop_x = 4;
int block_size = 4 * 32;
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
auto y = torch::empty_like(x);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "fused_bias_act_kernel", [&] {
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
});
return y;
}
...@@ -182,7 +182,18 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order, ...@@ -182,7 +182,18 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold, const Tensor dets_sorted, const float iou_threshold,
const int multi_label); const int multi_label);
Tensor upfirdn2d(const Tensor& input, const Tensor& kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
int pad_y1);
Tensor fused_bias_leakyrelu(const Tensor& input, const Tensor& bias,
const Tensor& refer, int act, int grad, float alpha,
float scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu,
"fused_bias_leakyrelu (CUDA)");
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_compiling_cuda_version", &get_compiling_cuda_version, m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
"get_compiling_cuda_version"); "get_compiling_cuda_version");
......
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
torch::Tensor upfirdn2d_op(const torch::Tensor& input,
const torch::Tensor& kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1);
#endif
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y, int pad_x0,
int pad_x1, int pad_y0, int pad_y1) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA(input);
CHECK_CUDA(kernel);
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
pad_y0, pad_y1);
#else
AT_ERROR("UpFirDn2d is not compiled with GPU support");
#endif
}
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
int c = a / b;
if (c * b > a) {
c--;
}
return c;
}
struct UpFirDn2DKernelParams {
int up_x;
int up_y;
int down_x;
int down_y;
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
int major_dim;
int in_h;
int in_w;
int minor_dim;
int kernel_h;
int kernel_w;
int out_h;
int out_w;
int loop_major;
int loop_x;
};
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
int major_idx_base = blockIdx.z * p.loop_major;
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
x_p += x_px;
k_p += k_px;
}
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
major_idx_base >= p.major_dim) {
return;
}
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major & major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base;
loop_x < p.loop_x & tile_out_x < p.out_w;
loop_x++, tile_out_x += tile_out_w) {
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
__syncthreads();
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
p.minor_dim +
minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
__syncthreads();
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
scalar_t v = 0.0;
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] *
sk[kernel_y + y * up_y][kernel_x + x * up_x];
if (out_x < p.out_w & out_y < p.out_h) {
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
}
}
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
UpFirDn2DKernelParams p;
auto x = input.contiguous();
auto k = kernel.contiguous();
p.major_dim = x.size(0);
p.in_h = x.size(1);
p.in_w = x.size(2);
p.minor_dim = x.size(3);
p.kernel_h = k.size(0);
p.kernel_w = k.size(1);
p.up_x = up_x;
p.up_y = up_y;
p.down_x = down_x;
p.down_y = down_y;
p.pad_x0 = pad_x0;
p.pad_x1 = pad_x1;
p.pad_y0 = pad_y0;
p.pad_y1 = pad_y1;
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
p.down_y;
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
p.down_x;
auto out =
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
int mode = -1;
int tile_out_h = -1;
int tile_out_w = -1;
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 1;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 3 && p.kernel_w <= 3) {
mode = 2;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 3;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 4;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 5;
tile_out_h = 8;
tile_out_w = 32;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 6;
tile_out_h = 8;
tile_out_w = 32;
}
dim3 block_size;
dim3 grid_size;
if (tile_out_h > 0 && tile_out_w > 0) {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 1;
block_size = dim3(32 * 8, 1, 1);
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
(p.major_dim - 1) / p.loop_major + 1);
} else {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 4;
block_size = dim3(4, 32, 1);
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
(p.major_dim - 1) / p.loop_major + 1);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
switch (mode) {
case 1:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 2:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 3:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 4:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 5:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 6:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
default:
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
}
});
return out;
}
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['fused_bias_leakyrelu'])
class FusedBiasLeakyReLUFunctionBackward(Function):
"""Calculate second order deviation.
This function is to compute the second order deviation for the fused leaky
relu operation.
"""
@staticmethod
def forward(ctx, grad_output, out, negative_slope, scale):
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
empty = grad_output.new_empty(0)
grad_input = ext_module.fused_bias_leakyrelu(grad_output, empty, out,
3, 1, negative_slope,
scale)
dim = [0]
if grad_input.ndim > 2:
dim += list(range(2, grad_input.ndim))
grad_bias = grad_input.sum(dim).detach()
return grad_input, grad_bias
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
out, = ctx.saved_tensors
# The second order deviation, in fact, contains two parts, while the
# the first part is zero. Thus, we direct consider the second part
# which is similar with the first order deviation in implementation.
gradgrad_out = ext_module.fused_bias_leakyrelu(gradgrad_input,
gradgrad_bias, out, 3,
1, ctx.negative_slope,
ctx.scale)
return gradgrad_out, None, None, None
class FusedBiasLeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)
out = ext_module.fused_bias_leakyrelu(input, bias, empty, 3, 0,
negative_slope, scale)
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
return out
@staticmethod
def backward(ctx, grad_output):
out, = ctx.saved_tensors
grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply(
grad_output, out, ctx.negative_slope, ctx.scale)
return grad_input, grad_bias, None, None
class FusedBiasLeakyReLU(nn.Module):
"""Fused bias leaky ReLU.
This function is introduced in the StyleGAN2:
http://arxiv.org/abs/1912.04958
The bias term comes from the convolution operation. In addition, to keep
the variance of the feature map or gradients unchanged, they also adopt a
scale similarly with Kaiming initalization. However, since the
:math:`1 + \alpha^2` : is too small, we can just ignore it. Therefore, the
final sacle is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
your own scale.
TODO: Implement the CPU version.
Args:
channel (int): The channnel number of the feature map.
negative_slope (float, optional): Same as nn.LeakyRelu.
Defaults to 0.2.
scale (float, optional): A scalar to adjust the variance of the feature
map. Defaults to 2**0.5.
"""
def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
super(FusedBiasLeakyReLU, self).__init__()
self.bias = nn.Parameter(torch.zeros(num_channels))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
return fused_bias_leakyrelu(input, self.bias, self.negative_slope,
self.scale)
def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
"""Fused bias leaky ReLU function.
This function is introduced in the StyleGAN2:
http://arxiv.org/abs/1912.04958
The bias term comes from the convolution operation. In addition, to keep
the variance of the feature map or gradients unchanged, they also adopt a
scale similarly with Kaiming initalization. However, since the
:math:`1 + \alpha^2` : is too small, we can just ignore it. Therefore, the
final sacle is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
your own scale.
Args:
input (torch.Tensor): Input feature map.
bias (nn.Parameter): The bias from convolution operation.
negative_slope (float, optional): Same as nn.LeakyRelu.
Defaults to 0.2.
scale (float, optional): A scalar to adjust the variance of the feature
map. Defaults to 2**0.5.
Returns:
torch.Tensor: Feature map after non-linear activation.
"""
if not input.is_cuda:
return bias_leakyrelu_ref(input, bias, negative_slope, scale)
return FusedBiasLeakyReLUFunction.apply(input, bias, negative_slope, scale)
def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
if bias is not None:
assert bias.ndim == 1
assert bias.shape[0] == x.shape[1]
x = x + bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)])
x = F.leaky_relu(x, negative_slope)
if scale != 1:
x = x * scale
return x
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
import torch
from torch.autograd import Function
from torch.nn import functional as F
from ..utils import ext_loader
upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
class UpFirDn2dBackward(Function):
@staticmethod
def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
in_size, out_size):
up_x, up_y = up
down_x, down_y = down
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
grad_input = upfirdn2d_ext.upfirdn2d(
grad_output,
grad_kernel,
down_x,
down_y,
up_x,
up_y,
g_pad_x0,
g_pad_x1,
g_pad_y0,
g_pad_y1,
)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
in_size[3])
ctx.save_for_backward(kernel)
pad_x0, pad_x1, pad_y0, pad_y1 = pad
ctx.up_x = up_x
ctx.up_y = up_y
ctx.down_x = down_x
ctx.down_y = down_y
ctx.pad_x0 = pad_x0
ctx.pad_x1 = pad_x1
ctx.pad_y0 = pad_y0
ctx.pad_y1 = pad_y1
ctx.in_size = in_size
ctx.out_size = out_size
return grad_input
@staticmethod
def backward(ctx, gradgrad_input):
kernel, = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
ctx.in_size[3], 1)
gradgrad_out = upfirdn2d_ext.upfirdn2d(
gradgrad_input,
kernel,
ctx.up_x,
ctx.up_y,
ctx.down_x,
ctx.down_y,
ctx.pad_x0,
ctx.pad_x1,
ctx.pad_y0,
ctx.pad_y1,
)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
# ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
ctx.out_size[0], ctx.out_size[1])
return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
@staticmethod
def forward(ctx, input, kernel, up, down, pad):
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad
kernel_h, kernel_w = kernel.shape
batch, channel, in_h, in_w = input.shape
ctx.in_size = input.shape
input = input.reshape(-1, in_h, in_w, 1)
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
ctx.out_size = (out_h, out_w)
ctx.up = (up_x, up_y)
ctx.down = (down_x, down_y)
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
g_pad_x0 = kernel_w - pad_x0 - 1
g_pad_y0 = kernel_h - pad_y0 - 1
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x,
down_y, pad_x0, pad_x1, pad_y0, pad_y1)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)
return out
@staticmethod
def backward(ctx, grad_output):
kernel, grad_kernel = ctx.saved_tensors
grad_input = UpFirDn2dBackward.apply(
grad_output,
kernel,
grad_kernel,
ctx.up,
ctx.down,
ctx.pad,
ctx.g_pad,
ctx.in_size,
ctx.out_size,
)
return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
"""UpFRIDn for 2d features.
UpFIRDn is short for upsample, apply FIR filter and downsample. More
details can be found in:
https://www.mathworks.com/help/signal/ref/upfirdn.html
Args:
input (Tensor): Tensor with shape of (n, c, h, w).
kernel (Tensor): Filter kernel.
up (int, optional): Upsampling factor. Defaults to 1.
down (int, optional): Downsampling factor. Defaults to 1.
pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad).
Defaults to (0, 0).
Returns:
Tensor: Tensor after UpFIRDn.
"""
if input.device.type == 'cpu':
out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0],
pad[1], pad[0], pad[1])
else:
out = UpFirDn2d.apply(input, kernel, (up, up), (down, down),
(pad[0], pad[1], pad[0], pad[1]))
return out
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(
out,
[0, 0,
max(pad_x0, 0),
max(pad_x1, 0),
max(pad_y0, 0),
max(pad_y1, 0)])
out = out[:,
max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
import pytest
import torch
from torch.autograd import gradcheck, gradgradcheck
class TestFusedBiasLeakyReLU(object):
@classmethod
def setup_class(cls):
if not torch.cuda.is_available():
return
cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda()
cls.bias = torch.zeros(2, requires_grad=True).cuda()
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_gradient(self):
from mmcv.ops import FusedBiasLeakyReLU
gradcheck(
FusedBiasLeakyReLU(2).cuda(),
self.input_tensor,
eps=1e-4,
atol=1e-3)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_gradgradient(self):
from mmcv.ops import FusedBiasLeakyReLU
gradgradcheck(
FusedBiasLeakyReLU(2).cuda(),
self.input_tensor,
eps=1e-4,
atol=1e-3)
import pytest
import torch
from torch.autograd import gradcheck, gradgradcheck
class TestUpFirDn2d(object):
"""Unit test for UpFirDn2d.
Here, we just test the basic case of upsample version. More gerneal tests
will be included in other unit test for UpFirDnUpsample and
UpFirDnDownSample modules.
"""
@classmethod
def setup_class(cls):
kernel_1d = torch.tensor([1., 3., 3., 1.])
cls.kernel = kernel_1d[:, None] * kernel_1d[None, :]
cls.kernel = cls.kernel / cls.kernel.sum()
cls.factor = 2
pad = cls.kernel.shape[0] - cls.factor
cls.pad = ((pad + 1) // 2 + cls.factor - 1, pad // 2)
cls.input_tensor = torch.randn((2, 3, 4, 4), requires_grad=True)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_upfirdn2d(self):
from mmcv.ops import upfirdn2d
gradcheck(
upfirdn2d,
(self.input_tensor.cuda(), self.kernel.type_as(
self.input_tensor).cuda(), self.factor, 1, self.pad),
eps=1e-4,
atol=1e-3)
gradgradcheck(
upfirdn2d,
(self.input_tensor.cuda(), self.kernel.type_as(
self.input_tensor).cuda(), self.factor, 1, self.pad),
eps=1e-4,
atol=1e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment