Unverified commit 2a3d2d4a authored by Leojc, committed by GitHub

Reimplement cc_attention using pure pytorch (#1201)



* Reimplement cc_attention using pure pytorch

* fix: avoid BC-Breaking

* delete cc_attention related cpp and cuda files

* delete cc_attention related lines in pybind.cpp

* make out Tensor contiguous.

* remove unneeded lines.

* Update mmcv/ops/cc_attention.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update TestCrissCrossAttention

* passing pre-commit

* Update docstring of CrissCrossAttention

* Update docstring of CrissCrossAttention

* Update mmcv/ops/cc_attention.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* [docs] Polish the docstring

* [Docs] Polish the docstring
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 642d2818
mmcv/ops/cc_attention.py
@@ -2,77 +2,47 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd.function import once_differentiable
 
 from mmcv.cnn import PLUGIN_LAYERS, Scale
-from ..utils import ext_loader
-
-ext_module = ext_loader.load_ext(
-    '_ext', ['ca_forward', 'ca_backward', 'ca_map_forward', 'ca_map_backward'])
-
-
-class CAWeightFunction(torch.autograd.Function):
-
-    @staticmethod
-    def symbolic(g, t, f):
-        return g.op('mmcv::MMCVCAWeight', t, f)
-
-    @staticmethod
-    def forward(ctx, t, f):
-        n, c, h, w = t.size()
-        weight = torch.zeros(n, h + w - 1, h, w).to(t.device)
-        ext_module.ca_forward(t, f, weight)
-
-        ctx.save_for_backward(t, f)
-
-        return weight
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, dw):
-        t, f = ctx.saved_tensors
-        dt = torch.zeros_like(t)
-        df = torch.zeros_like(f)
-        ext_module.ca_backward(dw, t, f, dt, df)
-
-        return dt, df
-
-
-class CAMapFunction(torch.autograd.Function):
-
-    @staticmethod
-    def symbolic(g, weight, v):
-        return g.op('mmcv::MMCVCAMap', weight, v)
-
-    @staticmethod
-    def forward(ctx, weight, v):
-        out = torch.zeros_like(v)
-        ext_module.ca_map_forward(weight, v, out)
-        ctx.save_for_backward(weight, v)
-        return out
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, dout):
-        weight, v = ctx.saved_tensors
-        dw = torch.zeros_like(weight)
-        dv = torch.zeros_like(v)
-        ext_module.ca_map_backward(dout, weight, v, dw, dv)
-        return dw, dv
-
-
-ca_weight = CAWeightFunction.apply
-ca_map = CAMapFunction.apply
+
+
+def NEG_INF_DIAG(n, device):
+    """Returns a diagonal matrix of size [n, n].
+
+    The diagonal are all "-inf". This is for avoiding calculating the
+    overlapped element in the Criss-Cross twice.
+    """
+    return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
 
 
 @PLUGIN_LAYERS.register_module()
 class CrissCrossAttention(nn.Module):
-    """Criss-Cross Attention Module."""
+    """Criss-Cross Attention Module.
+
+    .. note::
+        Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
+        to a pure PyTorch and equivalent implementation. For more
+        details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
+
+        Speed comparison for one forward pass
+
+        - Input size: [2,512,97,97]
+        - Device: 1 NVIDIA GeForce RTX 2080 Ti
+
+        +-----------------------+---------------+------------+---------------+
+        |                       |PyTorch version|CUDA version|Relative speed |
+        +=======================+===============+============+===============+
+        |with torch.no_grad()   |0.00554402 s   |0.0299619 s |5.4x           |
+        +-----------------------+---------------+------------+---------------+
+        |no with torch.no_grad()|0.00562803 s   |0.0301349 s |5.4x           |
+        +-----------------------+---------------+------------+---------------+
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+    """
 
     def __init__(self, in_channels):
-        super(CrissCrossAttention, self).__init__()
+        super().__init__()
         self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
         self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
         self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
@@ -80,14 +50,30 @@ class CrissCrossAttention(nn.Module):
         self.in_channels = in_channels
 
     def forward(self, x):
-        proj_query = self.query_conv(x)
-        proj_key = self.key_conv(x)
-        proj_value = self.value_conv(x)
-
-        energy = ca_weight(proj_query, proj_key)
-        attention = F.softmax(energy, 1)
-        out = ca_map(attention, proj_value)
+        """forward function of Criss-Cross Attention.
+
+        Args:
+            x (Tensor): Input feature. \
+                shape (batch_size, in_channels, height, width)
+        Returns:
+            Tensor: Output of the layer, with shape of \
+                (batch_size, in_channels, height, width)
+        """
+        B, C, H, W = x.size()
+        query = self.query_conv(x)
+        key = self.key_conv(x)
+        value = self.value_conv(x)
+        energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG(
+            H, query.device)
+        energy_H = energy_H.transpose(1, 2)
+        energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
+        attn = F.softmax(
+            torch.cat([energy_H, energy_W], dim=-1), dim=-1)  # [B,H,W,(H+W)]
+        out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H])
+        out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:])
+
         out = self.gamma(out) + x
+        out = out.contiguous()
 
         return out
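With the custom CUDA op gone, the module runs wherever PyTorch runs. A minimal usage sketch (the shapes below are illustrative, not taken from the PR):

```python
import torch
from mmcv.ops import CrissCrossAttention

x = torch.randn(2, 64, 45, 45)             # (batch, in_channels, height, width)
cca = CrissCrossAttention(in_channels=64)  # works on CPU or GPU
out = cca(x)
assert out.shape == x.shape                # gamma(out) + x preserves the input shape

# Timing in the spirit of the table above would wrap a single forward pass:
with torch.no_grad():
    out = cca(x)
```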
......
// Copyright (c) OpenMMLab. All rights reserved
#ifndef CC_ATTENTION_CUDA_KERNEL_CUH
#define CC_ATTENTION_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num,
int chn, int height, int width) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int z = blockIdx.z % len;
int batch = blockIdx.z / len;
if (x < width && y < height) {
T *weight_ptr = weight + (batch * len + z) * sp + y * width + x;
const int t_offset = y * width + x;
const int j = (z - width < y) ? z - width : z - width + 1;
const int f_offset = z < width ? y * width + z : j * width + x;
for (int plane = 0; plane < chn; ++plane) {
const int tf_base = (batch * chn + plane) * sp;
*weight_ptr += t[tf_base + t_offset] * f[tf_base + f_offset];
}
}
}
template <typename T>
__global__ void ca_backward_kernel_t(const T *dw, const T *t, const T *f, T *dt,
int num, int chn, int height, int width) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + i) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + y * width + i];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i < y ? i : i - 1;
T _dw = dw[(batch * len + width + j) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + i * width + x];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
}
}
template <typename T>
__global__ void ca_backward_kernel_f(const T *dw, const T *t, const T *f, T *df,
int num, int chn, int height, int width) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + x) * sp + y * width + i];
T _t = t[(batch * chn + plane) * sp + y * width + i];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i > y ? y : y - 1;
T _dw = dw[(batch * len + width + j) * sp + i * width + x];
T _t = t[(batch * chn + plane) * sp + i * width + x];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
}
}
template <typename T>
__global__ void ca_map_forward_kernel(const T *weight, const T *g, T *out,
int num, int chn, int height, int width) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
T _g = g[(batch * chn + plane) * sp + y * width + i];
T _w = weight[(batch * len + i) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i < y ? i : i - 1;
T _g = g[(batch * chn + plane) * sp + i * width + x];
T _w = weight[(batch * len + width + j) * sp + y * width + x];
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
}
}
template <typename T>
__global__ void ca_map_backward_kernel_w(const T *dout, const T *weight,
const T *g, T *dw, int num, int chn,
int height, int width) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int z = blockIdx.z % len;
int batch = blockIdx.z / len;
if (x < width && y < height) {
int widx = (batch * len + z) * sp + y * width + x;
int dout_idx = batch * chn * sp + y * width + x;
int gidx = batch * chn * sp;
if (z < width) {
gidx += y * width + z;
} else {
int j = z - width;
j = j < y ? j : j + 1;
gidx += j * width + x;
}
for (int plane = 0; plane < chn; plane++) {
dw[widx] += dout[dout_idx + plane * sp] * g[gidx + plane * sp];
}
}
}
template <typename T>
__global__ void ca_map_backward_kernel_g(const T *dout, const T *weight,
const T *g, T *dg, int num, int chn,
int height, int width) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width;
int len = height + width - 1;
int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
int index = (batch * chn + plane) * sp + y * width + x;
if (x < width && y < height) {
for (int i = 0; i < width; ++i) {
dg[index] += dout[(batch * chn + plane) * sp + y * width + i] *
weight[(batch * len + x) * sp + y * width + i];
}
for (int i = 0; i < height; ++i) {
if (i == y) continue;
int j = i > y ? y : y - 1;
dg[index] += dout[(batch * chn + plane) * sp + i * width + x] *
weight[(batch * len + width + j) * sp + i * width + x];
}
}
}
#endif // CC_ATTENTION_CUDA_KERNEL_CUH
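For readers mapping the deleted kernels onto the new einsum code: ca_forward_kernel above fills, for every position (y, x), its affinities against the same row and the same column, skipping the position's own row in the column pass so the overlapping pixel is counted only once. A loop-based PyTorch sketch of that output layout (purely illustrative; ca_forward_ref is not an MMCV function):

```python
import torch


def ca_forward_ref(t, f):
    """Reference for the removed ca_forward kernel; t, f: [n, c, h, w]."""
    n, c, h, w = t.shape
    weight = t.new_zeros(n, h + w - 1, h, w)
    for y in range(h):
        for x in range(w):
            # first w channels: affinities along the row of (y, x)
            weight[:, :w, y, x] = (t[:, :, y, x, None] * f[:, :, y, :]).sum(1)
            # remaining h - 1 channels: affinities along the column, own row skipped
            rows = [i for i in range(h) if i != y]
            weight[:, w:, y, x] = (t[:, :, y, x, None] * f[:, :, rows, x]).sum(1)
    return weight
```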
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f, Tensor weight);
void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
const Tensor f, Tensor dt, Tensor df);
void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
Tensor out);
void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
const Tensor g, Tensor dw, Tensor dg);
void ca_forward_cuda(const Tensor t, const Tensor f, Tensor weight) {
CAForwardCUDAKernelLauncher(t, f, weight);
}
void ca_backward_cuda(const Tensor dw, const Tensor t, const Tensor f,
Tensor dt, Tensor df) {
CABackwardCUDAKernelLauncher(dw, t, f, dt, df);
}
void ca_map_forward_cuda(const Tensor weight, const Tensor g, Tensor out) {
CAMapForwardCUDAKernelLauncher(weight, g, out);
}
void ca_map_backward_cuda(const Tensor dout, const Tensor weight,
const Tensor g, Tensor dw, Tensor dg) {
CAMapBackwardCUDAKernelLauncher(dout, weight, g, dw, dg);
}
#endif
void ca_forward(const Tensor t, const Tensor f, Tensor weight) {
if (t.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(t);
CHECK_CUDA_INPUT(f);
CHECK_CUDA_INPUT(weight);
ca_forward_cuda(t, f, weight);
#else
AT_ERROR("ca is not compiled with GPU support");
#endif
} else {
AT_ERROR("ca is not implemented on the CPU");
}
}
void ca_backward(const Tensor dw, const Tensor t, const Tensor f, Tensor dt,
Tensor df) {
if (dw.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(dw);
CHECK_CUDA_INPUT(t);
CHECK_CUDA_INPUT(f);
CHECK_CUDA_INPUT(dt);
CHECK_CUDA_INPUT(df);
ca_backward_cuda(dw, t, f, dt, df);
#else
AT_ERROR("ca is not compiled with GPU support");
#endif
} else {
AT_ERROR("ca is not implemented on the CPU");
}
}
void ca_map_forward(const Tensor weight, const Tensor g, Tensor out) {
if (weight.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(g);
CHECK_CUDA_INPUT(out);
ca_map_forward_cuda(weight, g, out);
#else
AT_ERROR("ca_map is not compiled with GPU support");
#endif
} else {
AT_ERROR("ca is not implemented on the CPU");
}
}
void ca_map_backward(const Tensor dout, const Tensor weight, const Tensor g,
Tensor dw, Tensor dg) {
if (dout.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(dout);
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(g);
CHECK_CUDA_INPUT(dw);
CHECK_CUDA_INPUT(dg);
ca_map_backward_cuda(dout, weight, g, dw, dg);
#else
AT_ERROR("ca_map is not compiled with GPU support");
#endif
} else {
AT_ERROR("ca is not implemented on the CPU");
}
}
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
#include "cc_attention_pytorch.h"
using namespace parrots;
#ifdef MMCV_WITH_CUDA
/*void ca_forward_cuda(const Tensor t, const Tensor f, Tensor weight);*/
void ca_forward_cuda_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
const auto &t = buildATensor(ctx, ins[0]);
const auto &f = buildATensor(ctx, ins[1]);
auto weight = buildATensor(ctx, outs[0]);
ca_forward_cuda(t, f, weight);
}
/* void ca_backward_cuda(const Tensor dw, const Tensor t, const Tensor f,
* Tensor dt, Tensor df)
*/
void ca_backward_cuda_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
const auto &dw = buildATensor(ctx, ins[0]);
const auto &t = buildATensor(ctx, ins[1]);
const auto &f = buildATensor(ctx, ins[2]);
auto dt = buildATensor(ctx, outs[0]);
auto df = buildATensor(ctx, outs[1]);
ca_backward_cuda(dw, t, f, dt, df);
}
/* void ca_map_forward_cuda(const Tensor weight, const Tensor g, Tensor out); */
void ca_map_forward_cuda_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
const auto &weight = buildATensor(ctx, ins[0]);
const auto &g = buildATensor(ctx, ins[1]);
auto out = buildATensor(ctx, outs[0]);
ca_map_forward_cuda(weight, g, out);
}
/* void ca_map_backward_cuda(const Tensor dout, const Tensor weight,
* const Tensor g, Tensor dw, Tensor dg);
*/
void ca_map_backward_cuda_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
const auto &dout = buildATensor(ctx, ins[0]);
const auto &weight = buildATensor(ctx, ins[1]);
const auto &g = buildATensor(ctx, ins[2]);
auto dw = buildATensor(ctx, outs[0]);
auto dg = buildATensor(ctx, outs[1]);
ca_map_backward_cuda(dout, weight, g, dw, dg);
}
PARROTS_EXTENSION_REGISTER(ca_forward)
.input(2)
.output(1)
.apply(ca_forward_cuda_parrots)
.done();
PARROTS_EXTENSION_REGISTER(ca_backward)
.input(3)
.output(2)
.apply(ca_backward_cuda_parrots)
.done();
PARROTS_EXTENSION_REGISTER(ca_map_forward)
.input(2)
.output(1)
.apply(ca_map_forward_cuda_parrots)
.done();
PARROTS_EXTENSION_REGISTER(ca_map_backward)
.input(3)
.output(2)
.apply(ca_map_backward_cuda_parrots)
.done();
#endif
// Copyright (c) OpenMMLab. All rights reserved
#ifndef CC_ATTENTION_PYTORCH_H
#define CC_ATTENTION_PYTORCH_H
#include <torch/extension.h>
using namespace at;
void ca_forward_cuda(const Tensor t, const Tensor f, Tensor weight);
void ca_backward_cuda(const Tensor dw, const Tensor t, const Tensor f,
Tensor dt, Tensor df);
void ca_map_forward_cuda(const Tensor weight, const Tensor g, Tensor out);
void ca_map_backward_cuda(const Tensor dout, const Tensor weight,
const Tensor g, Tensor dw, Tensor dg);
#endif // CC_ATTENTION_PYTORCH_H
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f, Tensor weight);
void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
const Tensor f, Tensor dt, Tensor df);
void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
Tensor out);
void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
const Tensor g, Tensor dw, Tensor dg);
void ca_forward_cuda(const Tensor t, const Tensor f, Tensor weight) {
CAForwardCUDAKernelLauncher(t, f, weight);
}
void ca_backward_cuda(const Tensor dw, const Tensor t, const Tensor f,
Tensor dt, Tensor df) {
CABackwardCUDAKernelLauncher(dw, t, f, dt, df);
}
void ca_map_forward_cuda(const Tensor weight, const Tensor g, Tensor out) {
CAMapForwardCUDAKernelLauncher(weight, g, out);
}
void ca_map_backward_cuda(const Tensor dout, const Tensor weight,
const Tensor g, Tensor dw, Tensor dg) {
CAMapBackwardCUDAKernelLauncher(dout, weight, g, dw, dg);
}
#endif
void ca_forward(const Tensor t, const Tensor f, Tensor weight) {
if (t.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(t);
CHECK_CUDA_INPUT(f);
CHECK_CUDA_INPUT(weight);
ca_forward_cuda(t, f, weight);
#else
AT_ERROR("ca is not compiled with GPU support");
#endif
} else {
AT_ERROR("ca is not implemented on the CPU");
}
}
void ca_backward(const Tensor dw, const Tensor t, const Tensor f, Tensor dt,
Tensor df) {
if (dw.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(dw);
CHECK_CUDA_INPUT(t);
CHECK_CUDA_INPUT(f);
CHECK_CUDA_INPUT(dt);
CHECK_CUDA_INPUT(df);
ca_backward_cuda(dw, t, f, dt, df);
#else
AT_ERROR("ca is not compiled with GPU support");
#endif
} else {
AT_ERROR("ca is not implemented on the CPU");
}
}
void ca_map_forward(const Tensor weight, const Tensor g, Tensor out) {
if (weight.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(g);
CHECK_CUDA_INPUT(out);
ca_map_forward_cuda(weight, g, out);
#else
AT_ERROR("ca_map is not compiled with GPU support");
#endif
} else {
AT_ERROR("ca is not implemented on the CPU");
}
}
void ca_map_backward(const Tensor dout, const Tensor weight, const Tensor g,
Tensor dw, Tensor dg) {
if (dout.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(dout);
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(g);
CHECK_CUDA_INPUT(dw);
CHECK_CUDA_INPUT(dg);
ca_map_backward_cuda(dout, weight, g, dw, dg);
#else
AT_ERROR("ca_map is not compiled with GPU support");
#endif
} else {
AT_ERROR("ca is not implemented on the CPU");
}
}
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/LikeLy-Journey/SegmenTron/blob/master/segmentron/modules/csrc/criss_cross_attention/ca_cuda.cu
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
#include "cc_attention_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
Tensor weight) {
AT_ASSERTM(t.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(f.device().is_cuda(), "input must be a CUDA tensor");
auto n = t.size(0);
auto c = t.size(1);
auto h = t.size(2);
auto w = t.size(3);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Run kernel
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w - 1;
dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
ca_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
t.contiguous().data_ptr<scalar_t>(),
f.contiguous().data_ptr<scalar_t>(),
weight.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
THCudaCheck(cudaGetLastError());
}
void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
const Tensor f, Tensor dt, Tensor df) {
AT_ASSERTM(dw.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(t.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(f.device().is_cuda(), "input must be a CUDA tensor");
auto n = t.size(0);
auto c = t.size(1);
auto h = t.size(2);
auto w = t.size(3);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Run kernel
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = c * n;
dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
ca_backward_kernel_t<scalar_t><<<blocks, threads, 0, stream>>>(
dw.contiguous().data_ptr<scalar_t>(),
t.contiguous().data_ptr<scalar_t>(),
f.contiguous().data_ptr<scalar_t>(),
dt.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
AT_DISPATCH_FLOATING_TYPES(f.scalar_type(), "ca_backward_kernel_f", [&] {
ca_backward_kernel_f<scalar_t><<<blocks, threads, 0, stream>>>(
dw.contiguous().data_ptr<scalar_t>(),
t.contiguous().data_ptr<scalar_t>(),
f.contiguous().data_ptr<scalar_t>(),
df.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
THCudaCheck(cudaGetLastError());
}
void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
Tensor out) {
AT_ASSERTM(weight.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(g.device().is_cuda(), "input must be a CUDA tensor");
auto n = g.size(0);
auto c = g.size(1);
auto h = g.size(2);
auto w = g.size(3);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Run kernel
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = c * n;
dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
ca_map_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
weight.contiguous().data_ptr<scalar_t>(),
g.contiguous().data_ptr<scalar_t>(),
out.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
THCudaCheck(cudaGetLastError());
}
void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
const Tensor g, Tensor dw, Tensor dg) {
AT_ASSERTM(dout.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(weight.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(g.device().is_cuda(), "input must be a CUDA tensor");
auto n = dout.size(0);
auto c = dout.size(1);
auto h = dout.size(2);
auto w = dout.size(3);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Run kernel
dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w - 1;
dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES(
weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
ca_map_backward_kernel_w<scalar_t><<<blocks, threads, 0, stream>>>(
dout.contiguous().data_ptr<scalar_t>(),
weight.contiguous().data_ptr<scalar_t>(),
g.contiguous().data_ptr<scalar_t>(),
dw.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
d3 = c * n;
blocks = dim3(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
ca_map_backward_kernel_g<scalar_t><<<blocks, threads, 0, stream>>>(
dout.contiguous().data_ptr<scalar_t>(),
weight.contiguous().data_ptr<scalar_t>(),
g.contiguous().data_ptr<scalar_t>(),
dg.contiguous().data_ptr<scalar_t>(), n, c, h, w);
});
THCudaCheck(cudaGetLastError());
}
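ca_map_forward_kernel, launched above, performs the matching aggregation step: each output pixel is a weighted sum of the value map over its row and its column, again skipping the own-row term in the column pass. A loop-based sketch in the same illustrative spirit (ca_map_forward_ref is not an MMCV function); composing it with ca_forward_ref plus a softmax over dim 1 mirrors what the new einsum-based forward computes:

```python
import torch


def ca_map_forward_ref(weight, g):
    """Reference for the removed ca_map_forward kernel.

    weight: [n, h + w - 1, h, w], e.g. a softmax over dim 1 of ca_forward_ref;
    g: [n, c, h, w] (the value map). Returns out: [n, c, h, w].
    """
    n, c, h, w = g.shape
    out = torch.zeros_like(g)
    for y in range(h):
        for x in range(w):
            # gather along the row of (y, x)
            out[:, :, y, x] = (g[:, :, y, :] * weight[:, None, :w, y, x]).sum(-1)
            # gather along the column, own row skipped
            rows = [i for i in range(h) if i != y]
            out[:, :, y, x] += (g[:, :, rows, x] * weight[:, None, w:, y, x]).sum(-1)
    return out
```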
pybind.cpp
@@ -158,16 +158,6 @@ void sync_bn_backward_data(const Tensor grad_output, const Tensor weight,
                            const Tensor norm, const Tensor std,
                            Tensor grad_input);
 
-void ca_forward(const Tensor t, const Tensor f, Tensor weight);
-
-void ca_backward(const Tensor dw, const Tensor t, const Tensor f, Tensor dt,
-                 Tensor df);
-
-void ca_map_forward(const Tensor weight, const Tensor g, Tensor out);
-
-void ca_map_backward(const Tensor dout, const Tensor weight, const Tensor g,
-                     Tensor dw, Tensor dg);
-
 void psamask_forward(const Tensor input, Tensor output, const int psa_type,
                      const int num_, const int h_feature, const int w_feature,
                      const int h_mask, const int w_mask, const int half_h_mask,
@@ -385,15 +375,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "sync_bn backward_data", py::arg("grad_output"), py::arg("weight"),
         py::arg("grad_weight"), py::arg("grad_bias"), py::arg("norm"),
         py::arg("std"), py::arg("grad_input"));
-  m.def("ca_forward", &ca_forward, "ccattention forward", py::arg("t"),
-        py::arg("f"), py::arg("weight"));
-  m.def("ca_backward", &ca_backward, "ccattention backward", py::arg("dw"),
-        py::arg("t"), py::arg("f"), py::arg("dt"), py::arg("df"));
-  m.def("ca_map_forward", &ca_map_forward, "ccattention map forward",
-        py::arg("weight"), py::arg("g"), py::arg("out"));
-  m.def("ca_map_backward", &ca_map_backward, "ccattention map backward",
-        py::arg("dout"), py::arg("weight"), py::arg("g"), py::arg("dw"),
-        py::arg("dg"));
   m.def("psamask_forward", &psamask_forward, "PSAMASK forward (CPU/CUDA)",
         py::arg("input"), py::arg("output"), py::arg("psa_type"),
         py::arg("num_"), py::arg("h_feature"), py::arg("w_feature"),
......
@@ -17,8 +17,7 @@ class Loss(nn.Module):
 class TestCrissCrossAttention(object):
 
     def test_cc_attention(self):
-        if not torch.cuda.is_available():
-            return
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
         from mmcv.ops import CrissCrossAttention
         loss_func = Loss()
@@ -42,9 +41,9 @@ class TestCrissCrossAttention(object):
             channel = shape[1]
             cca = CrissCrossAttention(channel)
-            cca.cuda()
-            input = input.cuda()
-            label = label.cuda()
+            cca.to(device)
+            input = input.to(device)
+            label = label.to(device)
             cca.train()
 
             test_output = cca(input)
             test_loss = loss_func(test_output, label)
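Since the op is now pure PyTorch, autograd supplies the backward pass, so a numerical gradient check becomes straightforward. A hedged sketch of such a check (not part of the updated test file; the sizes are arbitrary):

```python
import torch
from mmcv.ops import CrissCrossAttention

# Small double-precision problem so finite differences are accurate.
cca = CrissCrossAttention(8).double()
x = torch.randn(1, 8, 5, 5, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(cca, (x,), eps=1e-6, atol=1e-4)
```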
......