"docs/source/zh/index.mdx" did not exist on "6e1c5786dcd770b7fbf23b945e0d3110c0ec316e"
Unverified Commit 59407bbe authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Add Deformable DETR (#17281)



* First draft

* More improvements

* Improve model, add custom CUDA code

* Import torch before

* Add script that imports custom layer

* Add everything in new ops directory

* Import custom layer in modeling file

* Fix ARCHIVE_MAP typo

* Creating the custom kernel on the fly.

* Import custom layer in modeling file

* More improvements

* Fix CUDA loading

* More improvements

* Improve conversion script

* Improve conversion script

* Make it work until encoder_outputs

* Make forward pass work

* More improvements

* Make logits match original implementation

* Make implementation also support single_scale model

* Add support for single_scale and dilation checkpoint

* Add support for with_box_refine model

* Support also two stage model

* Improve tests

* Fix more tests

* Make more tests pass

* Upload all models to the hub

* Clean up some code

* Improve decoder outputs

* Rename intermediate hidden states and reference points

* Improve model outputs

* Move tests to dedicated folder

* Improve model outputs

* Fix retain_grad test

* Improve docs

* Clean up and make test_initialization pass

* Improve variable names

* Add copied from statements

* Improve docs

* Fix style

* Improve docs

* Improve docs, move tests to model folder

* Fix rebase

* Remove DetrForSegmentation from auto mapping

* Apply suggestions from code review

* Improve variable names and docstrings

* Apply some more suggestions from code review

* Apply suggestion from code review

* better docs and variables names

* hint to num_queries and two_stage confusion

* remove asserts and code refactor

* add exception if two_stage is True and with_box_refine is False

* use f-strings

* Improve docs and variable names

* Fix code quality

* Fix rebase

* Add require_torch_gpu decorator

* Add pip install ninja to CI jobs

* Apply suggestion of @sgugger

* Remove DeformableDetrForObjectDetection from auto mapping

* Remove DeformableDetrModel from auto mapping

* Add model to toctree

* Add model back to mappings, skip model in pipeline tests

* Apply @sgugger's suggestion

* Fix imports in the init

* Fix copies

* Add CPU implementation

* Comment out GPU function

* Undo previous change

* Apply more suggestions

* Remove require_torch_gpu annotator

* Fix quality

* Add logger.info

* Fix logger

* Fix variable names

* Fix initializaztion

* Add missing initialization

* Update checkpoint name

* Add model to doc tests

* Add CPU/GPU equivalence test

* Add Deformable DETR to pipeline tests

* Skip model for object detection pipeline
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: default avatarNouamane Tazi <nouamane98@gmail.com>
Co-authored-by: default avatarSylvain Gugger <Sylvain.gugger@gmail.com>
parent 5a70a77b
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
/*!
**************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
* Copyright (c) 2018 Microsoft
**************************************************************************
*/
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads)
{
return (N + num_threads - 1) / num_threads;
}
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
*grad_attn_weight = top_grad * val;
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockSize; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockSize/2; s>0; s>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
for (unsigned int tid = 1; tid < blockDim.x; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
const scalar_t* data_value,
const int64_t* data_spatial_shapes,
const int64_t* data_level_start_index,
const scalar_t* data_sampling_loc,
const scalar_t* data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* data_col)
{
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
const scalar_t* grad_col,
const scalar_t* data_value,
const int64_t * data_spatial_shapes,
const int64_t * data_level_start_index,
const scalar_t * data_sampling_loc,
const scalar_t * data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
if (channels > 1024)
{
if ((channels & 1023) == 0)
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
else{
switch(channels)
{
case 1:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 2:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 4:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 8:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 16:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 32:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 64:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 128:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 256:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 512:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
case 1024:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
break;
default:
if (channels < 64)
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
else
{
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_sampling_loc,
grad_attn_weight);
}
}
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "cpu/ms_deform_attn_cpu.h"
#ifdef WITH_CUDA
#include "cuda/ms_deform_attn_cuda.h"
#endif
at::Tensor
ms_deform_attn_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_forward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<at::Tensor>
ms_deform_attn_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
if (value.type().is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "ms_deform_attn.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
\ No newline at end of file
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Loading of Deformable DETR's CUDA kernels"""
import os
def load_cuda_kernels():
from torch.utils.cpp_extension import load
root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_kernel")
src_files = [
os.path.join(root, filename)
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]
load(
"MultiScaleDeformableAttention",
src_files,
# verbose=True,
with_cuda=True,
extra_include_paths=[root],
# build_directory=os.path.dirname(os.path.realpath(__file__)),
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)
import MultiScaleDeformableAttention as MSDA
return MSDA
# coding=utf-8
# Copyright 2022 SenseTime and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Deformable DETR model."""
import copy
import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_torch_cuda_available,
is_vision_available,
replace_return_docstrings,
requires_backends,
)
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels
logger = logging.get_logger(__name__)
# Move this to not compile only when importing, this needs to happen later, like in __init__.
if is_torch_cuda_available():
logger.info("Loading custom CUDA kernels...")
MultiScaleDeformableAttention = load_cuda_kernels()
else:
MultiScaleDeformableAttention = None
class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
def forward(
context,
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
):
context.im2col_step = im2col_step
output = MultiScaleDeformableAttention.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
context.im2col_step,
)
context.save_for_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
)
return output
@staticmethod
@once_differentiable
def backward(context, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = context.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
context.im2col_step,
)
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
if is_vision_available():
from transformers.models.detr.feature_extraction_detr import center_to_corners_format
if is_timm_available():
from timm import create_model
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DeformableDetrConfig"
_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
"sensetime/deformable-detr",
# See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr
]
@dataclass
class DeformableDetrDecoderOutput(ModelOutput):
"""
Base class for outputs of the DeformableDetrDecoder. This class adds two attributes to
BaseModelOutputWithCrossAttentions, namely:
- a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
- a stacked tensor of intermediate reference points.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
Stacked intermediate hidden states (output of each layer of the decoder).
intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Stacked intermediate reference points (reference points of each layer of the decoder).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
the self-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
used to compute the weighted average in the cross-attention heads.
"""
last_hidden_state: torch.FloatTensor = None
intermediate_hidden_states: torch.FloatTensor = None
intermediate_reference_points: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class DeformableDetrModelOutput(ModelOutput):
"""
Base class for outputs of the Deformable DETR encoder-decoder model.
Args:
init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Initial reference points sent through the Transformer decoder.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`):
Stacked intermediate hidden states (output of each layer of the decoder).
intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`):
Stacked intermediate reference points (reference points of each layer of the decoder).
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer
plus the initial embedding outputs.
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries,
num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted
average in the self-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
layer plus the initial embedding outputs.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
foreground and background).
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
Logits of predicted bounding boxes coordinates in the first stage.
"""
init_reference_points: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
intermediate_hidden_states: torch.FloatTensor = None
intermediate_reference_points: torch.FloatTensor = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
enc_outputs_class: Optional[torch.FloatTensor] = None
enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
@dataclass
class DeformableDetrObjectDetectionOutput(ModelOutput):
"""
Output type of [`DeformableDetrForObjectDetection`].
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
scale-invariant IoU loss.
loss_dict (`Dict`, *optional*):
A dictionary containing the individual losses. Useful for logging.
logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
Classification logits (including no-object) for all queries.
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~AutoFeatureExtractor.post_process`] to retrieve the unnormalized bounding
boxes.
auxiliary_outputs (`list[Dict]`, *optional*):
Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
`pred_boxes`) for each decoder layer.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer
plus the initial embedding outputs.
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries,
num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted
average in the self-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
layer plus the initial embedding outputs.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_heads, 4,
4)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average
in the self-attention heads.
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`):
Stacked intermediate hidden states (output of each layer of the decoder).
intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`):
Stacked intermediate reference points (reference points of each layer of the decoder).
init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Initial reference points sent through the Transformer decoder.
enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
foreground and background).
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
Logits of predicted bounding boxes coordinates in the first stage.
"""
loss: Optional[torch.FloatTensor] = None
loss_dict: Optional[Dict] = None
logits: torch.FloatTensor = None
pred_boxes: torch.FloatTensor = None
auxiliary_outputs: Optional[List[Dict]] = None
init_reference_points: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
intermediate_hidden_states: Optional[torch.FloatTensor] = None
intermediate_reference_points: Optional[torch.FloatTensor] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
enc_outputs_class: Optional = None
enc_outputs_coord_logits: Optional = None
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def inverse_sigmoid(x, eps=1e-5):
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->DeformableDetr
class DeformableDetrFrozenBatchNorm2d(nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
torchvision.models.resnet[18,34,50,101] produce nans.
"""
def __init__(self, n):
super().__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x):
# move reshapes to the beginning
# to make it user-friendly
weight = self.weight.reshape(1, -1, 1, 1)
bias = self.bias.reshape(1, -1, 1, 1)
running_var = self.running_var.reshape(1, -1, 1, 1)
running_mean = self.running_mean.reshape(1, -1, 1, 1)
epsilon = 1e-5
scale = weight * (running_var + epsilon).rsqrt()
bias = bias - running_mean * scale
return x * scale + bias
# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DeformableDetr
def replace_batch_norm(m, name=""):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if isinstance(target_attr, nn.BatchNorm2d):
frozen = DeformableDetrFrozenBatchNorm2d(target_attr.num_features)
bn = getattr(m, attr_str)
frozen.weight.data.copy_(bn.weight)
frozen.bias.data.copy_(bn.bias)
frozen.running_mean.data.copy_(bn.running_mean)
frozen.running_var.data.copy_(bn.running_var)
setattr(m, attr_str, frozen)
for n, ch in m.named_children():
replace_batch_norm(ch, n)
class DeformableDetrTimmConvEncoder(nn.Module):
"""
Convolutional encoder (backbone) from the timm library.
nn.BatchNorm2d layers are replaced by DeformableDetrFrozenBatchNorm2d as defined above.
"""
def __init__(self, config):
super().__init__()
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
requires_backends(self, ["timm"])
out_indices = (2, 3, 4) if config.num_feature_levels > 1 else (4,)
backbone = create_model(
config.backbone, pretrained=True, features_only=True, out_indices=out_indices, **kwargs
)
# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
self.intermediate_channel_sizes = self.model.feature_info.channels()
self.strides = self.model.feature_info.reduction()
if "resnet" in config.backbone:
for name, parameter in self.model.named_parameters():
if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
"""
Outputs feature maps of latter stages C_3 through C_5 in ResNet if `config.num_feature_levels > 1`, otherwise
outputs feature maps of C_5.
"""
# send pixel_values through the model to get list of feature maps
features = self.model(pixel_values)
out = []
for feature_map in features:
# downsample pixel_mask to match shape of corresponding feature_map
mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
out.append((feature_map, mask))
return out
# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DeformableDetr
class DeformableDetrConvModel(nn.Module):
"""
This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
"""
def __init__(self, conv_encoder, position_embedding):
super().__init__()
self.conv_encoder = conv_encoder
self.position_embedding = position_embedding
def forward(self, pixel_values, pixel_mask):
# send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
out = self.conv_encoder(pixel_values, pixel_mask)
pos = []
for feature_map, mask in out:
# position encoding
pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
return out, pos
# Copied from transformers.models.detr.modeling_detr._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None):
"""
Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`.
"""
batch_size, source_len = mask.size()
target_len = target_len if target_len is not None else source_len
expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
class DeformableDetrSinePositionEmbedding(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
need paper, generalized to work on images.
"""
def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.embedding_dim = embedding_dim
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, pixel_values, pixel_mask):
if pixel_mask is None:
raise ValueError("No pixel mask provided")
y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
# Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding
class DeformableDetrLearnedPositionEmbedding(nn.Module):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, embedding_dim=256):
super().__init__()
self.row_embeddings = nn.Embedding(50, embedding_dim)
self.column_embeddings = nn.Embedding(50, embedding_dim)
def forward(self, pixel_values, pixel_mask=None):
height, width = pixel_values.shape[-2:]
width_values = torch.arange(width, device=pixel_values.device)
height_values = torch.arange(height, device=pixel_values.device)
x_emb = self.column_embeddings(width_values)
y_emb = self.row_embeddings(height_values)
pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
pos = pos.permute(2, 0, 1)
pos = pos.unsqueeze(0)
pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
return pos
# Copied from transformers.models.detr.modeling_detr.build_position_encoding with Detr->DeformableDetr
def build_position_encoding(config):
n_steps = config.d_model // 2
if config.position_embedding_type == "sine":
# TODO find a better way of exposing other arguments
position_embedding = DeformableDetrSinePositionEmbedding(n_steps, normalize=True)
elif config.position_embedding_type == "learned":
position_embedding = DeformableDetrLearnedPositionEmbedding(n_steps)
else:
raise ValueError(f"Not supported {config.position_embedding_type}")
return position_embedding
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
# for debug and test only,
# need to use cuda version instead
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# N_*M_, D_, Lq_, P_
sampling_value_l_ = F.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
return output.transpose(1, 2).contiguous()
class DeformableDetrMultiscaleDeformableAttention(nn.Module):
"""
Multiscale deformable attention as proposed in Deformable DETR.
"""
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
)
dim_per_head = embed_dim // num_heads
# check if dim_per_head is power of 2
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
warnings.warn(
"You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the"
" dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
" implementation."
)
self.im2col_step = 64
self.d_model = embed_dim
self.n_levels = n_levels
self.n_heads = num_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self._reset_parameters()
def _reset_parameters(self):
nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.n_heads, 1, 1, 2)
.repeat(1, self.n_levels, self.n_points, 1)
)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
nn.init.constant_(self.attention_weights.weight.data, 0.0)
nn.init.constant_(self.attention_weights.bias.data, 0.0)
nn.init.xavier_uniform_(self.value_proj.weight.data)
nn.init.constant_(self.value_proj.bias.data, 0.0)
nn.init.xavier_uniform_(self.output_proj.weight.data)
nn.init.constant_(self.output_proj.bias.data, 0.0)
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None,
encoder_attention_mask=None,
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
output_attentions: bool = False,
):
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
value = self.value_proj(encoder_hidden_states)
if attention_mask is not None:
# we invert the attention_mask
value = value.masked_fill(~attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
)
attention_weights = self.attention_weights(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
)
attention_weights = F.softmax(attention_weights, -1).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
)
# batch_size, num_queries, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
)
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
try:
# GPU
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
except Exception:
# CPU
output = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
return output, attention_weights
class DeformableDetrMultiheadAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper.
Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
batch_size, target_len, embed_dim = hidden_states.size()
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
hidden_states_original = hidden_states
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
# get queries, keys and values
query_states = self.q_proj(hidden_states) * self.scaling
key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
source_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
raise ValueError(
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
f" {attn_weights.size()}"
)
# expand attention_mask
if attention_mask is not None:
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, target_len, source_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
f" {attention_mask.size()}"
)
attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
class DeformableDetrEncoderLayer(nn.Module):
def __init__(self, config: DeformableDetrConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = DeformableDetrMultiscaleDeformableAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.encoder_n_points,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
output_attentions: bool = False,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Input to the layer.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Attention mask.
position_embeddings (`torch.FloatTensor`, *optional*):
Position embeddings, to be added to `hidden_states`.
reference_points (`torch.FloatTensor`, *optional*):
Reference points.
spatial_shapes (`torch.LongTensor`, *optional*):
Spatial shapes of the backbone feature maps.
level_start_index (`torch.LongTensor`, *optional*):
Level start index.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
# Apply Multi-scale Deformable Attention Module on the multi-scale feature maps.
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if self.training:
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class DeformableDetrDecoderLayer(nn.Module):
def __init__(self, config: DeformableDetrConfig):
super().__init__()
self.embed_dim = config.d_model
# self-attention
self.self_attn = DeformableDetrMultiheadAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
# cross-attention
self.encoder_attn = DeformableDetrMultiscaleDeformableAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.decoder_n_points,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
# feedforward neural networks
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(seq_len, batch, embed_dim)`.
position_embeddings (`torch.FloatTensor`, *optional*):
Position embeddings that are added to the queries and keys in the self-attention layer.
reference_points (`torch.FloatTensor`, *optional*):
Reference points.
spatial_shapes (`torch.LongTensor`, *optional*):
Spatial shapes.
level_start_index (`torch.LongTensor`, *optional*):
Level start index.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
second_residual = hidden_states
# Cross-Attention
cross_attn_weights = None
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = second_residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
# Copied from transformers.models.detr.modeling_detr.DetrClassificationHead
class DeformableDetrClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, input_dim: int, inner_dim: int, num_classes: int, pooler_dropout: float):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class DeformableDetrPreTrainedModel(PreTrainedModel):
config_class = DeformableDetrConfig
base_model_prefix = "model"
main_input_name = "pixel_values"
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, DeformableDetrLearnedPositionEmbedding):
nn.init.uniform_(module.row_embeddings.weight)
nn.init.uniform_(module.column_embeddings.weight)
elif isinstance(module, DeformableDetrMultiscaleDeformableAttention):
module._reset_parameters()
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if hasattr(module, "reference_points") and not self.config.two_stage:
nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
nn.init.constant_(module.reference_points.bias.data, 0.0)
if hasattr(module, "level_embed"):
nn.init.normal_(module.level_embed)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, DeformableDetrDecoder):
module.gradient_checkpointing = value
DEFORMABLE_DETR_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`DeformableDetrConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
DEFORMABLE_DETR_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it.
Pixel values can be obtained using [`AutoFeatureExtractor`]. See [`AutoFeatureExtractor.__call__`] for
details.
pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*):
Not used by default. Can be used to mask object queries.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
can choose to directly pass a flattened representation of an image.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
embedded representation.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
[`DeformableDetrEncoderLayer`].
The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers.
Args:
config: DeformableDetrConfig
"""
def __init__(self, config: DeformableDetrConfig):
super().__init__(config)
self.dropout = config.dropout
self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
# Initialize weights and apply final processing
self.post_init()
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""
Get reference points for each feature map. Used in decoder.
Args:
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
Valid ratios of each feature map.
device (`torch.device`):
Device on which to create the tensors.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
"""
reference_points_list = []
for level, (height, width) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
)
# TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(
self,
inputs_embeds=None,
attention_mask=None,
position_embeddings=None,
spatial_shapes=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
- 1 for pixel features that are real (i.e. **not masked**),
- 0 for pixel features that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Position embeddings that are added to the queries and keys in each self-attention layer.
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
Starting index of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
Ratio of valid area in each feature level.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = inputs_embeds
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DeformableDetrDecoderLayer`].
The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
Some tweaks for Deformable DETR:
- `position_embeddings`, `reference_points`, `spatial_shapes` and `valid_ratios` are added to the forward pass.
- it also returns a stack of intermediate outputs and reference points from all decoding layers.
Args:
config: DeformableDetrConfig
"""
def __init__(self, config: DeformableDetrConfig):
super().__init__(config)
self.dropout = config.dropout
self.layers = nn.ModuleList([DeformableDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
self.gradient_checkpointing = False
# hack implementation for iterative bounding box refinement and two-stage Deformable DETR
self.bbox_embed = None
self.class_embed = None
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
position_embeddings=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
The query embeddings that are passed into the decoder.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
Position embeddings that are added to the queries and keys in each self-attention layer.
reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*):
Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of the feature maps.
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
Indexes for the start of each feature level. In range `[0, sequence_length]`.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
Ratio of valid area in each feature level.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is not None:
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
intermediate = ()
intermediate_reference_points = ()
for idx, decoder_layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = (
reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
)
else:
if reference_points.shape[-1] != 2:
raise ValueError("Reference points' last dimension must be of size 2")
reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
encoder_hidden_states,
encoder_attention_mask,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
encoder_hidden_states=encoder_hidden_states,
reference_points=reference_points_input,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
# hack implementation for iterative bounding box refinement
if self.bbox_embed is not None:
tmp = self.bbox_embed[idx](hidden_states)
if reference_points.shape[-1] == 4:
new_reference_points = tmp + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
else:
if reference_points.shape[-1] != 2:
raise ValueError(
f"Reference points' last dimension must be of size 2, but is {reference_points.shape[-1]}"
)
new_reference_points = tmp
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
intermediate += (hidden_states,)
intermediate_reference_points += (reference_points,)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
intermediate = torch.stack(intermediate)
intermediate_reference_points = torch.stack(intermediate_reference_points)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
intermediate,
intermediate_reference_points,
all_hidden_states,
all_self_attns,
all_cross_attentions,
]
if v is not None
)
return DeformableDetrDecoderOutput(
last_hidden_state=hidden_states,
intermediate_hidden_states=intermediate,
intermediate_reference_points=intermediate_reference_points,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
@add_start_docstrings(
"""
The bare Deformable DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw
hidden-states without any specific head on top.
""",
DEFORMABLE_DETR_START_DOCSTRING,
)
class DeformableDetrModel(DeformableDetrPreTrainedModel):
def __init__(self, config: DeformableDetrConfig):
super().__init__(config)
# Create backbone + positional encoding
backbone = DeformableDetrTimmConvEncoder(config)
position_embeddings = build_position_encoding(config)
self.backbone = DeformableDetrConvModel(backbone, position_embeddings)
# Create input projection layers
if config.num_feature_levels > 1:
num_backbone_outs = len(backbone.strides)
input_proj_list = []
for _ in range(num_backbone_outs):
in_channels = backbone.intermediate_channel_sizes[_]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, config.d_model, kernel_size=1),
nn.GroupNorm(32, config.d_model),
)
)
for _ in range(config.num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(32, config.d_model),
)
)
in_channels = config.d_model
self.input_proj = nn.ModuleList(input_proj_list)
else:
self.input_proj = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1),
nn.GroupNorm(32, config.d_model),
)
]
)
if not config.two_stage:
self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model * 2)
self.encoder = DeformableDetrEncoder(config)
self.decoder = DeformableDetrDecoder(config)
self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model))
if config.two_stage:
self.enc_output = nn.Linear(config.d_model, config.d_model)
self.enc_output_norm = nn.LayerNorm(config.d_model)
self.pos_trans = nn.Linear(config.d_model * 2, config.d_model * 2)
self.pos_trans_norm = nn.LayerNorm(config.d_model * 2)
else:
self.reference_points = nn.Linear(config.d_model, 2)
self.post_init()
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def freeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
param.requires_grad_(False)
def unfreeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
param.requires_grad_(True)
def get_valid_ratio(self, mask):
"""Get the valid ratio of all feature maps."""
_, height, width = mask.shape
valid_height = torch.sum(~mask[:, :, 0], 1)
valid_width = torch.sum(~mask[:, 0, :], 1)
valid_ratio_heigth = valid_height.float() / height
valid_ratio_width = valid_width.float() / width
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
return valid_ratio
def get_proposal_pos_embed(self, proposals):
"""Get the position embedding of the proposals."""
num_pos_feats = 128
temperature = 10000
scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
# batch_size, num_queries, 4
proposals = proposals.sigmoid() * scale
# batch_size, num_queries, 4, 128
pos = proposals[:, :, :, None] / dim_t
# batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512
pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
return pos
def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
"""Generate the encoder output proposals from encoded enc_output.
Args:
enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`.
spatial_shapes (Tensor[num_feature_levels, 2]): Spatial shapes of the feature maps.
Returns:
`tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction.
- object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to
directly predict a bounding box. (without the need of a decoder)
- output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse
sigmoid.
"""
batch_size = enc_output.shape[0]
proposals = []
_cur = 0
for level, (height, width) in enumerate(spatial_shapes):
mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device),
torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device),
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
width_heigth = torch.ones_like(grid) * 0.05 * (2.0**level)
proposal = torch.cat((grid, width_heigth), -1).view(batch_size, -1, 4)
proposals.append(proposal)
_cur += height * width
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals)) # inverse sigmoid
output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf"))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
# assign each pixel as an object query
object_query = enc_output
object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0))
object_query = object_query.masked_fill(~output_proposals_valid, float(0))
object_query = self.enc_output_norm(self.enc_output(object_query))
return object_query, output_proposals
@add_start_docstrings_to_model_forward(DEFORMABLE_DETR_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DeformableDetrModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values,
pixel_mask=None,
decoder_attention_mask=None,
encoder_outputs=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, DeformableDetrModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("SenseTime/deformable-detr")
>>> model = DeformableDetrModel.from_pretrained("SenseTime/deformable-detr")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 300, 256]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, num_channels, height, width = pixel_values.shape
device = pixel_values.device
if pixel_mask is None:
pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)
# Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper)
# First, sent pixel_values + pixel_mask through Backbone to obtain the features
# which is a list of tuples
features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)
# Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
sources = []
masks = []
for level, (source, mask) in enumerate(features):
sources.append(self.input_proj[level](source))
masks.append(mask)
if mask is None:
raise ValueError("No attention mask was provided")
# Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
if self.config.num_feature_levels > len(sources):
_len_sources = len(sources)
for level in range(_len_sources, self.config.num_feature_levels):
if level == _len_sources:
source = self.input_proj[level](features[-1][0])
else:
source = self.input_proj[level](sources[-1])
mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
sources.append(source)
masks.append(mask)
position_embeddings_list.append(pos_l)
# Create queries
query_embeds = None
if not self.config.two_stage:
query_embeds = self.query_position_embeddings.weight
# Prepare encoder inputs (by flattening)
source_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)):
batch_size, num_channels, height, width = source.shape
spatial_shape = (height, width)
spatial_shapes.append(spatial_shape)
source = source.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
source_flatten.append(source)
mask_flatten.append(mask)
source_flatten = torch.cat(source_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# revert valid_ratios
valid_ratios = ~valid_ratios.bool()
valid_ratios = valid_ratios.float()
# Fourth, sent source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
# Also provide spatial_shapes, level_start_index and valid_ratios
if encoder_outputs is None:
encoder_outputs = self.encoder(
inputs_embeds=source_flatten,
attention_mask=mask_flatten,
position_embeddings=lvl_pos_embed_flatten,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# Fifth, prepare decoder inputs
batch_size, _, num_channels = encoder_outputs[0].shape
enc_outputs_class = None
enc_outputs_coord_logits = None
if self.config.two_stage:
object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
encoder_outputs[0], ~mask_flatten, spatial_shapes
)
# hack implementation for two-stage Deformable DETR
# apply a detection head to each pixel (A.4 in paper)
# linear projection for bounding box binary classification (i.e. foreground and background)
enc_outputs_class = self.decoder.class_embed[-1](object_query_embedding)
# 3-layer FFN to predict bounding boxes coordinates (bbox regression branch)
delta_bbox = self.decoder.bbox_embed[-1](object_query_embedding)
enc_outputs_coord_logits = delta_bbox + output_proposals
# only keep top scoring `config.two_stage_num_proposals` proposals
topk = self.config.two_stage_num_proposals
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_logits = torch.gather(
enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
)
topk_coords_logits = topk_coords_logits.detach()
reference_points = topk_coords_logits.sigmoid()
init_reference_points = reference_points
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_logits)))
query_embed, target = torch.split(pos_trans_out, num_channels, dim=2)
else:
query_embed, target = torch.split(query_embeds, num_channels, dim=1)
query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1)
target = target.unsqueeze(0).expand(batch_size, -1, -1)
reference_points = self.reference_points(query_embed).sigmoid()
init_reference_points = reference_points
decoder_outputs = self.decoder(
inputs_embeds=target,
position_embeddings=query_embed,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=mask_flatten,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
enc_outputs = tuple(value for value in [enc_outputs_class, enc_outputs_coord_logits] if value is not None)
tuple_outputs = (init_reference_points,) + decoder_outputs + encoder_outputs + enc_outputs
return tuple_outputs
return DeformableDetrModelOutput(
init_reference_points=init_reference_points,
last_hidden_state=decoder_outputs.last_hidden_state,
intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
intermediate_reference_points=decoder_outputs.intermediate_reference_points,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
enc_outputs_class=enc_outputs_class,
enc_outputs_coord_logits=enc_outputs_coord_logits,
)
@add_start_docstrings(
"""
Deformable DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on
top, for tasks such as COCO detection.
""",
DEFORMABLE_DETR_START_DOCSTRING,
)
class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
def __init__(self, config: DeformableDetrConfig):
super().__init__(config)
# Deformable DETR encoder-decoder model
self.model = DeformableDetrModel(config)
# Detection heads on top
self.class_embed = nn.Linear(config.d_model, config.num_labels)
self.bbox_embed = DeformableDetrMLPPredictionHead(
input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
)
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
self.class_embed.bias.data = torch.ones(config.num_labels) * bias_value
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
# if two-stage, the last class_embed and bbox_embed is for region proposal generation
num_pred = (config.decoder_layers + 1) if config.two_stage else config.decoder_layers
if config.with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
# hack implementation for iterative bounding box refinement
self.model.decoder.bbox_embed = self.bbox_embed
else:
nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
self.model.decoder.bbox_embed = None
if config.two_stage:
# hack implementation for two-stage
self.model.decoder.class_embed = self.class_embed
for box_embed in self.bbox_embed:
nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
# Initialize weights and apply final processing
self.post_init()
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
@add_start_docstrings_to_model_forward(DEFORMABLE_DETR_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DeformableDetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values,
pixel_mask=None,
decoder_attention_mask=None,
encoder_outputs=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`List[Dict]` of len `(batch_size,)`, *optional*):
Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, DeformableDetrForObjectDetection
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("SenseTime/deformable-detr")
>>> model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> # convert outputs (bounding boxes and class logits) to COCO API
>>> target_sizes = torch.tensor([image.size[::-1]])
>>> results = feature_extractor.post_process(outputs, target_sizes=target_sizes)[0]
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
... box = [round(i, 2) for i in box.tolist()]
... # let's only keep detections with score > 0.7
... if score > 0.7:
... print(
... f"Detected {model.config.id2label[label.item()]} with confidence "
... f"{round(score.item(), 3)} at location {box}"
... )
Detected cat with confidence 0.856 at location [342.19, 24.3, 640.02, 372.25]
Detected remote with confidence 0.739 at location [40.79, 72.78, 176.76, 117.25]
Detected cat with confidence 0.859 at location [16.5, 52.84, 318.25, 470.78]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# First, sent images through DETR base model to obtain encoder + decoder outputs
outputs = self.model(
pixel_values,
pixel_mask=pixel_mask,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
init_reference = outputs.init_reference_points if return_dict else outputs[0]
inter_references = outputs.intermediate_reference_points if return_dict else outputs[3]
# class logits + predicted bounding boxes
outputs_classes = []
outputs_coords = []
for level in range(hidden_states.shape[0]):
if level == 0:
reference = init_reference
else:
reference = inter_references[level - 1]
reference = inverse_sigmoid(reference)
outputs_class = self.class_embed[level](hidden_states[level])
delta_bbox = self.bbox_embed[level](hidden_states[level])
if reference.shape[-1] == 4:
outputs_coord_logits = delta_bbox + reference
elif reference.shape[-1] == 2:
delta_bbox[..., :2] += reference
outputs_coord_logits = delta_bbox
else:
raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")
outputs_coord = outputs_coord_logits.sigmoid()
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
outputs_class = torch.stack(outputs_classes)
outputs_coord = torch.stack(outputs_coords)
logits = outputs_class[-1]
pred_boxes = outputs_coord[-1]
loss, loss_dict, auxiliary_outputs = None, None, None
if labels is not None:
# First: create the matcher
matcher = DeformableDetrHungarianMatcher(
class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost
)
# Second: create the criterion
losses = ["labels", "boxes", "cardinality"]
criterion = DeformableDetrLoss(
matcher=matcher,
num_classes=self.config.num_labels,
eos_coef=self.config.eos_coefficient,
losses=losses,
)
criterion.to(self.device)
# Third: compute the losses, based on outputs and labels
outputs_loss = {}
outputs_loss["logits"] = logits
outputs_loss["pred_boxes"] = pred_boxes
if self.config.auxiliary_loss:
intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
outputs_class = self.class_embed(intermediate)
outputs_coord = self.bbox_embed(intermediate).sigmoid()
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
if self.config.two_stage:
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
loss_dict = criterion(outputs_loss, labels)
# Fourth: compute total loss, as a weighted sum of the various losses
weight_dict = {"loss_ce": 1, "loss_bbox": self.config.bbox_loss_coefficient}
weight_dict["loss_giou"] = self.config.giou_loss_coefficient
if self.config.auxiliary_loss:
aux_weight_dict = {}
for i in range(self.config.decoder_layers - 1):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
if not return_dict:
if auxiliary_outputs is not None:
output = (logits, pred_boxes) + auxiliary_outputs + outputs
else:
output = (logits, pred_boxes) + outputs
tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
return tuple_outputs
dict_outputs = DeformableDetrObjectDetectionOutput(
loss=loss,
loss_dict=loss_dict,
logits=logits,
pred_boxes=pred_boxes,
auxiliary_outputs=auxiliary_outputs,
last_hidden_state=outputs.last_hidden_state,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
intermediate_hidden_states=outputs.intermediate_hidden_states,
init_reference_points=outputs.init_reference_points,
intermediate_reference_points=outputs.intermediate_reference_points,
enc_outputs_class=outputs.enc_outputs_class,
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
)
return dict_outputs
# Copied from transformers.models.detr.modeling_detr.dice_loss
def dice_loss(inputs, targets, num_boxes):
"""
Compute the DICE loss, similar to generalized IOU for masks
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs (0 for the negative class and 1 for the positive
class).
"""
inputs = inputs.sigmoid()
inputs = inputs.flatten(1)
numerator = 2 * (inputs * targets).sum(1)
denominator = inputs.sum(-1) + targets.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
return loss.sum() / num_boxes
# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs (`torch.FloatTensor` of arbitrary shape):
The predictions for each example.
targets (`torch.FloatTensor` with the same shape as `inputs`)
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
and 1 for the positive class).
alpha (`float`, *optional*, defaults to `0.25`):
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
gamma (`int`, *optional*, defaults to `2`):
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
# add modulating factor
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
return loss.mean(1).sum() / num_boxes
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class DeformableDetrLoss(nn.Module):
"""
This class computes the losses for DeformableDetrForObjectDetection. The process happens in two steps: 1) we
compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of
matched ground-truth / prediction (supervise class and box)
"""
def __init__(self, matcher, num_classes, eos_coef, losses, focal_alpha=0.25):
"""
Create the criterion.
A note on the num_classes parameter (copied from original repo in detr.py): "the naming of the `num_classes`
parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id + 1`, where max_obj_id
is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90, so we pass
`num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you should pass
`num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion
https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
Parameters:
matcher: module able to compute a matching between targets and proposals.
num_classes: number of object categories, omitting the special no-object category.
eos_coef: relative classification weight applied to the no-object category.
losses: list of all the losses to be applied. See get_loss for list of available losses.
focal_alpha: alpha in Focal Loss.
"""
super().__init__()
self.matcher = matcher
self.num_classes = num_classes
self.losses = losses
self.focal_alpha = focal_alpha
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
if "logits" not in outputs:
raise ValueError("No logits were found in the outputs")
source_logits = outputs["logits"]
idx = self._get_source_permutation_idx(indices)
target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(
source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device
)
target_classes[idx] = target_classes_o
target_classes_onehot = torch.zeros(
[source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1],
dtype=source_logits.dtype,
layout=source_logits.layout,
device=source_logits.device,
)
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
target_classes_onehot = target_classes_onehot[:, :, :-1]
loss_ce = (
sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
* source_logits.shape[1]
)
losses = {"loss_ce": loss_ce}
return losses
@torch.no_grad()
def loss_cardinality(self, outputs, targets, indices, num_boxes):
"""
Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
"""
logits = outputs["logits"]
device = logits.device
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
losses = {"cardinality_error": card_err}
return losses
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""
Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
are expected in format (center_x, center_y, w, h), normalized by the image size.
"""
if "pred_boxes" not in outputs:
raise ValueError("No predicted boxes found in outputs")
idx = self._get_source_permutation_idx(indices)
source_boxes = outputs["pred_boxes"][idx]
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(
generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses
def _get_source_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
source_idx = torch.cat([source for (source, _) in indices])
return batch_idx, source_idx
def _get_target_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
target_idx = torch.cat([target for (_, target) in indices])
return batch_idx, target_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes):
loss_map = {
"labels": self.loss_labels,
"cardinality": self.loss_cardinality,
"boxes": self.loss_boxes,
}
if loss not in loss_map:
raise ValueError(f"Loss {loss} not supported")
return loss_map[loss](outputs, targets, indices, num_boxes)
def forward(self, outputs, targets):
"""
This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()
# Compute all the requested losses
losses = {}
for loss in self.losses:
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if "auxiliary_outputs" in outputs:
for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
indices = self.matcher(auxiliary_outputs, targets)
for loss in self.losses:
l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
losses.update(l_dict)
if "enc_outputs" in outputs:
enc_outputs = outputs["enc_outputs"]
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
bt["labels"] = torch.zeros_like(bt["labels"])
indices = self.matcher(enc_outputs, bin_targets)
for loss in self.losses:
kwargs = {}
if loss == "labels":
# Logging is enabled only for the last layer
kwargs["log"] = False
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
l_dict = {k + "_enc": v for k, v in l_dict.items()}
losses.update(l_dict)
return losses
# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead
class DeformableDetrMLPPredictionHead(nn.Module):
"""
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
height and width of a bounding box w.r.t. an image.
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
# Copied from transformers.models.detr.modeling_detr.DetrHungarianMatcher
class DeformableDetrHungarianMatcher(nn.Module):
"""
This class computes an assignment between the targets and the predictions of the network.
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
un-matched (and thus treated as non-objects).
Args:
class_cost:
The relative weight of the classification error in the matching cost.
bbox_cost:
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
giou_cost:
The relative weight of the giou loss of the bounding box in the matching cost.
"""
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
super().__init__()
requires_backends(self, ["scipy"])
self.class_cost = class_cost
self.bbox_cost = bbox_cost
self.giou_cost = giou_cost
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
raise ValueError("All costs of the Matcher can't be 0")
@torch.no_grad()
def forward(self, outputs, targets):
"""
Args:
outputs (`dict`):
A dictionary that contains at least these entries:
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
targets (`List[dict]`):
A list of targets (len(targets) = batch_size), where each target is a dict containing:
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
ground-truth
objects in the target) containing the class labels
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
Returns:
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
batch_size, num_queries = outputs["logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
target_ids = torch.cat([v["class_labels"] for v in targets])
target_bbox = torch.cat([v["boxes"] for v in targets])
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
class_cost = -out_prob[:, target_ids]
# Compute the L1 cost between boxes
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
# Compute the giou cost between boxes
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
# Final cost matrix
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
sizes = [len(v["boxes"]) for v in targets]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
return t if t.dtype in (torch.float32, torch.float64) else t.float()
else:
return t if t.dtype in (torch.int32, torch.int64) else t.int()
# Copied from transformers.models.detr.modeling_detr.box_area
def box_area(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
Args:
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
< x2` and `0 <= y1 < y2`.
Returns:
`torch.FloatTensor`: a tensor containing the area for each box.
"""
boxes = _upcast(boxes)
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# Copied from transformers.models.detr.modeling_detr.box_iou
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
Returns:
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
iou, union = box_iou(boxes1, boxes2)
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
area = width_height[:, :, 0] * width_height[:, :, 1]
return iou - (area - union) / area
# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
# Copied from transformers.models.detr.modeling_detr.NestedTensor
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
if tensor_list[0].ndim == 3:
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
batch_shape = [len(tensor_list)] + max_size
batch_size, num_channels, height, width = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("Only 3-dimensional tensors are supported")
return NestedTensor(tensor, mask)
......@@ -273,7 +273,7 @@ class DetrFrozenBatchNorm2d(nn.Module):
"""
def __init__(self, n):
super(DetrFrozenBatchNorm2d, self).__init__()
super().__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
......@@ -286,7 +286,7 @@ class DetrFrozenBatchNorm2d(nn.Module):
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(DetrFrozenBatchNorm2d, self)._load_from_state_dict(
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
......@@ -387,14 +387,14 @@ class DetrConvModel(nn.Module):
return out, pos
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
batch_size, source_len = mask.size()
target_len = target_len if target_len is not None else source_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
......@@ -449,12 +449,12 @@ class DetrLearnedPositionEmbedding(nn.Module):
self.column_embeddings = nn.Embedding(50, embedding_dim)
def forward(self, pixel_values, pixel_mask=None):
h, w = pixel_values.shape[-2:]
i = torch.arange(w, device=pixel_values.device)
j = torch.arange(h, device=pixel_values.device)
x_emb = self.column_embeddings(i)
y_emb = self.row_embeddings(j)
pos = torch.cat([x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1)], dim=-1)
height, width = pixel_values.shape[-2:]
width_values = torch.arange(width, device=pixel_values.device)
height_values = torch.arange(height, device=pixel_values.device)
x_emb = self.column_embeddings(width_values)
y_emb = self.row_embeddings(height_values)
pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
pos = pos.permute(2, 0, 1)
pos = pos.unsqueeze(0)
pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
......@@ -506,8 +506,8 @@ class DetrAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
......@@ -526,7 +526,7 @@ class DetrAttention(nn.Module):
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()
batch_size, target_len, embed_dim = hidden_states.size()
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
......@@ -543,35 +543,36 @@ class DetrAttention(nn.Module):
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states_original), -1, bsz)
key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz)
key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
source_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
if attention_mask.size() != (batch_size, 1, target_len, source_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
f" {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
......@@ -580,8 +581,8 @@ class DetrAttention(nn.Module):
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
else:
attn_weights_reshaped = None
......@@ -589,15 +590,15 @@ class DetrAttention(nn.Module):
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
attn_output = self.out_proj(attn_output)
......@@ -632,7 +633,8 @@ class DetrEncoderLayer(nn.Module):
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
position_embeddings (`torch.FloatTensor`, *optional*): position embeddings, to be added to hidden_states.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
......@@ -714,7 +716,8 @@ class DetrDecoderLayer(nn.Module):
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
position_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys
in the cross-attention layer.
......@@ -724,7 +727,8 @@ class DetrDecoderLayer(nn.Module):
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
......@@ -957,7 +961,7 @@ class DetrEncoder(DetrPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None
......@@ -1081,15 +1085,17 @@ class DetrDecoder(DetrPreTrainedModel):
combined_attention_mask = None
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
combined_attention_mask = combined_attention_mask + _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
encoder_attention_mask = _expand_mask(
encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
)
# optional intermediate hidden states
intermediate = () if self.config.auxiliary_loss else None
......@@ -1889,21 +1895,22 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs (0 for the negative class and 1 for the positive
class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
inputs (`torch.FloatTensor` of arbitrary shape):
The predictions for each example.
targets (`torch.FloatTensor` with the same shape as `inputs`)
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
and 1 for the positive class).
alpha (`float`, *optional*, defaults to `0.25`):
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
gamma (`int`, *optional*, defaults to `2`):
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
# add modulating factor
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
......@@ -1981,10 +1988,10 @@ class DetrLoss(nn.Module):
"""
logits = outputs["logits"]
device = logits.device
tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float())
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
losses = {"cardinality_error": card_err}
return losses
......@@ -2191,19 +2198,19 @@ class DetrHungarianMatcher(nn.Module):
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
tgt_ids = torch.cat([v["class_labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])
target_ids = torch.cat([v["class_labels"] for v in targets])
target_bbox = torch.cat([v["boxes"] for v in targets])
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
class_cost = -out_prob[:, tgt_ids]
class_cost = -out_prob[:, target_ids]
# Compute the L1 cost between boxes
bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1)
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
# Compute the giou cost between boxes
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(tgt_bbox))
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
# Final cost matrix
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
......@@ -2267,15 +2274,17 @@ def generalized_box_iou(boxes1, boxes2):
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
area = width_height[:, :, 0] * width_height[:, :, 1]
return iou - (area - union) / area
......@@ -2317,11 +2326,11 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
if tensor_list[0].ndim == 3:
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
batch_size, num_channels, height, width = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
......
......@@ -1211,8 +1211,8 @@ class DetrAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
......@@ -1231,7 +1231,7 @@ class DetrAttention(nn.Module):
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()
batch_size, target_len, embed_dim = hidden_states.size()
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
......@@ -1248,35 +1248,36 @@ class DetrAttention(nn.Module):
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states_original), -1, bsz)
key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz)
key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
source_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
if attention_mask.size() != (batch_size, 1, target_len, source_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
f" {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
......@@ -1285,8 +1286,8 @@ class DetrAttention(nn.Module):
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
else:
attn_weights_reshaped = None
......@@ -1294,15 +1295,15 @@ class DetrAttention(nn.Module):
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
attn_output = self.out_proj(attn_output)
......@@ -1351,7 +1352,8 @@ class DetrDecoderLayer(nn.Module):
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
position_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys
in the cross-attention layer.
......@@ -1361,7 +1363,8 @@ class DetrDecoderLayer(nn.Module):
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
......@@ -1416,14 +1419,14 @@ class DetrDecoderLayer(nn.Module):
# Copied from transformers.models.detr.modeling_detr._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
batch_size, source_len = mask.size()
target_len = target_len if target_len is not None else source_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
......
......@@ -874,21 +874,22 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs (0 for the negative class and 1 for the positive
class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples. Default = -1 (no weighting).
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
inputs (`torch.FloatTensor` of arbitrary shape):
The predictions for each example.
targets (`torch.FloatTensor` with the same shape as `inputs`)
A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
and 1 for the positive class).
alpha (`float`, *optional*, defaults to `0.25`):
Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
gamma (`int`, *optional*, defaults to `2`):
Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
Returns:
Loss tensor
"""
prob = inputs.sigmoid()
ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
# add modulating factor
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
......@@ -966,10 +967,10 @@ class YolosLoss(nn.Module):
"""
logits = outputs["logits"]
device = logits.device
tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
card_err = nn.functional.l1_loss(card_pred.float(), tgt_lengths.float())
card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
losses = {"cardinality_error": card_err}
return losses
......@@ -1176,19 +1177,19 @@ class YolosHungarianMatcher(nn.Module):
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
tgt_ids = torch.cat([v["class_labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])
target_ids = torch.cat([v["class_labels"] for v in targets])
target_bbox = torch.cat([v["boxes"] for v in targets])
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
class_cost = -out_prob[:, tgt_ids]
class_cost = -out_prob[:, target_ids]
# Compute the L1 cost between boxes
bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1)
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
# Compute the giou cost between boxes
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(tgt_bbox))
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
# Final cost matrix
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
......@@ -1252,15 +1253,17 @@ def generalized_box_iou(boxes1, boxes2):
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
area = width_height[:, :, 0] * width_height[:, :, 1]
return iou - (area - union) / area
......@@ -1302,11 +1305,11 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
if tensor_list[0].ndim == 3:
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
batch_size, num_channels, height, width = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
......
......@@ -3,6 +3,30 @@
from ..utils import DummyObject, requires_backends
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
class DeformableDetrForObjectDetection(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DeformableDetrModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DeformableDetrPreTrainedModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import requires_backends
DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
class DetrForObjectDetection:
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm"])
class DetrForSegmentation:
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm"])
class DetrModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm"])
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Deformable DETR model. """
import inspect
import math
import unittest
from typing import Dict, List, Tuple
from transformers import DeformableDetrConfig, is_timm_available, is_vision_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_timm, require_torch_gpu, require_vision, slow, torch_device
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
if is_timm_available():
import torch
from transformers import DeformableDetrForObjectDetection, DeformableDetrModel
if is_vision_available():
from PIL import Image
from transformers import AutoFeatureExtractor
class DeformableDetrModelTester:
def __init__(
self,
parent,
batch_size=8,
is_training=True,
use_labels=True,
hidden_size=256,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
num_queries=12,
num_channels=3,
image_size=196,
n_targets=8,
num_labels=91,
num_feature_levels=4,
encoder_n_points=2,
decoder_n_points=6,
):
self.parent = parent
self.batch_size = batch_size
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.num_queries = num_queries
self.num_channels = num_channels
self.image_size = image_size
self.n_targets = n_targets
self.num_labels = num_labels
self.num_feature_levels = num_feature_levels
self.encoder_n_points = encoder_n_points
self.decoder_n_points = decoder_n_points
# we also set the expected seq length for both encoder and decoder
self.encoder_seq_length = (
math.ceil(self.image_size / 8) ** 2
+ math.ceil(self.image_size / 16) ** 2
+ math.ceil(self.image_size / 32) ** 2
+ math.ceil(self.image_size / 64) ** 2
)
self.decoder_seq_length = self.num_queries
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
labels = None
if self.use_labels:
# labels is a list of Dict (each Dict being the labels for a given example in the batch)
labels = []
for i in range(self.batch_size):
target = {}
target["class_labels"] = torch.randint(
high=self.num_labels, size=(self.n_targets,), device=torch_device
)
target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device)
labels.append(target)
config = self.get_config()
return config, pixel_values, pixel_mask, labels
def get_config(self):
return DeformableDetrConfig(
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
decoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
num_queries=self.num_queries,
num_labels=self.num_labels,
num_feature_levels=self.num_feature_levels,
encoder_n_points=self.encoder_n_points,
decoder_n_points=self.decoder_n_points,
)
def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
return config, inputs_dict
def create_and_check_deformable_detr_model(self, config, pixel_values, pixel_mask, labels):
model = DeformableDetrModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size))
def create_and_check_deformable_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
model = DeformableDetrForObjectDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
@require_timm
class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (DeformableDetrModel, DeformableDetrForObjectDetection) if is_timm_available() else ()
is_encoder_decoder = True
test_torchscript = False
test_pruning = False
test_head_masking = False
test_missing_keys = False
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class.__name__ == "DeformableDetrForObjectDetection":
labels = []
for i in range(self.model_tester.batch_size):
target = {}
target["class_labels"] = torch.ones(
size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long
)
target["boxes"] = torch.ones(
self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float
)
target["masks"] = torch.ones(
self.model_tester.n_targets,
self.model_tester.image_size,
self.model_tester.image_size,
device=torch_device,
dtype=torch.float,
)
labels.append(target)
inputs_dict["labels"] = labels
return inputs_dict
def setUp(self):
self.model_tester = DeformableDetrModelTester(self)
self.config_tester = ConfigTester(self, config_class=DeformableDetrConfig, has_text_modality=False)
def test_config(self):
# we don't test common_properties and arguments_init as these don't apply for Deformable DETR
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
def test_deformable_detr_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deformable_detr_model(*config_and_inputs)
def test_deformable_detr_object_detection_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deformable_detr_object_detection_head_model(*config_and_inputs)
@unittest.skip(reason="Deformable DETR does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Deformable DETR does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="Deformable DETR is not a generative model")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="Deformable DETR does not use token embeddings")
def test_resize_tokens_embeddings(self):
pass
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
self.model_tester.num_feature_levels,
self.model_tester.encoder_n_points,
],
)
out_len = len(outputs)
correct_outlen = 8
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Object Detection model returns pred_logits and pred_boxes
if model_class.__name__ == "DeformableDetrForObjectDetection":
correct_outlen += 2
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, self.model_tester.num_queries, self.model_tester.num_queries],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
self.model_tester.num_feature_levels,
self.model_tester.decoder_n_points,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
self.model_tester.num_feature_levels,
self.model_tester.encoder_n_points,
],
)
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
print("Model class:", model_class)
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
)
def test_retain_grad_hidden_states_attentions(self):
# removed retain_grad and grad on decoder_hidden_states, as queries don't require grad
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs)
# we take the second output since last_hidden_state is the second item
output = outputs[1]
encoder_hidden_states = outputs.encoder_hidden_states[0]
encoder_attentions = outputs.encoder_attentions[0]
encoder_hidden_states.retain_grad()
encoder_attentions.retain_grad()
decoder_attentions = outputs.decoder_attentions[0]
decoder_attentions.retain_grad()
cross_attentions = outputs.cross_attentions[0]
cross_attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(encoder_attentions.grad)
self.assertIsNotNone(decoder_attentions.grad)
self.assertIsNotNone(cross_attentions.grad)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
if model.config.is_encoder_decoder:
expected_arg_names = ["pixel_values", "pixel_mask"]
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" in arg_names
else []
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
else:
expected_arg_names = ["pixel_values", "pixel_mask"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_different_timm_backbone(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if model_class.__name__ == "DeformableDetrForObjectDetection":
expected_shape = (
self.model_tester.batch_size,
self.model_tester.num_queries,
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
self.assertTrue(outputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
print("Model class:", model_class)
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if param.requires_grad:
if (
"level_embed" in name
or "sampling_offsets.bias" in name
or "value_proj" in name
or "output_proj" in name
or "reference_points" in name
):
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
TOLERANCE = 1e-4
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_timm
@require_vision
@slow
class DeformableDetrModelIntegrationTests(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return AutoFeatureExtractor.from_pretrained("SenseTime/deformable-detr") if is_vision_available() else None
def test_inference_object_detection_head(self):
model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr").to(torch_device)
feature_extractor = self.default_feature_extractor
image = prepare_img()
encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)
pixel_values = encoding["pixel_values"].to(torch_device)
pixel_mask = encoding["pixel_mask"].to(torch_device)
with torch.no_grad():
outputs = model(pixel_values, pixel_mask)
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
expected_logits = torch.tensor(
[[-9.6645, -4.3449, -5.8705], [-9.7035, -3.8504, -5.0724], [-10.5634, -5.3379, -7.5116]]
).to(torch_device)
expected_boxes = torch.tensor(
[[0.8693, 0.2289, 0.2492], [0.3150, 0.5489, 0.5845], [0.5563, 0.7580, 0.8518]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_logits, atol=1e-4))
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4))
def test_inference_object_detection_head_with_box_refine_two_stage(self):
model = DeformableDetrForObjectDetection.from_pretrained(
"SenseTime/deformable-detr-with-box-refine-two-stage"
).to(torch_device)
feature_extractor = self.default_feature_extractor
image = prepare_img()
encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)
pixel_values = encoding["pixel_values"].to(torch_device)
pixel_mask = encoding["pixel_mask"].to(torch_device)
with torch.no_grad():
outputs = model(pixel_values, pixel_mask)
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
expected_logits = torch.tensor(
[[-6.7108, -4.3213, -6.3777], [-8.9014, -6.1799, -6.7240], [-6.9315, -4.4735, -6.2298]]
).to(torch_device)
expected_boxes = torch.tensor(
[[0.2583, 0.5499, 0.4683], [0.7652, 0.9068, 0.4882], [0.5490, 0.2763, 0.0564]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_logits, atol=1e-4))
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4))
@require_torch_gpu
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
feature_extractor = self.default_feature_extractor
image = prepare_img()
encoding = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoding["pixel_values"]
pixel_mask = encoding["pixel_mask"]
# 1. run model on CPU
model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr-single-scale")
with torch.no_grad():
cpu_outputs = model(pixel_values, pixel_mask)
# 2. run model on GPU
model.to("cuda")
with torch.no_grad():
gpu_outputs = model(pixel_values.to("cuda"), pixel_mask.to("cuda"))
# 3. assert equivalence
for key in cpu_outputs.keys():
assert torch.allclose(cpu_outputs[key], gpu_outputs[key].cpu(), atol=1e-4)
expected_logits = torch.tensor(
[[-9.9051, -4.2541, -6.4852], [-9.6947, -4.0854, -6.8033], [-10.0665, -5.8470, -7.7003]]
)
assert torch.allclose(cpu_outputs.logits[0, :3, :3], expected_logits, atol=1e-4)
......@@ -53,6 +53,12 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
def get_test_pipeline(self, model, tokenizer, feature_extractor):
if model.__class__.__name__ == "DeformableDetrForObjectDetection":
self.skipTest(
"""Deformable DETR requires a custom CUDA kernel.
"""
)
object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
......
......@@ -46,6 +46,8 @@ PRIVATE_MODELS = [
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"DeformableDetrEncoder", # Building part of bigger (tested) model.
"DeformableDetrDecoder", # Building part of bigger (tested) model.
"OPTDecoder", # Building part of bigger (tested) model.
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
"SegformerDecodeHead", # Building part of bigger (tested) model.
......
......@@ -28,6 +28,7 @@ src/transformers/models/data2vec/modeling_data2vec_audio.py
src/transformers/models/data2vec/modeling_data2vec_vision.py
src/transformers/models/deberta/modeling_deberta.py
src/transformers/models/deberta_v2/modeling_deberta_v2.py
src/transformers/models/deformable_detr/modeling_deformable_detr.py
src/transformers/models/deit/modeling_deit.py
src/transformers/models/deit/modeling_tf_deit.py
src/transformers/models/detr/modeling_detr.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment