Unverified Commit c8e85b28 authored by Jintao Lin, committed by GitHub

Add `tin_shift` function (#492)



* add tin shift

* add unittest

* add docstring

* add docstring

* parrots for tin_shift

* fix lint

* fix lint
Co-authored-by: jiaomenglei <jiaomenglei@sensetime.com>
parent 15537c5a
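
For reference, here is a minimal usage sketch of the op introduced by this commit. It is an illustrative example, not part of the diff below: it assumes a CUDA-enabled build of mmcv and uses arbitrary tensor shapes that follow the documented [N, num_segments, C, H*W] layout, with the channel count divisible by shift.size(1).

# Illustrative usage of the new op (assumes a CUDA build of mmcv).
import torch
from mmcv.ops import TINShift, tin_shift

# Feature map [N, num_segments, C, H*W]; per-sample, per-group shifts [N, num_segments].
x = torch.rand(2, 4, 8, 16, device='cuda')
shift = torch.randint(-2, 3, (2, 4), device='cuda').int()

out = tin_shift(x, shift)        # functional interface
out_mod = TINShift()(x, shift)   # nn.Module interface
assert out.shape == x.shape
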
...@@ -20,6 +20,7 @@ from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool
from .saconv import SAConv2d
from .sync_bn import SyncBatchNorm
from .tin_shift import TINShift, tin_shift
from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
__all__ = [
...@@ -34,5 +35,5 @@ __all__ = [
    'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
    'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
    'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
    'SAConv2d', 'TINShift', 'tin_shift'
]
#include "parrots_cpp_helper.hpp"
void TINShiftForwardCUDAKernelLauncher(const DArrayLite input,
const DArrayLite shift,
DArrayLite output, cudaStream_t stream);
void TINShiftBackwardCUDAKernelLauncher(const DArrayLite grad_output,
const DArrayLite shift,
DArrayLite grad_input,
cudaStream_t stream);
void tin_shift_forward_cuda(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
const auto &input = ins[0];
const auto &shift = ins[1];
auto &output = outs[0];
cudaStream_t stream = getStreamNative<CudaDevice>(ctx.getStream());
TINShiftForwardCUDAKernelLauncher(input, shift, output, stream);
}
void tin_shift_backward_cuda(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
const auto &grad_output = ins[0];
const auto &shift = ins[1];
auto &grad_input = outs[0];
cudaStream_t stream = getStreamNative<CudaDevice>(ctx.getStream());
TINShiftBackwardCUDAKernelLauncher(grad_output, shift, grad_input, stream);
}
PARROTS_EXTENSION_REGISTER(tin_shift_forward)
.input(2)
.output(1)
.apply(tin_shift_forward_cuda)
.done();
PARROTS_EXTENSION_REGISTER(tin_shift_backward)
.input(2)
.output(1)
.apply(tin_shift_backward_cuda)
.done();
\ No newline at end of file
#include "parrots_cuda_helper.hpp"
#include "tin_shift_cuda_kernel.cuh"
void TINShiftForwardCUDAKernelLauncher(const DArrayLite input,
const DArrayLite shift,
DArrayLite output, cudaStream_t stream) {
int output_size = output.size();
int batch_size = input.dim(0);
int t_size = input.dim(1);
int channels = input.dim(2);
int hw_size = input.dim(3);
int group_size = shift.dim(1);
int group_channel = channels / group_size;
int num_kernels = batch_size * hw_size * channels;
PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF(
input.elemType().prim(), ([&] {
tin_shift_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.ptr<scalar_t>(), shift.ptr<int>(),
output.ptr<scalar_t>(), batch_size, channels, t_size, hw_size,
group_size, group_channel);
}));
PARROTS_CUDA_CHECK(cudaGetLastError());
}
void TINShiftBackwardCUDAKernelLauncher(const DArrayLite grad_output,
const DArrayLite shift,
DArrayLite grad_input,
cudaStream_t stream) {
int output_size = grad_output.size();
int batch_size = grad_output.dim(0);
int t_size = grad_output.dim(1);
int channels = grad_output.dim(2);
int hw_size = grad_output.dim(3);
int group_size = shift.dim(1);
int group_channel = channels / group_size;
int num_kernels = batch_size * hw_size * channels;
PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.elemType().prim(), ([&] {
tin_shift_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.ptr<scalar_t>(), shift.ptr<int>(),
grad_input.ptr<scalar_t>(), batch_size, channels, t_size,
hw_size, group_size, group_channel);
}));
PARROTS_CUDA_CHECK(cudaGetLastError());
}
...@@ -155,6 +155,11 @@ void psamask_backward(Tensor grad_output, const Tensor grad_input,
                      const int w_feature, const int h_mask, const int w_mask,
                      const int half_h_mask, const int half_w_mask);
void tin_shift_forward(const Tensor input, const Tensor shift, Tensor output);
void tin_shift_backward(Tensor grad_output, const Tensor shift,
const Tensor grad_input);
Tensor bottom_pool_forward(Tensor input);
Tensor bottom_pool_backward(Tensor input, Tensor grad_output);
...@@ -329,6 +334,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"),
        py::arg("h_mask"), py::arg("w_mask"), py::arg("half_h_mask"),
        py::arg("half_w_mask"));
m.def("tin_shift_forward", &tin_shift_forward, "tin_shift forward",
py::arg("input"), py::arg("shift"), py::arg("output"));
m.def("tin_shift_backward", &tin_shift_backward, "tin_shift backward",
py::arg("grad_output"), py::arg("shift"), py::arg("grad_input"));
m.def("bottom_pool_forward", &bottom_pool_forward, "Bottom Pool Forward", m.def("bottom_pool_forward", &bottom_pool_forward, "Bottom Pool Forward",
py::arg("input"), py::call_guard<py::gil_scoped_release>()); py::arg("input"), py::call_guard<py::gil_scoped_release>());
m.def("bottom_pool_backward", &bottom_pool_backward, "Bottom Pool Backward", m.def("bottom_pool_backward", &bottom_pool_backward, "Bottom Pool Backward",
......
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void TINShiftForwardCUDAKernelLauncher(Tensor input, Tensor shift,
Tensor output);
void TINShiftBackwardCUDAKernelLauncher(Tensor grad_output, Tensor shift,
Tensor grad_input);
void tin_shift_forward_cuda(Tensor input, Tensor shift, Tensor output) {
TINShiftForwardCUDAKernelLauncher(input, shift, output);
}
void tin_shift_backward_cuda(Tensor grad_output, Tensor shift,
Tensor grad_input) {
TINShiftBackwardCUDAKernelLauncher(grad_output, shift, grad_input);
}
#endif
void tin_shift_forward(Tensor input, Tensor shift, Tensor output) {
if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(shift);
CHECK_CUDA_INPUT(output);
tin_shift_forward_cuda(input, shift, output);
#else
AT_ERROR("TINShift is not compiled with GPU support");
#endif
} else {
AT_ERROR("TINShift is not implemented on CPU");
}
}
void tin_shift_backward(Tensor grad_output, Tensor shift, Tensor grad_input) {
if (grad_output.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(shift);
CHECK_CUDA_INPUT(grad_input);
tin_shift_backward_cuda(grad_output, shift, grad_input);
#else
AT_ERROR("TINShift is not compiled with GPU support");
#endif
} else {
AT_ERROR("TINShift is not implemented on CPU");
}
}
#include "pytorch_cuda_helper.hpp"
#include "tin_shift_cuda_kernel.cuh"
void TINShiftForwardCUDAKernelLauncher(Tensor input, Tensor shift,
Tensor output) {
int output_size = output.numel();
int batch_size = input.size(0);
int t_size = input.size(1);
int channels = input.size(2);
int hw_size = input.size(3);
int group_size = shift.size(1);
int group_channel = channels / group_size;
int num_kernels = batch_size * hw_size * channels;
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "tin_shift_forward_cuda_kernel", [&] {
tin_shift_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(), shift.data_ptr<int>(),
output.data_ptr<scalar_t>(), batch_size, channels, t_size,
hw_size, group_size, group_channel);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void TINShiftBackwardCUDAKernelLauncher(Tensor grad_output, Tensor shift,
Tensor grad_input) {
int output_size = grad_output.numel();
int batch_size = grad_output.size(0);
int t_size = grad_output.size(1);
int channels = grad_output.size(2);
int hw_size = grad_output.size(3);
int group_size = shift.size(1);
int group_channel = channels / group_size;
int num_kernels = batch_size * hw_size * channels;
at::cuda::CUDAGuard device_guard(grad_output.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "tin_shift_backward_cuda_kernel", [&] {
tin_shift_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_output.data_ptr<scalar_t>(),
shift.data_ptr<int>(), grad_input.data_ptr<scalar_t>(),
batch_size, channels, t_size, hw_size, group_size,
group_channel);
});
AT_CUDA_CHECK(cudaGetLastError());
}
#ifndef TIN_SHIFT_CUDA_KERNEL_CUH
#define TIN_SHIFT_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void tin_shift_forward_cuda_kernel(
const int nthreads, const T* input, const int* shift, T* output,
const int batch_size, const int channels, const int t_size,
const int hw_size, const int group_size, const int group_channel) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int hw_index = index % hw_size;
const int j = (index / hw_size) % channels;
const int n_index = (index / hw_size / channels) % batch_size;
int group_id = j / group_channel;
int t_shift = shift[n_index * group_size + group_id];
int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
for (int i = 0; i < t_size; i++) {
int now_t = i + t_shift;
int data_id = i * hw_size * channels + offset;
if (now_t < 0 || now_t >= t_size) {
continue;
}
int out_id = now_t * hw_size * channels + offset;
output[out_id] = input[data_id];
}
}
}
template <typename T>
__global__ void tin_shift_backward_cuda_kernel(
const int nthreads, const T* input, const int* shift, T* output,
const int batch_size, const int channels, const int t_size,
const int hw_size, const int group_size, const int group_channel) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int hw_index = index % hw_size;
const int j = (index / hw_size) % channels;
const int n_index = (index / hw_size / channels) % batch_size;
int group_id = j / group_channel;
int t_shift = shift[n_index * group_size + group_id];
int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
for (int i = 0; i < t_size; i++) {
int now_t = i + t_shift;
int data_id = i * hw_size * channels + offset;
if (now_t < 0 || now_t >= t_size) {
continue;
}
int out_id = now_t * hw_size * channels + offset;
output[out_id] = input[data_id];
}
}
}
#endif // TIN_SHIFT_CUDA_KERNEL_CUH
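
To make the kernel's indexing easier to follow, here is a pure-PyTorch sketch of the same forward shift (an illustrative reference added for this writeup, not part of the commit): each channel group of every sample is moved along the temporal axis by its entry in `shift`; frames shifted out of range are dropped, and vacated positions stay zero. The backward kernel above uses the same indexing pattern with grad_output and grad_input in place of input and output.

import torch

def tin_shift_reference(x, shift):
    """Reference forward shift; x: [N, T, C, HW], shift: [N, G], C % G == 0."""
    n, t, c, _ = x.shape
    g = shift.size(1)
    gc = c // g  # channels per shift group (group_channel in the kernel)
    out = torch.zeros_like(x)
    for b in range(n):
        for grp in range(g):
            cs = slice(grp * gc, (grp + 1) * gc)
            s = int(shift[b, grp])
            for i in range(t):
                j = i + s  # destination frame (now_t in the kernel)
                if 0 <= j < t:
                    out[b, j, cs] = x[b, i, cs]
    return out
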
# Code reference from "Temporal Interlacing Network"
# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py
# Hao Shao, Shengju Qian, Yu Liu
# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk
import torch
import torch.nn as nn
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext',
['tin_shift_forward', 'tin_shift_backward'])
class TINShiftFunction(Function):
@staticmethod
def forward(ctx, input, shift):
ctx.save_for_backward(shift)
out = torch.zeros_like(input)
ext_module.tin_shift_forward(input, shift, out)
return out
@staticmethod
def backward(ctx, grad_output):
shift = ctx.saved_tensors[0]
data_grad_input = grad_output.new(*grad_output.size()).zero_()
shift_grad_input = shift.new(*shift.size()).zero_()
ext_module.tin_shift_backward(grad_output, shift, data_grad_input)
return data_grad_input, shift_grad_input
tin_shift = TINShiftFunction.apply
class TINShift(nn.Module):
"""Temporal Interlace Shift.
Temporal Interlace shift is a differentiable temporal-wise frame shifting
which is proposed in "Temporal Interlacing Network"
Please refer to https://arxiv.org/abs/2001.06499 for more details.
Code is modified from https://github.com/mit-han-lab/temporal-shift-module
"""
def forward(self, input, shift):
"""Perform temporal interlace shift.
Args:
input (Tensor): Feature map with shape [N, num_segments, C, H * W].
shift (Tensor): Shift tensor with shape [N, num_segments].
Returns:
Feature map after temporal interlace shift.
"""
return tin_shift(input, shift)
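
A short, hypothetical usage sketch of the module (assuming a CUDA build): the spatial dimensions are flattened to match the [N, num_segments, C, H*W] layout described in the docstring above, then restored afterwards.

import torch
from mmcv.ops import TINShift

n, t, c, h, w = 2, 4, 8, 7, 7
feat = torch.rand(n, t, c, h, w, device='cuda')
shift = torch.randint(-1, 2, (n, t), device='cuda').int()

# Flatten H and W before the shift, reshape back afterwards.
tin = TINShift()
out = tin(feat.view(n, t, c, h * w), shift).view(n, t, c, h, w)
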
import os
import numpy as np
import pytest
import torch
_USING_PARROTS = True
try:
from parrots.autograd import gradcheck
except ImportError:
from torch.autograd import gradcheck
_USING_PARROTS = False
cur_dir = os.path.dirname(os.path.abspath(__file__))
inputs = ([[[[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]],
[[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]],
[[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]],
[[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]]],
[[[-0.5897, 0.7544], [1.0593, 0.8388], [-0.5732, 0.5692]],
[[-0.6766, -1.4657], [1.2362, 0.4913], [-1.1820, -1.4341]],
[[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]],
[[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]]]])
shifts = [([[1, 0, 1, -2], [-2, 1, -1, 1]]), ([[2, 1, 2, -1], [-1, 2, 0, 2]])]
outputs = [([[[[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]],
[[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]],
[[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]],
[[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]]],
[[[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]],
[[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]],
[[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]],
[[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]]]]),
([[[[0.4369, -3.7571], [-1.1835, -1.6374], [0.9534, -0.1321]],
[[-0.4658, 0.2162], [-0.8135, -0.3903], [-0.1720, -0.0599]],
[[0.4851, 1.8224], [0.8973, 0.3779], [2.3454, 1.0319]],
[[0.0420, 0.3574], [0.7641, 0.2384], [0.2759, 0.4931]]],
[[[-0.6766, -1.4657], [1.2362, 0.4913], [-1.1820, -1.4341]],
[[0.6476, -0.7391], [1.4314, -0.3522], [0.8401, -0.7757]],
[[1.4306, 0.9726], [1.0518, -0.8820], [-0.5129, -0.7876]],
[[0.0000, 0.0000], [0.0000, 0.0000], [0.0000, 0.0000]]]])]
grads = [[[[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]]],
[[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]],
[[0., 0.], [0., 0.], [0., 0.]], [[0., 0.], [0., 0.], [0., 0.]]]],
[[[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]]],
[[[1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]], [[0., 0.], [0., 0.], [0., 0.]]]]]
def _test_tinshift_gradcheck(dtype):
try:
from mmcv.ops import tin_shift
except ModuleNotFoundError:
pytest.skip('TinShift op is not successfully compiled')
if dtype == torch.half:
pytest.skip('"add_cpu/sub_cpu" not implemented for Half')
for shift in shifts:
np_input = np.array(inputs)
np_shift = np.array(shift)
x = torch.tensor(
np_input, dtype=dtype, device='cuda', requires_grad=True)
shift = torch.tensor(np_shift, device='cuda').int()
if torch.__version__ == 'parrots':
gradcheck(tin_shift, (x, shift))
else:
gradcheck(tin_shift, (x, shift), atol=1, rtol=0.1)
def _test_tinshift_allclose(dtype):
try:
from mmcv.ops import tin_shift
except ModuleNotFoundError:
pytest.skip('TinShift op is not successfully compiled')
for shift, output, grad in zip(shifts, outputs, grads):
np_input = np.array(inputs)
np_shift = np.array(shift)
np_output = np.array(output)
np_grad = np.array(grad)
x = torch.tensor(
np_input, dtype=dtype, device='cuda', requires_grad=True)
shift = torch.tensor(np_shift, device='cuda').int()
output = tin_shift(x, shift)
output.backward(torch.ones_like(output))
assert np.allclose(
output.data.type(torch.float).cpu().numpy(), np_output, 1e-3)
assert np.allclose(
x.grad.data.type(torch.float).cpu().numpy(), np_grad, 1e-3)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half])
def test_tinshift(dtype):
_test_tinshift_allclose(dtype=dtype)
_test_tinshift_gradcheck(dtype=dtype)