Unverified commit 54a7ebb4, authored by ZhangShilong and committed by GitHub

[Feature]: support Multi-Scale-DeformAttention in deformable-detr (#878)

* add c++ ms_deform_atten

* fix cpp lint

* fix cpp lint

* clang format

* remove cmakefile

* google style

* clang-format precommit

* use clang-format-lint-action

* add transformer base class

* add merge

* add docstr

* add pyargs

* fix according to comments

* register module

* change to use basemodule

* add _ between build function

* split the name

* fix according to comments

* fix lint and fix unit test

* fix cpp lint

* fix bug of deformdetr_atten

* fix dropout

* fix residual

* use CUDA_1D_KERNEL_LOOP
parent 0dd0c49a
import copy
import math
import warnings

import torch
import torch.nn as nn
from mmcv import ConfigDict
from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
constant_init, xavier_init)
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
from mmcv.runner.base_module import BaseModule
from mmcv.utils import build_from_cfg
from .registry import (ATTENTION, POSITIONAL_ENCODING, TRANSFORMER_LAYER,
@@ -135,6 +139,201 @@ class MultiheadAttention(BaseModule):
return residual + self.dropout(out)
@ATTENTION.register_module()
class MultiScaleDeformableAttention(BaseModule):
"""An attention module used in Deformable-Detr. `Deformable DETR:
Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 8.
num_levels (int): The number of feature maps used in
the attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_residual`.
Default: 0.1.
norm_cfg (dict): Config dict for the normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
im2col_step=64,
dropout=0.1,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weight()
def init_weight(self):
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
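# Initialize the offset bias so that, before training, each head samples
# along a distinct direction around the reference point, with each
# successive sampling point pushed one step further out.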
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
def forward(self,
query,
key,
value,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
residual (Tensor): The tensor used for addition, with the
same shape as `query`. Default: None. If None, `query`
will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements in range [0, 1], top-left (0, 0),
bottom-right (1, 1), including the padding area;
or (N, Length_{query}, num_levels, 4), where the
additional two dimensions (w, h) form reference boxes.
key_padding_mask (Tensor): ByteTensor used to mask the
padded keys in `value`, with shape [bs, num_key].
spatial_shapes (Tensor): Spatial shapes of the features at
different levels. With shape (num_levels, 2), the last
dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor with shape (num_levels, ) which can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
inp_residual = query if residual is None else residual
if query_pos is not None:
query = query + query_pos
# change to (bs, num_query, embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_key, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_key
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_key, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
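# Two reference formats: 2-d references are normalized (x, y) centers,
# so offsets are rescaled by each level's (w, h); 4-d references also
# carry a normalized box (w, h), so offsets are scaled by half the box
# size per sampling point.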
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output).permute(1, 0, 2)
# (num_query, bs, embed_dims)
return self.dropout(output) + inp_residual
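A minimal usage sketch of the module (not part of this commit): the two-level pyramid and all tensor sizes below are illustrative assumptions, and on a CPU-only build the call exercises the pure-PyTorch fallback.

def _demo_multi_scale_deformable_attention():
    # Illustrative sizes: batch of 2, 100 queries, a 2-level pyramid.
    bs, num_query, embed_dims = 2, 100, 256
    spatial_shapes = torch.tensor([[32, 32], [16, 16]], dtype=torch.long)
    level_start_index = torch.cat((spatial_shapes.new_zeros(1),
                                   spatial_shapes.prod(1).cumsum(0)[:-1]))
    num_key = int(spatial_shapes.prod(1).sum())  # 32*32 + 16*16 keys
    attn = MultiScaleDeformableAttention(
        embed_dims=embed_dims, num_heads=8, num_levels=2, num_points=4)
    query = torch.rand(num_query, bs, embed_dims)
    value = torch.rand(num_key, bs, embed_dims)
    # One normalized (x, y) reference point per query and level.
    reference_points = torch.rand(bs, num_query, 2, 2)
    out = attn(query, None, value,
               reference_points=reference_points,
               spatial_shapes=spatial_shapes,
               level_start_index=level_start_index)
    assert out.shape == (num_query, bs, embed_dims)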
class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with residual connection.
...
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#ifndef DEFORM_ATTN_CUDA_KERNEL
#define DEFORM_ATTN_CUDA_KERNEL
#include "common_cuda_helper.hpp"
#include "pytorch_cuda_helper.hpp"
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}
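// Bilinear sampling helper for the forward pass: interpolates the value
// tensor (laid out as H x W x num_heads x channels per level) at the
// fractional location (h, w) for head m and channel c, treating
// out-of-bounds corners as zero.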
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const scalar_t &h,
const scalar_t &w, const int &m, const int &c) {
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
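// Backward counterpart of the bilinear sampler: scatters the incoming
// gradient into grad_value with atomicAdd and writes the gradients w.r.t.
// the sampling location and attention weight to the given per-thread
// buffers (shared memory in the reduction kernels).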
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const scalar_t &h,
const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
const scalar_t &attn_weight, scalar_t *&grad_value,
scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value + ptr1, w1 * top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value + ptr2, w2 * top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value + ptr3, w3 * top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value + ptr4, w4 * top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
*grad_attn_weight = top_grad * val;
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
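// Same as above, but accumulates the location/weight gradients directly
// into global memory with atomicAdd; used by the global-memory fallback
// kernel, where several threads may share one gradient slot.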
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const scalar_t &h,
const scalar_t &w, const int &m, const int &c, const scalar_t &top_grad,
const scalar_t &attn_weight, scalar_t *&grad_value,
scalar_t *grad_sampling_loc, scalar_t *grad_attn_weight) {
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value + ptr1, w1 * top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value + ptr2, w2 * top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value + ptr3, w3 * top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value + ptr4, w4 * top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
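// Forward kernel: one thread per output element (batch, query, head,
// channel). Each thread walks all levels and sampling points,
// accumulating attention-weighted bilinear samples into its output slot.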
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(
const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight, const int batch_size,
const int spatial_size, const int num_heads, const int channels,
const int num_levels, const int num_query, const int num_point,
scalar_t *data_col) {
CUDA_1D_KERNEL_LOOP(index, n) {
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr =
data_value +
(data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h,
spatial_w, num_heads, channels,
h_im, w_im, m_col, c_col) *
weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
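// Backward kernel (static block size, serial reduction): each thread
// caches its location/weight gradients in shared memory, then thread 0
// sums the whole block and writes the result.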
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
if (tid == 0) {
scalar_t _grad_w = cache_grad_sampling_loc[0],
_grad_h = cache_grad_sampling_loc[1],
_grad_a = cache_grad_attn_weight[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockSize; ++tid) {
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
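// Backward kernel (static block size, tree reduction): like v1, but the
// per-block sum is computed with a log2(blockSize)-step parallel
// reduction in shared memory.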
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1];
}
__syncthreads();
}
if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
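// Backward kernel (dynamic shared memory, serial reduction): the same
// scheme as the blocksize-aware v1 kernel, but the shared-memory cache is
// sized at launch time for block widths without a template instantiation.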
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
if (tid == 0) {
scalar_t _grad_w = cache_grad_sampling_loc[0],
_grad_h = cache_grad_sampling_loc[1],
_grad_a = cache_grad_attn_weight[0];
int sid = 2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
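// Backward kernel (dynamic shared memory, tree reduction): the extra
// `spre` bound folds in the trailing element on each step so the
// reduction also works when the active width is not a power of two.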
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
s >>= 1, spre >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_attn_weight[tid] +=
cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] +=
cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0) {
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
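// Backward kernel (tree reduction, multi-block): used when the channels
// of one sampling location span several blocks, so each block's partial
// result must be combined with atomicAdd instead of a plain store.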
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
extern __shared__ int _s[];
scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight + threadIdx.x) = 0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc + (threadIdx.x << 1),
cache_grad_attn_weight + threadIdx.x);
}
__syncthreads();
for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
s >>= 1, spre >>= 1) {
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre) {
cache_grad_attn_weight[tid] +=
cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] +=
cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] +=
cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0) {
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
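// Backward fallback without shared memory: every thread accumulates its
// contributions straight into global memory through the *_gm bilinear
// helper's atomicAdds.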
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(
const int n, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
CUDA_1D_KERNEL_LOOP(index, n) {
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset =
data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col = 0; p_col < num_point; ++p_col) {
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
#endif // DEFORM_ATTN_CUDA_KERNEL
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
Tensor ms_deform_attn_cuda_forward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step);
std::vector<Tensor> ms_deform_attn_cuda_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output,
const int im2col_step);
#endif
Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
CHECK_CUDA_INPUT(spatial_shapes)
CHECK_CUDA_INPUT(level_start_index)
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output,
const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
CHECK_CUDA_INPUT(spatial_shapes)
CHECK_CUDA_INPUT(level_start_index)
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
CHECK_CUDA_INPUT(grad_output)
return ms_deform_attn_cuda_backward(value, spatial_shapes,
level_start_index, sampling_loc,
attn_weight, grad_output, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <ms_deform_attn_cuda_kernel.cuh>
#include <vector>
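// Host-side launcher for the forward kernel: one thread per output
// element, CUDA_NUM_THREADS threads per block.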
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size,
const int num_heads, const int channels,
const int num_levels, const int num_query,
const int num_point, scalar_t *data_col) {
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index,
data_sampling_loc, data_attn_weight, batch_size, spatial_size,
num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
void ms_deformable_col2im_cuda(
cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
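// Dispatch on the channel count: power-of-two channels up to 1024 get a
// kernel templated on the block size (serial reduction below 64 channels,
// tree reduction from 64 up); other sizes <= 1024 use the dynamically
// sized shared-memory kernels; channels above 1024 use the multi-block
// kernel when divisible by 1024, else the global-memory fallback.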
const int num_threads =
(channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
if (channels > 1024) {
if ((channels & 1023) == 0) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
}
} else {
switch (channels) {
case 1:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 2:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 4:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 8:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 16:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 32:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 64:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 128:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 256:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 512:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 1024:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
default:
if (channels < 64) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
}
}
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(),
"spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(),
"level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(),
"sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0,
"batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);
auto output =
at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view(
{batch / im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
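// Process the batch in chunks of im2col_step_ samples, one kernel launch
// per chunk.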
for (int n = 0; n < batch / im2col_step_; ++n) {
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(
value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(
at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query,
num_point, columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads * channels});
return output;
}
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
const at::Tensor &attn_weight, const at::Tensor &grad_output,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(),
"grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(),
"spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(),
"level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(),
"sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0,
"batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view(
{batch / im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch / im2col_step_; ++n) {
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(
value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(
at::cuda::getCurrentCUDAStream(), grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query,
num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size);
}));
}
return {grad_value, grad_sampling_loc, grad_attn_weight};
}
@@ -92,6 +92,19 @@ void modulated_deform_conv_backward(
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias);
Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight, const int im2col_step);
std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output,
const int im2col_step);
Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);
Tensor softnms(Tensor boxes, Tensor scores, Tensor dets, float iou_threshold,
@@ -182,12 +195,12 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold,
const int multi_label);
Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
int pad_y1);
Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias,
const Tensor &refer, int act, int grad, float alpha,
float scale);
void roi_align_rotated_forward(Tensor input, Tensor rois, Tensor output,
@@ -401,4 +414,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_input"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise"));
m.def("ms_deform_attn_forward", &ms_deform_attn_forward,
"forward function of multi-scale deformable attention",
py::arg("value"), py::arg("value_spatial_shapes"),
py::arg("value_level_start_index"), py::arg("sampling_locations"),
py::arg("attention_weights"), py::arg("im2col_step"));
m.def("ms_deform_attn_backward", &ms_deform_attn_backward,
"backward function of multi-scale deformable attention",
py::arg("value"), py::arg("value_spatial_shapes"),
py::arg("value_level_start_index"), py::arg("sampling_locations"),
py::arg("attention_weights"), py::arg("grad_output"),
py::arg("im2col_step"));
}
import torch
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
class MultiScaleDeformableAttnFunction(Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index,
sampling_locations, attention_weights, im2col_step):
"""GPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
(bs, num_keys, mum_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
im2col_step (Tensor): The step used in image to column.
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
ctx.im2col_step = im2col_step
output = ext_module.ms_deform_attn_forward(value, value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations,
attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
"""GPU version of backward function.
Args:
grad_output (Tensor): Gradient
of output tensor of forward.
Returns:
Tuple[Tensor]: Gradient
of input tensors in forward.
"""
value, value_spatial_shapes, value_level_start_index,\
sampling_locations, attention_weights = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = \
ext_module.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
ctx.im2col_step)
return grad_value, None, None, \
grad_sampling_loc, grad_attn_weight, None
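# Editor's note: a minimal usage sketch for the autograd function above,
# assuming a CUDA build of the `_ext` extension and the same toy shapes
# used by the tests further below (illustrative, not part of the patch).
#
# import torch
# from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction
#
# bs, num_heads, head_dims = 1, 2, 2
# num_queries, num_levels, num_points = 2, 2, 2
# shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
# level_start_index = torch.cat(
#     (shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
# num_keys = int(shapes.prod(1).sum())
# value = torch.rand(bs, num_keys, num_heads, head_dims).cuda()
# sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels,
#                                 num_points, 2).cuda()
# attention_weights = torch.rand(bs, num_queries, num_heads, num_levels,
#                                num_points).cuda()
# attention_weights /= attention_weights.sum(
#     -1, keepdim=True).sum(-2, keepdim=True)
# # output: (bs, num_queries, num_heads * head_dims)
# output = MultiScaleDeformableAttnFunction.apply(
#     value, shapes, level_start_index, sampling_locations,
#     attention_weights, 2)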
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
sampling_locations, attention_weights):
"""CPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
(bs, num_keys, mum_heads, embed_dims//num_heads)
value_spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_heads, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
when calculate the attention, has shape
(bs ,num_queries, num_heads, num_levels, num_points),
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ =\
sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (H_, W_) in enumerate(value_spatial_shapes):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
bs * num_heads, embed_dims, H_, W_)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :,
level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(
value_l_,
sampling_grid_l_,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
attention_weights).sum(-1).view(bs, num_heads * embed_dims,
num_queries)
return output.transpose(1, 2).contiguous()
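Editor's note: the `sampling_grids = 2 * sampling_locations - 1` line above remaps sampling locations from the normalized [0, 1] range to the [-1, 1] grid convention that `F.grid_sample` expects. A minimal standalone check of that convention (illustrative, not part of the patch):
import torch
import torch.nn.functional as F

feat = torch.arange(12.).view(1, 1, 3, 4)  # (N, C, H, W)
loc = torch.tensor([[[[0.5, 0.5]]]])       # normalized (x, y) in [0, 1]
grid = 2 * loc - 1                         # grid_sample expects [-1, 1]
out = F.grid_sample(
    feat, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
# grid (0, 0) samples pixel (x=1.5, y=1.0): the midpoint of 5.0 and 6.0.
assert torch.allclose(out.flatten(), torch.tensor([5.5]))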
import pytest
import torch
from torch.autograd import gradcheck
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
def test_forward_multi_scale_deformable_attn_pytorch():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
S = sum([(H * W).item() for H, W in shapes])
torch.manual_seed(3)
value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
    # Smoke test: only checks that the pure-PyTorch path runs without error.
    multi_scale_deformable_attn_pytorch(value.double(), shapes,
                                        sampling_locations.double(),
                                        attention_weights.double()).detach()
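# Editor's note: each test normalizes the random attention weights with two
# chained sums, so they sum to 1 jointly over the level and point dimensions,
# mirroring the softmax normalization the attention module applies. A minimal
# standalone check (illustrative, not part of the test suite):
#
# import torch
#
# w = torch.rand(1, 2, 2, 2, 2) + 1e-5  # (bs, num_queries, num_heads, L, P)
# w /= w.sum(-1, keepdim=True).sum(-2, keepdim=True)
# assert torch.allclose(w.sum(dim=(-2, -1)), torch.ones(1, 2, 2))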
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
im2col_step = 2
output_pytorch = multi_scale_deformable_attn_pytorch(
value.double(), shapes, sampling_locations.double(),
attention_weights.double()).detach().cpu()
output_cuda = MultiScaleDeformableAttnFunction.apply(
value.double(), shapes, level_start_index, sampling_locations.double(),
attention_weights.double(), im2col_step).detach().cpu()
assert torch.allclose(output_cuda, output_pytorch)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max()
assert max_abs_err < 1e-18
assert max_rel_err < 1e-15
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_float():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
im2col_step = 2
output_pytorch = multi_scale_deformable_attn_pytorch(
value, shapes, sampling_locations, attention_weights).detach().cpu()
output_cuda = MultiScaleDeformableAttnFunction.apply(
value, shapes, level_start_index, sampling_locations,
attention_weights, im2col_step).detach().cpu()
assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max()
assert max_abs_err < 1e-9
assert max_rel_err < 1e-6
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('channels', [4, 30, 32, 64, 71, 1025, 2048, 3096])
def test_gradient_numerical(channels,
grad_value=True,
grad_sampling_loc=True,
grad_attn_weight=True):
N, M, _ = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
value = torch.rand(N, S, M, channels).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
im2col_step = 2
func = MultiScaleDeformableAttnFunction.apply
value.requires_grad = grad_value
sampling_locations.requires_grad = grad_sampling_loc
attention_weights.requires_grad = grad_attn_weight
assert gradcheck(
func,
(value.double(), shapes, level_start_index,
sampling_locations.double(), attention_weights.double(), im2col_step))
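Editor's note: `gradcheck` compares the analytic gradients of the CUDA kernels against numeric finite differences, which is why the inputs are cast to double above. A minimal standalone sketch of the same mechanism (illustrative only):
import torch
from torch.autograd import gradcheck

# gradcheck needs float64 inputs to keep the finite-difference error small.
x = torch.rand(3, dtype=torch.double, requires_grad=True)
assert gradcheck(lambda t: (t * t).sum(), (x, ))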