Commit 34e4011b authored by zk's avatar zk
Browse files

首次提交

parents
Pipeline #3503 failed with stages
in 0 seconds
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
namespace groundingdino {
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
} // namespace groundingdino
\ No newline at end of file
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <vector>
#include "ms_deform_im2col_cuda_hip.cuh"
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
namespace groundingdino {
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = ::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::hip::getCurrentHIPStreamMasqueradingAsCUDA(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
// at::Tensor ms_deform_attn_forward_wrapper(
// const at::Tensor &value,
// const at::Tensor &spatial_shapes,
// const at::Tensor &level_start_index,
// const at::Tensor &sampling_loc,
// const at::Tensor &attn_weight,
// int64_t im2col_step // ✅ 注意这里
// )
// {
// return groundingdino::ms_deform_attn_cuda_forward(
// value,
// spatial_shapes,
// level_start_index,
// sampling_loc,
// attn_weight,
// im2col_step
// );
// }
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = ::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::hip::getCurrentHIPStreamMasqueradingAsCUDA(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_sampling_loc, grad_attn_weight
};
}
} // namespace groundingdino
// #include <torch/library.h>
// // 注册 schema
// TORCH_LIBRARY(my_ops, m) {
// m.def("ms_deform_attn(Tensor value, Tensor spatial_shapes, Tensor level_start_index, Tensor sampling_loc, Tensor attn_weight, int im2col_step) -> Tensor");
// }
// // CUDA实现
// TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
// m.impl("ms_deform_attn", groundingdino::ms_deform_attn_forward_wrapper);
// }
\ No newline at end of file
/*!
**************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
* Copyright (c) 2018 Microsoft
**************************************************************************
*/
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads)
{
return (N + num_threads - 1) / num_threads;
}
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
*grad_attn_weight = top_grad * val;
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockSize; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockSize/2; s>0; s>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
const scalar_t* data_value,
const int64_t* data_spatial_shapes,
const int64_t* data_level_start_index,
const scalar_t* data_sampling_loc,
const scalar_t* data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* data_col)
{
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
const scalar_t* grad_col,
const scalar_t* data_value,
const int64_t * data_spatial_shapes,
const int64_t * data_level_start_index,
const scalar_t * data_sampling_loc,
const scalar_t * data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
if (channels > 1024)
{
if ((channels & 1023) == 0)
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
else{
switch(channels)
{
case 1:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 2:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 4:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 8:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 16:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 32:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 64:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 128:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 256:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 512:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 1024:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
default:
if (channels < 64)
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
\ No newline at end of file
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
/*!
**************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
* Copyright (c) 2018 Microsoft
**************************************************************************
*/
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <THH/THHAtomics.cuh>
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads)
{
return (N + num_threads - 1) / num_threads;
}
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
*grad_attn_weight = top_grad * val;
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockSize; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockSize/2; s>0; s>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
void ms_deformable_im2col_cuda(hipStream_t stream,
const scalar_t* data_value,
const int64_t* data_spatial_shapes,
const int64_t* data_level_start_index,
const scalar_t* data_sampling_loc,
const scalar_t* data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* data_col)
{
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
hipLaunchKernelGGL(( ms_deformable_im2col_gpu_kernel<scalar_t>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
hipError_t err = hipGetLastError();
if (err != hipSuccess)
{
printf("error in ms_deformable_im2col_cuda: %s\n", hipGetErrorString(err));
}
}
template <typename scalar_t>
void ms_deformable_col2im_cuda(hipStream_t stream,
const scalar_t* grad_col,
const scalar_t* data_value,
const int64_t * data_spatial_shapes,
const int64_t * data_level_start_index,
const scalar_t * data_sampling_loc,
const scalar_t * data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
if (channels > 1024)
{
if ((channels & 1023) == 0)
{
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
num_threads*3*sizeof(scalar_t), stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_gm<scalar_t>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
else{
switch(channels)
{
case 1:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 2:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 4:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 8:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 16:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 32:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 64:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 128:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 256:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 512:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 1024:
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
default:
if (channels < 64)
{
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
num_threads*3*sizeof(scalar_t), stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
hipLaunchKernelGGL(( ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
num_threads*3*sizeof(scalar_t), stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
}
hipError_t err = hipGetLastError();
if (err != hipSuccess)
{
printf("error in ms_deformable_col2im_cuda: %s\n", hipGetErrorString(err));
}
}
\ No newline at end of file
#include <cuda_runtime_api.h>
namespace groundingdino {
int get_cudart_version() {
return CUDART_VERSION;
}
} // namespace groundingdino
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <hip/hip_runtime_api.h>
namespace groundingdino {
int get_cudart_version() {
return DTKRT_VERSION;
}
} // namespace groundingdino
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "MsDeformAttn/ms_deform_attn.h"
namespace groundingdino {
#ifdef WITH_CUDA
extern int get_cudart_version();
#endif
std::string get_cuda_version() {
#ifdef WITH_CUDA
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
auto printCudaStyleVersion = [&](int v) {
oss << (v / 1000) << "." << (v / 10 % 100);
if (v % 10 != 0) {
oss << "." << (v % 10);
}
};
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else
return std::string("not available");
#endif
}
// similar to
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
std::string get_compiler_version() {
std::ostringstream ss;
#if defined(__GNUC__)
#ifndef __clang__
{ ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
#endif
#endif
#if defined(__clang_major__)
{
ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
<< __clang_patchlevel__;
}
#endif
#if defined(_MSC_VER)
{ ss << "MSVC " << _MSC_FULL_VER; }
#endif
return ss.str();
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
} // namespace groundingdino
\ No newline at end of file
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
class FeatureResizer(nn.Module):
"""
This class takes as input a set of embeddings of dimension C1 and outputs a set of
embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
"""
def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
super().__init__()
self.do_ln = do_ln
# Object feature encoding
self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
self.dropout = nn.Dropout(dropout)
def forward(self, encoder_features):
x = self.fc(encoder_features)
if self.do_ln:
x = self.layer_norm(x)
output = self.dropout(x)
return output
def l1norm(X, dim, eps=1e-8):
"""L1-normalize columns of X"""
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
X = torch.div(X, norm)
return X
def l2norm(X, dim, eps=1e-8):
"""L2-normalize columns of X"""
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
X = torch.div(X, norm)
return X
def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
"""
query: (n_context, queryL, d)
context: (n_context, sourceL, d)
"""
batch_size_q, queryL = query.size(0), query.size(1)
batch_size, sourceL = context.size(0), context.size(1)
# Get attention
# --> (batch, d, queryL)
queryT = torch.transpose(query, 1, 2)
# (batch, sourceL, d)(batch, d, queryL)
# --> (batch, sourceL, queryL)
attn = torch.bmm(context, queryT)
if raw_feature_norm == "softmax":
# --> (batch*sourceL, queryL)
attn = attn.view(batch_size * sourceL, queryL)
attn = nn.Softmax()(attn)
# --> (batch, sourceL, queryL)
attn = attn.view(batch_size, sourceL, queryL)
elif raw_feature_norm == "l2norm":
attn = l2norm(attn, 2)
elif raw_feature_norm == "clipped_l2norm":
attn = nn.LeakyReLU(0.1)(attn)
attn = l2norm(attn, 2)
else:
raise ValueError("unknown first norm type:", raw_feature_norm)
# --> (batch, queryL, sourceL)
attn = torch.transpose(attn, 1, 2).contiguous()
# --> (batch*queryL, sourceL)
attn = attn.view(batch_size * queryL, sourceL)
attn = nn.Softmax()(attn * smooth)
# --> (batch, queryL, sourceL)
attn = attn.view(batch_size, queryL, sourceL)
# --> (batch, sourceL, queryL)
attnT = torch.transpose(attn, 1, 2).contiguous()
# --> (batch, d, sourceL)
contextT = torch.transpose(context, 1, 2)
# (batch x d x sourceL)(batch x sourceL x queryL)
# --> (batch, d, queryL)
weightedContext = torch.bmm(contextT, attnT)
# --> (batch, queryL, d)
weightedContext = torch.transpose(weightedContext, 1, 2)
return weightedContext, attnT
class BiMultiHeadAttention(nn.Module):
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
super(BiMultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.v_dim = v_dim
self.l_dim = l_dim
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
self.scale = self.head_dim ** (-0.5)
self.dropout = dropout
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
self.stable_softmax_2d = True
self.clamp_min_for_underflow = True
self.clamp_max_for_overflow = True
self._reset_parameters()
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _reset_parameters(self):
nn.init.xavier_uniform_(self.v_proj.weight)
self.v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.l_proj.weight)
self.l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_v_proj.weight)
self.values_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_l_proj.weight)
self.values_l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_v_proj.weight)
self.out_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_l_proj.weight)
self.out_l_proj.bias.data.fill_(0)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
"""_summary_
Args:
v (_type_): bs, n_img, dim
l (_type_): bs, n_text, dim
attention_mask_v (_type_, optional): _description_. bs, n_img
attention_mask_l (_type_, optional): _description_. bs, n_text
Returns:
_type_: _description_
"""
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
bsz, tgt_len, _ = v.size()
query_states = self.v_proj(v) * self.scale
key_states = self._shape(self.l_proj(l), -1, bsz)
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_v_states = value_v_states.view(*proj_shape)
value_l_states = value_l_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
)
if self.stable_softmax_2d:
attn_weights = attn_weights - attn_weights.max()
if self.clamp_min_for_underflow:
attn_weights = torch.clamp(
attn_weights, min=-50000
) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights = torch.clamp(
attn_weights, max=50000
) # Do not increase 50000, data type half has quite limited range
attn_weights_T = attn_weights.transpose(1, 2)
attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
if self.clamp_min_for_underflow:
attn_weights_l = torch.clamp(
attn_weights_l, min=-50000
) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights_l = torch.clamp(
attn_weights_l, max=50000
) # Do not increase 50000, data type half has quite limited range
# mask vison for language
if attention_mask_v is not None:
attention_mask_v = (
attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
)
attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
attn_weights_l = attn_weights_l.softmax(dim=-1)
# mask language for vision
if attention_mask_l is not None:
attention_mask_l = (
attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
)
attn_weights.masked_fill_(attention_mask_l, float("-inf"))
attn_weights_v = attn_weights.softmax(dim=-1)
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
)
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
raise ValueError(
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
)
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output_v = attn_output_v.transpose(1, 2)
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
attn_output_l = attn_output_l.transpose(1, 2)
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
attn_output_v = self.out_v_proj(attn_output_v)
attn_output_l = self.out_l_proj(attn_output_l)
return attn_output_v, attn_output_l
# Bi-Direction MHA (text->image, image->text)
class BiAttentionBlock(nn.Module):
def __init__(
self,
v_dim,
l_dim,
embed_dim,
num_heads,
dropout=0.1,
drop_path=0.0,
init_values=1e-4,
cfg=None,
):
"""
Inputs:
embed_dim - Dimensionality of input and attention feature vectors
hidden_dim - Dimensionality of hidden layer in feed-forward network
(usually 2-4x larger than embed_dim)
num_heads - Number of heads to use in the Multi-Head Attention block
dropout - Amount of dropout to apply in the feed-forward network
"""
super(BiAttentionBlock, self).__init__()
# pre layer norm
self.layer_norm_v = nn.LayerNorm(v_dim)
self.layer_norm_l = nn.LayerNorm(l_dim)
self.attn = BiMultiHeadAttention(
v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
)
# add layer scale for training stability
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
v = self.layer_norm_v(v)
l = self.layer_norm_l(l)
delta_v, delta_l = self.attn(
v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
)
# v, l = v + delta_v, l + delta_l
v = v + self.drop_path(self.gamma_v * delta_v)
l = l + self.drop_path(self.gamma_l * delta_l)
return v, l
# def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops.boxes import nms
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import (
NestedTensor,
accuracy,
get_world_size,
interpolate,
inverse_sigmoid,
is_dist_avail_and_initialized,
nested_tensor_from_tensor_list,
)
from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.util.visualizer import COCOVisualizer
from groundingdino.util.vl_utils import create_positive_map_from_span
from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone
from .bertwarper import (
BertModelWarper,
generate_masks_with_special_tokens,
generate_masks_with_special_tokens_and_transfer_map,
)
from .transformer import build_transformer
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
class GroundingDINO(nn.Module):
"""This is the Cross-Attention Detector module that performs object detection"""
def __init__(
self,
backbone,
transformer,
num_queries,
aux_loss=False,
iter_update=False,
query_dim=2,
num_feature_levels=1,
nheads=8,
# two stage
two_stage_type="no", # ['no', 'standard']
dec_pred_bbox_embed_share=True,
two_stage_class_embed_share=True,
two_stage_bbox_embed_share=True,
num_patterns=0,
dn_number=100,
dn_box_noise_scale=0.4,
dn_label_noise_ratio=0.5,
dn_labelbook_size=100,
text_encoder_type="bert-base-uncased",
sub_sentence_present=True,
max_text_len=256,
):
"""Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
self.hidden_dim = hidden_dim = transformer.d_model
self.num_feature_levels = num_feature_levels
self.nheads = nheads
self.max_text_len = 256
self.sub_sentence_present = sub_sentence_present
# setting query dim
self.query_dim = query_dim
assert query_dim == 4
# for dn training
self.num_patterns = num_patterns
self.dn_number = dn_number
self.dn_box_noise_scale = dn_box_noise_scale
self.dn_label_noise_ratio = dn_label_noise_ratio
self.dn_labelbook_size = dn_labelbook_size
# bert
self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
self.bert.pooler.dense.weight.requires_grad_(False)
self.bert.pooler.dense.bias.requires_grad_(False)
self.bert = BertModelWarper(bert_model=self.bert)
self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
nn.init.constant_(self.feat_map.bias.data, 0)
nn.init.xavier_uniform_(self.feat_map.weight.data)
# freeze
# special tokens
self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
# prepare input projection layers
if num_feature_levels > 1:
num_backbone_outs = len(backbone.num_channels)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
)
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(32, hidden_dim),
)
)
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
]
)
self.backbone = backbone
self.aux_loss = aux_loss
self.box_pred_damping = box_pred_damping = None
self.iter_update = iter_update
assert iter_update, "Why not iter_update?"
# prepare pred layers
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed = ContrastiveEmbed()
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
else:
box_embed_layerlist = [
copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
]
class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed
self.transformer.decoder.class_embed = self.class_embed
# two stage
self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
two_stage_type
)
if two_stage_type != "no":
if two_stage_bbox_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_bbox_embed = _bbox_embed
else:
self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
if two_stage_class_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_class_embed = _class_embed
else:
self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
self.refpoint_embed = None
self._reset_parameters()
def _reset_parameters(self):
# init input_proj
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
def set_image_tensor(self, samples: NestedTensor):
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
self.features, self.poss = self.backbone(samples)
def unset_image_tensor(self):
if hasattr(self, 'features'):
del self.features
if hasattr(self,'poss'):
del self.poss
def set_image_features(self, features , poss):
self.features = features
self.poss = poss
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
# def forward(self, samples: NestedTensor, targets: List = None, **kw):
def forward(self, samples: NestedTensor,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
text_self_attention_masks: Optional[torch.Tensor] = None):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
# if targets is None:
# captions = kw["captions"]
# else:
# captions = [t["caption"] for t in targets]
# # encoder texts
# tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
# samples.device
# )
# (
# text_self_attention_masks,
# position_ids,
# cate_to_token_mask_list,
# ) = generate_masks_with_special_tokens_and_transfer_map(
# tokenized, self.specical_tokens, self.tokenizer
# )
# if text_self_attention_masks.shape[1] > self.max_text_len:
# text_self_attention_masks = text_self_attention_masks[
# :, : self.max_text_len, : self.max_text_len
# ]
# position_ids = position_ids[:, : self.max_text_len]
# tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
# tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
# tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
# # extract text embeddings
# if self.sub_sentence_present:
# tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
# tokenized_for_encoder["attention_mask"] = text_self_attention_masks
# tokenized_for_encoder["position_ids"] = position_ids
# else:
# # import ipdb; ipdb.set_trace()
# tokenized_for_encoder = tokenized
# ---------------------------输入不同--------------------------
tokenized_for_encoder ={}
tokenized_for_encoder["input_ids"] = input_ids
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
tokenized_for_encoder["position_ids"] = position_ids
tokenized_for_encoder["token_type_ids"] = token_type_ids
# --------------------------输入不同结束------------------------
bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
#------------------------------2---------------------------------
# text_token_mask = tokenized.attention_mask.bool() # bs, 195
#-----------------------------------------------------------------
text_token_mask = attention_mask.bool() # bs, 195
# text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask
if encoded_text.shape[1] > self.max_text_len:
encoded_text = encoded_text[:, : self.max_text_len, :]
text_token_mask = text_token_mask[:, : self.max_text_len]
position_ids = position_ids[:, : self.max_text_len]
text_self_attention_masks = text_self_attention_masks[
:, : self.max_text_len, : self.max_text_len
]
text_dict = {
"encoded_text": encoded_text, # bs, 195, d_model
"text_token_mask": text_token_mask, # bs, 195
"position_ids": position_ids, # bs, 195
"text_self_attention_masks": text_self_attention_masks, # bs, 195,195
}
# import ipdb; ipdb.set_trace()
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
if not hasattr(self, 'features') or not hasattr(self, 'poss'):
self.set_image_tensor(samples)
srcs = []
masks = []
for l, feat in enumerate(self.features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](self.features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
self.poss.append(pos_l)
input_query_bbox = input_query_label = attn_mask = dn_meta = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict
)
# deformable-detr-like anchor update
outputs_coord_list = []
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
zip(reference[:-1], self.bbox_embed, hs)
):
layer_delta_unsig = layer_bbox_embed(layer_hs)
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
outputs_coord_list.append(layer_outputs_unsig)
outputs_coord_list = torch.stack(outputs_coord_list)
# output
outputs_class = torch.stack(
[
layer_cls_embed(layer_hs, text_dict)
for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
]
)
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
# # for intermediate outputs
# if self.aux_loss:
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
# # for encoder output
# if hs_enc is not None:
# # prepare intermediate outputs
# interm_coord = ref_enc[-1]
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
#-------------------------2---------------------------------
# unset_image_tensor = kw.get('unset_image_tensor', True)
# if unset_image_tensor:
# self.unset_image_tensor() ## If necessary
#-----------------------------------------------------------
return out
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [
{"pred_logits": a, "pred_boxes": b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
]
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
def build_groundingdino(args):
backbone = build_backbone(args)
transformer = build_transformer(args)
dn_labelbook_size = args.dn_labelbook_size
dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
sub_sentence_present = args.sub_sentence_present
model = GroundingDINO(
backbone,
transformer,
num_queries=args.num_queries,
aux_loss=True,
iter_update=True,
query_dim=4,
num_feature_levels=args.num_feature_levels,
nheads=args.nheads,
dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
two_stage_type=args.two_stage_type,
two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
two_stage_class_embed_share=args.two_stage_class_embed_share,
num_patterns=args.num_patterns,
dn_number=0,
dn_box_noise_scale=args.dn_box_noise_scale,
dn_label_noise_ratio=args.dn_label_noise_ratio,
dn_labelbook_size=dn_labelbook_size,
text_encoder_type=args.text_encoder_type,
sub_sentence_present=sub_sentence_present,
max_text_len=args.max_text_len,
)
return model
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops.boxes import nms
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import (
NestedTensor,
accuracy,
get_world_size,
interpolate,
inverse_sigmoid,
is_dist_avail_and_initialized,
nested_tensor_from_tensor_list,
)
from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.util.visualizer import COCOVisualizer
from groundingdino.util.vl_utils import create_positive_map_from_span
from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone
from .bertwarper import (
BertModelWarper,
generate_masks_with_special_tokens,
generate_masks_with_special_tokens_and_transfer_map,
)
from .transformer import build_transformer
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
class GroundingDINO(nn.Module):
"""This is the Cross-Attention Detector module that performs object detection"""
def __init__(
self,
backbone,
transformer,
num_queries,
aux_loss=False,
iter_update=False,
query_dim=2,
num_feature_levels=1,
nheads=8,
# two stage
two_stage_type="no", # ['no', 'standard']
dec_pred_bbox_embed_share=True,
two_stage_class_embed_share=True,
two_stage_bbox_embed_share=True,
num_patterns=0,
dn_number=100,
dn_box_noise_scale=0.4,
dn_label_noise_ratio=0.5,
dn_labelbook_size=100,
text_encoder_type="bert-base-uncased",
sub_sentence_present=True,
max_text_len=256,
):
"""Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
self.hidden_dim = hidden_dim = transformer.d_model
self.num_feature_levels = num_feature_levels
self.nheads = nheads
self.max_text_len = 256
self.sub_sentence_present = sub_sentence_present
# setting query dim
self.query_dim = query_dim
assert query_dim == 4
# for dn training
self.num_patterns = num_patterns
self.dn_number = dn_number
self.dn_box_noise_scale = dn_box_noise_scale
self.dn_label_noise_ratio = dn_label_noise_ratio
self.dn_labelbook_size = dn_labelbook_size
# bert
self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
self.bert.pooler.dense.weight.requires_grad_(False)
self.bert.pooler.dense.bias.requires_grad_(False)
self.bert = BertModelWarper(bert_model=self.bert)
self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
nn.init.constant_(self.feat_map.bias.data, 0)
nn.init.xavier_uniform_(self.feat_map.weight.data)
# freeze
# special tokens
self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
# prepare input projection layers
if num_feature_levels > 1:
num_backbone_outs = len(backbone.num_channels)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
)
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(32, hidden_dim),
)
)
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
]
)
self.backbone = backbone
self.aux_loss = aux_loss
self.box_pred_damping = box_pred_damping = None
self.iter_update = iter_update
assert iter_update, "Why not iter_update?"
# prepare pred layers
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed = ContrastiveEmbed()
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
else:
box_embed_layerlist = [
copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
]
class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed
self.transformer.decoder.class_embed = self.class_embed
# two stage
self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
two_stage_type
)
if two_stage_type != "no":
if two_stage_bbox_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_bbox_embed = _bbox_embed
else:
self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
if two_stage_class_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_class_embed = _class_embed
else:
self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
self.refpoint_embed = None
self._reset_parameters()
def _reset_parameters(self):
# init input_proj
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
def set_image_tensor(self, samples: NestedTensor):
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
self.features, self.poss = self.backbone(samples)
def unset_image_tensor(self):
if hasattr(self, 'features'):
del self.features
if hasattr(self,'poss'):
del self.poss
def set_image_features(self, features , poss):
self.features = features
self.poss = poss
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
# def forward(self, samples: NestedTensor, targets: List = None, **kw):
def forward(self, samples: NestedTensor,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
text_self_attention_masks: Optional[torch.Tensor] = None):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
# if targets is None:
# captions = kw["captions"]
# else:
# captions = [t["caption"] for t in targets]
# # encoder texts
# tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
# samples.device
# )
# (
# text_self_attention_masks,
# position_ids,
# cate_to_token_mask_list,
# ) = generate_masks_with_special_tokens_and_transfer_map(
# tokenized, self.specical_tokens, self.tokenizer
# )
# if text_self_attention_masks.shape[1] > self.max_text_len:
# text_self_attention_masks = text_self_attention_masks[
# :, : self.max_text_len, : self.max_text_len
# ]
# position_ids = position_ids[:, : self.max_text_len]
# tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
# tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
# tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
# # extract text embeddings
# if self.sub_sentence_present:
# tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
# tokenized_for_encoder["attention_mask"] = text_self_attention_masks
# tokenized_for_encoder["position_ids"] = position_ids
# else:
# # import ipdb; ipdb.set_trace()
# tokenized_for_encoder = tokenized
# ---------------------------输入不同--------------------------
tokenized_for_encoder ={}
tokenized_for_encoder["input_ids"] = input_ids
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
tokenized_for_encoder["position_ids"] = position_ids
tokenized_for_encoder["token_type_ids"] = token_type_ids
# --------------------------输入不同结束------------------------
bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
#------------------------------2---------------------------------
# text_token_mask = tokenized.attention_mask.bool() # bs, 195
#-----------------------------------------------------------------
text_token_mask = attention_mask.bool() # bs, 195
# text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask
if encoded_text.shape[1] > self.max_text_len:
encoded_text = encoded_text[:, : self.max_text_len, :]
text_token_mask = text_token_mask[:, : self.max_text_len]
position_ids = position_ids[:, : self.max_text_len]
text_self_attention_masks = text_self_attention_masks[
:, : self.max_text_len, : self.max_text_len
]
text_dict = {
"encoded_text": encoded_text, # bs, 195, d_model
"text_token_mask": text_token_mask, # bs, 195
"position_ids": position_ids, # bs, 195
"text_self_attention_masks": text_self_attention_masks, # bs, 195,195
}
# import ipdb; ipdb.set_trace()
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
if not hasattr(self, 'features') or not hasattr(self, 'poss'):
self.set_image_tensor(samples)
srcs = []
masks = []
for l, feat in enumerate(self.features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](self.features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
self.poss.append(pos_l)
input_query_bbox = input_query_label = attn_mask = dn_meta = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict
)
# deformable-detr-like anchor update
outputs_coord_list = []
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
zip(reference[:-1], self.bbox_embed, hs)
):
layer_delta_unsig = layer_bbox_embed(layer_hs)
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
outputs_coord_list.append(layer_outputs_unsig)
outputs_coord_list = torch.stack(outputs_coord_list)
# output
outputs_class = torch.stack(
[
layer_cls_embed(layer_hs, text_dict)
for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
]
)
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
# # for intermediate outputs
# if self.aux_loss:
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
# # for encoder output
# if hs_enc is not None:
# # prepare intermediate outputs
# interm_coord = ref_enc[-1]
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
#-------------------------2---------------------------------
# unset_image_tensor = kw.get('unset_image_tensor', True)
# if unset_image_tensor:
# self.unset_image_tensor() ## If necessary
#-----------------------------------------------------------
return out
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [
{"pred_logits": a, "pred_boxes": b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
]
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
def build_groundingdino(args):
backbone = build_backbone(args)
transformer = build_transformer(args)
dn_labelbook_size = args.dn_labelbook_size
dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
sub_sentence_present = args.sub_sentence_present
model = GroundingDINO(
backbone,
transformer,
num_queries=args.num_queries,
aux_loss=True,
iter_update=True,
query_dim=4,
num_feature_levels=args.num_feature_levels,
nheads=args.nheads,
dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
two_stage_type=args.two_stage_type,
two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
two_stage_class_embed_share=args.two_stage_class_embed_share,
num_patterns=args.num_patterns,
dn_number=0,
dn_box_noise_scale=args.dn_box_noise_scale,
dn_label_noise_ratio=args.dn_label_noise_ratio,
dn_labelbook_size=dn_labelbook_size,
text_encoder_type=args.text_encoder_type,
sub_sentence_present=sub_sentence_present,
max_text_len=args.max_text_len,
)
return model
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops.boxes import nms
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import (
NestedTensor,
accuracy,
get_world_size,
interpolate,
inverse_sigmoid,
is_dist_avail_and_initialized,
nested_tensor_from_tensor_list,
)
from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.util.visualizer import COCOVisualizer
from groundingdino.util.vl_utils import create_positive_map_from_span
from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone
from .bertwarper import (
BertModelWarper,
generate_masks_with_special_tokens,
generate_masks_with_special_tokens_and_transfer_map,
)
from .transformer import build_transformer
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
class GroundingDINO(nn.Module):
"""This is the Cross-Attention Detector module that performs object detection"""
def __init__(
self,
backbone,
transformer,
num_queries,
aux_loss=False,
iter_update=False,
query_dim=2,
num_feature_levels=1,
nheads=8,
# two stage
two_stage_type="no", # ['no', 'standard']
dec_pred_bbox_embed_share=True,
two_stage_class_embed_share=True,
two_stage_bbox_embed_share=True,
num_patterns=0,
dn_number=100,
dn_box_noise_scale=0.4,
dn_label_noise_ratio=0.5,
dn_labelbook_size=100,
text_encoder_type="bert-base-uncased",
sub_sentence_present=True,
max_text_len=256,
):
"""Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
self.hidden_dim = hidden_dim = transformer.d_model
self.num_feature_levels = num_feature_levels
self.nheads = nheads
self.max_text_len = 256
self.sub_sentence_present = sub_sentence_present
# setting query dim
self.query_dim = query_dim
assert query_dim == 4
# for dn training
self.num_patterns = num_patterns
self.dn_number = dn_number
self.dn_box_noise_scale = dn_box_noise_scale
self.dn_label_noise_ratio = dn_label_noise_ratio
self.dn_labelbook_size = dn_labelbook_size
# bert
self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
self.bert.pooler.dense.weight.requires_grad_(False)
self.bert.pooler.dense.bias.requires_grad_(False)
self.bert = BertModelWarper(bert_model=self.bert)
self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
nn.init.constant_(self.feat_map.bias.data, 0)
nn.init.xavier_uniform_(self.feat_map.weight.data)
# freeze
# special tokens
self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
# prepare input projection layers
if num_feature_levels > 1:
num_backbone_outs = len(backbone.num_channels)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.num_channels[_]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
)
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(32, hidden_dim),
)
)
in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list)
else:
assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim),
)
]
)
self.backbone = backbone
self.aux_loss = aux_loss
self.box_pred_damping = box_pred_damping = None
self.iter_update = iter_update
assert iter_update, "Why not iter_update?"
# prepare pred layers
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed = ContrastiveEmbed()
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share:
box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
else:
box_embed_layerlist = [
copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
]
class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed
self.transformer.decoder.class_embed = self.class_embed
# two stage
self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
two_stage_type
)
if two_stage_type != "no":
if two_stage_bbox_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_bbox_embed = _bbox_embed
else:
self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
if two_stage_class_embed_share:
assert dec_pred_bbox_embed_share
self.transformer.enc_out_class_embed = _class_embed
else:
self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
self.refpoint_embed = None
self._reset_parameters()
def _reset_parameters(self):
# init input_proj
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
def set_image_tensor(self, samples: NestedTensor):
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
self.features, self.poss = self.backbone(samples)
def unset_image_tensor(self):
if hasattr(self, 'features'):
del self.features
if hasattr(self,'poss'):
del self.poss
def set_image_features(self, features , poss):
self.features = features
self.poss = poss
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
def forward(self, samples: NestedTensor, targets: List = None, **kw):
# def forward(self, samples: NestedTensor,
# input_ids: Optional[torch.Tensor] = None,
# attention_mask: Optional[torch.Tensor] = None,
# position_ids: Optional[torch.Tensor] = None,
# token_type_ids: Optional[torch.Tensor] = None,
# text_self_attention_masks: Optional[torch.Tensor] = None):
"""The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
if targets is None:
captions = kw["captions"]
else:
captions = [t["caption"] for t in targets]
# encoder texts
tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
samples.device
)
(
text_self_attention_masks,
position_ids,
cate_to_token_mask_list,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, self.specical_tokens, self.tokenizer
)
if text_self_attention_masks.shape[1] > self.max_text_len:
text_self_attention_masks = text_self_attention_masks[
:, : self.max_text_len, : self.max_text_len
]
position_ids = position_ids[:, : self.max_text_len]
tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
# extract text embeddings
if self.sub_sentence_present:
tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
tokenized_for_encoder["position_ids"] = position_ids
else:
# import ipdb; ipdb.set_trace()
tokenized_for_encoder = tokenized
# ---------------------------输入不同--------------------------
# tokenized_for_encoder ={}
# tokenized_for_encoder["input_ids"] = input_ids
# tokenized_for_encoder["attention_mask"] = text_self_attention_masks
# tokenized_for_encoder["position_ids"] = position_ids
# tokenized_for_encoder["token_type_ids"] = token_type_ids
# --------------------------输入不同结束------------------------
bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
#------------------------------2---------------------------------
text_token_mask = tokenized.attention_mask.bool() # bs, 195
#-----------------------------------------------------------------
# text_token_mask = attention_mask.bool() # bs, 195
# text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask
if encoded_text.shape[1] > self.max_text_len:
encoded_text = encoded_text[:, : self.max_text_len, :]
text_token_mask = text_token_mask[:, : self.max_text_len]
position_ids = position_ids[:, : self.max_text_len]
text_self_attention_masks = text_self_attention_masks[
:, : self.max_text_len, : self.max_text_len
]
text_dict = {
"encoded_text": encoded_text, # bs, 195, d_model
"text_token_mask": text_token_mask, # bs, 195
"position_ids": position_ids, # bs, 195
"text_self_attention_masks": text_self_attention_masks, # bs, 195,195
}
# import ipdb; ipdb.set_trace()
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
if not hasattr(self, 'features') or not hasattr(self, 'poss'):
self.set_image_tensor(samples)
srcs = []
masks = []
for l, feat in enumerate(self.features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](self.features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
self.poss.append(pos_l)
input_query_bbox = input_query_label = attn_mask = dn_meta = None
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict
)
# deformable-detr-like anchor update
outputs_coord_list = []
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
zip(reference[:-1], self.bbox_embed, hs)
):
layer_delta_unsig = layer_bbox_embed(layer_hs)
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
outputs_coord_list.append(layer_outputs_unsig)
outputs_coord_list = torch.stack(outputs_coord_list)
# output
outputs_class = torch.stack(
[
layer_cls_embed(layer_hs, text_dict)
for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
]
)
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
# # for intermediate outputs
# if self.aux_loss:
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
# # for encoder output
# if hs_enc is not None:
# # prepare intermediate outputs
# interm_coord = ref_enc[-1]
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
#-------------------------2---------------------------------
unset_image_tensor = kw.get('unset_image_tensor', True)
if unset_image_tensor:
self.unset_image_tensor() ## If necessary
#-----------------------------------------------------------
return out
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [
{"pred_logits": a, "pred_boxes": b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
]
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
def build_groundingdino(args):
backbone = build_backbone(args)
transformer = build_transformer(args)
dn_labelbook_size = args.dn_labelbook_size
dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
sub_sentence_present = args.sub_sentence_present
model = GroundingDINO(
backbone,
transformer,
num_queries=args.num_queries,
aux_loss=True,
iter_update=True,
query_dim=4,
num_feature_levels=args.num_feature_levels,
nheads=args.nheads,
dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
two_stage_type=args.two_stage_type,
two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
two_stage_class_embed_share=args.two_stage_class_embed_share,
num_patterns=args.num_patterns,
dn_number=0,
dn_box_noise_scale=args.dn_box_noise_scale,
dn_label_noise_ratio=args.dn_label_noise_ratio,
dn_labelbook_size=dn_labelbook_size,
text_encoder_type=args.text_encoder_type,
sub_sentence_present=sub_sentence_present,
max_text_len=args.max_text_len,
)
return model
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
# ------------------------------------------------------------------------------------------------
import math
import warnings
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.init import constant_, xavier_uniform_
try:
from groundingdino import _C
except:
warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
# helpers
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
class MultiScaleDeformableAttnFunction(Function):
@staticmethod
def forward(
ctx,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
ctx.im2col_step = im2col_step
output = _C.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
ctx.im2col_step,
)
ctx.save_for_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
ctx.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
def multi_scale_deformable_attn_pytorch(
value: torch.Tensor,
value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
) -> torch.Tensor:
bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (H_, W_) in enumerate(value_spatial_shapes):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = (
value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(bs, num_heads * embed_dims, num_queries)
)
return output.transpose(1, 2).contiguous()
class MultiScaleDeformableAttention(nn.Module):
"""Multi-Scale Deformable Attention Module used in Deformable-DETR
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dim (int): The embedding dimension of Attention. Default: 256.
num_heads (int): The number of attention heads. Default: 8.
num_levels (int): The number of feature map used in Attention. Default: 4.
num_points (int): The number of sampling points for each query
in each head. Default: 4.
img2col_steps (int): The step used in image_to_column. Defualt: 64.
dropout (float): Dropout layer used in output. Default: 0.1.
batch_first (bool): if ``True``, then the input and output tensor will be
provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
"""
def __init__(
self,
embed_dim: int = 256,
num_heads: int = 8,
num_levels: int = 4,
num_points: int = 4,
img2col_step: int = 64,
batch_first: bool = False,
):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
"embed_dim must be divisible by num_heads, but got {} and {}".format(
embed_dim, num_heads
)
)
head_dim = embed_dim // num_heads
self.batch_first = batch_first
if not _is_power_of_2(head_dim):
warnings.warn(
"""
You'd better set d_model in MSDeformAttn to make sure that
each dim of the attention head a power of 2, which is more efficient.
"""
)
self.im2col_step = img2col_step
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.init_weights()
def _reset_parameters(self):
return self.init_weights()
def init_weights(self):
"""
Default initialization for Parameters of Module.
"""
constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
2.0 * math.pi / self.num_heads
)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.num_heads, 1, 1, 2)
.repeat(1, self.num_levels, self.num_points, 1)
)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.0)
constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.0)
def freeze_sampling_offsets(self):
print("Freeze sampling offsets")
self.sampling_offsets.weight.requires_grad = False
self.sampling_offsets.bias.requires_grad = False
def freeze_attention_weights(self):
print("Freeze attention weights")
self.attention_weights.weight.requires_grad = False
self.attention_weights.bias.requires_grad = False
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
query_pos: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
reference_points: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.Tensor] = None,
level_start_index: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
"""Forward Function of MultiScaleDeformableAttention
Args:
query (torch.Tensor): Query embeddings with shape
`(num_query, bs, embed_dim)`
key (torch.Tensor): Key embeddings with shape
`(num_key, bs, embed_dim)`
value (torch.Tensor): Value embeddings with shape
`(num_key, bs, embed_dim)`
query_pos (torch.Tensor): The position embedding for `query`. Default: None.
key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
indicating which elements within `key` to be ignored in attention.
reference_points (torch.Tensor): The normalized reference points
with shape `(bs, num_query, num_levels, 2)`,
all elements is range in [0, 1], top-left (0, 0),
bottom-right (1, 1), including padding are.
or `(N, Length_{query}, num_levels, 4)`, add additional
two dimensions `(h, w)` to form reference boxes.
spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
level_start_index (torch.Tensor): The start index of each level. A tensor with
shape `(num_levels, )` which can be represented as
`[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
Returns:
torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
"""
if value is None:
value = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], float(0))
value = value.view(bs, num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points
)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(
bs,
num_query,
self.num_heads,
self.num_levels,
self.num_points,
)
# bs, num_query, num_heads, num_levels, num_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets
/ self.num_points
* reference_points[:, :, None, :, None, 2:]
* 0.5
)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(
reference_points.shape[-1]
)
)
if torch.cuda.is_available() and value.is_cuda:
halffloat = False
if value.dtype == torch.float16:
halffloat = True
value = value.float()
sampling_locations = sampling_locations.float()
attention_weights = attention_weights.float()
output = MultiScaleDeformableAttnFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
if halffloat:
output = output.half()
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights
)
output = self.output_proj(output)
if not self.batch_first:
output = output.permute(1, 0, 2)
return output
def create_dummy_class(klass, dependency, message=""):
"""
When a dependency of a class is not available, create a dummy class which throws ImportError
when used.
Args:
klass (str): name of the class.
dependency (str): name of the dependency.
message: extra message to print
Returns:
class: a class object
"""
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
if message:
err = err + " " + message
class _DummyMetaClass(type):
# throw error on class attribute access
def __getattr__(_, __): # noqa: B902
raise ImportError(err)
class _Dummy(object, metaclass=_DummyMetaClass):
# throw error on constructor
def __init__(self, *args, **kwargs):
raise ImportError(err)
return _Dummy
def create_dummy_func(func, dependency, message=""):
"""
When a dependency of a function is not available, create a dummy function which throws
ImportError when used.
Args:
func (str): name of the function.
dependency (str or list[str]): name(s) of the dependency.
message: extra message to print
Returns:
function: a function object
"""
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
if message:
err = err + " " + message
if isinstance(dependency, (list, tuple)):
dependency = ",".join(dependency)
def _dummy(*args, **kwargs):
raise ImportError(err)
return _dummy
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
# ------------------------------------------------------------------------------------------------
import math
import warnings
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.init import constant_, xavier_uniform_
try:
from groundingdino import _C
except:
warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
# helpers
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
class MultiScaleDeformableAttnFunction(Function):
@staticmethod
def forward(
ctx,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
ctx.im2col_step = im2col_step
output = _C.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
ctx.im2col_step,
)
ctx.save_for_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
ctx.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
# =================================================================
# 新增这个 symbolic 方法,用于 ONNX 导出
# =================================================================
@staticmethod
def symbolic(
g,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
"""
当调用 torch.onnx.export 时,遇到此 Function 会执行这个逻辑。
"""
# g.op 表示在 ONNX 图中插入一个节点
# "custom::ms_deform_attn" 是我们给它起的唯一名字 (Domain::OpName)
# im2col_step 是一个标量整数,在 ONNX 中作为 attribute 传入,后缀 _i 表示 int
return g.op(
"custom::ms_deform_attn",
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step_i=im2col_step,
)
def multi_scale_deformable_attn_pytorch(
value: torch.Tensor,
value_spatial_shapes: torch.Tensor,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
) -> torch.Tensor:
bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level, (H_, W_) in enumerate(value_spatial_shapes):
# bs, H_*W_, num_heads, embed_dims ->
# bs, H_*W_, num_heads*embed_dims ->
# bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_
value_l_ = (
value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
)
# bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 ->
# bs*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) ->
# (bs, num_heads, num_queries, num_levels, num_points) ->
# (bs, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
bs * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(bs, num_heads * embed_dims, num_queries)
)
return output.transpose(1, 2).contiguous()
class MultiScaleDeformableAttention(nn.Module):
"""Multi-Scale Deformable Attention Module used in Deformable-DETR
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dim (int): The embedding dimension of Attention. Default: 256.
num_heads (int): The number of attention heads. Default: 8.
num_levels (int): The number of feature map used in Attention. Default: 4.
num_points (int): The number of sampling points for each query
in each head. Default: 4.
img2col_steps (int): The step used in image_to_column. Defualt: 64.
dropout (float): Dropout layer used in output. Default: 0.1.
batch_first (bool): if ``True``, then the input and output tensor will be
provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
"""
def __init__(
self,
embed_dim: int = 256,
num_heads: int = 8,
num_levels: int = 4,
num_points: int = 4,
img2col_step: int = 64,
batch_first: bool = False,
):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
"embed_dim must be divisible by num_heads, but got {} and {}".format(
embed_dim, num_heads
)
)
head_dim = embed_dim // num_heads
self.batch_first = batch_first
if not _is_power_of_2(head_dim):
warnings.warn(
"""
You'd better set d_model in MSDeformAttn to make sure that
each dim of the attention head a power of 2, which is more efficient.
"""
)
self.im2col_step = img2col_step
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.init_weights()
def _reset_parameters(self):
return self.init_weights()
def init_weights(self):
"""
Default initialization for Parameters of Module.
"""
constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
2.0 * math.pi / self.num_heads
)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.num_heads, 1, 1, 2)
.repeat(1, self.num_levels, self.num_points, 1)
)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.0)
constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.0)
def freeze_sampling_offsets(self):
print("Freeze sampling offsets")
self.sampling_offsets.weight.requires_grad = False
self.sampling_offsets.bias.requires_grad = False
def freeze_attention_weights(self):
print("Freeze attention weights")
self.attention_weights.weight.requires_grad = False
self.attention_weights.bias.requires_grad = False
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
query_pos: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
reference_points: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.Tensor] = None,
level_start_index: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
"""Forward Function of MultiScaleDeformableAttention
Args:
query (torch.Tensor): Query embeddings with shape
`(num_query, bs, embed_dim)`
key (torch.Tensor): Key embeddings with shape
`(num_key, bs, embed_dim)`
value (torch.Tensor): Value embeddings with shape
`(num_key, bs, embed_dim)`
query_pos (torch.Tensor): The position embedding for `query`. Default: None.
key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
indicating which elements within `key` to be ignored in attention.
reference_points (torch.Tensor): The normalized reference points
with shape `(bs, num_query, num_levels, 2)`,
all elements is range in [0, 1], top-left (0, 0),
bottom-right (1, 1), including padding are.
or `(N, Length_{query}, num_levels, 4)`, add additional
two dimensions `(h, w)` to form reference boxes.
spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
level_start_index (torch.Tensor): The start index of each level. A tensor with
shape `(num_levels, )` which can be represented as
`[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
Returns:
torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
"""
if value is None:
value = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], float(0))
value = value.view(bs, num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points
)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(
bs,
num_query,
self.num_heads,
self.num_levels,
self.num_points,
)
# bs, num_query, num_heads, num_levels, num_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets
/ self.num_points
* reference_points[:, :, None, :, None, 2:]
* 0.5
)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(
reference_points.shape[-1]
)
)
if torch.cuda.is_available() and value.is_cuda:
halffloat = False
if value.dtype == torch.float16:
halffloat = True
value = value.float()
sampling_locations = sampling_locations.float()
attention_weights = attention_weights.float()
output = MultiScaleDeformableAttnFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
if halffloat:
output = output.half()
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights
)
output = self.output_proj(output)
if not self.batch_first:
output = output.permute(1, 0, 2)
return output
def create_dummy_class(klass, dependency, message=""):
"""
When a dependency of a class is not available, create a dummy class which throws ImportError
when used.
Args:
klass (str): name of the class.
dependency (str): name of the dependency.
message: extra message to print
Returns:
class: a class object
"""
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
if message:
err = err + " " + message
class _DummyMetaClass(type):
# throw error on class attribute access
def __getattr__(_, __): # noqa: B902
raise ImportError(err)
class _Dummy(object, metaclass=_DummyMetaClass):
# throw error on constructor
def __init__(self, *args, **kwargs):
raise ImportError(err)
return _Dummy
def create_dummy_func(func, dependency, message=""):
"""
When a dependency of a function is not available, create a dummy function which throws
ImportError when used.
Args:
func (str): name of the function.
dependency (str or list[str]): name(s) of the dependency.
message: extra message to print
Returns:
function: a function object
"""
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
if message:
err = err + " " + message
if isinstance(dependency, (list, tuple)):
dependency = ",".join(dependency)
def _dummy(*args, **kwargs):
raise ImportError(err)
return _dummy
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR Transformer class.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from typing import Optional
import torch
import torch.utils.checkpoint as checkpoint
from torch import Tensor, nn
from groundingdino.util.misc import inverse_sigmoid
from .fuse_modules import BiAttentionBlock
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
from .transformer_vanilla import TransformerEncoderLayer
from .utils import (
MLP,
_get_activation_fn,
_get_clones,
gen_encoder_output_proposals,
gen_sineembed_for_position,
get_sine_pos_embed,
)
class Transformer(nn.Module):
def __init__(
self,
d_model=256,
nhead=8,
num_queries=300,
num_encoder_layers=6,
num_unicoder_layers=0,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.0,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
query_dim=4,
num_patterns=0,
# for deformable encoder
num_feature_levels=1,
enc_n_points=4,
dec_n_points=4,
# init query
learnable_tgt_init=False,
# two stage
two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
embed_init_tgt=False,
# for text
use_text_enhancer=False,
use_fusion_layer=False,
use_checkpoint=False,
use_transformer_ckpt=False,
use_text_cross_attention=False,
text_dropout=0.1,
fusion_dropout=0.1,
fusion_droppath=0.0,
):
super().__init__()
self.num_feature_levels = num_feature_levels
self.num_encoder_layers = num_encoder_layers
self.num_unicoder_layers = num_unicoder_layers
self.num_decoder_layers = num_decoder_layers
self.num_queries = num_queries
assert query_dim == 4
# choose encoder layer type
encoder_layer = DeformableTransformerEncoderLayer(
d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
)
if use_text_enhancer:
text_enhance_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead // 2,
dim_feedforward=dim_feedforward // 2,
dropout=text_dropout,
)
else:
text_enhance_layer = None
if use_fusion_layer:
feature_fusion_layer = BiAttentionBlock(
v_dim=d_model,
l_dim=d_model,
embed_dim=dim_feedforward // 2,
num_heads=nhead // 2,
dropout=fusion_dropout,
drop_path=fusion_droppath,
)
else:
feature_fusion_layer = None
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
assert encoder_norm is None
self.encoder = TransformerEncoder(
encoder_layer,
num_encoder_layers,
d_model=d_model,
num_queries=num_queries,
text_enhance_layer=text_enhance_layer,
feature_fusion_layer=feature_fusion_layer,
use_checkpoint=use_checkpoint,
use_transformer_ckpt=use_transformer_ckpt,
)
# choose decoder layer type
decoder_layer = DeformableTransformerDecoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
dec_n_points,
use_text_cross_attention=use_text_cross_attention,
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer,
num_decoder_layers,
decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model,
query_dim=query_dim,
num_feature_levels=num_feature_levels,
)
self.d_model = d_model
self.nhead = nhead
self.dec_layers = num_decoder_layers
self.num_queries = num_queries # useful for single stage model only
self.num_patterns = num_patterns
if not isinstance(num_patterns, int):
Warning("num_patterns should be int but {}".format(type(num_patterns)))
self.num_patterns = 0
if num_feature_levels > 1:
if self.num_encoder_layers > 0:
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
else:
self.level_embed = None
self.learnable_tgt_init = learnable_tgt_init
assert learnable_tgt_init, "why not learnable_tgt_init"
self.embed_init_tgt = embed_init_tgt
if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
self.tgt_embed = nn.Embedding(self.num_queries, d_model)
nn.init.normal_(self.tgt_embed.weight.data)
else:
self.tgt_embed = None
# for two stage
self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
two_stage_type
)
if two_stage_type == "standard":
# anchor selection at the output of encoder
self.enc_output = nn.Linear(d_model, d_model)
self.enc_output_norm = nn.LayerNorm(d_model)
self.two_stage_wh_embedding = None
if two_stage_type == "no":
self.init_ref_points(num_queries) # init self.refpoint_embed
self.enc_out_class_embed = None
self.enc_out_bbox_embed = None
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
if self.num_feature_levels > 1 and self.level_embed is not None:
nn.init.normal_(self.level_embed)
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, 4)
def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
"""
Input:
- srcs: List of multi features [bs, ci, hi, wi]
- masks: List of multi masks [bs, hi, wi]
- refpoint_embed: [bs, num_dn, 4]. None in infer
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
- tgt: [bs, num_dn, d_model]. None in infer
"""
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c
mask = mask.flatten(1) # bs, hw
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
if self.num_feature_levels > 1 and self.level_embed is not None:
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
else:
lvl_pos_embed = pos_embed
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=src_flatten.device
)
level_start_index = torch.cat(
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# two stage
enc_topk_proposals = enc_refpoint_embed = None
#########################################################
# Begin Encoder
#########################################################
memory, memory_text = self.encoder(
src_flatten,
pos=lvl_pos_embed_flatten,
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
key_padding_mask=mask_flatten,
memory_text=text_dict["encoded_text"],
text_attention_mask=~text_dict["text_token_mask"],
# we ~ the mask . False means use the token; True means pad the token
position_ids=text_dict["position_ids"],
text_self_attention_masks=text_dict["text_self_attention_masks"],
)
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
#########################################################
text_dict["encoded_text"] = memory_text
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if memory.isnan().any() | memory.isinf().any():
# import ipdb; ipdb.set_trace()
if self.two_stage_type == "standard":
output_memory, output_proposals = gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes
)
output_memory = self.enc_output_norm(self.enc_output(output_memory))
if text_dict is not None:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
else:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
topk_logits = enc_outputs_class_unselected.max(-1)[0]
enc_outputs_coord_unselected = (
self.enc_out_bbox_embed(output_memory) + output_proposals
) # (bs, \sum{hw}, 4) unsigmoid
topk = self.num_queries
topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
# gather boxes
refpoint_embed_undetach = torch.gather(
enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
) # unsigmoid
refpoint_embed_ = refpoint_embed_undetach.detach()
init_box_proposal = torch.gather(
output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
).sigmoid() # sigmoid
# gather tgt
tgt_undetach = torch.gather(
output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
)
if self.embed_init_tgt:
tgt_ = (
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
) # nq, bs, d_model
else:
tgt_ = tgt_undetach.detach()
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
elif self.two_stage_type == "no":
tgt_ = (
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
) # nq, bs, d_model
refpoint_embed_ = (
self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
) # nq, bs, 4
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
if self.num_patterns > 0:
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
self.num_queries, 1
) # 1, n_q*n_pat, d_model
tgt = tgt_embed + tgt_pat
init_box_proposal = refpoint_embed_.sigmoid()
else:
raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
#########################################################
# End preparing tgt
# - tgt: bs, NQ, d_model
# - refpoint_embed(unsigmoid): bs, NQ, d_model
#########################################################
#########################################################
# Begin Decoder
#########################################################
hs, references = self.decoder(
tgt=tgt.transpose(0, 1),
memory=memory.transpose(0, 1),
memory_key_padding_mask=mask_flatten,
pos=lvl_pos_embed_flatten.transpose(0, 1),
refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
tgt_mask=attn_mask,
memory_text=text_dict["encoded_text"],
text_attention_mask=~text_dict["text_token_mask"],
# we ~ the mask . False means use the token; True means pad the token
)
#########################################################
# End Decoder
# hs: n_dec, bs, nq, d_model
# references: n_dec+1, bs, nq, query_dim
#########################################################
#########################################################
# Begin postprocess
#########################################################
if self.two_stage_type == "standard":
hs_enc = tgt_undetach.unsqueeze(0)
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
else:
hs_enc = ref_enc = None
#########################################################
# End postprocess
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
# ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
#########################################################
return hs, references, hs_enc, ref_enc, init_box_proposal
# hs: (n_dec, bs, nq, d_model)
# references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
# ref_enc: sigmoid coordinates. \
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
class TransformerEncoder(nn.Module):
def __init__(
self,
encoder_layer,
num_layers,
d_model=256,
num_queries=300,
enc_layer_share=False,
text_enhance_layer=None,
feature_fusion_layer=None,
use_checkpoint=False,
use_transformer_ckpt=False,
):
"""_summary_
Args:
encoder_layer (_type_): _description_
num_layers (_type_): _description_
norm (_type_, optional): _description_. Defaults to None.
d_model (int, optional): _description_. Defaults to 256.
num_queries (int, optional): _description_. Defaults to 300.
enc_layer_share (bool, optional): _description_. Defaults to False.
"""
super().__init__()
# prepare layers
self.layers = []
self.text_layers = []
self.fusion_layers = []
if num_layers > 0:
self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
if text_enhance_layer is not None:
self.text_layers = _get_clones(
text_enhance_layer, num_layers, layer_share=enc_layer_share
)
if feature_fusion_layer is not None:
self.fusion_layers = _get_clones(
feature_fusion_layer, num_layers, layer_share=enc_layer_share
)
else:
self.layers = []
del encoder_layer
if text_enhance_layer is not None:
self.text_layers = []
del text_enhance_layer
if feature_fusion_layer is not None:
self.fusion_layers = []
del feature_fusion_layer
self.query_scale = None
self.num_queries = num_queries
self.num_layers = num_layers
self.d_model = d_model
self.use_checkpoint = use_checkpoint
self.use_transformer_ckpt = use_transformer_ckpt
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(
self,
# for images
src: Tensor,
pos: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
valid_ratios: Tensor,
key_padding_mask: Tensor,
# for texts
memory_text: Tensor = None,
text_attention_mask: Tensor = None,
pos_text: Tensor = None,
text_self_attention_masks: Tensor = None,
position_ids: Tensor = None,
):
"""
Input:
- src: [bs, sum(hi*wi), 256]
- pos: pos embed for src. [bs, sum(hi*wi), 256]
- spatial_shapes: h,w of each level [num_level, 2]
- level_start_index: [num_level] start point of level in sum(hi*wi).
- valid_ratios: [bs, num_level, 2]
- key_padding_mask: [bs, sum(hi*wi)]
- memory_text: bs, n_text, 256
- text_attention_mask: bs, n_text
False for no padding; True for padding
- pos_text: bs, n_text, 256
- position_ids: bs, n_text
Intermedia:
- reference_points: [bs, sum(hi*wi), num_level, 2]
Outpus:
- output: [bs, sum(hi*wi), 256]
"""
output = src
# preparation and reshape
if self.num_layers > 0:
reference_points = self.get_reference_points(
spatial_shapes, valid_ratios, device=src.device
)
if self.text_layers:
# generate pos_text
bs, n_text, text_dim = memory_text.shape
if pos_text is None and position_ids is None:
pos_text = (
torch.arange(n_text, device=memory_text.device)
.float()
.unsqueeze(0)
.unsqueeze(-1)
.repeat(bs, 1, 1)
)
pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
if position_ids is not None:
pos_text = get_sine_pos_embed(
position_ids[..., None], num_pos_feats=256, exchange_xy=False
)
# main process
for layer_id, layer in enumerate(self.layers):
# if output.isnan().any() or memory_text.isnan().any():
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
if self.fusion_layers:
if self.use_checkpoint:
output, memory_text = checkpoint.checkpoint(
self.fusion_layers[layer_id],
output,
memory_text,
key_padding_mask,
text_attention_mask,
)
else:
output, memory_text = self.fusion_layers[layer_id](
v=output,
l=memory_text,
attention_mask_v=key_padding_mask,
attention_mask_l=text_attention_mask,
)
if self.text_layers:
memory_text = self.text_layers[layer_id](
src=memory_text.transpose(0, 1),
src_mask=~text_self_attention_masks, # note we use ~ for mask here
src_key_padding_mask=text_attention_mask,
pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
).transpose(0, 1)
# main process
if self.use_transformer_ckpt:
output = checkpoint.checkpoint(
layer,
output,
pos,
reference_points,
spatial_shapes,
level_start_index,
key_padding_mask,
)
else:
output = layer(
src=output,
pos=pos,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask,
)
return output, memory_text
class TransformerDecoder(nn.Module):
def __init__(
self,
decoder_layer,
num_layers,
norm=None,
return_intermediate=False,
d_model=256,
query_dim=4,
num_feature_levels=1,
):
super().__init__()
if num_layers > 0:
self.layers = _get_clones(decoder_layer, num_layers)
else:
self.layers = []
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
assert return_intermediate, "support return_intermediate only"
self.query_dim = query_dim
assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
self.num_feature_levels = num_feature_levels
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
self.query_pos_sine_scale = None
self.query_scale = None
self.bbox_embed = None
self.class_embed = None
self.d_model = d_model
self.ref_anchor_head = None
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
# for memory
level_start_index: Optional[Tensor] = None, # num_levels
spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
valid_ratios: Optional[Tensor] = None,
# for text
memory_text: Optional[Tensor] = None,
text_attention_mask: Optional[Tensor] = None,
):
"""
Input:
- tgt: nq, bs, d_model
- memory: hw, bs, d_model
- pos: hw, bs, d_model
- refpoints_unsigmoid: nq, bs, 2/4
- valid_ratios/spatial_shapes: bs, nlevel, 2
"""
output = tgt
intermediate = []
reference_points = refpoints_unsigmoid.sigmoid()
ref_points = [reference_points]
for layer_id, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = (
reference_points[:, :, None]
* torch.cat([valid_ratios, valid_ratios], -1)[None, :]
) # nq, bs, nlevel, 4
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
query_sine_embed = gen_sineembed_for_position(
reference_points_input[:, :, 0, :]
) # nq, bs, 256*2
# conditional query
raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
pos_scale = self.query_scale(output) if self.query_scale is not None else 1
query_pos = pos_scale * raw_query_pos
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if query_pos.isnan().any() | query_pos.isinf().any():
# import ipdb; ipdb.set_trace()
# main process
output = layer(
tgt=output,
tgt_query_pos=query_pos,
tgt_query_sine_embed=query_sine_embed,
tgt_key_padding_mask=tgt_key_padding_mask,
tgt_reference_points=reference_points_input,
memory_text=memory_text,
text_attention_mask=text_attention_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
memory_level_start_index=level_start_index,
memory_spatial_shapes=spatial_shapes,
memory_pos=pos,
self_attn_mask=tgt_mask,
cross_attn_mask=memory_mask,
)
if output.isnan().any() | output.isinf().any():
print(f"output layer_id {layer_id} is nan")
try:
num_nan = output.isnan().sum().item()
num_inf = output.isinf().sum().item()
print(f"num_nan {num_nan}, num_inf {num_inf}")
except Exception as e:
print(e)
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# import ipdb; ipdb.set_trace()
# iter update
if self.bbox_embed is not None:
# box_holder = self.bbox_embed(output)
# box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
# new_reference_points = box_holder[..., :self.query_dim].sigmoid()
reference_before_sigmoid = inverse_sigmoid(reference_points)
delta_unsig = self.bbox_embed[layer_id](output)
outputs_unsig = delta_unsig + reference_before_sigmoid
new_reference_points = outputs_unsig.sigmoid()
reference_points = new_reference_points.detach()
# if layer_id != self.num_layers - 1:
ref_points.append(new_reference_points)
intermediate.append(self.norm(output))
return [
[itm_out.transpose(0, 1) for itm_out in intermediate],
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
]
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
):
super().__init__()
# self attention
self.self_attn = MSDeformAttn(
embed_dim=d_model,
num_levels=n_levels,
num_heads=n_heads,
num_points=n_points,
batch_first=True,
)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(
self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None
):
# self attention
# import ipdb; ipdb.set_trace()
src2 = self.self_attn(
query=self.with_pos_embed(src, pos),
reference_points=reference_points,
value=src,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask,
)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class DeformableTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
use_text_feat_guide=False,
use_text_cross_attention=False,
):
super().__init__()
# cross attention
self.cross_attn = MSDeformAttn(
embed_dim=d_model,
num_levels=n_levels,
num_heads=n_heads,
num_points=n_points,
batch_first=True,
)
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm1 = nn.LayerNorm(d_model)
# cross attention text
if use_text_cross_attention:
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.catext_norm = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm3 = nn.LayerNorm(d_model)
self.key_aware_proj = None
self.use_text_feat_guide = use_text_feat_guide
assert not use_text_feat_guide
self.use_text_cross_attention = use_text_cross_attention
def rm_self_attn_modules(self):
self.self_attn = None
self.dropout2 = None
self.norm2 = None
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
with torch.cuda.amp.autocast(enabled=False):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward(
self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
memory_text: Optional[Tensor] = None, # bs, num_token, d_model
text_attention_mask: Optional[Tensor] = None, # bs, num_token
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
):
"""
Input:
- tgt/tgt_query_pos: nq, bs, d_model
-
"""
assert cross_attn_mask is None
# self attention
if self.self_attn is not None:
# import ipdb; ipdb.set_trace()
q = k = self.with_pos_embed(tgt, tgt_query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
if self.use_text_cross_attention:
tgt2 = self.ca_text(
self.with_pos_embed(tgt, tgt_query_pos),
memory_text.transpose(0, 1),
memory_text.transpose(0, 1),
key_padding_mask=text_attention_mask,
)[0]
tgt = tgt + self.catext_dropout(tgt2)
tgt = self.catext_norm(tgt)
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
value=memory.transpose(0, 1),
spatial_shapes=memory_spatial_shapes,
level_start_index=memory_level_start_index,
key_padding_mask=memory_key_padding_mask,
).transpose(0, 1)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ffn
tgt = self.forward_ffn(tgt)
return tgt
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
num_queries=args.num_queries,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
query_dim=args.query_dim,
activation=args.transformer_activation,
num_patterns=args.num_patterns,
num_feature_levels=args.num_feature_levels,
enc_n_points=args.enc_n_points,
dec_n_points=args.dec_n_points,
learnable_tgt_init=True,
# two stage
two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
embed_init_tgt=args.embed_init_tgt,
use_text_enhancer=args.use_text_enhancer,
use_fusion_layer=args.use_fusion_layer,
use_checkpoint=args.use_checkpoint,
use_transformer_ckpt=args.use_transformer_ckpt,
use_text_cross_attention=args.use_text_cross_attention,
text_dropout=args.text_dropout,
fusion_dropout=args.fusion_dropout,
fusion_droppath=args.fusion_droppath,
)
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR Transformer class.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from typing import Optional
import torch
import torch.utils.checkpoint as checkpoint
from torch import Tensor, nn
from groundingdino.util.misc import inverse_sigmoid
from .fuse_modules import BiAttentionBlock
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
from .transformer_vanilla import TransformerEncoderLayer
from .utils import (
MLP,
_get_activation_fn,
_get_clones,
gen_encoder_output_proposals,
gen_sineembed_for_position,
get_sine_pos_embed,
)
class Transformer(nn.Module):
def __init__(
self,
d_model=256,
nhead=8,
num_queries=300,
num_encoder_layers=6,
num_unicoder_layers=0,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.0,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
query_dim=4,
num_patterns=0,
# for deformable encoder
num_feature_levels=1,
enc_n_points=4,
dec_n_points=4,
# init query
learnable_tgt_init=False,
# two stage
two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
embed_init_tgt=False,
# for text
use_text_enhancer=False,
use_fusion_layer=False,
use_checkpoint=False,
use_transformer_ckpt=False,
use_text_cross_attention=False,
text_dropout=0.1,
fusion_dropout=0.1,
fusion_droppath=0.0,
):
super().__init__()
self.num_feature_levels = num_feature_levels
self.num_encoder_layers = num_encoder_layers
self.num_unicoder_layers = num_unicoder_layers
self.num_decoder_layers = num_decoder_layers
self.num_queries = num_queries
assert query_dim == 4
# choose encoder layer type
encoder_layer = DeformableTransformerEncoderLayer(
d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
)
if use_text_enhancer:
text_enhance_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead // 2,
dim_feedforward=dim_feedforward // 2,
dropout=text_dropout,
)
else:
text_enhance_layer = None
if use_fusion_layer:
feature_fusion_layer = BiAttentionBlock(
v_dim=d_model,
l_dim=d_model,
embed_dim=dim_feedforward // 2,
num_heads=nhead // 2,
dropout=fusion_dropout,
drop_path=fusion_droppath,
)
else:
feature_fusion_layer = None
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
assert encoder_norm is None
self.encoder = TransformerEncoder(
encoder_layer,
num_encoder_layers,
d_model=d_model,
num_queries=num_queries,
text_enhance_layer=text_enhance_layer,
feature_fusion_layer=feature_fusion_layer,
use_checkpoint=use_checkpoint,
use_transformer_ckpt=use_transformer_ckpt,
)
# choose decoder layer type
decoder_layer = DeformableTransformerDecoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
dec_n_points,
use_text_cross_attention=use_text_cross_attention,
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer,
num_decoder_layers,
decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model,
query_dim=query_dim,
num_feature_levels=num_feature_levels,
)
self.d_model = d_model
self.nhead = nhead
self.dec_layers = num_decoder_layers
self.num_queries = num_queries # useful for single stage model only
self.num_patterns = num_patterns
if not isinstance(num_patterns, int):
Warning("num_patterns should be int but {}".format(type(num_patterns)))
self.num_patterns = 0
if num_feature_levels > 1:
if self.num_encoder_layers > 0:
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
else:
self.level_embed = None
self.learnable_tgt_init = learnable_tgt_init
assert learnable_tgt_init, "why not learnable_tgt_init"
self.embed_init_tgt = embed_init_tgt
if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
self.tgt_embed = nn.Embedding(self.num_queries, d_model)
nn.init.normal_(self.tgt_embed.weight.data)
else:
self.tgt_embed = None
# for two stage
self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
two_stage_type
)
if two_stage_type == "standard":
# anchor selection at the output of encoder
self.enc_output = nn.Linear(d_model, d_model)
self.enc_output_norm = nn.LayerNorm(d_model)
self.two_stage_wh_embedding = None
if two_stage_type == "no":
self.init_ref_points(num_queries) # init self.refpoint_embed
self.enc_out_class_embed = None
self.enc_out_bbox_embed = None
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
if self.num_feature_levels > 1 and self.level_embed is not None:
nn.init.normal_(self.level_embed)
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, 4)
def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
"""
Input:
- srcs: List of multi features [bs, ci, hi, wi]
- masks: List of multi masks [bs, hi, wi]
- refpoint_embed: [bs, num_dn, 4]. None in infer
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
- tgt: [bs, num_dn, d_model]. None in infer
"""
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c
mask = mask.flatten(1) # bs, hw
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
if self.num_feature_levels > 1 and self.level_embed is not None:
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
else:
lvl_pos_embed = pos_embed
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=src_flatten.device
)
level_start_index = torch.cat(
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# two stage
enc_topk_proposals = enc_refpoint_embed = None
#########################################################
# Begin Encoder
#########################################################
memory, memory_text = self.encoder(
src_flatten,
pos=lvl_pos_embed_flatten,
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
key_padding_mask=mask_flatten,
memory_text=text_dict["encoded_text"],
text_attention_mask=~text_dict["text_token_mask"],
# we ~ the mask . False means use the token; True means pad the token
position_ids=text_dict["position_ids"],
text_self_attention_masks=text_dict["text_self_attention_masks"],
)
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
#########################################################
text_dict["encoded_text"] = memory_text
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if memory.isnan().any() | memory.isinf().any():
# import ipdb; ipdb.set_trace()
if self.two_stage_type == "standard":
output_memory, output_proposals = gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes
)
output_memory = self.enc_output_norm(self.enc_output(output_memory))
if text_dict is not None:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
else:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
topk_logits = enc_outputs_class_unselected.max(-1)[0]
enc_outputs_coord_unselected = (
self.enc_out_bbox_embed(output_memory) + output_proposals
) # (bs, \sum{hw}, 4) unsigmoid
topk = self.num_queries
topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
# gather boxes
refpoint_embed_undetach = torch.gather(
enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
) # unsigmoid
refpoint_embed_ = refpoint_embed_undetach.detach()
init_box_proposal = torch.gather(
output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
).sigmoid() # sigmoid
# gather tgt
tgt_undetach = torch.gather(
output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
)
if self.embed_init_tgt:
tgt_ = (
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
) # nq, bs, d_model
else:
tgt_ = tgt_undetach.detach()
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
elif self.two_stage_type == "no":
tgt_ = (
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
) # nq, bs, d_model
refpoint_embed_ = (
self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
) # nq, bs, 4
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
if self.num_patterns > 0:
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
self.num_queries, 1
) # 1, n_q*n_pat, d_model
tgt = tgt_embed + tgt_pat
init_box_proposal = refpoint_embed_.sigmoid()
else:
raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
#########################################################
# End preparing tgt
# - tgt: bs, NQ, d_model
# - refpoint_embed(unsigmoid): bs, NQ, d_model
#########################################################
#########################################################
# Begin Decoder
#########################################################
hs, references = self.decoder(
tgt=tgt.transpose(0, 1),
memory=memory.transpose(0, 1),
memory_key_padding_mask=mask_flatten,
pos=lvl_pos_embed_flatten.transpose(0, 1),
refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
tgt_mask=attn_mask,
memory_text=text_dict["encoded_text"],
text_attention_mask=~text_dict["text_token_mask"],
# we ~ the mask . False means use the token; True means pad the token
)
#########################################################
# End Decoder
# hs: n_dec, bs, nq, d_model
# references: n_dec+1, bs, nq, query_dim
#########################################################
#########################################################
# Begin postprocess
#########################################################
if self.two_stage_type == "standard":
hs_enc = tgt_undetach.unsqueeze(0)
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
else:
hs_enc = ref_enc = None
#########################################################
# End postprocess
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
# ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
#########################################################
return hs, references, hs_enc, ref_enc, init_box_proposal
# hs: (n_dec, bs, nq, d_model)
# references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
# ref_enc: sigmoid coordinates. \
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
class TransformerEncoder(nn.Module):
def __init__(
self,
encoder_layer,
num_layers,
d_model=256,
num_queries=300,
enc_layer_share=False,
text_enhance_layer=None,
feature_fusion_layer=None,
use_checkpoint=False,
use_transformer_ckpt=False,
):
"""_summary_
Args:
encoder_layer (_type_): _description_
num_layers (_type_): _description_
norm (_type_, optional): _description_. Defaults to None.
d_model (int, optional): _description_. Defaults to 256.
num_queries (int, optional): _description_. Defaults to 300.
enc_layer_share (bool, optional): _description_. Defaults to False.
"""
super().__init__()
# prepare layers
self.layers = []
self.text_layers = []
self.fusion_layers = []
if num_layers > 0:
self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
if text_enhance_layer is not None:
self.text_layers = _get_clones(
text_enhance_layer, num_layers, layer_share=enc_layer_share
)
if feature_fusion_layer is not None:
self.fusion_layers = _get_clones(
feature_fusion_layer, num_layers, layer_share=enc_layer_share
)
else:
self.layers = []
del encoder_layer
if text_enhance_layer is not None:
self.text_layers = []
del text_enhance_layer
if feature_fusion_layer is not None:
self.fusion_layers = []
del feature_fusion_layer
self.query_scale = None
self.num_queries = num_queries
self.num_layers = num_layers
self.d_model = d_model
self.use_checkpoint = use_checkpoint
self.use_transformer_ckpt = use_transformer_ckpt
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(
self,
# for images
src: Tensor,
pos: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
valid_ratios: Tensor,
key_padding_mask: Tensor,
# for texts
memory_text: Tensor = None,
text_attention_mask: Tensor = None,
pos_text: Tensor = None,
text_self_attention_masks: Tensor = None,
position_ids: Tensor = None,
):
"""
Input:
- src: [bs, sum(hi*wi), 256]
- pos: pos embed for src. [bs, sum(hi*wi), 256]
- spatial_shapes: h,w of each level [num_level, 2]
- level_start_index: [num_level] start point of level in sum(hi*wi).
- valid_ratios: [bs, num_level, 2]
- key_padding_mask: [bs, sum(hi*wi)]
- memory_text: bs, n_text, 256
- text_attention_mask: bs, n_text
False for no padding; True for padding
- pos_text: bs, n_text, 256
- position_ids: bs, n_text
Intermedia:
- reference_points: [bs, sum(hi*wi), num_level, 2]
Outpus:
- output: [bs, sum(hi*wi), 256]
"""
output = src
# preparation and reshape
if self.num_layers > 0:
reference_points = self.get_reference_points(
spatial_shapes, valid_ratios, device=src.device
)
if self.text_layers:
# generate pos_text
bs, n_text, text_dim = memory_text.shape
if pos_text is None and position_ids is None:
pos_text = (
torch.arange(n_text, device=memory_text.device)
.float()
.unsqueeze(0)
.unsqueeze(-1)
.repeat(bs, 1, 1)
)
pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
if position_ids is not None:
pos_text = get_sine_pos_embed(
position_ids[..., None], num_pos_feats=256, exchange_xy=False
)
# main process
for layer_id, layer in enumerate(self.layers):
# if output.isnan().any() or memory_text.isnan().any():
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
# import ipdb; ipdb.set_trace()
if self.fusion_layers:
if self.use_checkpoint:
output, memory_text = checkpoint.checkpoint(
self.fusion_layers[layer_id],
output,
memory_text,
key_padding_mask,
text_attention_mask,
)
else:
output, memory_text = self.fusion_layers[layer_id](
v=output,
l=memory_text,
attention_mask_v=key_padding_mask,
attention_mask_l=text_attention_mask,
)
if self.text_layers:
memory_text = self.text_layers[layer_id](
src=memory_text.transpose(0, 1),
src_mask=~text_self_attention_masks, # note we use ~ for mask here
src_key_padding_mask=text_attention_mask,
pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
).transpose(0, 1)
# main process
if self.use_transformer_ckpt:
output = checkpoint.checkpoint(
layer,
output,
pos,
reference_points,
spatial_shapes,
level_start_index,
key_padding_mask,
)
else:
output = layer(
src=output,
pos=pos,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask,
)
return output, memory_text
class TransformerDecoder(nn.Module):
def __init__(
self,
decoder_layer,
num_layers,
norm=None,
return_intermediate=False,
d_model=256,
query_dim=4,
num_feature_levels=1,
):
super().__init__()
if num_layers > 0:
self.layers = _get_clones(decoder_layer, num_layers)
else:
self.layers = []
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
assert return_intermediate, "support return_intermediate only"
self.query_dim = query_dim
assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
self.num_feature_levels = num_feature_levels
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
self.query_pos_sine_scale = None
self.query_scale = None
self.bbox_embed = None
self.class_embed = None
self.d_model = d_model
self.ref_anchor_head = None
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
# for memory
level_start_index: Optional[Tensor] = None, # num_levels
spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
valid_ratios: Optional[Tensor] = None,
# for text
memory_text: Optional[Tensor] = None,
text_attention_mask: Optional[Tensor] = None,
):
"""
Input:
- tgt: nq, bs, d_model
- memory: hw, bs, d_model
- pos: hw, bs, d_model
- refpoints_unsigmoid: nq, bs, 2/4
- valid_ratios/spatial_shapes: bs, nlevel, 2
"""
output = tgt
intermediate = []
reference_points = refpoints_unsigmoid.sigmoid()
ref_points = [reference_points]
for layer_id, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = (
reference_points[:, :, None]
* torch.cat([valid_ratios, valid_ratios], -1)[None, :]
) # nq, bs, nlevel, 4
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
query_sine_embed = gen_sineembed_for_position(
reference_points_input[:, :, 0, :]
) # nq, bs, 256*2
# conditional query
raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
pos_scale = self.query_scale(output) if self.query_scale is not None else 1
query_pos = pos_scale * raw_query_pos
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# if query_pos.isnan().any() | query_pos.isinf().any():
# import ipdb; ipdb.set_trace()
# main process
output = layer(
tgt=output,
tgt_query_pos=query_pos,
tgt_query_sine_embed=query_sine_embed,
tgt_key_padding_mask=tgt_key_padding_mask,
tgt_reference_points=reference_points_input,
memory_text=memory_text,
text_attention_mask=text_attention_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
memory_level_start_index=level_start_index,
memory_spatial_shapes=spatial_shapes,
memory_pos=pos,
self_attn_mask=tgt_mask,
cross_attn_mask=memory_mask,
)
if output.isnan().any() | output.isinf().any():
print(f"output layer_id {layer_id} is nan")
try:
num_nan = output.isnan().sum().item()
num_inf = output.isinf().sum().item()
print(f"num_nan {num_nan}, num_inf {num_inf}")
except Exception as e:
print(e)
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
# import ipdb; ipdb.set_trace()
# iter update
if self.bbox_embed is not None:
# box_holder = self.bbox_embed(output)
# box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
# new_reference_points = box_holder[..., :self.query_dim].sigmoid()
reference_before_sigmoid = inverse_sigmoid(reference_points)
delta_unsig = self.bbox_embed[layer_id](output)
outputs_unsig = delta_unsig + reference_before_sigmoid
new_reference_points = outputs_unsig.sigmoid()
reference_points = new_reference_points.detach()
# if layer_id != self.num_layers - 1:
ref_points.append(new_reference_points)
intermediate.append(self.norm(output))
return [
[itm_out.transpose(0, 1) for itm_out in intermediate],
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
]
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
):
super().__init__()
# self attention
self.self_attn = MSDeformAttn(
embed_dim=d_model,
num_levels=n_levels,
num_heads=n_heads,
num_points=n_points,
batch_first=True,
)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(
self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None
):
# self attention
# import ipdb; ipdb.set_trace()
src2 = self.self_attn(
query=self.with_pos_embed(src, pos),
reference_points=reference_points,
value=src,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask,
)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class DeformableTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4,
use_text_feat_guide=False,
use_text_cross_attention=False,
):
super().__init__()
# cross attention
self.cross_attn = MSDeformAttn(
embed_dim=d_model,
num_levels=n_levels,
num_heads=n_heads,
num_points=n_points,
batch_first=True,
)
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm1 = nn.LayerNorm(d_model)
# cross attention text
if use_text_cross_attention:
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.catext_norm = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.norm3 = nn.LayerNorm(d_model)
self.key_aware_proj = None
self.use_text_feat_guide = use_text_feat_guide
assert not use_text_feat_guide
self.use_text_cross_attention = use_text_cross_attention
def rm_self_attn_modules(self):
self.self_attn = None
self.dropout2 = None
self.norm2 = None
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
with torch.cuda.amp.autocast(enabled=False):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward(
self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
memory_text: Optional[Tensor] = None, # bs, num_token, d_model
text_attention_mask: Optional[Tensor] = None, # bs, num_token
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
):
"""
Input:
- tgt/tgt_query_pos: nq, bs, d_model
-
"""
assert cross_attn_mask is None
# self attention
if self.self_attn is not None:
# import ipdb; ipdb.set_trace()
q = k = self.with_pos_embed(tgt, tgt_query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
if self.use_text_cross_attention:
tgt2 = self.ca_text(
self.with_pos_embed(tgt, tgt_query_pos),
memory_text.transpose(0, 1),
memory_text.transpose(0, 1),
key_padding_mask=text_attention_mask,
)[0]
tgt = tgt + self.catext_dropout(tgt2)
tgt = self.catext_norm(tgt)
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
value=memory.transpose(0, 1),
spatial_shapes=memory_spatial_shapes,
level_start_index=memory_level_start_index,
key_padding_mask=memory_key_padding_mask,
).transpose(0, 1)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ffn
tgt = self.forward_ffn(tgt)
return tgt
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
num_queries=args.num_queries,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
query_dim=args.query_dim,
activation=args.transformer_activation,
num_patterns=args.num_patterns,
num_feature_levels=args.num_feature_levels,
enc_n_points=args.enc_n_points,
dec_n_points=args.dec_n_points,
learnable_tgt_init=True,
# two stage
two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
embed_init_tgt=args.embed_init_tgt,
use_text_enhancer=args.use_text_enhancer,
use_fusion_layer=args.use_fusion_layer,
use_checkpoint=args.use_checkpoint,
use_transformer_ckpt=args.use_transformer_ckpt,
use_text_cross_attention=args.use_text_cross_attention,
text_dropout=args.text_dropout,
fusion_dropout=args.fusion_dropout,
fusion_droppath=args.fusion_droppath,
)
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from .utils import (
MLP,
_get_activation_fn,
_get_clones,
gen_encoder_output_proposals,
gen_sineembed_for_position,
sigmoid_focal_loss,
)
class TextTransformer(nn.Module):
def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
self.nheads = nheads
self.dim_feedforward = dim_feedforward
self.norm = None
single_encoder_layer = TransformerEncoderLayer(
d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout
)
self.layers = _get_clones(single_encoder_layer, num_layers)
def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor):
"""
Args:
text_attention_mask: bs, num_token
memory_text: bs, num_token, d_model
Raises:
RuntimeError: _description_
Returns:
output: bs, num_token, d_model
"""
output = memory_text.transpose(0, 1)
for layer in self.layers:
output = layer(output, src_key_padding_mask=text_attention_mask)
if self.norm is not None:
output = self.norm(output)
return output.transpose(0, 1)
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.nhead = nhead
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
# repeat attn mask
if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
# bs, num_q, num_k
src_mask = src_mask.repeat(self.nhead, 1, 1)
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
# src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import copy
import math
import torch
import torch.nn.functional as F
from torch import Tensor, nn
def _get_clones(module, N, layer_share=False):
# import ipdb; ipdb.set_trace()
if layer_share:
return nn.ModuleList([module for i in range(N)])
else:
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def get_sine_pos_embed(
pos_tensor: torch.Tensor,
num_pos_feats: int = 128,
temperature: int = 10000,
exchange_xy: bool = True,
):
"""generate sine position embedding from a position tensor
Args:
pos_tensor (torch.Tensor): shape: [..., n].
num_pos_feats (int): projected shape for each float in the tensor.
temperature (int): temperature in the sine/cosine function.
exchange_xy (bool, optional): exchange pos x and pos y. \
For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
Returns:
pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
"""
scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
def sine_func(x: torch.Tensor):
sin_x = x * scale / dim_t
sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
return sin_x
pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
if exchange_xy:
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = torch.cat(pos_res, dim=-1)
return pos_res
def gen_encoder_output_proposals(
memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None
):
"""
Input:
- memory: bs, \sum{hw}, d_model
- memory_padding_mask: bs, \sum{hw}
- spatial_shapes: nlevel, 2
- learnedwh: 2
Output:
- output_memory: bs, \sum{hw}, d_model
- output_proposals: bs, \sum{hw}, 4
"""
N_, S_, C_ = memory.shape
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
# import ipdb; ipdb.set_trace()
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
if learnedwh is not None:
# import ipdb; ipdb.set_trace()
wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
else:
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
# scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
# grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
# wh = torch.ones_like(grid) / scale
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += H_ * W_
# import ipdb; ipdb.set_trace()
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
-1, keepdim=True
)
output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
# output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
# output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
return output_memory, output_proposals
class RandomBoxPerturber:
def __init__(
self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2
) -> None:
self.noise_scale = torch.Tensor(
[x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]
)
def __call__(self, refanchors: Tensor) -> Tensor:
nq, bs, query_dim = refanchors.shape
device = refanchors.device
noise_raw = torch.rand_like(refanchors)
noise_scale = self.noise_scale.to(device)[:query_dim]
new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
return new_refanchors.clamp_(0, 1)
def sigmoid_focal_loss(
inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False
):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if no_reduction:
return loss
return loss.mean(1).sum() / num_boxes
class MLP(nn.Module):
"""Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
def _get_activation_fn(activation, d_model=256, batch_dim=0):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
if activation == "prelu":
return nn.PReLU()
if activation == "selu":
return F.selu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
def gen_sineembed_for_position(pos_tensor):
# n_query, bs, _ = pos_tensor.size()
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else:
raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
return pos
class ContrastiveEmbed(nn.Module):
def __init__(self, max_text_len=256):
"""
Args:
max_text_len: max length of text.
"""
super().__init__()
self.max_text_len = max_text_len
def forward(self, x, text_dict):
"""_summary_
Args:
x (_type_): _description_
text_dict (_type_): _description_
{
'encoded_text': encoded_text, # bs, 195, d_model
'text_token_mask': text_token_mask, # bs, 195
# True for used tokens. False for padding tokens
}
Returns:
_type_: _description_
"""
assert isinstance(text_dict, dict)
y = text_dict["encoded_text"]
text_token_mask = text_dict["text_token_mask"]
res = x @ y.transpose(-1, -2)
res.masked_fill_(~text_token_mask[:, None, :], float("-inf"))
# padding to max_text_len
new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device)
new_res[..., : res.shape[-1]] = res
return new_res
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .GroundingDINO import build_groundingdino
def build_model(args):
# we use register to maintain models from catdet6 on.
from .registry import MODULE_BUILD_FUNCS
assert args.modelname in MODULE_BUILD_FUNCS._module_dict
build_func = MODULE_BUILD_FUNCS.get(args.modelname)
model = build_func(args)
return model
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Author: Yihao Chen
# @Date: 2021-08-16 16:03:17
# @Last Modified by: Shilong Liu
# @Last Modified time: 2022-01-23 15:26
# modified from mmcv
import inspect
from functools import partial
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __repr__(self):
format_str = self.__class__.__name__ + "(name={}, items={})".format(
self._name, list(self._module_dict.keys())
)
return format_str
def __len__(self):
return len(self._module_dict)
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
return self._module_dict.get(key, None)
def registe_with_name(self, module_name=None, force=False):
return partial(self.register, module_name=module_name, force=force)
def register(self, module_build_function, module_name=None, force=False):
"""Register a module build function.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isfunction(module_build_function):
raise TypeError(
"module_build_function must be a function, but got {}".format(
type(module_build_function)
)
)
if module_name is None:
module_name = module_build_function.__name__
if not force and module_name in self._module_dict:
raise KeyError("{} is already registered in {}".format(module_name, self.name))
self._module_dict[module_name] = module_build_function
return module_build_function
MODULE_BUILD_FUNCS = Registry("model build functions")
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment