Unverified commit 732ff509, authored by pc, committed by GitHub

Add ms_deformable_attn in parrots (#1042)

parent a6377240
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
Tensor ms_deform_attn_cuda_forward(const Tensor &value,
const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step);
void ms_deform_attn_cuda_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);
#endif
Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
CHECK_CUDA_INPUT(spatial_shapes)
CHECK_CUDA_INPUT(level_start_index)
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc,
Tensor &grad_attn_weight, const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
CHECK_CUDA_INPUT(spatial_shapes)
CHECK_CUDA_INPUT(level_start_index)
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
CHECK_CUDA_INPUT(grad_output)
CHECK_CUDA_INPUT(grad_value)
CHECK_CUDA_INPUT(grad_sampling_loc)
CHECK_CUDA_INPUT(grad_attn_weight)
ms_deform_attn_cuda_backward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, grad_output,
grad_value, grad_sampling_loc,
grad_attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
} else {
AT_ERROR("Not implemented on the CPU");
}
}
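A minimal usage sketch for the two bindings above; it is not part of the commit. The tensor shapes follow the size() calls in ms_deform_attn_cuda_forward/_backward below, the ext_loader import and the '_ext' module name mirror how mmcv loads its compiled ops and are assumed here, and all sizes are arbitrary example values.

import torch
from mmcv.utils import ext_loader

# hypothetical loader call; '_ext' is assumed to be mmcv's compiled extension
ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_forward', 'ms_deform_attn_backward'])

bs, num_heads, channels = 2, 4, 8
num_query, num_point, im2col_step = 10, 4, 2
spatial_shapes = torch.tensor([[16, 16], [8, 8]], dtype=torch.long).cuda()
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1]))
num_keys = int(spatial_shapes.prod(1).sum())  # "spatial_size" in the kernels
num_levels = spatial_shapes.size(0)

value = torch.rand(bs, num_keys, num_heads, channels).cuda()
sampling_loc = torch.rand(
    bs, num_query, num_heads, num_levels, num_point, 2).cuda()
attn_weight = torch.rand(
    bs, num_query, num_heads, num_levels, num_point).cuda()

# forward returns a (bs, num_query, num_heads * channels) tensor
output = ext_module.ms_deform_attn_forward(
    value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
    im2col_step=im2col_step)

# backward writes its results into pre-allocated gradient tensors
grad_output = torch.ones_like(output)
grad_value = torch.zeros_like(value)
grad_sampling_loc = torch.zeros_like(sampling_loc)
grad_attn_weight = torch.zeros_like(attn_weight)
ext_module.ms_deform_attn_backward(
    value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
    grad_output, grad_value, grad_sampling_loc, grad_attn_weight,
    im2col_step=im2col_step)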
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <ms_deform_attn_cuda_kernel.cuh>
#include <vector>
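// Forward helper ("im2col"): launches one CUDA thread per output element
// (batch_size * num_query * num_heads * channels in total); each thread
// bilinearly samples the multi-scale value maps at its sampling locations
// and accumulates the attention-weighted samples into data_col.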
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size,
const int num_heads, const int channels,
const int num_levels, const int num_query,
const int num_point, scalar_t *data_col) {
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index,
data_sampling_loc, data_attn_weight, batch_size, spatial_size,
num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
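// Backward helper ("col2im"): scatters grad_col back into grad_value,
// grad_sampling_loc and grad_attn_weight. Power-of-two channel counts up to
// 1024 dispatch to reduction kernels specialized at compile time; other
// counts up to 1024 use the generic shared-memory reductions (v1 below 64
// channels, v2 otherwise); counts above 1024 use the multi-block
// shared-memory kernel when divisible by 1024 and the global-memory variant
// otherwise.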
template <typename scalar_t>
void ms_deformable_col2im_cuda(
cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
const int num_threads =
(channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
if (channels > 1024) {
if ((channels & 1023) == 0) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
}
} else {
switch (channels) {
case 1:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 2:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 4:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 8:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 16:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 32:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 64:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 128:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 256:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 512:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 1024:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
default:
if (channels < 64) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
}
}
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(),
"spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(),
"level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(),
"sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
batch, im2col_step_);
auto output =
at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view(
{batch / im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch / im2col_step_; ++n) {
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(
value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(
at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query,
num_point, columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads * channels});
return output;
}
void ms_deform_attn_cuda_backward(
const at::Tensor &value, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
const at::Tensor &attn_weight, const at::Tensor &grad_output,
at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
at::Tensor &grad_attn_weight, const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(),
"grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(),
"spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(),
"level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(),
"sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
batch, im2col_step_);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view(
{batch / im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch / im2col_step_; ++n) {
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(
value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(
at::cuda::getCurrentCUDAStream(), grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query,
num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size);
}));
}
}
#include <torch/extension.h>
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
using namespace at;
using namespace parrots;
Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight, const int im2col_step);
void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc,
Tensor &grad_attn_weight, const int im2col_step);
void ms_deform_attn_forward_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
int im2col_step;
SSAttrs(attr).get<int>("im2col_step", im2col_step).done();
const auto &value = buildATensor(ctx, ins[0]);
const auto &spatial_shapes = buildATensor(ctx, ins[1]);
const auto &level_start_index = buildATensor(ctx, ins[2]);
const auto &sampling_loc = buildATensor(ctx, ins[3]);
const auto &attn_weight = buildATensor(ctx, ins[4]);
auto out = ms_deform_attn_forward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, im2col_step);
updateDArray(ctx, out, outs[0]);
}
void ms_deform_attn_backward_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
int im2col_step;
SSAttrs(attr).get<int>("im2col_step", im2col_step).done();
const auto &value = buildATensor(ctx, ins[0]);
const auto &spatial_shapes = buildATensor(ctx, ins[1]);
const auto &level_start_index = buildATensor(ctx, ins[2]);
const auto &sampling_loc = buildATensor(ctx, ins[3]);
const auto &attn_weight = buildATensor(ctx, ins[4]);
const auto &grad_output = buildATensor(ctx, ins[5]);
auto grad_value = buildATensor(ctx, outs[0]);
auto grad_sampling_loc = buildATensor(ctx, outs[1]);
auto grad_attn_weight = buildATensor(ctx, outs[2]);
ms_deform_attn_backward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, grad_output, grad_value,
grad_sampling_loc, grad_attn_weight, im2col_step);
}
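// Registration: both ops take the "im2col_step" attribute. The forward op
// consumes 5 inputs (value, spatial_shapes, level_start_index, sampling_loc,
// attn_weight) and produces 1 output; the backward op consumes 6 inputs (the
// same 5 plus grad_output) and writes the 3 pre-allocated gradient tensors
// passed as its outputs.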
PARROTS_EXTENSION_REGISTER(ms_deform_attn_forward)
.attr("im2col_step")
.input(5)
.output(1)
.apply(ms_deform_attn_forward_parrots)
.done();
PARROTS_EXTENSION_REGISTER(ms_deform_attn_backward)
.attr("im2col_step")
.input(6)
.output(3)
.apply(ms_deform_attn_backward_parrots)
.done();
@@ -19,11 +19,11 @@ Tensor ms_deform_attn_cuda_forward(const Tensor &value,
const Tensor &attn_weight,
const int im2col_step);
-std::vector<Tensor> ms_deform_attn_cuda_backward(
+void ms_deform_attn_cuda_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
-const Tensor &attn_weight, const Tensor &grad_output,
-const int im2col_step);
+const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
+Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step);
#endif
@@ -48,13 +48,13 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
AT_ERROR("Not implemented on the CPU");
}
-std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
-const Tensor &spatial_shapes,
-const Tensor &level_start_index,
-const Tensor &sampling_loc,
-const Tensor &attn_weight,
-const Tensor &grad_output,
-const int im2col_step) {
+void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
+const Tensor &level_start_index,
+const Tensor &sampling_loc,
+const Tensor &attn_weight,
+const Tensor &grad_output, Tensor &grad_value,
+Tensor &grad_sampling_loc,
+Tensor &grad_attn_weight, const int im2col_step) {
if (value.type().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(value)
@@ -63,12 +63,17 @@ std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight)
CHECK_CUDA_INPUT(grad_output)
-return ms_deform_attn_cuda_backward(value, spatial_shapes,
-level_start_index, sampling_loc,
-attn_weight, grad_output, im2col_step);
+CHECK_CUDA_INPUT(grad_value)
+CHECK_CUDA_INPUT(grad_sampling_loc)
+CHECK_CUDA_INPUT(grad_attn_weight)
+ms_deform_attn_cuda_backward(value, spatial_shapes, level_start_index,
+sampling_loc, attn_weight, grad_output,
+grad_value, grad_sampling_loc,
+grad_attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
+} else {
+AT_ERROR("Not implemented on the CPU");
}
AT_ERROR("Not implemented on the CPU");
}
@@ -286,11 +286,12 @@ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
return output;
}
-std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+void ms_deform_attn_cuda_backward(
const at::Tensor &value, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
const at::Tensor &attn_weight, const at::Tensor &grad_output,
-const int im2col_step) {
+at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
+at::Tensor &grad_attn_weight, const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
@@ -328,10 +329,6 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
batch, im2col_step_);
-auto grad_value = at::zeros_like(value);
-auto grad_sampling_loc = at::zeros_like(sampling_loc);
-auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
@@ -360,6 +357,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
n * im2col_step_ * per_attn_weight_size);
}));
}
-return {grad_value, grad_sampling_loc, grad_attn_weight};
}
@@ -97,13 +97,13 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &sampling_loc,
const Tensor &attn_weight, const int im2col_step);
-std::vector<Tensor> ms_deform_attn_backward(const Tensor &value,
-const Tensor &spatial_shapes,
-const Tensor &level_start_index,
-const Tensor &sampling_loc,
-const Tensor &attn_weight,
-const Tensor &grad_output,
-const int im2col_step);
+void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
+const Tensor &level_start_index,
+const Tensor &sampling_loc,
+const Tensor &attn_weight,
+const Tensor &grad_output, Tensor &grad_value,
+Tensor &grad_sampling_loc,
+Tensor &grad_attn_weight, const int im2col_step);
Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);
@@ -445,5 +445,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("value"), py::arg("value_spatial_shapes"),
py::arg("value_level_start_index"), py::arg("sampling_locations"),
py::arg("attention_weights"), py::arg("grad_output"),
py::arg("im2col_step"));
py::arg("grad_value"), py::arg("grad_sampling_loc"),
py::arg("grad_attn_weight"), py::arg("im2col_step"));
}
@@ -35,11 +35,13 @@ class MultiScaleDeformableAttnFunction(Function):
"""
ctx.im2col_step = im2col_step
-output = ext_module.ms_deform_attn_forward(value, value_spatial_shapes,
-value_level_start_index,
-sampling_locations,
-attention_weights,
-ctx.im2col_step)
+output = ext_module.ms_deform_attn_forward(
+value,
+value_spatial_shapes,
+value_level_start_index,
+sampling_locations,
+attention_weights,
+im2col_step=ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations,
attention_weights)
@@ -60,15 +62,21 @@ class MultiScaleDeformableAttnFunction(Function):
"""
value, value_spatial_shapes, value_level_start_index,\
sampling_locations, attention_weights = ctx.saved_tensors
-grad_value, grad_sampling_loc, grad_attn_weight = \
-ext_module.ms_deform_attn_backward(
-value,
-value_spatial_shapes,
-value_level_start_index,
-sampling_locations,
-attention_weights,
-grad_output,
-ctx.im2col_step)
+grad_value = torch.zeros_like(value)
+grad_sampling_loc = torch.zeros_like(sampling_locations)
+grad_attn_weight = torch.zeros_like(attention_weights)
+ext_module.ms_deform_attn_backward(
+value,
+value_spatial_shapes,
+value_level_start_index,
+sampling_locations,
+attention_weights,
+grad_output.contiguous(),
+grad_value,
+grad_sampling_loc,
+grad_attn_weight,
+im2col_step=ctx.im2col_step)
return grad_value, None, None, \
grad_sampling_loc, grad_attn_weight, None
@@ -16,22 +16,46 @@ else:
from parrots import extension
has_return_value_ops = [
-'nms', 'softnms', 'nms_match', 'nms_rotated', 'top_pool_forward',
-'top_pool_backward', 'bottom_pool_forward', 'bottom_pool_backward',
-'left_pool_forward', 'left_pool_backward', 'right_pool_forward',
-'right_pool_backward', 'fused_bias_leakyrelu', 'upfirdn2d'
+'nms',
+'softnms',
+'nms_match',
+'nms_rotated',
+'top_pool_forward',
+'top_pool_backward',
+'bottom_pool_forward',
+'bottom_pool_backward',
+'left_pool_forward',
+'left_pool_backward',
+'right_pool_forward',
+'right_pool_backward',
+'fused_bias_leakyrelu',
+'upfirdn2d',
+'ms_deform_attn_forward',
]
+def get_fake_func(name):
+def fake_func(*args, **kwargs):
+raise RuntimeError(
+'{} is not supported in parrots now'.format(name))
+return fake_func
def load_ext(name, funcs):
ExtModule = namedtuple('ExtModule', funcs)
ext_list = []
lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
for fun in funcs:
-if fun in has_return_value_ops:
-ext_list.append(extension.load(fun, name, lib_dir=lib_root).op)
+try:
+ext_fun = extension.load(fun, name, lib_dir=lib_root)
+except Exception:
+ext_fun = get_fake_func(fun)
+ext_list.append(ext_fun)
else:
-ext_list.append(
-extension.load(fun, name, lib_dir=lib_root).op_)
+if fun in has_return_value_ops:
+ext_list.append(ext_fun.op)
+else:
+ext_list.append(ext_fun.op_)
return ExtModule(*ext_list)
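An illustrative sketch of using the updated loader; it is not part of the commit, and the import path and '_ext' module name are assumptions based on how mmcv's ops load their extension.

from mmcv.utils import ext_loader

# 'ms_deform_attn_forward' is now listed in has_return_value_ops, so it
# resolves to the variant that returns a tensor (.op); 'ms_deform_attn_backward'
# keeps the in-place variant (.op_), which writes the pre-allocated gradient
# tensors. If a kernel fails to load, the attribute becomes a stub that raises
# only when called.
ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])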
import pytest
import torch
-from torch.autograd import gradcheck
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
+_USING_PARROTS = True
+try:
+from parrots.autograd import gradcheck
+except ImportError:
+from torch.autograd import gradcheck
+_USING_PARROTS = False
def test_forward_multi_scale_deformable_attn_pytorch():
N, M, D = 1, 2, 2
@@ -118,8 +124,13 @@ def test_gradient_numerical(channels,
value.requires_grad = grad_value
sampling_locations.requires_grad = grad_sampling_loc
attention_weights.requires_grad = grad_attn_weight
-assert gradcheck(
-func,
-(value.double(), shapes, level_start_index,
-sampling_locations.double(), attention_weights.double(), im2col_step))
+if _USING_PARROTS:
+assert gradcheck(
+func, (value.double(), shapes, level_start_index,
+sampling_locations.double(), attention_weights.double(),
+im2col_step),
+no_grads=[shapes, level_start_index])
+else:
+assert gradcheck(func, (value.double(), shapes, level_start_index,
+sampling_locations.double(),
+attention_weights.double(), im2col_step))
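One way to run the updated gradient test directly; this is a sketch, and the test module path is an assumption.

import pytest

# collects every parametrization of the numerical gradcheck above
pytest.main(
    ['-q', 'tests/test_ops/test_ms_deformable_attn.py::test_gradient_numerical'])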