Commit cce6e1bf by chenych: First commit.
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
/*!
**************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
* Copyright (c) 2018 Microsoft
**************************************************************************
*/
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads)
{
return (N + num_threads - 1) / num_threads;
}
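// CUDA_KERNEL_LOOP is the usual grid-stride loop: each thread starts at
// blockIdx.x * blockDim.x + threadIdx.x and advances by the total number of
// launched threads, so the kernels stay correct even if fewer blocks than
// elements are launched. GET_BLOCKS(N, t) is ceiling division, so together
// with CUDA_NUM_THREADS = 1024 each element of an N-sized problem is normally
// handled by exactly one thread.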
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
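// Note: bottom_data points at one batch element of one feature level, laid out
// as (height, width, num_heads, channels); w_stride / h_stride above encode
// that layout. The function gathers the four integer neighbours of the
// fractional coordinate (h, w), treats out-of-range corners as zero, and
// blends them with the standard bilinear weights
// w1 = (1-lh)(1-lw), w2 = (1-lh)lw, w3 = lh(1-lw), w4 = lh*lw.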
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
*grad_attn_weight = top_grad * val;
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
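// Note on the two col2im helpers: both scatter w_i * top_grad * attn_weight
// into grad_value with atomicAdd. The _gm variant above also uses atomicAdd
// for grad_sampling_loc / grad_attn_weight, writing straight to global memory;
// the non-_gm variant instead stores into per-thread slots (shared memory in
// the callers below) that are reduced across the block afterwards. The factors
// width / height on the location gradients come from the mapping
// h_im = loc_h * spatial_h - 0.5 used in the kernels.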
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
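// Forward kernel: one thread per output element. The linear index decomposes
// as (((b * num_query + q) * num_heads + m) * channels + c); for every level
// and sampling point the normalized location is mapped to image coordinates
// (loc * spatial_extent - 0.5), samples falling entirely outside the feature
// map are skipped, and the bilinearly interpolated value is accumulated with
// its attention weight into data_col[index].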
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockSize; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockSize/2; s>0; s>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
const scalar_t* data_value,
const int64_t* data_spatial_shapes,
const int64_t* data_level_start_index,
const scalar_t* data_sampling_loc,
const scalar_t* data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* data_col)
{
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
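// Launch note: the forward pass assigns one thread to each of the
// batch_size * num_query * num_heads * channels outputs, using 1024-thread
// blocks; since the launch is asynchronous, cudaGetLastError here mainly
// catches launch-configuration errors.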
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
const scalar_t* grad_col,
const scalar_t* data_value,
const int64_t * data_spatial_shapes,
const int64_t * data_level_start_index,
const scalar_t * data_sampling_loc,
const scalar_t * data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
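// Dispatch strategy for the backward pass (as implemented below):
// - channels > 1024 and a multiple of 1024: shm_reduce_v2_multi_blocks, whose
//   per-block partial sums are combined with atomicAdd;
// - channels > 1024 otherwise: the _gm kernel, which uses only global-memory
//   atomics;
// - channels a power of two up to 1024: blocksize-templated kernels (a serial
//   reduction by thread 0 for <= 32 channels, a tree reduction for >= 64);
// - any other channel count: the dynamic-shared-memory v1/v2 kernels, whose
//   num_threads * 3 * sizeof(scalar_t) bytes hold two location-gradient slots
//   and one attention-weight slot per thread.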
if (channels > 1024)
{
if ((channels & 1023) == 0)
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
else{
switch(channels)
{
case 1:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 2:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 4:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 8:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 16:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 32:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 64:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 128:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 256:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 512:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 1024:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
default:
if (channels < 64)
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
#include <cuda_runtime_api.h>
#define CUDART_VERSION 11080
namespace adet {
int get_cudart_version() {
#if !defined(__HIP_PLATFORM_HCC__)
return CUDART_VERSION;
#else
int version = 0;
#if HIP_VERSION_MAJOR != 0
// Create a convention similar to that of CUDA, as assumed by other
// parts of the code.
version = HIP_VERSION_MINOR;
version += (HIP_VERSION_MAJOR * 100);
#else
hipRuntimeGetVersion(&version);
#endif
return version;
#endif
}
} // namespace adet
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "DeformAttn/ms_deform_attn.h"
namespace adet {
#ifdef WITH_CUDA
extern int get_cudart_version();
#endif
std::string get_cuda_version() {
#ifdef WITH_CUDA
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
auto printCudaStyleVersion = [&](int v) {
oss << (v / 1000) << "." << (v / 10 % 100);
if (v % 10 != 0) {
oss << "." << (v % 10);
}
};
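// e.g. with CUDART_VERSION == 11080 this prints "11.8"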
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else
return std::string("not available");
#endif
}
// similar to
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
std::string get_compiler_version() {
std::ostringstream ss;
#if defined(__GNUC__)
#ifndef __clang__
{ ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
#endif
#endif
#if defined(__clang_major__)
{
ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
<< __clang_patchlevel__;
}
#endif
#if defined(_MSC_VER)
{ ss << "MSVC " << _MSC_FULL_VER; }
#endif
return ss.str();
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
} // namespace adet
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import copy
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
from adet.utils.misc import inverse_sigmoid
from .ms_deform_attn import MSDeformAttn
from scipy.special import comb as n_over_k
from adet.utils.curve_utils import upcast
from adet.modeling.model.utils import MLP, gen_point_pos_embed
class DeformableTransformer(nn.Module):
def __init__(
self,
temp=10000,
d_model=256,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=1024,
dropout=0.1,
activation="relu",
return_intermediate_dec=False,
num_feature_levels=4,
dec_n_points=4,
enc_n_points=4,
num_proposals=300,
num_points=25
):
super().__init__()
self.d_model = d_model
self.nhead = nhead
self.num_proposals = num_proposals
encoder_layer = DeformableTransformerEncoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
enc_n_points
)
self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
decoder_layer = DeformableCompositeTransformerDecoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
dec_n_points
)
self.decoder = DeformableCompositeTransformerDecoder(
temp,
decoder_layer,
num_decoder_layers,
return_intermediate_dec,
d_model
)
self.level_embed = nn.Parameter(
torch.Tensor(num_feature_levels, d_model)
)
self.bezier_coord_embed = None
self.bezier_class_embed = None
self.enc_output = nn.Linear(d_model, d_model)
self.enc_output_norm = nn.LayerNorm(d_model)
self.num_points = num_points
Mtk = lambda n, t, k: t ** k * (1 - t) ** (n - k) * n_over_k(n, k)
BezierCoeff = lambda ts: [[Mtk(3, t, k) for k in range(4)] for t in ts]
curve_token = torch.linspace(0, 1, num_points)
self.bernstein_matrix = torch.tensor(BezierCoeff(curve_token), requires_grad=False)
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
normal_(self.level_embed)
def init_points_from_bezier_proposals(self, reference_bezier):
bz = reference_bezier.shape[0]
initial_reference_points = reference_bezier.view(bz, self.num_proposals, 4, 2)
initial_reference_points = torch.matmul(
upcast(self.bernstein_matrix.to(initial_reference_points.device)),
upcast(initial_reference_points)
)
return initial_reference_points
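# Note: Mtk(3, t, k) above is the cubic Bernstein basis C(3, k) * t**k * (1 - t)**(3 - k),
# so self.bernstein_matrix has shape (num_points, 4). Multiplying it with the
# (bz, num_proposals, 4, 2) control points samples num_points (x, y) locations
# along each Bezier proposal, giving (bz, num_proposals, num_points, 2).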
def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
N_, S_, C_ = memory.shape
base_scale = 4.0
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
proposal = grid.repeat(1, 1, 1, 4)
assert proposal.shape[-1] == 8
proposal = proposal.view(N_, -1, 8)
proposals.append(proposal)
_cur += (H_ * W_)
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
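# valid_ratio is the fraction of each padded feature map that contains real
# content, as (width_ratio, height_ratio); it is used to rescale normalized
# reference points so sampling stays inside the unpadded region of every level.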
def forward(self, srcs, masks, pos_embeds, query_embed):
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
memory = self.encoder(
src_flatten,
spatial_shapes,
level_start_index,
valid_ratios,
lvl_pos_embed_flatten,
mask_flatten
)
bs, _, c = memory.shape
output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
enc_outputs_class = self.bezier_class_embed(output_memory)
enc_outputs_coord_unact = self.bezier_coord_embed(output_memory) + output_proposals
# select top-k curves
topk = self.num_proposals
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_unact = torch.gather(
enc_outputs_coord_unact,
1,
topk_proposals.unsqueeze(-1).repeat(1, 1, 8)
)
topk_coords_unact = topk_coords_unact.detach()
reference_points = topk_coords_unact.sigmoid() # bs, nq, 8
reference_points = self.init_points_from_bezier_proposals(reference_points) # bs, nq, num_points, 2
init_reference_out = reference_points
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1, -1)
hs, inter_references = self.decoder(
query_embed,
reference_points,
memory,
spatial_shapes,
level_start_index,
valid_ratios,
query_pos=None,
src_padding_mask=mask_flatten
)
inter_references_out = inter_references
return (hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact)
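# Forward pass in brief: per-level features, masks and positional encodings are
# flattened, concatenated and tagged with level embeddings, then encoded. The
# encoder memory is turned into per-location 8-dim proposals (4 Bezier control
# points x 2 coords) in inverse-sigmoid space; bezier_class_embed /
# bezier_coord_embed (left as None here, presumably attached by the wrapping
# model) score and refine them. The top num_proposals proposals are expanded to
# num_points on-curve reference points via the Bernstein matrix and fed to the
# composite decoder.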
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4
):
super().__init__()
# self attention
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(
self,
src,
pos,
reference_points,
spatial_shapes,
level_start_index,
padding_mask=None
):
# self attention
src2 = self.self_attn(
self.with_pos_embed(src, pos),
reference_points,
src,
spatial_shapes,
level_start_index,
padding_mask
)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class DeformableTransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
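# The encoder reference points are pixel centres (hence the 0.5 offsets)
# normalized by the valid extent of their own level, then multiplied by
# valid_ratios so every position gets a reference point in every feature level.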
def forward(
self,
src,
spatial_shapes,
level_start_index,
valid_ratios,
pos=None,
padding_mask=None
):
output = src
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
for _, layer in enumerate(self.layers):
output = layer(
output,
pos,
reference_points,
spatial_shapes,
level_start_index,
padding_mask
)
return output
class DeformableCompositeTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation="relu",
n_levels=4,
n_heads=8,
n_points=4
):
super().__init__()
# self attention (intra)
self.attn_intra = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.norm_intra = nn.LayerNorm(d_model)
self.dropout_intra = nn.Dropout(dropout)
# self attention (inter)
self.attn_inter = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.dropout_inter = nn.Dropout(dropout)
self.norm_inter = nn.LayerNorm(d_model)
# cross attention
self.attn_cross = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout_cross = nn.Dropout(dropout)
self.norm_cross = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward(
self,
tgt,
query_pos,
reference_points,
src,
src_spatial_shapes,
level_start_index,
src_padding_mask=None
):
## input size
# tgt: bs, n_q, n_pts, embed_dim
# query_pos: bs, n_q, n_pts, embed_dim
q = k = self.with_pos_embed(tgt, query_pos)
# q.flatten(0, 1).transpose(0, 1): (n_pts, bs*n_q, dim)
tgt2 = self.attn_intra(
q.flatten(0, 1).transpose(0, 1),
k.flatten(0, 1).transpose(0, 1),
tgt.flatten(0, 1).transpose(0, 1),
)[0].transpose(0, 1).reshape(q.shape)
tgt = tgt + self.dropout_intra(tgt2)
tgt = self.norm_intra(tgt)
q_inter = k_inter = tgt_inter = torch.swapdims(tgt, 1, 2) # (bs, n_pts, n_q, dim)
# q_inter.flatten(0, 1).transpose(0, 1): n_q, bs*n_pts, dim
tgt2_inter = self.attn_inter(
q_inter.flatten(0, 1).transpose(0, 1),
k_inter.flatten(0, 1).transpose(0, 1),
tgt_inter.flatten(0, 1).transpose(0, 1)
)[0].transpose(0, 1).reshape(q_inter.shape)
tgt_inter = tgt_inter + self.dropout_inter(tgt2_inter)
tgt_inter = torch.swapdims(self.norm_inter(tgt_inter), 1, 2)
# cross attention
if len(reference_points.shape) == 4:
reference_points_loc = reference_points[:, :, None, :, :].repeat(1, 1, tgt_inter.shape[2], 1, 1)
else:
assert reference_points.shape[2] == tgt_inter.shape[2]
reference_points_loc = reference_points
tgt2 = self.attn_cross(
self.with_pos_embed(tgt_inter, query_pos).flatten(1, 2),
reference_points_loc.flatten(1, 2),
src,
src_spatial_shapes,
level_start_index,
src_padding_mask
).reshape(tgt_inter.shape)
tgt_inter = tgt_inter + self.dropout_cross(tgt2)
tgt = self.norm_cross(tgt_inter)
# ffn
tgt = self.forward_ffn(tgt)
return tgt
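# Layer structure: (1) intra-query self-attention over the n_pts point
# embeddings of each query (bs*n_q folded into the batch dimension),
# (2) inter-query attention over the n_q queries at the same point index
# (dims 1 and 2 swapped), (3) multi-scale deformable cross-attention into the
# encoder memory at the per-point reference locations, (4) a feed-forward block.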
class DeformableCompositeTransformerDecoder(nn.Module):
def __init__(
self,
temp,
decoder_layer,
num_layers,
return_intermediate=False,
d_model=256
):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.return_intermediate = return_intermediate
self.ctrl_point_coord = None
self.ref_point_head = MLP(d_model, d_model, d_model, 2)
self.temp = temp
self.d_model = d_model
def forward(
self,
tgt,
reference_points,
src,
src_spatial_shapes,
src_level_start_index,
src_valid_ratios,
query_pos=None,
src_padding_mask=None
):
# reference_points: bs, 100, 4
# src_valid_ratios: bs, 4, 2
# query_pos: bs, 100, n_pts, d_model
output = tgt # bs, 100, n_pts, d_model
assert query_pos is None and reference_points.shape[-1] == 2
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = reference_points[:, :, None] \
* torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
else:
assert reference_points.shape[-1] == 2
# reference_points: (bs, nq, n_pts, 2), reference_points_input: (bs, nq, n_pts, 4, 2)
reference_points_input = reference_points[:, :, :, None] * src_valid_ratios[:, None, None]
query_pos = gen_point_pos_embed(reference_points_input[:, :, :, 0, :], self.d_model, self.temp)
query_pos = self.ref_point_head(query_pos)
output = layer(output, query_pos, reference_points_input, src,
src_spatial_shapes, src_level_start_index, src_padding_mask)
if self.ctrl_point_coord is not None:
tmp = self.ctrl_point_coord[lid](output)
new_reference_points = tmp + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
return output, reference_points
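# Each layer rebuilds query_pos from the current (first-level) reference point
# coordinates via gen_point_pos_embed + ref_point_head; if ctrl_point_coord
# heads are attached (they are None by default here), the reference points are
# iteratively refined in inverse-sigmoid space and detached between layers.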
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
import warnings
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
from torch.autograd.function import once_differentiable
from adet import _C
import sys
class _MSDeformAttnFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
ctx.im2col_step = im2col_step
output = _C.ms_deform_attn_forward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = \
_C.ms_deform_attn_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
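# The custom Function forwards to the compiled extension
# (_C.ms_deform_attn_forward / _backward), saves the five input tensors for
# backward, and returns gradients only for value, sampling_locations and
# attention_weights; the integer shape/index tensors and im2col_step get None.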
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
# for debugging and testing only;
# use the CUDA version instead
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# N_*M_, D_, Lq_, P_
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
mode='bilinear', padding_mode='zeros', align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
return output.transpose(1, 2).contiguous()
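# Illustrative sketch (hedged): a minimal usage example for the pure-PyTorch
# core above. All shapes (batch, heads, levels, points, queries) are made-up
# assumptions chosen to show the (N, Len_q, n_heads * head_dim) output layout;
# the _demo_* helper is hypothetical and never called by the library.
def _demo_ms_deform_attn_core_pytorch():
    N, M, D = 2, 8, 32                           # batch, heads, channels per head
    spatial_shapes = [(16, 16), (8, 8)]          # two feature levels
    L, P, Lq = len(spatial_shapes), 4, 10        # levels, points per level, queries
    S = sum(h * w for h, w in spatial_shapes)
    value = torch.rand(N, S, M, D)
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = F.softmax(torch.rand(N, Lq, M, L * P), -1).view(N, Lq, M, L, P)
    out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
    assert out.shape == (N, Lq, M * D)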
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n-1) == 0) and n != 0
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
        # _d_per_head should be a power of 2; the CUDA implementation is more efficient that way
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head "
                          "a power of 2, which is more efficient in our CUDA implementation.")
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but got {} instead.'.format(reference_points.shape[-1]))
output = _MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output)
# if decoder:
# return output, sampling_locations, attention_weights
return output
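# Illustrative sketch (hedged): how MSDeformAttn expects its inputs to be laid
# out. The shapes are assumptions for demonstration; the example only runs when
# a GPU and the compiled `adet._C` extension are available, because forward()
# dispatches to the custom CUDA kernel. The _demo_* helper is never called.
def _demo_ms_deform_attn():
    if not torch.cuda.is_available():
        return
    attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).cuda()
    spatial_shapes = torch.as_tensor([[16, 16], [8, 8]], dtype=torch.long, device='cuda')
    level_start_index = torch.cat(
        (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    len_in = int(spatial_shapes.prod(1).sum())
    query = torch.rand(2, 100, 256, device='cuda')
    reference_points = torch.rand(2, 100, 2, 2, device='cuda')   # (N, Len_q, n_levels, 2)
    input_flatten = torch.rand(2, len_in, 256, device='cuda')
    out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
    assert out.shape == (2, 100, 256)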
import torch
import numpy as np
import torch.nn as nn
class PositionalEncoding1D(nn.Module):
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
"""
:param channels: The last dimension of the tensor you want to apply pos emb to.
"""
super().__init__()
self.channels = num_pos_feats
dim_t = torch.arange(0, self.channels, 2).float()
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * np.pi
self.scale = scale
self.normalize = normalize
inv_freq = 1. / (temperature ** (dim_t / self.channels))
self.register_buffer('inv_freq', inv_freq)
def forward(self, tensor):
"""
:param tensor: A 2d tensor of size (len, c)
:return: Positional Encoding Matrix of size (len, c)
"""
if tensor.ndim != 2:
raise RuntimeError("The input tensor has to be 2D!")
x, orig_ch = tensor.shape
pos_x = torch.arange(
1, x + 1, device=tensor.device).type(self.inv_freq.type())
if self.normalize:
eps = 1e-6
pos_x = pos_x / (pos_x[-1:] + eps) * self.scale
sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1)
emb = torch.zeros((x, self.channels),
device=tensor.device).type(tensor.type())
emb[:, :self.channels] = emb_x
return emb[:, :orig_ch]
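# Illustrative sketch (hedged): minimal usage of PositionalEncoding1D. The
# sequence length and channel count are arbitrary assumptions; the _demo_*
# helper is hypothetical and not used elsewhere.
def _demo_positional_encoding_1d():
    pe = PositionalEncoding1D(num_pos_feats=64)
    emb = pe(torch.zeros(25, 64))                # (sequence length, channels)
    assert emb.shape == (25, 64)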
class PositionalEncoding2D(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * np.pi
self.scale = scale
def forward(self, tensors):
x = tensors.tensors
mask = tensors.mask
assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
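# Illustrative sketch (hedged): PositionalEncoding2D consumes a NestedTensor-like
# object exposing .tensors and .mask; a SimpleNamespace stands in for it here.
# All shapes are assumptions; the _demo_* helper is hypothetical.
def _demo_positional_encoding_2d():
    import types
    pe = PositionalEncoding2D(num_pos_feats=128, normalize=True)
    dummy = types.SimpleNamespace(
        tensors=torch.zeros(1, 256, 32, 32),
        mask=torch.zeros(1, 32, 32, dtype=torch.bool))      # no padded pixels
    pos = pe(dummy)                                          # (B, 2 * num_pos_feats, H, W)
    assert pos.shape == (1, 256, 32, 32)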
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .text_spotter import TransformerPureDetector
_EXCLUDE = {"torch", "ShapeSpec"}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from adet.layers.deformable_transformer import DeformableTransformer
from adet.utils.misc import (
NestedTensor,
inverse_sigmoid_offset,
nested_tensor_from_tensor_list,
sigmoid_offset
)
from adet.modeling.model.utils import MLP
class DETECTION_TRANSFORMER(nn.Module):
def __init__(self, cfg, backbone):
super().__init__()
self.device = torch.device(cfg.MODEL.DEVICE)
self.backbone = backbone
self.d_model = cfg.MODEL.TRANSFORMER.HIDDEN_DIM
self.nhead = cfg.MODEL.TRANSFORMER.NHEADS
self.num_encoder_layers = cfg.MODEL.TRANSFORMER.ENC_LAYERS
self.num_decoder_layers = cfg.MODEL.TRANSFORMER.DEC_LAYERS
self.dim_feedforward = cfg.MODEL.TRANSFORMER.DIM_FEEDFORWARD
self.dropout = cfg.MODEL.TRANSFORMER.DROPOUT
self.activation = "relu"
self.return_intermediate_dec = True
self.num_feature_levels = cfg.MODEL.TRANSFORMER.NUM_FEATURE_LEVELS
        self.dec_n_points = cfg.MODEL.TRANSFORMER.DEC_N_POINTS
        self.enc_n_points = cfg.MODEL.TRANSFORMER.ENC_N_POINTS
self.pos_embed_scale = cfg.MODEL.TRANSFORMER.POSITION_EMBEDDING_SCALE
self.num_classes = 1 # text or not text
self.voc_size = cfg.MODEL.TRANSFORMER.VOC_SIZE
self.sigmoid_offset = False
self.num_proposals = cfg.MODEL.TRANSFORMER.NUM_QUERIES
self.num_points = cfg.MODEL.TRANSFORMER.NUM_POINTS
self.point_embed = nn.Embedding(self.num_proposals * self.num_points, self.d_model)
self.transformer = DeformableTransformer(
temp=cfg.MODEL.TRANSFORMER.TEMPERATURE,
d_model=self.d_model,
nhead=self.nhead,
num_encoder_layers=self.num_encoder_layers,
num_decoder_layers=self.num_decoder_layers,
dim_feedforward=self.dim_feedforward,
dropout=self.dropout,
activation=self.activation,
return_intermediate_dec=self.return_intermediate_dec,
num_feature_levels=self.num_feature_levels,
dec_n_points=self.dec_n_points,
enc_n_points=self.enc_n_points,
num_proposals=self.num_proposals,
num_points=self.num_points
)
if self.num_feature_levels > 1:
strides = [8, 16, 32]
if cfg.MODEL.BACKBONE.NAME == 'build_swin_backbone':
                if cfg.MODEL.SWIN.TYPE in ('tiny', 'small'):
num_channels = [192, 384, 768]
else:
raise NotImplementedError
elif cfg.MODEL.BACKBONE.NAME == 'build_vitaev2_backbone':
if cfg.MODEL.ViTAEv2.TYPE == 'vitaev2_s':
num_channels = [128, 256, 512]
else:
raise NotImplementedError
else:
num_channels = [512, 1024, 2048]
num_backbone_outs = len(strides)
input_proj_list = []
            for idx in range(num_backbone_outs):
                in_channels = num_channels[idx]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, self.d_model, kernel_size=1),
nn.GroupNorm(32, self.d_model),
)
)
for _ in range(self.num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, self.d_model, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(32, self.d_model),
)
)
in_channels = self.d_model
self.input_proj = nn.ModuleList(input_proj_list)
else:
strides = [32]
num_channels = [2048]
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(
num_channels[0], self.d_model, kernel_size=1),
nn.GroupNorm(32, self.d_model),
)
]
)
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)
self.aux_loss = cfg.MODEL.TRANSFORMER.AUX_LOSS
# bezier center line proposal after the encoder
# x_0, y_0, ... , x_3, y_3
self.bezier_proposal_coord = MLP(self.d_model, self.d_model, 8, 3)
self.bezier_proposal_class = nn.Linear(self.d_model, self.num_classes) # text or non-text
# task specific heads after the decoder
self.ctrl_point_coord = MLP(self.d_model, self.d_model, 2, 3)
self.ctrl_point_class = nn.Linear(self.d_model, self.num_classes) # text or non-text
self.ctrl_point_text = nn.Linear(self.d_model, self.voc_size + 1) # specific character class for each point
self.boundary_head_on = cfg.MODEL.TRANSFORMER.BOUNDARY_HEAD
if self.boundary_head_on:
self.boundary_offset = MLP(self.d_model, self.d_model, 4, 3) # to rebuild the text boundary from queries
prior_prob = 0.01
bias_value = -np.log((1 - prior_prob) / prior_prob)
self.bezier_proposal_class.bias.data = torch.ones(self.num_classes) * bias_value
self.ctrl_point_class.bias.data = torch.ones(self.num_classes) * bias_value
self.ctrl_point_text.bias.data = torch.ones(self.voc_size + 1) * bias_value
nn.init.constant_(self.bezier_proposal_coord.layers[-1].weight.data, 0)
nn.init.constant_(self.bezier_proposal_coord.layers[-1].bias.data, 0)
self.transformer.bezier_coord_embed = self.bezier_proposal_coord
self.transformer.bezier_class_embed = self.bezier_proposal_class
nn.init.constant_(self.ctrl_point_coord.layers[-1].weight.data, 0)
nn.init.constant_(self.ctrl_point_coord.layers[-1].bias.data, 0)
if self.boundary_head_on:
nn.init.constant_(self.boundary_offset.layers[-1].weight.data, 0)
nn.init.constant_(self.boundary_offset.layers[-1].bias.data, 0)
######################################################################
# shared prediction heads
######################################################################
num_pred = self.num_decoder_layers
self.ctrl_point_coord = nn.ModuleList(
[self.ctrl_point_coord for _ in range(num_pred)]
)
self.ctrl_point_class = nn.ModuleList(
[self.ctrl_point_class for _ in range(num_pred)]
)
self.ctrl_point_text = nn.ModuleList(
[self.ctrl_point_text for _ in range(num_pred)]
)
if self.boundary_head_on:
self.boundary_offset = nn.ModuleList(
[self.boundary_offset for _ in range(num_pred)]
)
self.transformer.decoder.ctrl_point_coord = self.ctrl_point_coord
self.to(self.device)
def forward(self, samples: NestedTensor):
""" The forward expects a NestedTensor, which consists of:
               - samples.tensors: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
"""
if isinstance(samples, (list, torch.Tensor)):
samples = nested_tensor_from_tensor_list(samples)
features, pos = self.backbone(samples)
srcs = []
masks = []
for l, feat in enumerate(features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = masks[0]
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
pos.append(pos_l)
# (n_proposal x n_pts, d_model) -> (n_proposal, n_pts, d_model)
point_embed = self.point_embed.weight.reshape((self.num_proposals, self.num_points, self.d_model)) # not shared
(
hs,
init_reference,
inter_references,
enc_outputs_class,
enc_outputs_coord_unact
) = self.transformer(srcs, masks, pos, point_embed)
outputs_texts = []
outputs_coords = []
outputs_classes = []
if self.boundary_head_on:
outputs_bd_coords = []
for lvl in range(hs.shape[0]):
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
# hs shape: (bs, n_proposal, n_pts, d_model)
reference = inverse_sigmoid_offset(reference, offset=self.sigmoid_offset)
outputs_class = self.ctrl_point_class[lvl](hs[lvl])
outputs_text = self.ctrl_point_text[lvl](hs[lvl]) # bs, n_proposal, n_pts, voc_size
tmp = self.ctrl_point_coord[lvl](hs[lvl])
if self.boundary_head_on:
tmp_bd = self.boundary_offset[lvl](hs[lvl])
if reference.shape[-1] == 2:
tmp += reference
if self.boundary_head_on:
tmp_bd += reference.repeat(1, 1, 1, 2)
else:
raise NotImplementedError
outputs_coord = sigmoid_offset(tmp, offset=self.sigmoid_offset)
if self.boundary_head_on:
outputs_bd_coord = sigmoid_offset(tmp_bd, offset=self.sigmoid_offset)
outputs_bd_coords.append(outputs_bd_coord)
outputs_classes.append(outputs_class)
outputs_texts.append(outputs_text)
outputs_coords.append(outputs_coord)
outputs_class = torch.stack(outputs_classes)
outputs_text = torch.stack(outputs_texts)
outputs_coord = torch.stack(outputs_coords)
if self.boundary_head_on:
outputs_bd_coord = torch.stack(outputs_bd_coords)
out = {
'pred_logits': outputs_class[-1],
'pred_text_logits': outputs_text[-1],
'pred_ctrl_points': outputs_coord[-1],
'pred_bd_points': outputs_bd_coord[-1] if self.boundary_head_on else None
}
if self.aux_loss:
out['aux_outputs'] = self._set_aux_loss(
outputs_class,
outputs_text,
outputs_coord,
outputs_bd_coord if self.boundary_head_on else None
)
enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
out['enc_outputs'] = {
'pred_logits': enc_outputs_class,
'pred_beziers': enc_outputs_coord
}
return out
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_text, outputs_coord, outputs_bd_coord):
if outputs_bd_coord is not None:
return [
{'pred_logits': a, 'pred_text_logits': b, 'pred_ctrl_points': c, 'pred_bd_points': d}
for a, b, c, d in zip(outputs_class[:-1], outputs_text[:-1], outputs_coord[:-1], outputs_bd_coord[:-1])
]
else:
return [
{'pred_logits': a, 'pred_text_logits': b, 'pred_ctrl_points': c}
for a, b, c in zip(outputs_class[:-1], outputs_text[:-1], outputs_coord[:-1])
]
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from adet.utils.misc import accuracy, is_dist_avail_and_initialized
from detectron2.utils.comm import get_world_size
from adet.utils.curve_utils import BezierSampler
def sigmoid_focal_loss(inputs, targets, num_inst, alpha: float = 0.25, gamma: float = 2):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs negative examples. Default: 0.25.
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(
inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if loss.ndim == 4:
return loss.mean((1, 2)).sum() / num_inst
elif loss.ndim == 3:
return loss.mean(1).sum() / num_inst
else:
raise NotImplementedError(f"Unsupported dim {loss.ndim}")
class SetCriterion(nn.Module):
"""
The process happens in two steps:
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""
def __init__(
self,
num_classes,
enc_matcher,
dec_matcher,
weight_dict,
enc_losses,
num_sample_points,
dec_losses,
voc_size,
num_ctrl_points,
focal_alpha=0.25,
focal_gamma=2.0
):
""" Create the criterion.
Parameters:
num_classes: number of object categories, omitting the special no-object category
matcher: module able to compute a matching between targets and proposals
weight_dict: dict containing as key the names of the losses and as values their relative weight.
losses: list of all the losses to be applied. See get_loss for list of available losses.
focal_alpha: alpha in Focal Loss
"""
super().__init__()
self.num_classes = num_classes
self.enc_matcher = enc_matcher
self.dec_matcher = dec_matcher
self.weight_dict = weight_dict
self.enc_losses = enc_losses
self.num_sample_points = num_sample_points
self.bezier_sampler = BezierSampler(num_sample_points=num_sample_points)
self.dec_losses = dec_losses
self.voc_size = voc_size
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
self.num_ctrl_points = num_ctrl_points
def loss_labels(self, outputs, targets, indices, num_inst, log=False):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert 'pred_logits' in outputs
src_logits = outputs['pred_logits']
idx = self._get_src_permutation_idx(indices)
target_classes = torch.full(src_logits.shape[:-1], self.num_classes,
dtype=torch.int64, device=src_logits.device)
target_classes_o = torch.cat([t["labels"][J]
for t, (_, J) in zip(targets, indices)])
if len(target_classes_o.shape) < len(target_classes[idx].shape):
target_classes_o = target_classes_o[..., None]
target_classes[idx] = target_classes_o
shape = list(src_logits.shape)
shape[-1] += 1
target_classes_onehot = torch.zeros(shape,
dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
target_classes_onehot.scatter_(-1, target_classes.unsqueeze(-1), 1)
target_classes_onehot = target_classes_onehot[..., :-1]
# src_logits, target_classes_onehot: (bs, nq, n_ctrl_pts, 1)
loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_inst,
alpha=self.focal_alpha, gamma=self.focal_gamma) * src_logits.shape[1]
losses = {'loss_ce': loss_ce}
if log:
# this should probably be a separate loss, not hacked in this one here
losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
return losses
def loss_beziers(self, outputs, targets, indices, num_inst):
# may FIX: (1) seg valid points
assert 'pred_beziers' in outputs
idx = self._get_src_permutation_idx(indices)
src_beziers = outputs['pred_beziers'][idx]
src_beziers = self.bezier_sampler.get_sample_points(src_beziers.view(-1, 4, 2))
target_beziers = torch.cat(
[t['beziers'][i] for t, (_, i) in zip(targets, indices)],
dim=0
)
target_beziers = self.bezier_sampler.get_sample_points(target_beziers)
if target_beziers.numel() == 0:
target_beziers = src_beziers.clone().detach()
loss_bezier = F.l1_loss(src_beziers, target_beziers, reduction='none')
losses = {}
losses['loss_bezier'] = loss_bezier.sum() / num_inst
return losses
@torch.no_grad()
def loss_cardinality(self, outputs, targets, indices, num_inst):
""" Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
"""
pred_logits = outputs['pred_logits']
device = pred_logits.device
tgt_lengths = torch.as_tensor(
[len(v["labels"]) for v in targets], device=device)
card_pred = (pred_logits.mean(-2).argmax(-1) == 0).sum(1)
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
losses = {'cardinality_error': card_err}
return losses
def loss_texts(self, outputs, targets, indices, num_inst):
# CTC loss for classification of points
assert 'pred_text_logits' in outputs
idx = self._get_src_permutation_idx(indices)
src_texts = outputs['pred_text_logits'][idx] # shape: (n, length, voc_size+1)
src_texts = src_texts.permute(1, 0, 2)
src = F.log_softmax(src_texts, dim=-1) # shape: (length, n, voc_size+1)
target_texts = torch.cat([t['texts'][i] for t, (_, i) in zip(targets, indices)]) # n, length
input_lengths = torch.full((src.size(1),), src.size(0), dtype=torch.long)
target_lengths = (target_texts != self.voc_size).long().sum(dim=-1)
target_texts = torch.cat([t[:l] for t, l in zip(target_texts, target_lengths)])
return {
'loss_texts': F.ctc_loss(
src,
target_texts,
input_lengths,
target_lengths,
blank=self.voc_size,
zero_infinity=True
)
}
def loss_ctrl_points(self, outputs, targets, indices, num_inst):
"""Compute the L1 regression loss
"""
assert 'pred_ctrl_points' in outputs
idx = self._get_src_permutation_idx(indices)
src_ctrl_points = outputs['pred_ctrl_points'][idx]
target_ctrl_points = torch.cat([t['ctrl_points'][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_ctrl_points = F.l1_loss(src_ctrl_points, target_ctrl_points, reduction='sum')
losses = {'loss_ctrl_points': loss_ctrl_points / num_inst}
return losses
def loss_bd_points(self, outputs, targets, indices, num_inst):
assert 'pred_bd_points' in outputs
idx = self._get_src_permutation_idx(indices)
src_bd_points = outputs['pred_bd_points'][idx]
target_bd_points = torch.cat([t['bd_points'][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_bd_points = F.l1_loss(src_bd_points, target_bd_points, reduction='sum')
losses = {'loss_bd_points': loss_bd_points / num_inst}
return losses
@staticmethod
def _get_src_permutation_idx(indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i)
for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
@staticmethod
def _get_tgt_permutation_idx(indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i)
for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
def get_loss(self, loss, outputs, targets, indices, num_inst, **kwargs):
loss_map = {
'labels': self.loss_labels,
'cardinality': self.loss_cardinality,
'ctrl_points': self.loss_ctrl_points,
'beziers': self.loss_beziers,
'texts': self.loss_texts,
'bd_points': self.loss_bd_points
}
assert loss in loss_map, f'do you really want to compute {loss} loss?'
return loss_map[loss](outputs, targets, indices, num_inst, **kwargs)
def forward(self, outputs, targets):
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.dec_matcher(outputs_without_aux, targets)
        # Compute the average number of target boxes across all nodes, for normalization purposes
num_inst = sum(len(t['ctrl_points']) for t in targets)
num_inst = torch.as_tensor(
[num_inst], dtype=torch.float, device=next(iter(outputs.values())).device)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_inst)
num_inst = torch.clamp(num_inst / get_world_size(), min=1).item()
# Compute all the requested losses
losses = {}
for loss in self.dec_losses:
kwargs = {}
losses.update(self.get_loss(loss, outputs, targets, indices, num_inst, **kwargs))
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if 'aux_outputs' in outputs:
for i, aux_outputs in enumerate(outputs['aux_outputs']):
indices = self.dec_matcher(aux_outputs, targets)
for loss in self.dec_losses:
kwargs = {}
if loss == 'labels':
# Logging is enabled only for the last layer
kwargs['log'] = False
l_dict = self.get_loss(
loss, aux_outputs, targets, indices, num_inst, **kwargs)
l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
losses.update(l_dict)
if 'enc_outputs' in outputs:
enc_outputs = outputs['enc_outputs']
indices = self.enc_matcher(enc_outputs, targets)
for loss in self.enc_losses:
kwargs = {}
if loss == 'labels':
kwargs['log'] = False
l_dict = self.get_loss(
loss, enc_outputs, targets, indices, num_inst, **kwargs)
l_dict = {k + f'_enc': v for k, v in l_dict.items()}
losses.update(l_dict)
return losses
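# Illustrative sketch (hedged): how _get_src_permutation_idx flattens matcher
# output into (batch_idx, query_idx) index tensors used to gather matched
# predictions, e.g. outputs['pred_logits'][idx]. The indices below are made up;
# the _demo_* helper is hypothetical and never called.
def _demo_get_src_permutation_idx():
    indices = [(torch.tensor([3, 1]), torch.tensor([0, 1])),   # image 0: queries 3, 1
               (torch.tensor([2]), torch.tensor([0]))]          # image 1: query 2
    batch_idx, src_idx = SetCriterion._get_src_permutation_idx(indices)
    assert batch_idx.tolist() == [0, 0, 1] and src_idx.tolist() == [3, 1, 2]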
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn
import torch.nn.functional as F
from adet.utils.curve_utils import BezierSampler
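# Illustrative sketch (hedged): both matchers below build a per-image cost
# matrix of shape (num_queries, num_targets) and hand it to scipy's
# linear_sum_assignment. A toy example of that final step, with a made-up
# 3-query / 2-target cost matrix; the _demo_* helper is hypothetical.
def _demo_linear_sum_assignment():
    cost = torch.tensor([[0.9, 0.1],
                         [0.2, 0.8],
                         [0.5, 0.5]])
    row, col = linear_sum_assignment(cost.numpy())
    # query 0 -> target 1, query 1 -> target 0 (total cost 0.3)
    assert list(zip(row.tolist(), col.tolist())) == [(0, 1), (1, 0)]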
class CtrlPointHungarianMatcher(nn.Module):
def __init__(
self,
class_weight: float = 1,
coord_weight: float = 1,
text_weight: float = 1,
focal_alpha: float = 0.25,
focal_gamma: float = 2.0
):
super().__init__()
self.class_weight = class_weight
self.coord_weight = coord_weight
self.text_weight = text_weight
self.alpha = focal_alpha
self.gamma = focal_gamma
        assert class_weight != 0 or coord_weight != 0, "all costs can't be 0"
def forward(self, outputs, targets):
with torch.no_grad():
sizes = [len(v["ctrl_points"]) for v in targets]
bs, num_queries = outputs["pred_logits"].shape[:2]
out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
out_texts = F.log_softmax(outputs['pred_text_logits'], dim=-1) # (bs, n_q, n_pts, voc+1)
n_pts, voc = out_texts.shape[2], out_texts.shape[-1] - 1
target_texts = torch.cat([v["texts"] for v in targets])
target_lengths = (target_texts != voc).long().sum(dim=-1)
target_texts = torch.split(target_texts, sizes, dim=0)
target_lengths = torch.split(target_lengths, sizes)
texts_cost_list = []
            for out_texts_batch, target_texts_batch, target_len_batch in zip(out_texts, target_texts, target_lengths):
                out_texts_batch_temp = out_texts_batch.repeat(target_texts_batch.shape[0], 1, 1).permute(1, 0, 2)
                input_len = torch.full((out_texts_batch_temp.size(1),), out_texts_batch_temp.size(0), dtype=torch.long)
                target_texts_batch_temp = torch.cat([
                    t[:target_len_batch[t_idx]].repeat(num_queries) for t_idx, t in enumerate(target_texts_batch)
                ])
                target_len_batch_temp = target_len_batch.reshape((-1, 1)).repeat(1, num_queries).reshape(-1)
                text_cost = F.ctc_loss(
                    out_texts_batch_temp,
                    target_texts_batch_temp,
input_len,
target_len_batch_temp,
blank=voc,
zero_infinity=True,
reduction='none'
)
text_cost.div_(target_len_batch_temp)
text_cost_cpu = text_cost.reshape((-1, num_queries)).transpose(1, 0).cpu()
texts_cost_list.append(text_cost_cpu)
# ctrl points of the text center line: (bz, n_q, n_pts, 2) --> (bz * n_q, n_pts * 2)
out_pts = outputs["pred_ctrl_points"].flatten(0, 1).flatten(-2)
tgt_pts = torch.cat([v["ctrl_points"] for v in targets]).flatten(-2)
neg_cost_class = (1 - self.alpha) * (out_prob ** self.gamma) * \
(-(1 - out_prob + 1e-8).log())
pos_cost_class = self.alpha * \
((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
cost_class = (pos_cost_class[..., 0] - neg_cost_class[..., 0]).mean(-1, keepdims=True)
cost_kpts = torch.cdist(out_pts, tgt_pts, p=1) # (bz * n_q, num_gt)
C = self.class_weight * cost_class + self.coord_weight * cost_kpts
C = C.view(bs, num_queries, -1).cpu()
indices = [linear_sum_assignment(
c[i] + self.text_weight * texts_cost_list[i]
) for i, c in enumerate(C.split(sizes, -1))]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
class BezierHungarianMatcher(nn.Module):
def __init__(
self,
class_weight: float = 1,
coord_weight: float = 1,
num_sample_points: int = 100,
focal_alpha: float = 0.25,
focal_gamma: float = 2.0
):
"""Creates the matcher
Params:
class_weight: This is the relative weight of the classification error in the matching cost
            coord_weight: relative weight of the L1 cost computed on points sampled along the Bezier curve
                          (not on the raw control points); refer to "https://github.com/voldemortX/pytorch-auto-drive"
"""
super().__init__()
self.class_weight = class_weight
self.coord_weight = coord_weight
self.num_sample_points = num_sample_points
self.alpha = focal_alpha
self.gamma = focal_gamma
        assert class_weight != 0 or coord_weight != 0, "all costs can't be 0"
self.bezier_sampler = BezierSampler(num_sample_points=num_sample_points)
def forward(self, outputs, targets):
with torch.no_grad():
bs, num_queries = outputs["pred_logits"].shape[:2]
out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
out_beziers = outputs["pred_beziers"].flatten(0, 1).view(-1, 4, 2) # (batch_size * num_queries, 4, 2)
tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_beziers = torch.cat([v["beziers"] for v in targets]) # (g, 4, 2)
# Compute the classification cost.
neg_cost_class = (1 - self.alpha) * (out_prob ** self.gamma) * \
(-(1 - out_prob + 1e-8).log())
pos_cost_class = self.alpha * \
((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
            # Compute the L1 cost between points sampled on the Bezier curves
cost_coord = torch.cdist(
(self.bezier_sampler.get_sample_points(out_beziers)).flatten(start_dim=-2),
(self.bezier_sampler.get_sample_points(tgt_beziers)).flatten(start_dim=-2),
p=1
)
C = self.class_weight * cost_class + self.coord_weight * cost_coord
C = C.view(bs, num_queries, -1).cpu()
sizes = [len(v["beziers"]) for v in targets]
indices = [
linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
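# Illustrative sketch (hedged): BezierHungarianMatcher compares curves through
# points sampled on them rather than through the raw control points. Below is a
# plain cubic-Bezier sampling sketch under that assumption; the real
# BezierSampler lives in adet.utils.curve_utils and may differ in detail, and
# the _demo_* helper is hypothetical.
def _demo_cubic_bezier_sampling(num_sample_points: int = 100):
    ctrl = torch.rand(5, 4, 2)                           # (n_curves, 4 control points, xy)
    t = torch.linspace(0, 1, num_sample_points).view(1, -1, 1)
    p0, p1, p2, p3 = [ctrl[:, i:i + 1, :] for i in range(4)]
    pts = ((1 - t) ** 3) * p0 + 3 * ((1 - t) ** 2) * t * p1 \
        + 3 * (1 - t) * (t ** 2) * p2 + (t ** 3) * p3    # (n_curves, num_sample_points, 2)
    assert pts.shape == (5, num_sample_points, 2)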
def build_matcher(cfg):
cfg = cfg.MODEL.TRANSFORMER.LOSS
return BezierHungarianMatcher(class_weight=cfg.BEZIER_CLASS_WEIGHT,
coord_weight=cfg.BEZIER_COORD_WEIGHT,
num_sample_points=cfg.BEZIER_SAMPLE_POINTS,
focal_alpha=cfg.FOCAL_ALPHA,
focal_gamma=cfg.FOCAL_GAMMA), \
CtrlPointHungarianMatcher(class_weight=cfg.POINT_CLASS_WEIGHT,
coord_weight=cfg.POINT_COORD_WEIGHT,
text_weight=cfg.POINT_TEXT_WEIGHT,
focal_alpha=cfg.FOCAL_ALPHA,
focal_gamma=cfg.FOCAL_GAMMA)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k)
for n, k in zip([input_dim] + h, h + [output_dim])
)
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
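# Illustrative sketch (hedged): MLP only maps the last dimension and keeps every
# leading one, which is how the coordinate/class heads use it on tensors shaped
# (bs, n_proposals, n_pts, d_model). The shapes are assumptions; the _demo_*
# helper is hypothetical.
def _demo_mlp():
    head = MLP(input_dim=256, hidden_dim=256, output_dim=2, num_layers=3)
    out = head(torch.rand(2, 100, 16, 256))
    assert out.shape == (2, 100, 16, 2)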
def gen_point_pos_embed(pts_tensor, d_model, temp):
# pts_tensor: bs, nq, n_pts, 2
scale = 2 * math.pi
dim = d_model // 2
dim_t = torch.arange(dim, dtype=torch.float32, device=pts_tensor.device)
dim_t = temp ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / dim)
x_embed = pts_tensor[:, :, :, 0] * scale
y_embed = pts_tensor[:, :, :, 1] * scale
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_x, pos_y), dim=-1)
return pos
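# Illustrative sketch (hedged): gen_point_pos_embed turns normalized (x, y)
# point coordinates into a d_model-dim sinusoidal embedding (d_model / 2 per
# axis). The shapes are assumptions; the _demo_* helper is hypothetical.
def _demo_gen_point_pos_embed():
    pts = torch.rand(2, 100, 16, 2)                  # (bs, n_queries, n_pts, xy) in [0, 1]
    pos = gen_point_pos_embed(pts, d_model=256, temp=10000)
    assert pos.shape == (2, 100, 16, 256)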
from .swin_transformer import build_swin_backbone
__all__ = [k for k in globals().keys() if not k.startswith("_")]
# --------------------------------------------------------
# Swin Transformer
# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
from detectron2.layers import ShapeSpec
class Mlp(nn.Module):
""" Multilayer perceptron."""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
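# Illustrative sketch (hedged): window_partition and window_reverse are exact
# inverses when H and W are multiples of the window size. The sizes are
# assumptions; the _demo_* helper is hypothetical and never called.
def _demo_window_partition_roundtrip():
    x = torch.rand(2, 14, 14, 96)                    # (B, H, W, C)
    windows = window_partition(x, window_size=7)     # (B * 2 * 2, 7, 7, C)
    assert windows.shape == (8, 7, 7, 96)
    assert torch.equal(window_reverse(windows, 7, 14, 14), x)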
class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
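# Illustrative sketch (hedged): WindowAttention operates on already-partitioned
# windows, i.e. inputs of shape (num_windows * B, window_size ** 2, C), and
# preserves that shape. The sizes are assumptions; the _demo_* helper is
# hypothetical.
def _demo_window_attention():
    attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
    x = torch.rand(8, 49, 96)
    assert attn(x).shape == (8, 49, 96)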
class SwinTransformerBlock(nn.Module):
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop
)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
""" Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
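# Illustrative sketch (hedged): PatchMerging halves the spatial resolution and
# doubles the channel count. The sizes are assumptions; the _demo_* helper is
# hypothetical.
def _demo_patch_merging():
    merge = PatchMerging(dim=96)
    x = torch.rand(2, 56 * 56, 96)                   # (B, H*W, C)
    y = merge(x, H=56, W=56)
    assert y.shape == (2, 28 * 28, 192)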
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depths of this stage.
num_heads (int): Number of attention head.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False
):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
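# Illustrative sketch (hedged): PatchEmbed is a strided convolution that
# tokenizes the image at 1/patch_size resolution. The sizes are assumptions;
# the _demo_* helper is hypothetical.
def _demo_patch_embed():
    embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
    x = torch.rand(2, 3, 224, 224)
    assert embed(x).shape == (2, 96, 56, 56)         # (B, embed_dim, H/4, W/4)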
class SwinTransformer(Backbone):
""" Swin Transformer backbone.
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention head of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
frozen_stages=-1,
use_checkpoint=False,
out_features=None
):
super(SwinTransformer, self).__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.frozen_stages = frozen_stages
self.out_features = out_features
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None
)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
pretrain_img_size[0] // patch_size[0],
pretrain_img_size[1] // patch_size[1]
]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
)
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
self._out_feature_strides = {}
self._out_feature_channels = {}
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint
)
self.layers.append(layer)
stage = f'stage{i_layer + 2}'
if stage in self.out_features:
self._out_feature_channels[stage] = embed_dim * 2 ** i_layer
self._out_feature_strides[stage] = 4 * 2 ** i_layer
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in range(self.num_layers):
stage = f'stage{i_layer + 2}'
if stage in self.out_features:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
self._freeze_stages()
        self._size_divisibility = 32
self.apply(self._init_weights)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@property
def size_divisibility(self):
        return self._size_divisibility
def forward(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic'
)
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = {}
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
name = f'stage{i + 2}'
if name in self.out_features:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs[name] = out
        return outs
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self.out_features
}
@BACKBONE_REGISTRY.register()
def build_swin_backbone(cfg, input_shape):
swin_type = cfg.MODEL.SWIN.TYPE
if swin_type == 'tiny':
backbone = SwinTransformer(
in_chans=input_shape.channels,
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
window_size=7,
drop_path_rate=cfg.MODEL.SWIN.DROP_PATH_RATE,
ape=False,
patch_norm=True,
frozen_stages=cfg.MODEL.BACKBONE.FREEZE_AT,
out_features=["stage3", "stage4", "stage5"]
)
elif swin_type == 'small':
backbone = SwinTransformer(
in_chans=input_shape.channels,
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
window_size=7,
drop_path_rate=cfg.MODEL.SWIN.DROP_PATH_RATE,
ape=False,
patch_norm=True,
frozen_stages=cfg.MODEL.BACKBONE.FREEZE_AT,
out_features=["stage3", "stage4", "stage5"]
)
else:
raise NotImplementedError
return backbone
from typing import List
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling import build_backbone
from detectron2.structures import ImageList, Instances
from adet.layers.pos_encoding import PositionalEncoding2D
from adet.modeling.model.losses import SetCriterion
from adet.modeling.model.matcher import build_matcher
from adet.modeling.model.detection_transformer import DETECTION_TRANSFORMER
from adet.utils.misc import NestedTensor
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for _, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.tensors.dtype))
return out, pos
class MaskedBackbone(nn.Module):
""" This is a thin wrapper around D2's backbone to provide padding masking"""
def __init__(self, cfg):
super().__init__()
self.backbone = build_backbone(cfg)
backbone_shape = self.backbone.output_shape()
self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
self.num_channels = backbone_shape[list(backbone_shape.keys())[-1]].channels
def forward(self, images):
features = self.backbone(images.tensor)
masks = self.mask_out_padding(
[features_per_level.shape for features_per_level in features.values()],
images.image_sizes,
images.tensor.device,
)
assert len(features) == len(masks)
for i, k in enumerate(features.keys()):
features[k] = NestedTensor(features[k], masks[i])
return features
def mask_out_padding(self, feature_shapes, image_sizes, device):
masks = []
assert len(feature_shapes) == len(self.feature_strides)
for idx, shape in enumerate(feature_shapes):
N, _, H, W = shape
masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device)
for img_idx, (h, w) in enumerate(image_sizes):
masks_per_feature_level[
img_idx,
: int(np.ceil(float(h) / self.feature_strides[idx])),
: int(np.ceil(float(w) / self.feature_strides[idx])),
] = 0
masks.append(masks_per_feature_level)
return masks
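# -----------------------------------------------------------------------------
# Hedged worked example (illustration only, with made-up sizes): how
# mask_out_padding marks padded pixels. A 50x70 image padded to a 64x80 batch
# gives an 8x10 feature map at stride 8; the valid region covers
# ceil(50/8)=7 rows and ceil(70/8)=9 columns, and everything else stays True,
# i.e. is treated as padding by the transformer.
def _demo_mask_out_padding():
    stride = 8
    N, H, W = 1, 8, 10                      # feature map shape at stride 8
    image_sizes = [(50, 70)]                # true (h, w) before padding
    mask = torch.ones((N, H, W), dtype=torch.bool)
    for img_idx, (h, w) in enumerate(image_sizes):
        mask[img_idx,
             : int(np.ceil(float(h) / stride)),
             : int(np.ceil(float(w) / stride))] = 0
    print(mask[0])  # False inside the image, True in the padded border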
def detector_postprocess(results, output_height, output_width, min_size=None, max_size=None):
"""
scale align
"""
if min_size and max_size:
        # eliminate the padding influence on results from the ViTAE backbone
size = min_size * 1.0
scale_img_size = min_size / min(output_width, output_height)
if output_height < output_width:
newh, neww = size, scale_img_size * output_width
else:
newh, neww = scale_img_size * output_height, size
if max(newh, neww) > max_size:
scale = max_size * 1.0 / max(newh, neww)
newh = newh * scale
neww = neww * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
scale_x, scale_y = (output_width / neww, output_height / newh)
else:
scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0])
# scale points
if results.has("ctrl_points"):
ctrl_points = results.ctrl_points
ctrl_points[:, 0::2] *= scale_x
ctrl_points[:, 1::2] *= scale_y
if results.has("bd") and not isinstance(results.bd, list):
bd = results.bd
bd[..., 0::2] *= scale_x
bd[..., 1::2] *= scale_y
return results
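# -----------------------------------------------------------------------------
# Hedged worked example (illustration only, made-up numbers): without
# min_size/max_size the function above just rescales coordinates from the
# network's image_size to the requested output resolution.
def _demo_postprocess_scaling():
    output_height, output_width = 400, 600
    image_size = (800, 1200)                          # network (h, w)
    scale_x = output_width / image_size[1]            # 0.5
    scale_y = output_height / image_size[0]           # 0.5
    ctrl_points = torch.tensor([[100., 200., 300., 400.]])  # x1, y1, x2, y2
    ctrl_points[:, 0::2] *= scale_x
    ctrl_points[:, 1::2] *= scale_y
    print(ctrl_points)                                # tensor([[ 50., 100., 150., 200.]])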
@META_ARCH_REGISTRY.register()
class TransformerPureDetector(nn.Module):
"""
Same as :class:`detectron2.modeling.ProposalNetwork`.
Use one stage detector and a second stage for instance-wise prediction.
"""
def __init__(self, cfg):
super().__init__()
self.device = torch.device(cfg.MODEL.DEVICE)
N_steps = cfg.MODEL.TRANSFORMER.HIDDEN_DIM // 2
self.test_score_threshold = cfg.MODEL.TRANSFORMER.INFERENCE_TH_TEST
self.min_size_test = None
self.max_size_test = None
if cfg.MODEL.BACKBONE.NAME == "build_vitaev2_backbone":
self.min_size_test = cfg.INPUT.MIN_SIZE_TEST
self.max_size_test = cfg.INPUT.MAX_SIZE_TEST
d2_backbone = MaskedBackbone(cfg)
backbone = Joiner(
d2_backbone,
PositionalEncoding2D(N_steps, cfg.MODEL.TRANSFORMER.TEMPERATURE, normalize=True)
)
backbone.num_channels = d2_backbone.num_channels
self.detection_transformer = DETECTION_TRANSFORMER(cfg, backbone)
bezier_matcher, point_matcher = build_matcher(cfg)
loss_cfg = cfg.MODEL.TRANSFORMER.LOSS
weight_dict = {
"loss_ce": loss_cfg.POINT_CLASS_WEIGHT,
"loss_texts": loss_cfg.POINT_TEXT_WEIGHT,
"loss_ctrl_points": loss_cfg.POINT_COORD_WEIGHT,
"loss_bd_points": loss_cfg.BOUNDARY_WEIGHT,
}
enc_weight_dict = {
"loss_bezier": loss_cfg.BEZIER_COORD_WEIGHT,
"loss_ce": loss_cfg.BEZIER_CLASS_WEIGHT
}
if loss_cfg.AUX_LOSS:
aux_weight_dict = {}
# decoder aux loss
for i in range(cfg.MODEL.TRANSFORMER.DEC_LAYERS - 1):
aux_weight_dict.update(
{k + f'_{i}': v for k, v in weight_dict.items()}
)
# encoder aux loss
aux_weight_dict.update(
                {k + '_enc': v for k, v in enc_weight_dict.items()}
)
weight_dict.update(aux_weight_dict)
enc_losses = ["labels", "beziers"]
if cfg.MODEL.TRANSFORMER.BOUNDARY_HEAD:
dec_losses = ["labels", "texts", "ctrl_points", "bd_points"]
else:
dec_losses = ["labels", "texts", "ctrl_points"]
self.criterion = SetCriterion(
self.detection_transformer.num_classes,
bezier_matcher,
point_matcher,
weight_dict,
enc_losses,
cfg.MODEL.TRANSFORMER.LOSS.BEZIER_SAMPLE_POINTS,
dec_losses,
cfg.MODEL.TRANSFORMER.VOC_SIZE,
self.detection_transformer.num_points,
focal_alpha=loss_cfg.FOCAL_ALPHA,
focal_gamma=loss_cfg.FOCAL_GAMMA
)
pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
self.normalizer = lambda x: (x - pixel_mean) / pixel_std
self.to(self.device)
def preprocess_image(self, batched_inputs):
"""
Normalize, pad and batch the input images.
"""
images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
images = ImageList.from_tensors(images)
return images
def forward(self, batched_inputs):
"""
Args:
batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
Each item in the list contains the inputs for one image.
For now, each item in the list is a dict that contains:
* image: Tensor, image in (C, H, W) format.
* instances (optional): groundtruth :class:`Instances`
* proposals (optional): :class:`Instances`, precomputed proposals.
Other information that's included in the original dicts, such as:
* "height", "width" (int): the output resolution of the model, used in inference.
See :meth:`postprocess` for details.
"""
images = self.preprocess_image(batched_inputs)
if self.training:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
targets = self.prepare_targets(gt_instances)
output = self.detection_transformer(images)
loss_dict = self.criterion(output, targets)
weight_dict = self.criterion.weight_dict
for k in loss_dict.keys():
if k in weight_dict:
loss_dict[k] *= weight_dict[k]
return loss_dict
else:
output = self.detection_transformer(images)
ctrl_point_cls = output["pred_logits"]
ctrl_point_coord = output["pred_ctrl_points"]
ctrl_point_text = output["pred_text_logits"]
bd_points = output["pred_bd_points"]
results = self.inference(
ctrl_point_cls,
ctrl_point_coord,
ctrl_point_text,
bd_points,
images.image_sizes
)
processed_results = []
for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes):
height = input_per_image.get("height", image_size[0])
width = input_per_image.get("width", image_size[1])
r = detector_postprocess(results_per_image, height, width, self.min_size_test, self.max_size_test)
processed_results.append({"instances": r})
return processed_results
def prepare_targets(self, targets):
new_targets = []
for targets_per_image in targets:
h, w = targets_per_image.image_size
gt_classes = targets_per_image.gt_classes
raw_beziers = targets_per_image.beziers
raw_ctrl_points = targets_per_image.polyline
raw_boundary = targets_per_image.boundary
gt_texts = targets_per_image.texts
gt_beziers = raw_beziers.reshape(-1, 4, 2) / \
torch.as_tensor([w, h], dtype=torch.float, device=self.device)[None, None, :]
gt_ctrl_points = raw_ctrl_points.reshape(-1, self.detection_transformer.num_points, 2) / \
torch.as_tensor([w, h], dtype=torch.float, device=self.device)[None, None, :]
gt_boundary = raw_boundary.reshape(-1, self.detection_transformer.num_points, 4) / \
torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)[None, None, :]
new_targets.append(
{
"labels": gt_classes,
"beziers": gt_beziers,
"ctrl_points": gt_ctrl_points,
"texts": gt_texts,
"bd_points": gt_boundary
}
)
return new_targets
def inference(
self,
ctrl_point_cls,
ctrl_point_coord,
ctrl_point_text,
bd_points,
image_sizes
):
assert len(ctrl_point_cls) == len(image_sizes)
results = []
        # text logits shape: (bs, nq, n_pts, voc_size)
ctrl_point_text = torch.softmax(ctrl_point_text, dim=-1)
prob = ctrl_point_cls.mean(-2).sigmoid()
scores, labels = prob.max(-1)
if bd_points is not None:
for scores_per_image, labels_per_image, ctrl_point_per_image, ctrl_point_text_per_image, bd, image_size in zip(
scores, labels, ctrl_point_coord, ctrl_point_text, bd_points, image_sizes
):
selector = scores_per_image >= self.test_score_threshold
scores_per_image = scores_per_image[selector]
labels_per_image = labels_per_image[selector]
ctrl_point_per_image = ctrl_point_per_image[selector]
ctrl_point_text_per_image = ctrl_point_text_per_image[selector]
bd = bd[selector]
result = Instances(image_size)
result.scores = scores_per_image
result.pred_classes = labels_per_image
result.rec_scores = ctrl_point_text_per_image
ctrl_point_per_image[..., 0] *= image_size[1]
ctrl_point_per_image[..., 1] *= image_size[0]
result.ctrl_points = ctrl_point_per_image.flatten(1)
_, text_pred = ctrl_point_text_per_image.topk(1)
result.recs = text_pred.squeeze(-1)
bd[..., 0::2] *= image_size[1]
bd[..., 1::2] *= image_size[0]
result.bd = bd
results.append(result)
return results
else:
for scores_per_image, labels_per_image, ctrl_point_per_image, ctrl_point_text_per_image, image_size in zip(
scores, labels, ctrl_point_coord, ctrl_point_text, image_sizes
):
selector = scores_per_image >= self.test_score_threshold
scores_per_image = scores_per_image[selector]
labels_per_image = labels_per_image[selector]
ctrl_point_per_image = ctrl_point_per_image[selector]
ctrl_point_text_per_image = ctrl_point_text_per_image[selector]
result = Instances(image_size)
result.scores = scores_per_image
result.pred_classes = labels_per_image
result.rec_scores = ctrl_point_text_per_image
ctrl_point_per_image[..., 0] *= image_size[1]
ctrl_point_per_image[..., 1] *= image_size[0]
result.ctrl_points = ctrl_point_per_image.flatten(1)
_, text_pred = ctrl_point_text_per_image.topk(1)
result.recs = text_pred.squeeze(-1)
result.bd = [None] * len(scores_per_image)
results.append(result)
return results
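# -----------------------------------------------------------------------------
# Hedged worked example (illustration only, made-up numbers): prepare_targets
# normalizes ground-truth coordinates by the image size so that every Bezier
# control point and polyline point ends up in [0, 1]; inference above applies
# the inverse scaling before returning Instances.
def _demo_target_normalization():
    h, w = 100, 200
    raw_beziers = torch.tensor([[20., 10., 60., 10., 140., 90., 180., 90.]])
    gt_beziers = raw_beziers.reshape(-1, 4, 2) / \
        torch.as_tensor([w, h], dtype=torch.float)[None, None, :]
    print(gt_beziers)   # x coordinates divided by 200, y coordinates by 100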
"""
Borrowed from timm (https://github.com/rwightman/pytorch-image-models)
"""
import torch
import torch.nn as nn
import numpy as np
from .window import WindowAttention, window_partition, window_reverse
import math
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.hidden_features = hidden_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
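# -----------------------------------------------------------------------------
# Hedged smoke test (illustration only, placeholder sizes): plain multi-head
# self-attention keeps the (B, N, C) token layout unchanged.
def _demo_attention():
    attn = Attention(dim=96, num_heads=8)
    tokens = torch.randn(2, 49, 96)          # B=2, N=7*7 tokens, C=96
    assert attn(tokens).shape == tokens.shape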
class AttentionPerformer(nn.Module):
def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., kernel_ratio=0.5):
super().__init__()
self.head_dim = dim // num_heads
self.emb = dim
self.kqv = nn.Linear(dim, 3 * self.emb)
self.dp = nn.Dropout(proj_drop)
self.proj = nn.Linear(self.emb, self.emb)
self.head_cnt = num_heads
self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.epsilon = 1e-8  # for numerical stability in the division
self.drop_path = nn.Identity()
self.m = int(self.head_dim * kernel_ratio)
self.w = torch.randn(self.head_cnt, self.m, self.head_dim)
for i in range(self.head_cnt):
self.w[i] = nn.Parameter(nn.init.orthogonal_(self.w[i]) * math.sqrt(self.m), requires_grad=False)
self.w.requires_grad_(False)
def prm_exp(self, x):
        # partly borrowed from https://github.com/lucidrains/performer-pytorch
        # and Simo Ryu (https://github.com/cloneofsimo)
        # ==== positive random features for Gaussian kernels ====
        # x = (B, H, T, hs)
        # w = (H, m, hs)
        # return : (B, H, T, m)
        # SM(x, y) = E_w[exp(w^T x - |x|^2/2) * exp(w^T y - |y|^2/2)]
        # therefore return exp(w^T x - |x|^2/2) / sqrt(m)
xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, 1, self.m) / 2
wtx = torch.einsum('bhti,hmi->bhtm', x.float(), self.w.to(x.device))
return torch.exp(wtx - xd) / math.sqrt(self.m)
def attn(self, x):
B, N, C = x.shape
kqv = self.kqv(x).reshape(B, N, 3, self.head_cnt, self.head_dim).permute(2, 0, 3, 1, 4)
k, q, v = kqv[0], kqv[1], kqv[2] # (B, H, T, hs)
kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, H, T, m), (B, H, T, m)
D = torch.einsum('bhti,bhi->bht', qp, kp.sum(dim=2)).unsqueeze(
dim=-1) # (B, H, T, m) * (B, H, m) -> (B, H, T, 1)
kptv = torch.einsum('bhin,bhim->bhnm', v.float(), kp) # (B, H, emb, m)
y = torch.einsum('bhti,bhni->bhtn', qp, kptv) / (
D.repeat(1, 1, 1, self.head_dim) + self.epsilon) # (B, H, T, emb)/Diag
# skip connection
y = y.permute(0, 2, 1, 3).reshape(B, N, self.emb)
y = self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection
return y
def forward(self, x):
x = self.attn(x)
return x
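# -----------------------------------------------------------------------------
# Hedged illustration (not part of the upstream file, placeholder sizes):
# prm_exp maps each head-dim vector to m = head_dim * kernel_ratio positive
# random features, and phi(q) @ phi(k)^T then stands in for the softmax kernel
# exp(q . k) inside the linear-attention computation of attn() above.
def _demo_performer_features():
    perf = AttentionPerformer(dim=64, num_heads=1)    # head_dim=64, m=32
    q = torch.randn(1, 1, 5, 64)                      # (B, H, T, head_dim)
    print(perf.prm_exp(q).shape)                      # torch.Size([1, 1, 5, 32])
    print(perf(torch.randn(2, 49, 64)).shape)         # torch.Size([2, 49, 64])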
class NormalCell(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, class_token=False, group=64,
tokens_type='transformer',
shift_size=0, window_size=0, gamma=False, init_values=1e-4, SE=False, img_size=224):
super().__init__()
self.H = None
self.W = None
self.norm1 = norm_layer(dim)
self.class_token = class_token
self.img_size = img_size
self.window_size = window_size
self.shift_size = shift_size
self.tokens_type = tokens_type
if tokens_type == 'transformer':
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
elif tokens_type == 'performer':
self.attn = AttentionPerformer(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
elif tokens_type == 'window':
self.attn = WindowAttention(
in_dim=dim, out_dim=dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.PCM = nn.Sequential(
nn.Conv2d(dim, mlp_hidden_dim, 3, 1, 1, 1, group),
nn.BatchNorm2d(mlp_hidden_dim),
nn.SiLU(inplace=True),
nn.Conv2d(mlp_hidden_dim, dim, 3, 1, 1, 1, group),
nn.BatchNorm2d(dim),
nn.SiLU(inplace=True),
nn.Conv2d(dim, dim, 3, 1, 1, 1, group),
)
if gamma:
self.gamma1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma3 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
else:
self.gamma1 = 1
self.gamma2 = 1
self.gamma3 = 1
def forward(self, x):
b, n, c = x.shape
H, W = self.H, self.W
assert n == H * W
shortcut = x
if self.tokens_type == 'window':
padding_td = (self.window_size - H % self.window_size) % self.window_size
padding_top = padding_td // 2
padding_down = padding_td - padding_top
padding_lr = (self.window_size - W % self.window_size) % self.window_size
padding_left = padding_lr // 2
padding_right = padding_lr - padding_left
if self.shift_size > 0 and min(H, W) > self.window_size:
shift_size = self.shift_size
else:
shift_size = 0
if shift_size > 0:
img_mask = torch.zeros((1, H + padding_td, W + padding_lr, 1)).cuda() # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -shift_size),
slice(-shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -shift_size),
slice(-shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
x = self.norm1(x)
x = x.view(b, H, W, c).permute(0, 3, 1, 2)
x = nn.functional.pad(x, (padding_left, padding_right, padding_top, padding_down))
x = x.permute(0, 2, 3, 1)
# cyclic shift
if shift_size > 0:
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
shifted_x = window_reverse(attn_windows, self.window_size, H + padding_td, W + padding_lr) # B H' W' C
# reverse cyclic shift
if shift_size > 0:
x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
else:
x = shifted_x
x = x[:, padding_top:padding_top + H, padding_left:padding_left + W, :]
x = x.reshape(b, H * W, c)
else:
x = self.gamma1 * self.attn(self.norm1(x))
if self.class_token:
n = n - 1
wh = int(math.sqrt(n))
convX = self.drop_path(
self.gamma2 * self.PCM(shortcut[:, 1:, :].view(b, wh, wh, c).permute(0, 3, 1, 2).contiguous()).permute(
0, 2, 3, 1).contiguous().view(b, n, c))
x = shortcut + self.drop_path(self.gamma1 * x)
x[:, 1:] = x[:, 1:] + convX
else:
# wh = int(math.sqrt(n))
convX = self.drop_path(
self.gamma2 * self.PCM(shortcut.view(b, H, W, c).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3,
1).contiguous().view(
b, n, c))
x = shortcut + self.drop_path(self.gamma1 * x) + convX
# x = x + convX
x = x + self.drop_path(self.gamma3 * self.mlp(self.norm2(x)))
return x
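# -----------------------------------------------------------------------------
# Hedged smoke test (illustration only, placeholder sizes): a 'transformer'
# NormalCell needs dim divisible by the PCM group count (group=64 by default)
# and by num_heads, and H/W must be set by the caller before forward().
def _demo_normal_cell():
    cell = NormalCell(dim=64, num_heads=4, mlp_ratio=4., tokens_type='transformer')
    cell.H, cell.W = 14, 14
    tokens = torch.randn(2, 14 * 14, 64)     # (B, H*W, C)
    assert cell(tokens).shape == tokens.shape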
def get_sinusoid_encoding(n_position, d_hid):
    ''' Sinusoidal position encoding table '''
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
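# -----------------------------------------------------------------------------
# Hedged illustration (not part of the upstream file): the table has shape
# (1, n_position, d_hid); even channels hold the sin terms and odd channels
# the matching cos terms, so position 0 encodes as (0, 1, 0, 1, ...).
def _demo_sinusoid_encoding():
    table = get_sinusoid_encoding(n_position=4, d_hid=6)
    print(table.shape)      # torch.Size([1, 4, 6])
    print(table[0, 0])      # tensor([0., 1., 0., 1., 0., 1.])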
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import numpy as np
from .token_transformer import Token_transformer
from .token_performer import Token_performer
from .window import WindowTransformerBlock, window_partition, window_reverse
class PRM(nn.Module):
def __init__(self, img_size=224, kernel_size=4, downsample_ratio=4, dilations=[1,6,12], in_chans=3, embed_dim=64, share_weights=False, op='cat'):
super().__init__()
self.dilations = dilations
self.embed_dim = embed_dim
self.downsample_ratio = downsample_ratio
self.op = op
self.kernel_size = kernel_size
self.stride = downsample_ratio
self.share_weights = share_weights
self.outSize = img_size // downsample_ratio
if share_weights:
self.convolution = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=self.kernel_size, \
stride=self.stride, padding=3*dilations[0]//2, dilation=dilations[0])
else:
self.convs = nn.ModuleList()
for dilation in self.dilations:
padding = math.ceil(((self.kernel_size-1)*dilation + 1 - self.stride) / 2)
self.convs.append(nn.Sequential(*[nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=self.kernel_size, \
stride=self.stride, padding=padding, dilation=dilation),
nn.GELU()]))
if self.op == 'sum':
self.out_chans = embed_dim
elif op == 'cat':
self.out_chans = embed_dim * len(self.dilations)
def forward(self, x):
B, C, W, H = x.shape
if self.share_weights:
padding = math.ceil(((self.kernel_size-1)*self.dilations[0] + 1 - self.stride) / 2)
y = nn.functional.conv2d(x, weight=self.convolution.weight, bias=self.convolution.bias, \
stride=self.downsample_ratio, padding=padding, dilation=self.dilations[0]).unsqueeze(dim=-1)
for i in range(1, len(self.dilations)):
padding = math.ceil(((self.kernel_size-1)*self.dilations[i] + 1 - self.stride) / 2)
_y = nn.functional.conv2d(x, weight=self.convolution.weight, bias=self.convolution.bias, \
stride=self.downsample_ratio, padding=padding, dilation=self.dilations[i]).unsqueeze(dim=-1)
y = torch.cat((y, _y), dim=-1)
else:
y = self.convs[0](x).unsqueeze(dim=-1)
for i in range(1, len(self.dilations)):
_y = self.convs[i](x).unsqueeze(dim=-1)
y = torch.cat((y, _y), dim=-1)
B, C, W, H, N = y.shape
if self.op == 'sum':
y = y.sum(dim=-1).flatten(2).permute(0,2,1).contiguous()
elif self.op == 'cat':
y = y.permute(0,4,1,2,3).flatten(3).reshape(B, N*C, W*H).permute(0,2,1).contiguous()
else:
raise NotImplementedError('no such operation: {} for multi-levels!'.format(self.op))
return y, (W, H)
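# -----------------------------------------------------------------------------
# Hedged smoke test (illustration only, placeholder sizes): with op='cat' the
# dilated-convolution pyramid is concatenated along channels, so a 224x224
# image with downsample_ratio=4 yields 56*56 tokens with
# embed_dim * len(dilations) channels; the padding formula above keeps every
# dilation branch at the same spatial size.
def _demo_prm():
    prm = PRM(img_size=224, kernel_size=7, downsample_ratio=4,
              dilations=[1, 2, 3, 4], in_chans=3, embed_dim=64, op='cat')
    y, (W, H) = prm(torch.randn(1, 3, 224, 224))
    print(y.shape, (W, H))   # torch.Size([1, 3136, 256]) (56, 56)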
class ReductionCell(nn.Module):
def __init__(self, img_size=224, in_chans=3, embed_dims=64, token_dims=64, downsample_ratios=4, kernel_size=7,
num_heads=1, dilations=[1,2,3,4], share_weights=False, op='cat', tokens_type='performer', group=1,
drop=0., attn_drop=0., drop_path=0., mlp_ratio=1.0, gamma=False, init_values=1e-4, SE=False, window_size=7):
super().__init__()
self.img_size = img_size
self.window_size = window_size
self.op = op
self.dilations = dilations
self.num_heads = num_heads
self.embed_dims = embed_dims
self.token_dims = token_dims
self.in_chans = in_chans
self.downsample_ratios = downsample_ratios
self.kernel_size = kernel_size
self.outSize = img_size
PCMStride = []
residual = downsample_ratios // 2
for _ in range(3):
PCMStride.append((residual > 0) + 1)
residual = residual // 2
assert residual == 0
self.pool = None
self.tokens_type = tokens_type
if tokens_type == 'pooling':
PCMStride = [1, 1, 1]
self.pool = nn.MaxPool2d(downsample_ratios, stride=downsample_ratios, padding=0)
tokens_type = 'transformer'
self.outSize = self.outSize // downsample_ratios
downsample_ratios = 1
self.PCM = nn.Sequential(
nn.Conv2d(in_chans, embed_dims, kernel_size=(3, 3), stride=PCMStride[0], padding=(1, 1), groups=group), # the 1st convolution
nn.BatchNorm2d(embed_dims),
nn.SiLU(inplace=True),
            nn.Conv2d(embed_dims, embed_dims, kernel_size=(3, 3), stride=PCMStride[1], padding=(1, 1), groups=group),  # the 2nd convolution
nn.BatchNorm2d(embed_dims),
nn.SiLU(inplace=True),
            nn.Conv2d(embed_dims, token_dims, kernel_size=(3, 3), stride=PCMStride[2], padding=(1, 1), groups=group),  # the 3rd convolution
)
self.PRM = PRM(img_size=img_size, kernel_size=kernel_size, downsample_ratio=downsample_ratios, dilations=self.dilations,
in_chans=in_chans, embed_dim=embed_dims, share_weights=share_weights, op=op)
self.outSize = self.outSize // downsample_ratios
in_chans = self.PRM.out_chans
if tokens_type == 'performer':
# assert num_heads == 1
self.attn = Token_performer(dim=in_chans, in_dim=token_dims, head_cnt=num_heads, kernel_ratio=0.5, gamma=gamma, init_values=init_values)
elif tokens_type == 'performer_less':
self.attn = None
self.PCM = None
elif tokens_type == 'transformer':
self.attn = Token_transformer(dim=in_chans, in_dim=token_dims, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop,
attn_drop=attn_drop, drop_path=drop_path, gamma=gamma, init_values=init_values)
elif tokens_type == 'window':
self.attn = WindowTransformerBlock(in_dim=in_chans, out_dim=token_dims, input_resolution=(self.img_size//self.downsample_ratios, self.img_size//self.downsample_ratios),
num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop,
attn_drop=attn_drop, drop_path=drop_path, window_size=window_size, shift_size=0)
if gamma:
self.gamma2 = nn.Parameter(init_values * torch.ones((token_dims)),requires_grad=True)
self.gamma3 = nn.Parameter(init_values * torch.ones((token_dims)),requires_grad=True)
else:
self.gamma2 = 1
self.gamma3 = 1
        self.num_patches = (img_size // 2) * (img_size // 2)  # there are 3 soft splits with strides 4, 2, 2 respectively
def forward(self, x, size):
H, W = size
if len(x.shape) < 4:
B, N, C = x.shape
# n = int(np.sqrt(N))
x = x.view(B, H, W, C).contiguous()
x = x.permute(0, 3, 1, 2)
if self.pool is not None:
x = self.pool(x)
shortcut = x
PRM_x, _ = self.PRM(x)
H, W = H // self.downsample_ratios, W // self.downsample_ratios
B, N, C = PRM_x.shape
assert N == H * W
if self.tokens_type == 'window':
# H, W = self.img_size // self.downsample_ratios, self.img_size // self.downsample_ratios
# b, _, c = PRM_x.shape
x = self.attn.norm1(PRM_x)
padding_td = (self.window_size - H % self.window_size) % self.window_size
padding_top = padding_td // 2
padding_down = padding_td - padding_top
padding_lr = (self.window_size - W % self.window_size) % self.window_size
padding_left = padding_lr // 2
padding_right = padding_lr - padding_left
x = x.view(B, H, W, C).permute(0, 3, 1, 2)
x = nn.functional.pad(x, (padding_left, padding_right, padding_top, padding_down))
x = x.permute(0, 2, 3, 1)
x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn.attn(x_windows, mask=self.attn.attn_mask) # nW*B, window_size*window_size, C
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.token_dims)
shifted_x = window_reverse(attn_windows, self.window_size, H+padding_td, W+padding_lr) # B H' W' C
x = shifted_x
x = x[:, padding_top:padding_top+H, padding_left:padding_left+W, :]
x = x.reshape(B, H * W, self.token_dims)
convX = self.PCM(shortcut)
convX = convX.permute(0, 2, 3, 1).view(*x.shape).contiguous()
x = x + self.attn.drop_path(convX * self.gamma2)
# x = shortcut + self.attn.drop_path(x)
# x = x + self.attn.drop_path(self.attn.mlp(self.attn.norm2(x)))
x = x + self.attn.drop_path(self.gamma3 * self.attn.mlp(self.attn.norm2(x)))
else:
if self.attn is None:
return PRM_x
convX = self.PCM(shortcut)
x = self.attn.attn(self.attn.norm1(PRM_x))
convX = convX.permute(0, 2, 3, 1).view(*x.shape).contiguous()
x = x + self.attn.drop_path(convX * self.gamma2)
x = x + self.attn.drop_path(self.gamma3 * self.attn.mlp(self.attn.norm2(x)))
return x, (H, W)
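# -----------------------------------------------------------------------------
# Hedged smoke test (illustration only, placeholder sizes): a 'performer'
# ReductionCell downsamples 4x, fusing the PRM tokens (attention path) with the
# strided PCM convolutions, and returns token_dims-channel tokens plus the new
# spatial size.
def _demo_reduction_cell():
    rc = ReductionCell(img_size=224, in_chans=3, embed_dims=64, token_dims=64,
                       downsample_ratios=4, kernel_size=7, num_heads=1,
                       dilations=[1, 2, 3, 4], tokens_type='performer')
    x, (H, W) = rc(torch.randn(1, 3, 224, 224), (224, 224))
    print(x.shape, (H, W))   # torch.Size([1, 3136, 64]) (56, 56)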
from .vitae_v2 import build_vitaev2_backbone
__all__ = [k for k in globals().keys() if not k.startswith("_")]
"""
Take Performer as T2T Transformer
"""
import math
import torch
import torch.nn as nn
import numpy as np
class Token_performer(nn.Module):
def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1, gamma=False, init_values=1e-4):
super().__init__()
self.head_dim = in_dim // head_cnt
self.emb = in_dim
self.kqv = nn.Linear(dim, 3 * self.emb)
self.dp = nn.Dropout(dp1)
self.proj = nn.Linear(self.emb, self.emb)
self.head_cnt = head_cnt
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(self.emb)
        self.epsilon = 1e-8  # for numerical stability in the division
self.drop_path = nn.Identity()
self.mlp = nn.Sequential(
nn.Linear(self.emb, 1 * self.emb),
nn.GELU(),
nn.Linear(1 * self.emb, self.emb),
nn.Dropout(dp2),
)
self.m = int(self.head_dim * kernel_ratio)
self.w = torch.randn(head_cnt, self.m, self.head_dim)
for i in range(self.head_cnt):
self.w[i] = nn.Parameter(nn.init.orthogonal_(self.w[i]) * math.sqrt(self.m), requires_grad=False)
self.w.requires_grad_(False)
if gamma:
self.gamma1 = nn.Parameter(init_values * torch.ones((self.emb)),requires_grad=True)
else:
self.gamma1 = 1
def prm_exp(self, x):
        # partly borrowed from https://github.com/lucidrains/performer-pytorch
        # and Simo Ryu (https://github.com/cloneofsimo)
        # ==== positive random features for Gaussian kernels ====
        # x = (B, H, N, hs)
        # w = (H, m, hs)
        # return : (B, H, N, m)
        # SM(x, y) = E_w[exp(w^T x - |x|^2/2) * exp(w^T y - |y|^2/2)]
        # therefore return exp(w^T x - |x|^2/2) / sqrt(m)
xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, 1, self.m) / 2
wtx = torch.einsum('bhti,hmi->bhtm', x.float(), self.w.to(x.device))
return torch.exp(wtx - xd) / math.sqrt(self.m)
def attn(self, x):
B, N, C = x.shape
kqv = self.kqv(x).reshape(B, N, 3, self.head_cnt, self.head_dim).permute(2, 0, 3, 1, 4)
k, q, v = kqv[0], kqv[1], kqv[2] # (B, H, T, hs)
kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, H, T, m), (B, H, T, m)
D = torch.einsum('bhti,bhi->bht', qp, kp.sum(dim=2)).unsqueeze(dim=-1) # (B, H, T, m) * (B, H, m) -> (B, H, T, 1)
kptv = torch.einsum('bhin,bhim->bhnm', v.float(), kp) # (B, H, emb, m)
y = torch.einsum('bhti,bhni->bhtn', qp, kptv) / (D.repeat(1, 1, 1, self.head_dim) + self.epsilon) # (B, H, T, emb)/Diag
# skip connection
y = y.permute(0, 2, 1, 3).reshape(B, N, self.emb)
v = v.permute(0, 2, 1, 3).reshape(B, N, self.emb)
y = v + self.dp(self.gamma1 * self.proj(y)) # same as token_transformer, use v as skip connection
return y
def forward(self, x):
x = self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
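# -----------------------------------------------------------------------------
# Hedged smoke test (illustration only, placeholder sizes): Token_performer
# projects dim-channel tokens to in_dim channels with linear-complexity
# attention, which is what keeps the long soft-split token sequences of the
# reduction cells tractable.
def _demo_token_performer():
    tp = Token_performer(dim=256, in_dim=64, head_cnt=1)
    tokens = torch.randn(2, 56 * 56, 256)
    print(tp(tokens).shape)   # torch.Size([2, 3136, 64])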