Commit 1345fab2 authored by luopl's avatar luopl
Browse files

Initial commit

parents
Pipeline #1263 canceled with stages
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
// Standard CUDA grid-stride loop: thread index `i` sweeps [0, n) across the
// whole grid, so kernel correctness does not depend on the launch size.
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
// Default (and maximum) number of threads per block used by the launchers.
const int CUDA_NUM_THREADS = 1024;
// Number of thread blocks needed to cover N work items with `num_threads`
// threads per block (ceiling division; returns 0 when N == 0).
inline int GET_BLOCKS(const int N, const int num_threads)
{
const int rounded_up = N + num_threads - 1;
return rounded_up / num_threads;
}
// Forward helper: bilinearly sample one (head m, channel c) scalar from a
// single feature level at fractional pixel location (h, w).
// The stride computation below implies bottom_data is laid out as
// [height, width, nheads, channels] with channels innermost.
// Corners outside [0, height) x [0, width) contribute 0 (implicit zero pad).
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
// The four integer corners surrounding (h, w).
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
// Fractional offsets (lh, lw) and their complements (hh, hw).
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
// Flattened strides for the [height, width, nheads, channels] layout.
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
// Fetch each in-bounds corner value; out-of-bounds corners stay 0.
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
// Standard bilinear weights and the interpolated value.
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
// Backward helper for one sampling point. Given the upstream gradient
// `top_grad` and this point's attention weight, it
//   * atomicAdd's the value gradient (bilinear weight * top_grad * attn_weight)
//     into grad_value at the four in-bounds corners, and
//   * writes d(loss)/d(sampling_loc) (w then h) and d(loss)/d(attn_weight)
//     through the given pointers with PLAIN stores — the caller must pass
//     thread-private storage (e.g. this thread's shared-memory slot).
// The width/height factors convert the pixel-space location gradient back to
// the normalized coordinates the sampling locations are expressed in
// (callers compute h_im = loc_h * spatial_h - 0.5).
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
// Corner indices and fractional offsets, exactly as in the forward helper.
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
// Strides for the [height, width, nheads, channels] value layout.
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
// Gradient flowing into the sampled value.
const scalar_t top_grad_value = top_grad * attn_weight;
// grad_h_weight / grad_w_weight accumulate d(val)/d(h) and d(val)/d(w).
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
// Recompute the interpolated value; its product with top_grad is the
// attention-weight gradient. Plain stores: destination must be private.
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
*grad_attn_weight = top_grad * val;
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}
// Same backward math as ms_deform_attn_col2im_bilinear, except ALL three
// gradient outputs (value, sampling location, attention weight) are
// accumulated with atomicAdd — so grad_sampling_loc / grad_attn_weight may
// point at shared global-memory slots written by many threads ("gm" variant).
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
const int &height, const int &width, const int &nheads, const int &channels,
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
const scalar_t &top_grad,
const scalar_t &attn_weight,
scalar_t* &grad_value,
scalar_t* grad_sampling_loc,
scalar_t* grad_attn_weight)
{
// Corner indices and fractional offsets, as in the forward helper.
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
// Strides for the [height, width, nheads, channels] value layout.
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = m * channels + c;
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t top_grad_value = top_grad * attn_weight;
// grad_h_weight / grad_w_weight accumulate d(val)/d(h) and d(val)/d(w).
scalar_t grad_h_weight = 0, grad_w_weight = 0;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
{
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
atomicAdd(grad_value+ptr1, w1*top_grad_value);
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
{
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
atomicAdd(grad_value+ptr2, w2*top_grad_value);
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
{
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
atomicAdd(grad_value+ptr3, w3*top_grad_value);
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
{
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
atomicAdd(grad_value+ptr4, w4*top_grad_value);
}
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
// Atomic accumulation: many threads may target the same gradient slots.
atomicAdd(grad_attn_weight, top_grad * val);
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}
// Forward kernel: one thread per output scalar of data_col.
// The flat index decodes (innermost first) to channel c_col, head m_col,
// query q_col, batch b_col. Each thread loops over every level and sampling
// point, bilinearly samples data_value and accumulates weight * sample into
// a single register before one final store.
// data_sampling_loc holds normalized (w, h) pairs; they are mapped to pixel
// coordinates via loc * spatial_extent - 0.5.
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// Decode the flat output index into (b_col, q_col, m_col, c_col).
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
// Per-(batch, query, head) offsets into the weight / location tensors;
// locations are (w, h) pairs, hence the << 1.
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col=0; l_col < num_levels; ++l_col)
{
// Each level occupies rows [level_start_id, level_start_id + h*w).
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
// Normalized -> pixel coordinates (align sample centers).
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
// Bilinear support is empty outside (-1, spatial) — skip entirely.
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
// Backward kernel, variant 1: blockSize is a compile-time copy of blockDim.x
// (the static shared caches below are sized by it). Per sampling point each
// thread writes its partial location/weight gradients into shared memory,
// then thread 0 sums all blockSize partials serially and writes the result
// with PLAIN stores — this variant assumes exactly one block owns each
// grad_sampling_loc / grad_attn_weight slot.
// NOTE(review): __syncthreads() sits inside the grid-stride loop, which is
// only safe if every thread of a block runs the same number of iterations;
// this appears to hold because the launcher sizes blocks by `channels` and
// n is a multiple of channels — confirm at the call site.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
// Per-thread scratch: (w, h) gradient pair and attention-weight gradient.
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
// Decode the flat index into (b_col, q_col, m_col, c_col).
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
// Advance the output gradient pointers to this sampling index.
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
// Zero this thread's cache slots: out-of-range samples contribute 0.
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
// Serial reduction by thread 0 over all blockSize partials.
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
// NOTE(review): this loop variable shadows the outer `tid`.
for (unsigned int tid = 1; tid < blockSize; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
// Plain stores: this block is the sole writer of these slots.
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
// Backward kernel, variant 2: same setup as v1 but the per-point partials
// are combined with a shared-memory tree reduction (stride halving from
// blockSize/2). The reduction is exact only when blockSize is a power of
// two; the visible dispatch instantiates power-of-two sizes. Results are
// written with plain stores, as in v1 (one block per output slot).
// NOTE(review): __syncthreads() inside the grid-stride loop — see v1.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
// Per-thread scratch: (w, h) gradient pair and attention-weight gradient.
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
// Decode the flat index into (b_col, q_col, m_col, c_col).
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
// Advance the output gradient pointers to this sampling index.
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
// Zero this thread's cache slots: out-of-range samples contribute 0.
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
// Binary tree reduction: each round folds the upper half into the lower.
for (unsigned int s=blockSize/2; s>0; s>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
}
__syncthreads();
}
// Thread 0 publishes the block total with plain stores.
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
// Backward kernel (dynamic shared memory, serial reduction): like the
// blocksize-aware v1 but the caches live in dynamic shared memory, so the
// launch must pass blockDim.x * 3 * sizeof(scalar_t) as the shared-memory
// size. Thread 0 serially sums the blockDim.x partials per sampling point
// and writes the result with plain stores (one block per output slot).
// NOTE(review): __syncthreads() inside the grid-stride loop — see v1.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
// Carve the dynamic shared buffer into loc pairs then weight scalars.
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
// Decode the flat index into (b_col, q_col, m_col, c_col).
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
// Advance the output gradient pointers to this sampling index.
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
// Zero this thread's cache slots: out-of-range samples contribute 0.
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
// Serial reduction by thread 0 over all blockDim.x partials.
if (tid == 0)
{
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
int sid=2;
// NOTE(review): this loop variable shadows the outer `tid`.
for (unsigned int tid = 1; tid < blockDim.x; ++tid)
{
_grad_w += cache_grad_sampling_loc[sid];
_grad_h += cache_grad_sampling_loc[sid + 1];
_grad_a += cache_grad_attn_weight[tid];
sid += 2;
}
// Plain stores: this block is the sole writer of these slots.
*grad_sampling_loc = _grad_w;
*(grad_sampling_loc + 1) = _grad_h;
*grad_attn_weight = _grad_a;
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
// Backward kernel (dynamic shared memory, tree reduction): handles block
// sizes that are NOT powers of two — when the halved stride s does not
// cover the previous span `spre`, the leftover element at tid + 2s is
// folded in by the inner `if`. Launch must pass blockDim.x * 3 *
// sizeof(scalar_t) of dynamic shared memory. Results written with plain
// stores (one block per output slot).
// NOTE(review): __syncthreads() inside the grid-stride loop — see v1.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
// Carve the dynamic shared buffer into loc pairs then weight scalars.
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
// Decode the flat index into (b_col, q_col, m_col, c_col).
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
// Advance the output gradient pointers to this sampling index.
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
// Zero this thread's cache slots: out-of-range samples contribute 0.
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
// Tree reduction; `spre` tracks the span being folded so an odd
// leftover element (tid + 2s < spre) is also accumulated.
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
// Thread 0 publishes the block total with plain stores.
if (tid == 0)
{
*grad_sampling_loc = cache_grad_sampling_loc[0];
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
// Backward kernel for channels > blockDim.x: identical tree reduction to
// shm_reduce_v2, but the per-block result is combined with atomicAdd because
// several blocks can contribute to the same (query, head) gradient slot —
// the launcher uses this when channels > 1024 and divisible by 1024.
// Launch must pass blockDim.x * 3 * sizeof(scalar_t) of dynamic shared mem.
// NOTE(review): __syncthreads() inside the grid-stride loop — see v1.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
// Carve the dynamic shared buffer into loc pairs then weight scalars.
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
// Decode the flat index into (b_col, q_col, m_col, c_col).
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
// Advance the output gradient pointers to this sampling index.
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
// Zero this thread's cache slots: out-of-range samples contribute 0.
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
*(cache_grad_attn_weight+threadIdx.x)=0;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
}
__syncthreads();
// Tree reduction with non-power-of-two tail handling (see shm_reduce_v2).
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
const unsigned int xid1 = tid << 1;
const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
}
}
__syncthreads();
}
// Atomic combine: other blocks may write the same slots concurrently.
if (tid == 0)
{
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
// Backward kernel, global-memory fallback: no shared memory and no
// intra-block reduction; every thread accumulates all three gradients
// directly with atomicAdd via ms_deform_attn_col2im_bilinear_gm. Imposes no
// constraint on blockDim (the launcher uses it when channels > 1024 and not
// divisible by 1024), at the cost of heavy atomic contention.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
// Decode the flat index into (b_col, q_col, m_col, c_col).
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
// Advance the output gradient pointers to this sampling index.
const int grad_sampling_ptr = data_weight_ptr;
grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const scalar_t h_im = loc_h * spatial_h - 0.5;
const scalar_t w_im = loc_w * spatial_w - 0.5;
// Out-of-range samples produced 0 in the forward pass: no gradient.
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
{
ms_deform_attn_col2im_bilinear_gm(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
top_grad, weight, grad_value_ptr,
grad_sampling_loc, grad_attn_weight);
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
grad_sampling_loc += grad_loc_stride;
}
}
}
}
// Host launcher for the forward (im2col) pass: one thread per output scalar
// (batch_size * num_query * num_heads * channels), CUDA_NUM_THREADS threads
// per block, enqueued asynchronously on `stream`. Launch-configuration
// errors are reported via cudaGetLastError; in-kernel execution errors
// surface at the next synchronizing CUDA call.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
const scalar_t* data_value,
const int64_t* data_spatial_shapes,
const int64_t* data_level_start_index,
const scalar_t* data_sampling_loc,
const scalar_t* data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* data_col)
{
// num_kernels == num_actual_kernels here; kept separate to mirror the
// col2im launcher's structure.
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
// Catch launch errors (bad grid/block config, etc.).
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
/**
 * Launch the multi-scale deformable attention backward (col2im) kernel.
 *
 * Dispatches to one of several kernel variants depending on `channels`:
 *  - channels > 1024 and a multiple of 1024: multi-block shared-memory reduction
 *  - channels > 1024 otherwise: global-memory atomics fallback
 *  - power-of-two channels <= 1024: compile-time-sized block reductions
 *  - any other channels: dynamic shared-memory reductions (v1 below 64, v2 otherwise)
 * All variants take the identical argument list and write grad_value,
 * grad_sampling_loc and grad_attn_weight. The launch is asynchronous on
 * `stream`; only launch-configuration errors are caught here.
 */
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
                              const scalar_t* grad_col,
                              const scalar_t* data_value,
                              const int64_t * data_spatial_shapes,
                              const int64_t * data_level_start_index,
                              const scalar_t * data_sampling_loc,
                              const scalar_t * data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* grad_value,
                              scalar_t* grad_sampling_loc,
                              scalar_t* grad_attn_weight)
{
  // Block size follows the channel count (capped at CUDA_NUM_THREADS) so the
  // per-block reductions line up with one channel group per block.
  const int num_threads = (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_blocks = GET_BLOCKS(num_actual_kernels, num_threads);
  // Shared-memory size used by the dynamic-reduction variants:
  // three scalars per thread (two for sampling-location grads, one for weight).
  const size_t reduce_smem = num_threads * 3 * sizeof(scalar_t);

// Every kernel variant takes exactly this argument list.
#define MSDA_COL2IM_ARGS                                                       \
    num_kernels, grad_col, data_value, data_spatial_shapes,                    \
    data_level_start_index, data_sampling_loc, data_attn_weight,               \
    batch_size, spatial_size, num_heads, channels, num_levels, num_query,      \
    num_point, grad_value, grad_sampling_loc, grad_attn_weight

  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      // channels is a multiple of 1024: shared-memory reduction spanning blocks.
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
          <<<num_blocks, num_threads, reduce_smem, stream>>>(MSDA_COL2IM_ARGS);
    }
    else
    {
      // Irregular large channel count: fall back to global-memory atomics.
      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
          <<<num_blocks, num_threads, 0, stream>>>(MSDA_COL2IM_ARGS);
    }
  }
  else
  {
// Power-of-two channel counts get a kernel specialized on the block size.
#define MSDA_COL2IM_CASE(N, KERNEL)                                            \
    case N:                                                                    \
      KERNEL<scalar_t, N>                                                      \
          <<<num_blocks, num_threads, 0, stream>>>(MSDA_COL2IM_ARGS);          \
      break;

    switch (channels)
    {
    MSDA_COL2IM_CASE(1,    ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
    MSDA_COL2IM_CASE(2,    ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
    MSDA_COL2IM_CASE(4,    ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
    MSDA_COL2IM_CASE(8,    ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
    MSDA_COL2IM_CASE(16,   ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
    MSDA_COL2IM_CASE(32,   ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1)
    MSDA_COL2IM_CASE(64,   ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
    MSDA_COL2IM_CASE(128,  ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
    MSDA_COL2IM_CASE(256,  ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
    MSDA_COL2IM_CASE(512,  ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
    MSDA_COL2IM_CASE(1024, ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2)
    default:
      // Non-power-of-two channels: dynamic shared-memory reduction.
      if (channels < 64)
      {
        ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
            <<<num_blocks, num_threads, reduce_smem, stream>>>(MSDA_COL2IM_ARGS);
      }
      else
      {
        ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
            <<<num_blocks, num_threads, reduce_smem, stream>>>(MSDA_COL2IM_ARGS);
      }
    }
#undef MSDA_COL2IM_CASE
  }
#undef MSDA_COL2IM_ARGS

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}
\ No newline at end of file
import math
import functools
from typing import Tuple, Union
import torch
from torch import Tensor, nn
class PositionEmbeddingSine(nn.Module):
"""Sinusoidal position embedding used in DETR model. See `End-to-End Object Detection
with Transformers <https://arxiv.org/pdf/2005.12872>`_ for more details.
:param num_pos_feats: The feature dimension for each position along x-axis or y-axis.
The final returned dimension for each position is 2 times of the input value,
defaults to 64
:param temperature: The temperature used for scaling the position embedding, defaults to 10000
:param normalize: Whether to normalize the position embedding, defaults to False
:param scale: A scale factor that scales the position embedding, which is used only when
`normalize` is True, defaults to 2*math.pi
:param eps: A value added to the denominator for numerical stability, defaults to 1e-6
:param offset: An offset added to embed, defaults to 0.0
"""
def __init__(
self,
num_pos_feats=64,
temperature: Union[int, Tuple[int, int]] = 10000,
normalize=False,
scale=2 * math.pi,
eps=1e-6,
offset=0.0,
):
super().__init__()
dim_t = 2 * torch.arange(num_pos_feats).div(2, rounding_mode="floor") / num_pos_feats
if isinstance(temperature, int):
dim_tx = dim_ty = temperature**dim_t
else:
assert len(temperature) == 2, "Only support two elements as (t_x, t_y) in temperature"
dim_tx, dim_ty = [t**dim_t for t in temperature]
self.register_buffer("dim_tx", dim_tx)
self.register_buffer("dim_ty", dim_ty)
self.normalize = normalize
self.scale = scale
self.eps = eps
self.offset = offset
def forward(self, mask: Tensor):
mask = mask.to(torch.int)
not_mask = 1 - mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale
x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale
else:
# RT-DETR uses unnormalized encoding with index from 0
y_embed = y_embed + self.offset
x_embed = x_embed + self.offset
pos_x = x_embed[:, :, :, None] / self.dim_tx
pos_y = y_embed[:, :, :, None] / self.dim_ty
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""Absolute pos embedding, learned."""
def __init__(self, num_embeddings: int = 50, num_pos_feats: int = 256):
super().__init__()
self.row_embed = nn.Embedding(num_embeddings, num_pos_feats)
self.col_embed = nn.Embedding(num_embeddings, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, mask: Tensor):
h, w = mask.shape[-2:]
i = torch.arange(w, device=mask.device)
j = torch.arange(h, device=mask.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
).permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)
)
return pos
@functools.lru_cache  # memoize per (num_pos_feats, temperature, device) triple
def get_dim_t(num_pos_feats: int, temperature: int, device: torch.device):
    """Return the frequency denominators ``temperature ** (2*floor(i/2)/d)``.

    Channel pairs (2k, 2k+1) share one frequency so sin/cos can interleave.
    """
    exponents = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
    exponents = 2 * torch.div(exponents, 2, rounding_mode="floor") / num_pos_feats
    return temperature**exponents
def get_sine_pos_embed(
    pos_tensor: Tensor,
    num_pos_feats: int = 128,
    temperature: int = 10000,
    scale: float = 2 * math.pi,
    exchange_xy: bool = True,
) -> Tensor:
    """Project each scalar coordinate in ``pos_tensor`` to a sinusoidal embedding.

    :param pos_tensor: tensor of shape (..., 2*n) holding coordinate values
    :param num_pos_feats: embedding size per scalar coordinate, defaults to 128
    :param temperature: temperature used for scaling the embedding, defaults to 10000
    :param scale: multiplier applied to coordinates before projection
    :param exchange_xy: swap the first two coordinates, so an input [x, y]
        yields [pos(y), pos(x)], defaults to True
    :return: position embedding with shape (..., n * num_pos_feats)
    """
    dim_t = get_dim_t(num_pos_feats, temperature, pos_tensor.device)
    proj = pos_tensor.unsqueeze(-1) * scale / dim_t
    # sin over even channels, cos over odd channels, interleaved.
    proj = torch.stack((proj[..., 0::2].sin(), proj[..., 1::2].cos()), dim=-1).flatten(-2)
    if exchange_xy:
        # Reorder coordinate slots as [1, 0, 2, 3, ...] along dim -2.
        order = torch.cat([
            torch.arange(1, -1, -1, device=proj.device),
            torch.arange(2, proj.shape[-2], device=proj.device),
        ])
        proj = torch.index_select(proj, -2, order)
    return proj.view(*pos_tensor.shape[:-1], -1)
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops
class PostProcess(nn.Module):
    """This module converts the model's output into the format expected by the coco api"""
    def __init__(
        self,
        select_box_nums_for_evaluation=100,
        nms_iou_threshold=-1,
        confidence_score=-1,
    ):
        """
        :param select_box_nums_for_evaluation: number of top-scoring boxes kept per image
        :param nms_iou_threshold: IoU threshold for NMS filtering; disabled when <= 0
        :param confidence_score: minimum score for keeping a prediction; disabled when <= 0
        """
        super().__init__()
        self.select_box_nums_for_evaluation = select_box_nums_for_evaluation
        self.nms_iou_threshold = nms_iou_threshold
        self.confidence_score = confidence_score

    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        """Convert raw predictions into per-image {"scores", "labels", "boxes"} dicts.

        :param outputs: dict with "pred_logits" (batch, queries, classes) and
            "pred_boxes" (batch, queries, 4) in normalized cxcywh format
        :param target_sizes: (batch, 2) tensor of (height, width) per image
        :return: list of dicts, one per image
        """
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2
        prob = out_logits.sigmoid()
        # top-k over the flattened (queries * classes) axis
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1),
            self.select_box_nums_for_evaluation,
            dim=1,
        )
        scores = topk_values
        # recover query index and class index from the flattened top-k index
        topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="trunc")
        labels = topk_indexes % out_logits.shape[2]
        boxes = box_ops._box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]
        item_indice = None
        # filter low-confidence predictions
        if self.confidence_score > 0:
            item_indice = [score > self.confidence_score for score in scores]
        # filter overlap predictions
        if self.nms_iou_threshold > 0:
            nms_indice = [
                box_ops.nms(box, score, iou_threshold=self.nms_iou_threshold)
                for box, score in zip(boxes, scores)
            ]
            # BUGFIX: when only NMS was enabled (confidence filtering off),
            # item_indice was still None here and iterating it raised a
            # TypeError; start from all-True masks in that case.
            if item_indice is None:
                item_indice = [torch.ones_like(score, dtype=torch.bool) for score in scores]
            nms_binary_indice = [torch.zeros_like(item_index, dtype=torch.bool) for item_index in item_indice]
            for nms_binary_index, nms_index in zip(nms_binary_indice, nms_indice):
                nms_binary_index[nms_index] = True
            item_indice = [
                item_index & nms_binary_index
                for item_index, nms_binary_index in zip(item_indice, nms_binary_indice)
            ]
        if item_indice is not None:
            scores = [score[item_index] for score, item_index in zip(scores, item_indice)]
            boxes = [box[item_index] for box, item_index in zip(boxes, item_indice)]
            labels = [label[item_index] for label, item_index in zip(labels, item_indice)]
        if torchvision._is_tracing():
            # avoid iteration warning during ONNX export
            # NOTE(review): this path assumes no filtering ran (tensors, not
            # lists); filtering while tracing would fail — confirm export setup.
            scores, labels, boxes = map(lambda x: x.unbind(0), (scores, labels, boxes))
        results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
        return results
class SegmentationPostProcess(nn.Module):
@torch.no_grad()
def forward(self, outputs, target_sizes, input_sizes, batched_input_size):
out_logits, out_bbox, out_mask = (
outputs["pred_logits"],
outputs["pred_boxes"],
outputs["pred_masks"],
)
assert len(out_logits) == len(target_sizes)
assert len(batched_input_size) == 2
# we average queries of the same class to get onehot segmentation image
out_class = out_logits.argmax(-1)
num_class = out_logits.shape[-1]
result_masks = []
for image_id in range(len(out_logits)):
result_masks_per_image = []
for cur_class in range(num_class):
class_index = out_class[image_id] == cur_class
mask_per_class = out_mask[image_id][class_index].sigmoid()
if mask_per_class.numel() == 0:
mask_per_class = mask_per_class.new_zeros((1, *mask_per_class.shape[-2:]))
mask_per_class = mask_per_class.mean(0)
result_masks_per_image.append(mask_per_class)
result_masks_per_image = torch.stack(result_masks_per_image, 0)
result_masks.append(result_masks_per_image)
result_masks = torch.stack(result_masks, 0)
# upsample masks with 1/4 resolution to input image shapes
result_masks = F.interpolate(
result_masks,
size=batched_input_size,
mode="bilinear",
align_corners=False,
)
# resize masks to original shapes and transform onehot into class
mask_results = []
for mask, (height, width), (out_height, out_width) in zip(
result_masks,
input_sizes,
target_sizes,
):
mask = F.interpolate(
mask[None, :, :height, :width],
size=(out_height, out_width),
mode="bilinear",
align_corners=False,
)[0]
mask_results.append({"masks": mask.argmax(0)})
return mask_results
import copy
import math
from typing import Tuple
import torch
import torchvision
from torch import nn
from models.bricks.base_transformer import TwostageTransformer
from models.bricks.basic import MLP
from models.bricks.ms_deform_attn import MultiScaleDeformableAttention
from models.bricks.position_encoding import PositionEmbeddingLearned, get_sine_pos_embed
from util.misc import inverse_sigmoid
class MaskPredictor(nn.Module):
def __init__(self, in_dim, h_dim):
super().__init__()
self.h_dim = h_dim
self.layer1 = nn.Sequential(
nn.LayerNorm(in_dim),
nn.Linear(in_dim, h_dim),
nn.GELU(),
)
self.layer2 = nn.Sequential(
nn.Linear(h_dim, h_dim // 2),
nn.GELU(),
nn.Linear(h_dim // 2, h_dim // 4),
nn.GELU(),
nn.Linear(h_dim // 4, 1),
)
self.apply(self.init_weights)
@staticmethod
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
def forward(self, x):
z = self.layer1(x)
z_local, z_global = torch.split(z, self.h_dim // 2, dim=-1)
z_global = z_global.mean(dim=1, keepdim=True).expand(-1, z_local.shape[1], -1)
z = torch.cat([z_local, z_global], dim=-1)
out = self.layer2(z)
return out
class SalienceTransformer(TwostageTransformer):
    """Two-stage transformer with salience-guided token filtering (Salience-DETR style).

    Pipeline (see ``forward``): predict a per-token foreground ("salience")
    score over the flattened multi-scale features, keep only the top-scoring
    tokens per encoder layer, run the encoder on that subset, select top-k
    encoder proposals (deduplicated with NMS) as decoder references, then run
    the decoder.  NOTE(review): relies on ``TwostageTransformer`` helpers
    (``flatten_multi_level``, ``get_lvl_pos_embed``, ``multi_level_misc``,
    ``gen_encoder_output_proposals``) defined outside this file.
    """
    def __init__(
        self,
        encoder: nn.Module,
        neck: nn.Module,
        decoder: nn.Module,
        num_classes: int,
        num_feature_levels: int = 4,
        two_stage_num_proposals: int = 900,
        level_filter_ratio: Tuple = (0.25, 0.5, 1.0, 1.0),
        layer_filter_ratio: Tuple = (1.0, 0.8, 0.6, 0.6, 0.4, 0.2),
    ):
        """
        :param encoder: transformer encoder (must expose ``embed_dim``)
        :param neck: optional feature neck applied to encoder output (may be None)
        :param decoder: transformer decoder producing classes and coordinates
        :param num_classes: number of object classes for the encoder head
        :param num_feature_levels: number of multi-scale feature levels
        :param two_stage_num_proposals: number of decoder queries/proposals
        :param level_filter_ratio: per-level fraction of tokens kept as foreground
        :param layer_filter_ratio: per-encoder-layer fraction of tokens updated
        """
        super().__init__(num_feature_levels, encoder.embed_dim)
        # model parameters
        self.two_stage_num_proposals = two_stage_num_proposals
        self.num_classes = num_classes
        # salience parameters
        self.register_buffer("level_filter_ratio", torch.Tensor(level_filter_ratio))
        self.register_buffer("layer_filter_ratio", torch.Tensor(layer_filter_ratio))
        # per-level mixing weights for the score-guided memory update in forward
        self.alpha = nn.Parameter(torch.Tensor(3), requires_grad=True)
        # model structure
        self.encoder = encoder
        self.neck = neck
        self.decoder = decoder
        self.tgt_embed = nn.Embedding(two_stage_num_proposals, self.embed_dim)
        self.encoder_class_head = nn.Linear(self.embed_dim, num_classes)
        self.encoder_bbox_head = MLP(self.embed_dim, self.embed_dim, 4, 3)
        # share the classification head with the encoder (used as enhance_mcsp there)
        self.encoder.enhance_mcsp = self.encoder_class_head
        self.enc_mask_predictor = MaskPredictor(self.embed_dim, self.embed_dim)
        self.init_weights()

    def init_weights(self):
        """Initialize query embeddings, encoder heads, and salience mixing weights."""
        # initialize embedding layers
        nn.init.normal_(self.tgt_embed.weight)
        # initialize encoder classification layers
        prior_prob = 0.01
        # focal-loss-style bias so initial foreground probability is prior_prob
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        nn.init.constant_(self.encoder_class_head.bias, bias_value)
        # initialize encoder regression layers
        nn.init.constant_(self.encoder_bbox_head.layers[-1].weight, 0.0)
        nn.init.constant_(self.encoder_bbox_head.layers[-1].bias, 0.0)
        # initialize alpha
        self.alpha.data.uniform_(-0.3, 0.3)

    def forward(
        self,
        multi_level_feats,
        multi_level_masks,
        multi_level_pos_embeds,
        noised_label_query,
        noised_box_query,
        attn_mask,
    ):
        """Full two-stage forward pass.

        :param multi_level_feats: per-level feature maps
        :param multi_level_masks: per-level padding masks
        :param multi_level_pos_embeds: per-level positional embeddings
        :param noised_label_query: denoising label queries (or None)
        :param noised_box_query: denoising box queries (or None)
        :param attn_mask: decoder self-attention mask for denoising groups
        :return: (outputs_classes, outputs_coords, enc_outputs_class,
                  enc_outputs_coord, salience_score)
        """
        # get input for encoder
        feat_flatten = self.flatten_multi_level(multi_level_feats)
        mask_flatten = self.flatten_multi_level(multi_level_masks)
        lvl_pos_embed_flatten = self.get_lvl_pos_embed(multi_level_pos_embeds)
        spatial_shapes, level_start_index, valid_ratios = self.multi_level_misc(multi_level_masks)
        backbone_output_memory = self.gen_encoder_output_proposals(
            feat_flatten + lvl_pos_embed_flatten, mask_flatten, spatial_shapes
        )[0]
        # calculate filtered tokens numbers for each feature map
        reverse_multi_level_masks = [~m for m in multi_level_masks]
        valid_token_nums = torch.stack([m.sum((1, 2)) for m in reverse_multi_level_masks], -1)
        focus_token_nums = (valid_token_nums * self.level_filter_ratio).int()
        level_token_nums = focus_token_nums.max(0)[0]
        focus_token_nums = focus_token_nums.sum(-1)
        # from high level to low level
        batch_size = feat_flatten.shape[0]
        selected_score = []
        selected_inds = []
        salience_score = []
        for level_idx in range(spatial_shapes.shape[0] - 1, -1, -1):
            start_index = level_start_index[level_idx]
            end_index = level_start_index[level_idx + 1] if level_idx < spatial_shapes.shape[0] - 1 else None
            level_memory = backbone_output_memory[:, start_index:end_index, :]
            mask = mask_flatten[:, start_index:end_index]
            # update the memory using the higher-level score_prediction
            # (``score`` carries over from the previous, higher-level iteration)
            if level_idx != spatial_shapes.shape[0] - 1:
                upsample_score = torch.nn.functional.interpolate(
                    score,
                    size=spatial_shapes[level_idx].unbind(),
                    mode="bilinear",
                    align_corners=True,
                )
                upsample_score = upsample_score.view(batch_size, -1, spatial_shapes[level_idx].prod())
                upsample_score = upsample_score.transpose(1, 2)
                level_memory = level_memory + level_memory * upsample_score * self.alpha[level_idx]
            # predict the foreground score of the current layer
            score = self.enc_mask_predictor(level_memory)
            # padded positions get the minimum score so they are never selected
            valid_score = score.squeeze(-1).masked_fill(mask, score.min())
            score = score.transpose(1, 2).view(batch_size, -1, *spatial_shapes[level_idx])
            # get the topk salience index of the current feature map level
            level_score, level_inds = valid_score.topk(level_token_nums[level_idx], dim=1)
            level_inds = level_inds + level_start_index[level_idx]
            salience_score.append(score)
            selected_inds.append(level_inds)
            selected_score.append(level_score)
        # merge levels (back in low-to-high order) and sort tokens by score
        selected_score = torch.cat(selected_score[::-1], 1)
        index = torch.sort(selected_score, dim=1, descending=True)[1]
        selected_inds = torch.cat(selected_inds[::-1], 1).gather(1, index)
        # create layer-wise filtering
        num_inds = selected_inds.shape[1]
        # change dtype to avoid shape inference error during exporting ONNX
        cast_dtype = num_inds.dtype if torchvision._is_tracing() else torch.int64
        layer_filter_ratio = (num_inds * self.layer_filter_ratio).to(cast_dtype)
        # each encoder layer sees a progressively smaller prefix of tokens
        selected_inds = [selected_inds[:, :r] for r in layer_filter_ratio]
        salience_score = salience_score[::-1]
        foreground_score = self.flatten_multi_level(salience_score).squeeze(-1)
        foreground_score = foreground_score.masked_fill(mask_flatten, foreground_score.min())
        # transformer encoder
        memory = self.encoder(
            query=feat_flatten,
            query_pos=lvl_pos_embed_flatten,
            query_key_padding_mask=mask_flatten,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            # salience input
            foreground_score=foreground_score,
            focus_token_nums=focus_token_nums,
            foreground_inds=selected_inds,
            multi_level_masks=multi_level_masks,
        )
        if self.neck is not None:
            # unflatten to per-level maps, run the neck, then re-flatten
            feat_unflatten = memory.split(spatial_shapes.prod(-1).unbind(), dim=1)
            feat_unflatten = dict((
                i,
                feat.transpose(1, 2).contiguous().reshape(-1, self.embed_dim, *spatial_shape),
            ) for i, (feat, spatial_shape) in enumerate(zip(feat_unflatten, spatial_shapes)))
            feat_unflatten = list(self.neck(feat_unflatten).values())
            memory = torch.cat([feat.flatten(2).transpose(1, 2) for feat in feat_unflatten], dim=1)
        # get encoder output, classes and coordinates
        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, mask_flatten, spatial_shapes
        )
        enc_outputs_class = self.encoder_class_head(output_memory)
        enc_outputs_coord = self.encoder_bbox_head(output_memory) + output_proposals
        enc_outputs_coord = enc_outputs_coord.sigmoid()
        # get topk output classes and coordinates
        # over-select 4x proposals first, then deduplicate with NMS below
        if torchvision._is_tracing():
            topk = torch.min(torch.tensor(self.two_stage_num_proposals * 4), enc_outputs_class.shape[1])
        else:
            topk = min(self.two_stage_num_proposals * 4, enc_outputs_class.shape[1])
        topk_scores, topk_index = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)
        topk_index = self.nms_on_topk_index(
            topk_scores, topk_index, spatial_shapes, level_start_index, iou_threshold=0.3
        ).unsqueeze(-1)
        enc_outputs_class = enc_outputs_class.gather(1, topk_index.expand(-1, -1, self.num_classes))
        enc_outputs_coord = enc_outputs_coord.gather(1, topk_index.expand(-1, -1, 4))
        # get target and reference points
        reference_points = enc_outputs_coord.detach()
        target = self.tgt_embed.weight.expand(multi_level_feats[0].shape[0], -1, -1)
        # combine with noised_label_query and noised_box_query for denoising training
        if noised_label_query is not None and noised_box_query is not None:
            target = torch.cat([noised_label_query, target], 1)
            reference_points = torch.cat([noised_box_query.sigmoid(), reference_points], 1)
        # decoder
        outputs_classes, outputs_coords = self.decoder(
            query=target,
            value=memory,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            attn_mask=attn_mask,
        )
        return outputs_classes, outputs_coords, enc_outputs_class, enc_outputs_coord, salience_score

    @staticmethod
    def fast_repeat_interleave(input, repeats):
        """torch.Tensor.repeat_interleave is slow for one-dimension input for unknown reasons.
        This is a simple faster implementation. Notice the return shares memory with the input.

        :param input: one-dimensional input Tensor
        :param repeats: repeat numbers of each element of ``input``
        :return: flattened tensor with each element repeated, expand-based (no copy)
        """
        # the following inplementation runs a little faster under one-dimension settings
        return torch.cat([aa.expand(bb) for aa, bb in zip(input, repeats)])

    @torch.no_grad()
    def nms_on_topk_index(
        self, topk_scores, topk_index, spatial_shapes, level_start_index, iou_threshold=0.3
    ):
        """Deduplicate top-k proposal indices with NMS on tiny 2x2 surrogate boxes.

        Each flattened token index is mapped back to its (x, y) position on its
        level; NMS then removes spatial near-duplicates per image and level.
        Returns ``(batch, min_num)`` kept indices, where ``min_num`` is the
        smallest per-image survivor count capped at two_stage_num_proposals.
        """
        batch_size, num_topk = topk_scores.shape
        if torchvision._is_tracing():
            num_pixels = spatial_shapes.prod(-1).unbind()
        else:
            num_pixels = spatial_shapes.prod(-1).tolist()
        # flatten topk_scores and topk_index for batched_nms
        topk_scores, topk_index = map(lambda x: x.view(-1), (topk_scores, topk_index))
        # get level coordinates for queries and construct boxes for them
        level_index = torch.arange(level_start_index.shape[0], device=level_start_index.device)
        feat_width, start_index, level_idx = map(
            lambda x: self.fast_repeat_interleave(x, num_pixels)[topk_index],
            (spatial_shapes[:, 1], level_start_index, level_index),
        )
        topk_spatial_index = topk_index - start_index
        x = topk_spatial_index % feat_width
        y = torch.div(topk_spatial_index, feat_width, rounding_mode="trunc")
        # 2x2 box centered on each token position
        coordinates = torch.stack([x - 1.0, y - 1.0, x + 1.0, y + 1.0], -1)
        # get unique idx for queries in different images and levels
        image_idx = torch.arange(batch_size).repeat_interleave(num_topk, 0)
        image_idx = image_idx.to(level_idx.device)
        idxs = level_idx + level_start_index.shape[0] * image_idx
        # perform batched_nms
        indices = torchvision.ops.batched_nms(coordinates, topk_scores, idxs, iou_threshold)
        # stack valid index
        results_index = []
        if torchvision._is_tracing():
            min_num = torch.tensor(self.two_stage_num_proposals)
        else:
            min_num = self.two_stage_num_proposals
        # get indices in each image
        for i in range(batch_size):
            topk_index_per_image = topk_index[indices[image_idx[indices] == i]]
            if torchvision._is_tracing():
                min_num = torch.min(topk_index_per_image.shape[0], min_num)
            else:
                min_num = min(topk_index_per_image.shape[0], min_num)
            results_index.append(topk_index_per_image)
        # truncate every image to the common minimum so results stack evenly
        return torch.stack([index[:min_num] for index in results_index])
class SalienceTransformerEncoderLayer(nn.Module):
    """Deformable encoder layer with a salience-guided pre-attention step.

    Before the usual deformable self-attention + FFN, the ``topk_sa`` most
    salient tokens (ranked by best class logit times foreground score) are
    refined with ordinary multi-head self-attention and scattered back.
    """
    def __init__(
        self,
        embed_dim=256,
        d_ffn=1024,
        dropout=0.1,
        n_heads=8,
        activation=nn.ReLU(inplace=True),
        n_levels=4,
        n_points=4,
        # focus parameter
        topk_sa=300,
    ):
        """
        :param embed_dim: token feature dimension
        :param d_ffn: hidden dimension of the feed-forward network
        :param dropout: dropout probability used throughout the layer
        :param n_heads: number of attention heads
        :param activation: FFN activation module (NOTE(review): the default
            instance is evaluated once and shared by all layers built with
            defaults — harmless for a stateless ReLU, but worth knowing)
        :param n_levels: number of feature levels for deformable attention
        :param n_points: sampling points per level for deformable attention
        :param topk_sa: number of salient tokens refined by pre-attention
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.topk_sa = topk_sa
        # pre attention
        self.pre_attention = nn.MultiheadAttention(embed_dim, n_heads, dropout, batch_first=True)
        self.pre_dropout = nn.Dropout(dropout)
        self.pre_norm = nn.LayerNorm(embed_dim)
        # self attention
        self.self_attn = MultiScaleDeformableAttention(embed_dim, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        # ffn
        self.linear1 = nn.Linear(embed_dim, d_ffn)
        self.activation = activation
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, embed_dim)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.init_weights()

    def init_weights(self):
        """Xavier-initialize the pre-attention projections and FFN linears."""
        # initialize self_attention
        nn.init.xavier_uniform_(self.pre_attention.in_proj_weight)
        nn.init.xavier_uniform_(self.pre_attention.out_proj.weight)
        # initialize Linear layer
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)

    @staticmethod
    def with_pos_embed(tensor, pos):
        # Add the positional embedding when one is provided.
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, query):
        """Feed-forward block with residual connection and LayerNorm."""
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(query))))
        query = query + self.dropout3(src2)
        query = self.norm2(query)
        return query

    def forward(
        self,
        query,
        query_pos,
        value,  # focus parameter
        reference_points,
        spatial_shapes,
        level_start_index,
        query_key_padding_mask=None,
        # focus parameter
        score_tgt=None,
        foreground_pre_layer=None,
    ):
        """Pre-attention over the top-k salient tokens, then deformable
        self-attention and FFN over the full query set.

        NOTE(review): ``score_tgt`` and ``foreground_pre_layer`` default to
        None but are dereferenced unconditionally — callers must supply them.
        """
        # combined salience: best class logit weighted by foreground score
        mc_score = score_tgt.max(-1)[0] * foreground_pre_layer
        select_tgt_index = torch.topk(mc_score, self.topk_sa, dim=1)[1]
        select_tgt_index = select_tgt_index.unsqueeze(-1).expand(-1, -1, self.embed_dim)
        select_tgt = torch.gather(query, 1, select_tgt_index)
        select_pos = torch.gather(query_pos, 1, select_tgt_index)
        query_with_pos = key_with_pos = self.with_pos_embed(select_tgt, select_pos)
        tgt2 = self.pre_attention(
            query_with_pos,
            key_with_pos,
            select_tgt,
        )[0]
        select_tgt = select_tgt + self.pre_dropout(tgt2)
        select_tgt = self.pre_norm(select_tgt)
        # write the refined salient tokens back into their original slots
        query = query.scatter(1, select_tgt_index, select_tgt)
        # self attention
        src2 = self.self_attn(
            query=self.with_pos_embed(query, query_pos),
            reference_points=reference_points,
            value=value,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            key_padding_mask=query_key_padding_mask,
        )
        query = query + self.dropout1(src2)
        query = self.norm1(query)
        # ffn
        query = self.forward_ffn(query)
        return query
class SalienceTransformerEncoder(nn.Module):
    """Stack of salience encoder layers with layer-wise token filtering.

    Each layer only updates the tokens listed in ``foreground_inds[layer_id]``;
    untouched tokens keep their previous values, and a learned background
    embedding is added to non-foreground tokens after the final layer.
    """
    def __init__(self, encoder_layer: nn.Module, num_layers: int = 6):
        """
        :param encoder_layer: layer instance to deep-copy ``num_layers`` times
        :param num_layers: number of stacked encoder layers, defaults to 6
        """
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.embed_dim = encoder_layer.embed_dim
        # learnt background embed for prediction
        self.background_embedding = PositionEmbeddingLearned(200, num_pos_feats=self.embed_dim // 2)
        self.init_weights()

    def init_weights(self):
        # initialize encoder layers
        for layer in self.layers:
            if hasattr(layer, "init_weights"):
                layer.init_weights()

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build normalized per-level reference points for deformable attention.

        Returns (x, y) coordinates of shape [n, s, l, 2] where s is the total
        token count over all levels, scaled by the per-level valid ratios.
        """
        reference_points_list = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            # pixel centers (0.5 offset) normalized by the valid extent
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, h - 0.5, h, dtype=torch.float32, device=device),
                torch.linspace(0.5, w - 0.5, w, dtype=torch.float32, device=device),
                indexing="ij",
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * h)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * w)
            ref = torch.stack((ref_x, ref_y), -1)  # [n, h*w, 2]
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)  # [n, s, 2]
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]  # [n, s, l, 2]
        return reference_points

    def forward(
        self,
        query,
        spatial_shapes,
        level_start_index,
        valid_ratios,
        query_pos=None,
        query_key_padding_mask=None,
        # salience input
        foreground_score=None,
        focus_token_nums=None,
        foreground_inds=None,
        multi_level_masks=None,
    ):
        """Run the filtered encoder stack.

        :param query: flattened multi-scale tokens, shape (batch, s, embed_dim)
        :param foreground_score: per-token salience score used by each layer
        :param focus_token_nums: per-image count of valid (non-padding) foreground tokens
        :param foreground_inds: per-layer token indices to update
        :param multi_level_masks: per-level padding masks for the background embedding
        :return: encoded memory of the same shape as ``query``
        """
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=query.device)
        b, n, s, p = reference_points.shape
        ori_reference_points = reference_points
        ori_pos = query_pos
        value = output = query
        for layer_id, layer in enumerate(self.layers):
            # gather this layer's foreground tokens plus their pos/score/refs
            inds_for_query = foreground_inds[layer_id].unsqueeze(-1).expand(-1, -1, self.embed_dim)
            query = torch.gather(output, 1, inds_for_query)
            query_pos = torch.gather(ori_pos, 1, inds_for_query)
            foreground_pre_layer = torch.gather(foreground_score, 1, foreground_inds[layer_id])
            reference_points = torch.gather(
                ori_reference_points.view(b, n, -1), 1,
                foreground_inds[layer_id].unsqueeze(-1).repeat(1, 1, s * p)
            ).view(b, -1, s, p)
            # NOTE(review): ``enhance_mcsp`` is attached externally
            # (SalienceTransformer.__init__), not defined on this module.
            score_tgt = self.enhance_mcsp(query)
            query = layer(
                query,
                query_pos,
                value,
                reference_points,
                spatial_shapes,
                level_start_index,
                query_key_padding_mask,
                score_tgt,
                foreground_pre_layer,
            )
            outputs = []
            for i in range(foreground_inds[layer_id].shape[0]):
                # scatter the updated tokens back, skipping per-image padding
                foreground_inds_no_pad = foreground_inds[layer_id][i][:focus_token_nums[i]]
                query_no_pad = query[i][:focus_token_nums[i]]
                outputs.append(
                    output[i].scatter(
                        0,
                        foreground_inds_no_pad.unsqueeze(-1).repeat(1, query.size(-1)),
                        query_no_pad,
                    )
                )
            output = torch.stack(outputs)
        # add learnt embedding for background
        if multi_level_masks is not None:
            background_embedding = [
                self.background_embedding(mask).flatten(2).transpose(1, 2) for mask in multi_level_masks
            ]
            background_embedding = torch.cat(background_embedding, dim=1)
            # zero out the embedding at the last layer's foreground positions
            background_embedding.scatter_(1, inds_for_query, 0)
            # padded positions contribute nothing
            background_embedding *= (~query_key_padding_mask).unsqueeze(-1)
            output = output + background_embedding
        return output
class SalienceTransformerDecoderLayer(nn.Module):
    """A single decoder layer: self-attention, multi-scale deformable
    cross-attention, then an FFN — each followed by dropout, a residual
    connection and LayerNorm (post-norm).

    :param embed_dim: token embedding dimension, defaults to 256
    :param d_ffn: hidden dimension of the feed-forward network, defaults to 1024
    :param n_heads: number of attention heads, defaults to 8
    :param dropout: dropout probability after each sub-layer, defaults to 0.1
    :param activation: activation module of the FFN
    :param n_levels: feature levels for deformable attention, defaults to 4
    :param n_points: sampling points per level, defaults to 4
    """
    def __init__(
        self,
        embed_dim=256,
        d_ffn=1024,
        n_heads=8,
        dropout=0.1,
        # NOTE(review): this default module is created once at import time and
        # shared by every layer that uses the default — confirm intended
        activation=nn.ReLU(inplace=True),
        n_levels=4,
        n_points=4,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = n_heads
        # cross attention
        self.cross_attn = MultiScaleDeformableAttention(embed_dim, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        # self attention
        self.self_attn = nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        # ffn
        self.linear1 = nn.Linear(embed_dim, d_ffn)
        self.activation = activation
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, embed_dim)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.init_weights()

    def init_weights(self):
        # initialize self_attention projections with Xavier uniform
        nn.init.xavier_uniform_(self.self_attn.in_proj_weight)
        nn.init.xavier_uniform_(self.self_attn.out_proj.weight)
        # initialize FFN Linear layers with Xavier uniform
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)

    @staticmethod
    def with_pos_embed(tensor, pos):
        # add the positional embedding if one is provided
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Feed-forward sub-layer with residual connection and LayerNorm."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        query,
        query_pos,
        reference_points,
        value,
        spatial_shapes,
        level_start_index,
        self_attn_mask=None,
        key_padding_mask=None,
    ):
        """Apply self-attention, deformable cross-attention and the FFN.

        :param query: decoder queries, (b, num_queries, embed_dim)
            (batch-first per the MultiheadAttention construction)
        :param query_pos: positional embedding added to queries/keys
        :param reference_points: reference points for deformable attention
        :param value: flattened encoder memory for cross-attention
        :param spatial_shapes: (h, w) of each feature level
        :param level_start_index: start offset of each level in ``value``
        :param self_attn_mask: optional self-attention mask
        :param key_padding_mask: padding mask over the encoder memory
        :return: updated queries, same shape as ``query``
        """
        # self attention (queries and keys carry the positional embedding)
        query_with_pos = key_with_pos = self.with_pos_embed(query, query_pos)
        query2 = self.self_attn(
            query=query_with_pos,
            key=key_with_pos,
            value=query,
            attn_mask=self_attn_mask,
        )[0]
        query = query + self.dropout2(query2)
        query = self.norm2(query)
        # cross attention
        query2 = self.cross_attn(
            query=self.with_pos_embed(query, query_pos),
            reference_points=reference_points,
            value=value,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            key_padding_mask=key_padding_mask,
        )
        query = query + self.dropout1(query2)
        query = self.norm1(query)
        # ffn
        query = self.forward_ffn(query)
        return query
class SalienceTransformerDecoder(nn.Module):
    """Decoder stack with iterative bounding-box refinement.

    Each layer refines the queries; per-layer class and box heads produce
    intermediate predictions (used for auxiliary losses), and the reference
    boxes are refined between layers.

    :param decoder_layer: layer module deep-copied ``num_layers`` times
    :param num_layers: number of decoder layers
    :param num_classes: output size of the classification heads
    """
    def __init__(self, decoder_layer, num_layers, num_classes):
        super().__init__()
        # parameters
        self.embed_dim = decoder_layer.embed_dim
        self.num_layers = num_layers
        self.num_classes = num_classes
        # decoder layers and embedding
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.ref_point_head = MLP(2 * self.embed_dim, self.embed_dim, self.embed_dim, 2)
        # iterative bounding box refinement
        self.class_head = nn.ModuleList([nn.Linear(self.embed_dim, num_classes) for _ in range(num_layers)])
        self.bbox_head = nn.ModuleList([MLP(self.embed_dim, self.embed_dim, 4, 3) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(self.embed_dim)
        self.init_weights()

    def init_weights(self):
        # initialize decoder layers
        for layer in self.layers:
            if hasattr(layer, "init_weights"):
                layer.init_weights()
        # initialize classification biases so the initial foreground
        # probability is ~0.01 (standard focal-loss initialization)
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        for class_head in self.class_head:
            nn.init.constant_(class_head.bias, bias_value)
        # initialize regression output layers to predict zero offsets
        for bbox_head in self.bbox_head:
            nn.init.constant_(bbox_head.layers[-1].weight, 0.0)
            nn.init.constant_(bbox_head.layers[-1].bias, 0.0)

    def forward(
        self,
        query,
        reference_points,
        value,
        spatial_shapes,
        level_start_index,
        valid_ratios,
        key_padding_mask=None,
        attn_mask=None,
    ):
        """Decode queries, refining the reference boxes layer by layer.

        :param query: initial decoder queries, (b, num_queries, embed_dim)
        :param reference_points: initial normalized reference boxes (sigmoid space)
        :param value: flattened encoder memory
        :param spatial_shapes: (h, w) of each feature level
        :param level_start_index: start offset of each level in ``value``
        :param valid_ratios: valid (non-padded) ratios per level
        :param key_padding_mask: padding mask over the encoder memory
        :param attn_mask: self-attention mask (e.g. denoising-group isolation)
        :return: (stacked per-layer class logits, stacked per-layer sigmoid boxes)
        """
        outputs_classes = []
        outputs_coords = []
        # duplicate valid ratios to cover all four box coordinates
        valid_ratio_scale = torch.cat([valid_ratios, valid_ratios], -1)[:, None]
        for layer_idx, layer in enumerate(self.layers):
            # scale references per level; detached so gradients to the boxes
            # only flow through the prediction heads below
            reference_points_input = reference_points.detach()[:, :, None] * valid_ratio_scale
            query_sine_embed = get_sine_pos_embed(reference_points_input[:, :, 0, :])
            query_pos = self.ref_point_head(query_sine_embed)
            # relation embedding
            query = layer(
                query=query,
                query_pos=query_pos,
                reference_points=reference_points_input,
                value=value,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                key_padding_mask=key_padding_mask,
                self_attn_mask=attn_mask,
            )
            # get output, reference_points are not detached for look_forward_twice
            output_class = self.class_head[layer_idx](self.norm(query))
            output_coord = self.bbox_head[layer_idx](self.norm(query)) + inverse_sigmoid(reference_points)
            output_coord = output_coord.sigmoid()
            outputs_classes.append(output_class)
            outputs_coords.append(output_coord)
            if layer_idx == self.num_layers - 1:
                break
            # iterative bounding box refinement (detached: refinement does not
            # backpropagate through earlier layers' reference points)
            reference_points = self.bbox_head[layer_idx](query) + inverse_sigmoid(reference_points.detach())
            reference_points = reference_points.sigmoid()
        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)
        return outputs_classes, outputs_coords
import copy
from typing import Dict
import torch
import torch.distributed
from torch import nn
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops
from models.bricks.losses import sigmoid_focal_loss, vari_sigmoid_focal_loss
from util.utils import get_world_size, is_dist_avail_and_initialized
class SetCriterion(nn.Module):
    """This class computes the loss for DETR.
    The process happens in two steps:
    1) we compute hungarian assignment between ground truth boxes and the outputs of the model
    2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(
        self,
        num_classes: int,
        matcher: nn.Module,
        weight_dict: Dict,
        alpha: float = 0.25,
        gamma: float = 2.0,
        two_stage_binary_cls=False,
    ):
        """Create the criterion.
        :param num_classes: number of object categories, omitting the special no-object category
        :param matcher: module able to compute a matching between targets and proposals
        :param weight_dict: dict containing as key the names of the losses and as values their relative weight
        :param alpha: alpha in Focal Loss, defaults to 0.25
        :param gamma: gamma in Focal loss, defaults to 2.0
        :param two_stage_binary_cls: Whether to use two-stage binary classification loss, defaults to False
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.alpha = alpha
        self.gamma = gamma
        self.two_stage_binary_cls = two_stage_binary_cls

    def loss_labels(self, outputs, targets, num_boxes, indices, **kwargs):
        """Classification loss (sigmoid focal loss over one-hot targets).
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]
        idx = self._get_src_permutation_idx(indices)
        # labels of the matched ground-truth boxes, aligned with idx
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # default every query to the no-object class (index num_classes)
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device,
        )
        target_classes[idx] = target_classes_o
        # one-hot with an extra (last) channel for no-object, dropped below
        target_classes_onehot = torch.zeros(
            [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
            dtype=src_logits.dtype,
            layout=src_logits.layout,
            device=src_logits.device,
        )
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
        target_classes_onehot = target_classes_onehot[:, :, :-1]
        # NOTE(review): scaled by the number of queries — presumably to undo a
        # per-query mean inside sigmoid_focal_loss; confirm against its impl
        loss_class = (
            sigmoid_focal_loss(
                src_logits,
                target_classes_onehot,
                num_boxes,
                alpha=self.alpha,
                gamma=self.gamma,
            ) * src_logits.shape[1]
        )
        losses = {"loss_class": loss_class}
        return losses

    def loss_boxes(self, outputs, targets, num_boxes, indices, **kwargs):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes
        # GIoU of matched pairs only (diagonal of the pairwise matrix)
        loss_giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops._box_cxcywh_to_xyxy(src_boxes),
                box_ops._box_cxcywh_to_xyxy(target_boxes),
            )
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices; returns (batch_idx, src_idx)
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices; returns (batch_idx, tgt_idx)
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def calculate_loss(self, outputs, targets, num_boxes, indices=None, **kwargs):
        """Match predictions to targets (unless ``indices`` is given, e.g. for
        denoising queries) and compute classification and box losses."""
        losses = {}
        # get matching results for each image
        if not indices:
            gt_boxes, gt_labels = list(zip(*map(lambda x: (x["boxes"], x["labels"]), targets)))
            pred_logits, pred_boxes = outputs["pred_logits"], outputs["pred_boxes"]
            indices = list(map(self.matcher, pred_boxes, pred_logits, gt_boxes, gt_labels))
        loss_class = self.loss_labels(outputs, targets, num_boxes, indices=indices)
        loss_boxes = self.loss_boxes(outputs, targets, num_boxes, indices=indices)
        losses.update(loss_class)
        losses.update(loss_boxes)
        return losses

    def forward(self, outputs, targets):
        """This performs the loss computation
        :param outputs: dict of tensors, see the output specification of the model for the format
        :param targets: list of dicts, such that len(targets) == batch_size
        :return: a dict containing losses
        """
        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor(
            data=[num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
        # Compute all the requested losses on the final-layer outputs
        losses = {}
        matching_outputs = {k: v for k, v in outputs.items() if k != "aux_outputs" and k != "enc_outputs"}
        losses.update(self.calculate_loss(matching_outputs, targets, num_boxes))
        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                # get matching results for each image
                losses_aux = self.calculate_loss(aux_outputs, targets, num_boxes)
                losses.update({k + f"_{i}": v for k, v in losses_aux.items()})
        # encoder (two-stage) outputs, optionally supervised as binary foreground
        if "enc_outputs" in outputs:
            enc_outputs = outputs["enc_outputs"]
            bin_targets = copy.deepcopy(targets)
            if self.two_stage_binary_cls:
                # collapse all categories into a single foreground class
                for bt in bin_targets:
                    bt["labels"] = torch.zeros_like(bt["labels"])
            losses_enc = self.calculate_loss(enc_outputs, bin_targets, num_boxes)
            losses.update({k + f"_enc": v for k, v in losses_enc.items()})
        return losses
class HybridSetCriterion(SetCriterion):
    """SetCriterion variant with an IoU-aware (vari-focal) classification loss:
    classification targets of matched queries are weighted by the detached IoU
    between the matched predicted and ground-truth boxes."""
    def loss_labels(self, outputs, targets, num_boxes, indices, **kwargs):
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
        # IoU of matched pairs only (diagonal of the pairwise IoU matrix)
        iou_score = torch.diag(
            box_ops.box_iou(
                box_ops._box_cxcywh_to_xyxy(src_boxes),
                box_ops._box_cxcywh_to_xyxy(target_boxes),
            )
        ).detach()  # add detach according to RT-DETR
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]
        # construct onehot targets, shape: (batch_size, num_queries, num_classes)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o
        # extra no-object channel added for one_hot, then dropped
        target_classes_onehot = F.one_hot(target_classes, self.num_classes + 1)[..., :-1]
        # construct iou_score, shape: (batch_size, num_queries)
        target_score = torch.zeros_like(target_classes, dtype=iou_score.dtype)
        target_score[idx] = iou_score
        # NOTE(review): scaled by the number of queries — presumably to undo a
        # per-query mean inside vari_sigmoid_focal_loss; confirm against its impl
        loss_class = (
            vari_sigmoid_focal_loss(
                src_logits,
                target_classes_onehot,
                target_score,
                num_boxes=num_boxes,
                alpha=self.alpha,
                gamma=self.gamma,
            ) * src_logits.shape[1]
        )
        losses = {"loss_class": loss_class}
        return losses
import copy
import logging
import os
from typing import Dict, List, Optional, Union
import torch.distributed
import torch.jit
import torchvision
from torch import Tensor, nn
from torchvision.models.detection.image_list import ImageList
from torchvision.ops import boxes as box_ops
from transforms import functional as F
from transforms import v2 as T
from transforms.functional import InterpolationMode
from util.misc import decode_labels, encode_labels, image_list_from_tensors
from util.utils import get_world_size, is_dist_avail_and_initialized
class EvalResize(nn.Module):
    """Resize transform friendly to ONNX and torchscript.

    Scales the shorter image side towards ``min_size`` and, when ``max_size``
    is given, additionally caps the longer side at ``max_size``.

    :param min_size: target size for the shorter image side
    :param max_size: optional upper bound for the longer image side
    :param interpolation: interpolation mode forwarded to ``F.resize``
    :param antialias: antialias flag forwarded to ``F.resize``
    """
    def __init__(
        self,
        min_size: int,
        max_size: Optional[int] = None,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ):
        super().__init__()
        # Bug fix: max_size is Optional (defaults to None) and the forward pass
        # explicitly handles None, but the previous assert required an int and
        # rejected the documented default. Validate its type only when given.
        assert isinstance(min_size, int), "min_size must be an int"
        assert max_size is None or isinstance(max_size, int), "max_size must be an int or None"
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, image: Tensor):
        """Resize a single image tensor of shape (..., h, w)."""
        assert isinstance(image, Tensor), "Only one image Tensor is supported"
        if torchvision._is_tracing():
            from torch.onnx import operators
            orig_height, orig_width = operators.shape_as_tensor(image)[-2:]
        else:
            orig_height, orig_width = torch.tensor(image.shape[-2:])
        # scale factor that brings the shorter side to min_size
        r = self.min_size / torch.min(orig_height, orig_width)
        if self.max_size is not None:
            # shrink further if the longer side would exceed max_size
            r = torch.min(r, self.max_size / torch.max(orig_height, orig_width))
        new_width = (orig_width * r).to(orig_width.dtype)
        new_height = (orig_height * r).to(orig_width.dtype)
        return F.resize(
            image, size=(new_height, new_width), interpolation=self.interpolation, antialias=self.antialias
        )
class BaseDetector(nn.Module):
    def __init__(self, min_size=None, max_size=None):
        """Initialize BaseDetector. Before forward propagation, input images should be padded and batched,
        and optionally resized. For training mode, the resize and other augmentations are done in
        dataset transformation, whereas in evaluation mode, since the model needs original shapes (before
        any augmentation) for calculating COCO metrics, the resize is performed inside the forward
        function. Therefore, the input image MUST NOT have any augmentation for evaluation mode!
        :param min_size: the minimum threshold to resize input images, defaults to None
        :param max_size: the maximum threshold to resize input images, defaults to None
        """
        super().__init__()
        # build the evaluation-time transform: optional resize, then dtype
        # conversion and ImageNet-mean/std normalization
        size = [s for s in (min_size, max_size) if isinstance(s, (int, float))]
        if len(size) != 0:
            eval_transform = [EvalResize(min(size), max(size), antialias=True)]
        else:
            eval_transform = []
        eval_transform.append(T.ConvertImageDtype(torch.float))
        eval_transform.append(T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), True))
        self.eval_transform = nn.Sequential(*eval_transform)
        # resolved lazily by the `device` property
        self._device = None

    @property
    def device(self):
        """Device of the model parameters, cached after first access."""
        if self._device is None:
            self._device = next(iter(self.parameters())).device
        return self._device

    @property
    def CLASSES(self):
        """This returns the classes of the current model. By default, the class
        information is encoded in a tensor named self._classes_. If not registered,
        the function will use default [0, ..., num_classes - 1] as a replacement.
        :return: A list contains class information of the current model.
        """
        logger = logging.getLogger(os.path.basename(os.getcwd()) + "." + __name__)
        if not hasattr(self, "_classes_") or self._classes_ is None:
            # Bug fix: Logger.warn is a deprecated alias; use Logger.warning
            logger.warning("register default classes for model")
            dummy_classes = tuple(str(s) for s in range(self.num_classes))
            self.register_buffer("_classes_", torch.tensor(encode_labels(dummy_classes)))
        return decode_labels(tuple(self._classes_.tolist()))

    @staticmethod
    def check_boxes(targets):
        """Assert that all target boxes (xyxy) have positive height and width."""
        for target_idx, target in enumerate(targets):
            boxes = target["boxes"]
            degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
            if degenerate_boxes.any():
                # print the first degenerate box
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                torch._assert(
                    False,
                    "All bounding boxes should have positive height and width."
                    f" Found invalid box {degen_bb} for target at index {target_idx}.",
                )

    def preprocess_image(self, images: List[Tensor]) -> ImageList:
        """Preprocess normalization and make up a batch for input images.
        :param images: list of input images, each has shape of (c, h, w)
        :return: ImageList of the normalized and batched images
        """
        if isinstance(images, torch.Tensor):
            images = images.unbind(0)
        # eval_transform only applies in evaluation mode and when non-empty
        if not self.training and self.eval_transform:
            images = [self.eval_transform(image) for image in images]
        images = image_list_from_tensors(images)
        return images

    @torch.inference_mode()
    def preprocess(self, images: List[Tensor], targets: List[Dict] = None):
        """Validate targets (if any) and batch/normalize the input images."""
        if targets is not None:
            self.check_boxes(targets)
        return self.preprocess_image(images), targets

    @staticmethod
    def query_original_sizes(images):
        """Return a (b, 2) tensor of original (h, w) sizes; tracing-safe for ONNX export."""
        if torchvision._is_tracing():
            from torch.onnx import operators
            if isinstance(images, torch.Tensor):
                images = images.unbind(0)
            original_sizes = [operators.shape_as_tensor(m)[-2:] for m in images]
            original_sizes = torch.stack(original_sizes).to(images[0].device)
        else:
            original_sizes = [m.shape[-2:] for m in images]
            original_sizes = torch.as_tensor(original_sizes, device=images[0].device)
        return original_sizes
class DETRDetector(BaseDetector):
    """DETR-family detector base class: target preparation, padding-mask
    construction, and auxiliary-output packaging shared by DETR variants."""

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # torchscript cannot represent dicts with heterogeneous values, so the
        # per-layer auxiliary outputs are packaged as a list of uniform dicts
        # (final layer excluded — it forms the main output)
        aux_outputs = []
        for logits, boxes in zip(outputs_class[:-1], outputs_coord[:-1]):
            aux_outputs.append({"pred_logits": logits, "pred_boxes": boxes})
        return aux_outputs

    @staticmethod
    def prepare_targets(images: ImageList, targets: List[Dict]):
        """Convert target boxes from absolute xyxy to normalized cxcywh."""
        # deep-copy so the caller's targets are never mutated
        targets = copy.deepcopy(targets)
        for (img_h, img_w), target in zip(images.image_sizes, targets):
            cxcywh = box_ops._box_xyxy_to_cxcywh(target["boxes"])
            target["boxes"] = cxcywh / cxcywh.new_tensor([img_w, img_h, img_w, img_h])
        return targets

    @staticmethod
    def construct_mask(images: ImageList):
        """Build a padding mask over the batched tensor: 1 = padded, 0 = valid."""
        batch, _, height, width = images.tensors.shape
        mask = images.tensors.new_ones((batch, height, width), device=images.tensors.device)
        for img_id, image_size in enumerate(images.image_sizes):
            mask[img_id, :image_size[0], :image_size[1]] = 0
        return mask

    @torch.no_grad()
    def preprocess(self, images: List[Tensor], targets: List[Dict] = None):
        """Batch/normalize the images, build the padding mask and, when
        targets are given, validate and normalize them as well."""
        images = self.preprocess_image(images)
        mask = self.construct_mask(images)
        if targets:
            self.check_boxes(targets)
            targets = self.prepare_targets(images, targets)
        return images, targets, mask
class DNDETRDetector(DETRDetector):
    """DETR detector base with contrastive denoising (DN) training support."""
    def compute_dn_loss(self, outputs, targets, **kwargs):
        """Compute matching losses for the denoising queries.

        Denoising queries are generated from noised ground truth, so the
        matching is known by construction: target ``t`` of group ``g`` maps to
        query ``g * max_gt_num_per_image + t`` — no Hungarian matching is run.

        :param outputs: dict with "denoising_output", "denoising_groups" and
            "max_gt_num_per_image"
        :param targets: list of per-image target dicts
        :return: dict of losses, keys suffixed with "_dn" / "_dn_{i}"
        """
        losses = {}
        device = self.device
        # get num_boxes, averaged over distributed workers, for normalization
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
        if "denoising_output" in outputs:
            denoising_output, denoising_groups, max_gt_num_per_image = (
                outputs["denoising_output"],
                outputs["denoising_groups"],
                outputs["max_gt_num_per_image"],
            )
            # build the fixed (query_idx, target_idx) pairs per image
            dn_idx = []
            for i in range(len(targets)):
                if len(targets[i]["labels"]) > 0:
                    group_index, target_index = torch.meshgrid(
                        torch.arange(denoising_groups, device=device),
                        torch.arange(len(targets[i]["labels"]), device=device),
                        indexing="ij",
                    )
                    output_idx = group_index * max_gt_num_per_image + target_index
                    output_idx = output_idx.flatten()
                    tgt_idx = target_index.flatten()
                else:
                    # image without ground truth: empty matching
                    output_idx = tgt_idx = torch.tensor([], dtype=torch.long, device=device)
                dn_idx.append((output_idx, tgt_idx))
            # calculate matching loss (each gt appears once per denoising group)
            denoising_losses = self.criterion.calculate_loss(
                denoising_output,
                targets,
                num_boxes=num_boxes * denoising_groups,
                indices=dn_idx,
                **kwargs,
            )
            denoising_losses = {k + "_dn": v for k, v in denoising_losses.items()}
            losses.update(denoising_losses)
            # calculate auxiliary loss for each intermediate decoder layer
            for i in range(len(denoising_output.get("aux_outputs", []))):
                denoising_output_aux = denoising_output["aux_outputs"][i]
                denoising_losses = self.criterion.calculate_loss(
                    denoising_output_aux,
                    targets,
                    num_boxes=num_boxes * denoising_groups,
                    indices=dn_idx,
                    **kwargs,
                )
                denoising_losses = {k + f"_dn_{i}": v for k, v in denoising_losses.items()}
                losses.update(denoising_losses)
        return losses

    def dn_post_process(self, outputs_class, outputs_coord, dn_metas):
        """Split decoder outputs into denoising and matching parts.

        Denoising queries occupy the first ``padding_size`` positions on the
        query axis; they are stripped into ``dn_metas["denoising_output"]``
        (with auxiliary outputs when enabled) and the remaining matching
        outputs are returned.
        """
        if dn_metas and "max_gt_num_per_image" in dn_metas:
            padding_size = dn_metas["max_gt_num_per_image"] * dn_metas["denoising_groups"]
            output_known_class = outputs_class[:, :, :padding_size, :]
            output_known_coord = outputs_coord[:, :, :padding_size, :]
            outputs_class = outputs_class[:, :, padding_size:, :]
            outputs_coord = outputs_coord[:, :, padding_size:, :]
            out = {
                "pred_logits": output_known_class[-1],
                "pred_boxes": output_known_coord[-1],
            }
            if self.aux_loss:
                out["aux_outputs"] = self._set_aux_loss(output_known_class, output_known_coord)
            dn_metas["denoising_output"] = out
        return outputs_class, outputs_coord
from typing import Dict, List, Tuple
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops
from models.bricks.denoising import GenerateCDNQueries
from models.bricks.losses import sigmoid_focal_loss
from models.detectors.base_detector import DNDETRDetector
class SalienceCriterion(nn.Module):
    """Supervision for the encoder's salience (foreground confidence) map.

    Each feature-map location receives a confidence in [0, 1] derived from its
    relative position inside the ground-truth boxes that contain it; boxes are
    assigned to feature levels via ``limit_range``. The predicted map is
    trained with a sigmoid focal loss.

    :param limit_range: per-level (min, max) bound on the largest
        point-to-border distance, used for level assignment
    :param noise_scale: amount of uniform noise blended into the targets
    :param alpha: focal loss alpha
    :param gamma: focal loss gamma
    """
    def __init__(
        self,
        limit_range: Tuple = ((-1, 64), (64, 128), (128, 256), (256, 99999)),
        noise_scale: float = 0.0,
        alpha: float = 0.25,
        gamma: float = 2.0,
    ):
        super().__init__()
        self.limit_range = limit_range
        self.noise_scale = noise_scale
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, foreground_mask, targets, feature_strides, image_sizes):
        """Compute the salience loss.

        :param foreground_mask: list of per-level predicted salience maps
        :param targets: per-image dicts with normalized cxcywh "boxes"
        :param feature_strides: per-level (stride_h, stride_w) w.r.t. the input
        :param image_sizes: per-image (h, w) sizes
        :return: {"loss_salience": loss}
        """
        # scale normalized cxcywh boxes to absolute xyxy image coordinates
        gt_boxes_list = []
        for t, (img_h, img_w) in zip(targets, image_sizes):
            boxes = t["boxes"]
            boxes = box_ops._box_cxcywh_to_xyxy(boxes)
            scale_factor = torch.tensor([img_w, img_h, img_w, img_h], device=boxes.device)
            gt_boxes_list.append(boxes * scale_factor)
        # build a confidence target per level and per image
        mask_targets = []
        for level_idx, (mask, feature_stride) in enumerate(zip(foreground_mask, feature_strides)):
            feature_shape = mask.shape[-2:]
            coord_x, coord_y = self.get_pixel_coordinate(feature_shape, feature_stride, device=mask.device)
            masks_per_level = []
            for gt_boxes in gt_boxes_list:
                # NOTE: deliberately reuses the name `mask` for the per-image
                # target; the level's predicted map is no longer needed here
                mask = self.get_mask_single_level(coord_x, coord_y, gt_boxes, level_idx)
                masks_per_level.append(mask)
            masks_per_level = torch.stack(masks_per_level)
            mask_targets.append(masks_per_level)
        mask_targets = torch.cat(mask_targets, dim=1)
        # flatten the per-level predictions to align with the targets
        foreground_mask = torch.cat([e.flatten(-2) for e in foreground_mask], -1)
        foreground_mask = foreground_mask.squeeze(1)
        # NOTE(review): with noise_scale == 0 the threshold is 0, so any
        # strictly positive target counts as a positive — confirm intended
        num_pos = torch.sum(mask_targets > 0.5 * self.noise_scale).clamp_(min=1)
        salience_loss = (
            sigmoid_focal_loss(
                foreground_mask,
                mask_targets,
                num_pos,
                alpha=self.alpha,
                gamma=self.gamma,
            ) * foreground_mask.shape[1]
        )
        return {"loss_salience": salience_loss}

    def get_pixel_coordinate(self, feature_shape, stride, device):
        """Return flattened absolute (x, y) centers of each feature-map cell."""
        height, width = feature_shape
        coord_y, coord_x = torch.meshgrid(
            torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device) * stride[0],
            torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device) * stride[1],
            indexing="ij",
        )
        coord_y = coord_y.reshape(-1)
        coord_x = coord_x.reshape(-1)
        return coord_x, coord_y

    def get_mask_single_level(self, coord_x, coord_y, gt_boxes, level_idx):
        """Confidence target for one level of one image.

        :param coord_x: flattened pixel x coordinates, shape (h*w,)
        :param coord_y: flattened pixel y coordinates, shape (h*w,)
        :param gt_boxes: absolute xyxy boxes, shape (m, 4)
        :param level_idx: index into self.limit_range for level assignment
        :return: per-location confidence in [0, 1], shape (h*w,)
        """
        # signed distances from each pixel center to each box border
        left_border_distance = coord_x[:, None] - gt_boxes[None, :, 0]  # (h*w, m)
        top_border_distance = coord_y[:, None] - gt_boxes[None, :, 1]
        right_border_distance = gt_boxes[None, :, 2] - coord_x[:, None]
        bottom_border_distance = gt_boxes[None, :, 3] - coord_y[:, None]
        border_distances = torch.stack(
            [left_border_distance, top_border_distance, right_border_distance, bottom_border_distance],
            dim=-1,
        )  # [h*w, m, 4]
        # the foreground queries must satisfy two requirements:
        # 1. the queries located in bounding boxes
        # 2. the distance from queries to the box center match the feature map stride
        min_border_distances = torch.min(border_distances, dim=-1)[0]  # [h*w, m]
        max_border_distances = torch.max(border_distances, dim=-1)[0]
        mask_in_gt_boxes = min_border_distances > 0
        min_limit, max_limit = self.limit_range[level_idx]
        mask_in_level = (max_border_distances > min_limit) & (max_border_distances <= max_limit)
        mask_pos = mask_in_gt_boxes & mask_in_level
        # scale-independent salience confidence: 1 at the box center and
        # decreasing towards the borders
        row_factor = left_border_distance + right_border_distance
        col_factor = top_border_distance + bottom_border_distance
        delta_x = (left_border_distance - right_border_distance) / row_factor
        delta_y = (top_border_distance - bottom_border_distance) / col_factor
        confidence = torch.sqrt(delta_x**2 + delta_y**2) / 2
        confidence_per_box = 1 - confidence
        confidence_per_box[~mask_in_gt_boxes] = 0
        # process positive coordinates: keep the best confidence over boxes
        if confidence_per_box.numel() != 0:
            mask = confidence_per_box.max(-1)[0]
        else:
            mask = torch.zeros(coord_y.shape, device=confidence.device, dtype=confidence.dtype)
        # process negative coordinates: zero pixels not positive at this level
        mask_pos = mask_pos.long().sum(dim=-1) >= 1
        mask[~mask_pos] = 0
        # add noise to add randomness
        mask = (1 - self.noise_scale) * mask + self.noise_scale * torch.rand_like(mask)
        return mask
# SalienceDETR has an architecture similar to FocusDETR
class SalienceDETR(DNDETRDetector):
    """SalienceDETR detector: backbone -> neck -> salience-guided transformer,
    trained with matching, denoising and salience (focus) losses.

    :param backbone: feature extraction backbone
    :param neck: channel-mapping / pyramid neck
    :param position_embedding: positional embedding built from padding masks
    :param transformer: salience transformer (encoder + decoder)
    :param criterion: matching-based loss criterion
    :param postprocessor: converts raw outputs to detection results
    :param focus_criterion: salience (foreground) supervision criterion
    :param num_classes: number of object categories, defaults to 91
    :param num_queries: number of decoder queries, defaults to 900
    :param denoising_nums: number of denoising queries, defaults to 100
    :param aux_loss: whether to output per-layer auxiliary predictions
    :param min_size: minimum eval resize threshold, defaults to None
    :param max_size: maximum eval resize threshold, defaults to None
    """
    def __init__(
        # model structure
        self,
        backbone: nn.Module,
        neck: nn.Module,
        position_embedding: nn.Module,
        transformer: nn.Module,
        criterion: nn.Module,
        postprocessor: nn.Module,
        focus_criterion: nn.Module,
        # model parameters
        num_classes: int = 91,
        num_queries: int = 900,
        denoising_nums: int = 100,
        # model variants
        aux_loss: bool = True,
        min_size: int = None,
        max_size: int = None,
    ):
        super().__init__(min_size, max_size)
        # define model parameters
        self.num_classes = num_classes
        self.aux_loss = aux_loss
        embed_dim = transformer.embed_dim
        # define model structures
        self.backbone = backbone
        self.neck = neck
        self.position_embedding = position_embedding
        self.transformer = transformer
        self.criterion = criterion
        self.postprocessor = postprocessor
        self.denoising_generator = GenerateCDNQueries(
            num_queries=num_queries,
            num_classes=num_classes,
            label_embed_dim=embed_dim,
            denoising_nums=denoising_nums,
            label_noise_prob=0.5,
            box_noise_scale=1.0,
        )
        self.focus_criterion = focus_criterion

    def forward(self, images: List[Tensor], targets: List[Dict] = None):
        """Run detection: returns a loss dict in training mode, otherwise
        postprocessed detections.

        :param images: list of input image tensors (c, h, w)
        :param targets: per-image target dicts, required for training
        """
        # get original image sizes, used for postprocess
        original_image_sizes = self.query_original_sizes(images)
        images, targets, mask = self.preprocess(images, targets)
        # extract features
        multi_level_feats = self.backbone(images.tensors)
        multi_level_feats = self.neck(multi_level_feats)
        # downsample the padding mask per level and embed positions
        multi_level_masks = []
        multi_level_position_embeddings = []
        for feature in multi_level_feats:
            multi_level_masks.append(F.interpolate(mask[None], size=feature.shape[-2:]).to(torch.bool)[0])
            multi_level_position_embeddings.append(self.position_embedding(multi_level_masks[-1]))
        if self.training:
            # collect ground truth for denoising generation
            gt_labels_list = [t["labels"] for t in targets]
            gt_boxes_list = [t["boxes"] for t in targets]
            noised_results = self.denoising_generator(gt_labels_list, gt_boxes_list)
            noised_label_query = noised_results[0]
            noised_box_query = noised_results[1]
            attn_mask = noised_results[2]
            denoising_groups = noised_results[3]
            max_gt_num_per_image = noised_results[4]
        else:
            # no denoising queries at inference time
            noised_label_query = None
            noised_box_query = None
            attn_mask = None
            denoising_groups = None
            max_gt_num_per_image = None
        # feed into transformer
        outputs_class, outputs_coord, enc_class, enc_coord, foreground_mask = self.transformer(
            multi_level_feats,
            multi_level_masks,
            multi_level_position_embeddings,
            noised_label_query,
            noised_box_query,
            attn_mask=attn_mask,
        )
        # hack implementation for distributed training: multiply the label
        # encoder weight by zero — presumably so DDP always sees the parameter
        # as used; confirm against the denoising generator's usage
        outputs_class[0] += self.denoising_generator.label_encoder.weight[0, 0] * 0.0
        # denoising postprocessing: split denoising / matching outputs
        if denoising_groups is not None and max_gt_num_per_image is not None:
            dn_metas = {
                "denoising_groups": denoising_groups,
                "max_gt_num_per_image": max_gt_num_per_image,
            }
            outputs_class, outputs_coord = self.dn_post_process(outputs_class, outputs_coord, dn_metas)
        # prepare for loss computation
        output = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.aux_loss:
            output["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
        # prepare two stage output
        output["enc_outputs"] = {"pred_logits": enc_class, "pred_boxes": enc_coord}
        if self.training:
            # compute loss (dn_metas is always defined in training mode since
            # denoising_groups is produced above)
            loss_dict = self.criterion(output, targets)
            dn_losses = self.compute_dn_loss(dn_metas, targets)
            loss_dict.update(dn_losses)
            # compute focus (salience) loss on the encoder foreground map
            feature_stride = [(
                images.tensors.shape[-2] / feature.shape[-2],
                images.tensors.shape[-1] / feature.shape[-1],
            ) for feature in multi_level_feats]
            focus_loss = self.focus_criterion(foreground_mask, targets, feature_stride, images.image_sizes)
            loss_dict.update(focus_loss)
            # loss reweighting: keep only losses present in weight_dict
            weight_dict = self.criterion.weight_dict
            loss_dict = dict((k, loss_dict[k] * weight_dict[k]) for k in loss_dict.keys() if k in weight_dict)
            return loss_dict
        detections = self.postprocessor(output, original_image_sizes)
        return detections
import torch
from scipy.optimize import linear_sum_assignment
from torch import Tensor, nn
from torchvision.ops.boxes import _box_cxcywh_to_xyxy, generalized_box_iou
class HungarianMatcher(nn.Module):
"""This class implements the Hungarian matching algorithm for bipartite graphs. It matches predicted bounding
boxes to ground truth boxes based on the minimum cost assignment. The cost is computed as a weighted sum of
classification, bounding box, and generalized intersection over union (IoU) costs. The focal loss is used to
weigh the classification cost. The HungarianMatcher class can be used in single or mixed assignment modes.
The mixed assignment modes is introduced in `Align-DETR <https://arxiv.org/abs/2304.07527>`_.
:param cost_class: The weight of the classification cost, defaults to 1
:param cost_bbox: The weight of the bounding box cost, defaults to 1
:param cost_giou: The weight of the generalized IoU cost, defaults to 1
:param focal_alpha: The alpha parameter of the focal loss, defaults to 0.25
:param focal_gamma: The gamma parameter of the focal loss, defaults to 2.0
:param mixed_match: If True, mixed assignment is used, defaults to False
"""
def __init__(
self,
cost_class: float = 1,
cost_bbox: float = 1,
cost_giou: float = 1,
focal_alpha: float = 0.25,
focal_gamma: float = 2.0,
mixed_match: bool = False,
):
super().__init__()
self.cost_class = cost_class
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
self.mixed_match = mixed_match
def calculate_class_cost(self, pred_logits, gt_labels, **kwargs):
out_prob = pred_logits.sigmoid()
# Compute the classification cost.
neg_cost_class = -(1 - self.focal_alpha) * out_prob**self.focal_gamma * (1 - out_prob + 1e-6).log()
pos_cost_class = -self.focal_alpha * (1 - out_prob)**self.focal_gamma * (out_prob + 1e-6).log()
cost_class = pos_cost_class[:, gt_labels] - neg_cost_class[:, gt_labels]
return cost_class
def calculate_bbox_cost(self, pred_boxes, gt_boxes, **kwargs):
# Compute the L1 cost between boxes
cost_bbox = torch.cdist(pred_boxes, gt_boxes, p=1)
return cost_bbox
def calculate_giou_cost(self, pred_boxes, gt_boxes, **kwargs):
# Compute the giou cost betwen boxes
cost_giou = -generalized_box_iou(_box_cxcywh_to_xyxy(pred_boxes), _box_cxcywh_to_xyxy(gt_boxes))
return cost_giou
@torch.no_grad()
def calculate_cost(self, pred_boxes: Tensor, pred_logits: Tensor, gt_boxes: Tensor, gt_labels: Tensor):
    """Weighted sum of class, L1-bbox and GIoU costs.

    Returns a (num_queries, num_gts) cost matrix used for assignment.
    """
    class_cost = self.calculate_class_cost(pred_logits, gt_labels)
    bbox_cost = self.calculate_bbox_cost(pred_boxes, gt_boxes)
    giou_cost = self.calculate_giou_cost(pred_boxes, gt_boxes)
    return (
        self.cost_bbox * bbox_cost
        + self.cost_class * class_cost
        + self.cost_giou * giou_cost
    )
@torch.no_grad()
def forward(
    self, pred_boxes: Tensor, pred_logits: Tensor, gt_boxes: Tensor, gt_labels: Tensor, gt_copy: int = 1
):
    """Match predictions to ground truths.

    Returns (query_indices, target_indices) as int64 tensors. With
    ``mixed_match`` enabled, the cost matrix is tiled so every ground truth
    can be matched by up to ``gt_copy`` queries (AlignDETR).
    """
    cost = self.calculate_cost(pred_boxes, pred_logits, gt_boxes, gt_labels)
    if not self.mixed_match:
        # plain one-to-one Hungarian assignment
        row_ind, col_ind = linear_sum_assignment(cost.cpu())
        return torch.as_tensor(row_ind), torch.as_tensor(col_ind)
    # mixed (one-to-many) assignment: cap the number of copies so at most
    # half of the queries get matched
    num_gts = cost.size(-1)
    num_queries = len(cost)
    if num_gts > 0:
        gt_copy = min(int(num_queries * 0.5 / num_gts), gt_copy)
    src_ind, tgt_ind = linear_sum_assignment(cost.cpu().repeat(1, gt_copy))
    # fold tiled column indices back onto the original targets
    tgt_ind = tgt_ind % num_gts
    tgt_ind, order = torch.as_tensor(tgt_ind, dtype=torch.int64).sort()
    src_ind = torch.as_tensor(src_ind, dtype=torch.int64)[order].view(-1)
    return src_ind, tgt_ind
from functools import partial
from typing import List
from torch import nn
from models.bricks.misc import Conv2dNormActivation
class ChannelMapper(nn.Module):
    """Projects a dict of multi-level feature maps to a common channel width.

    One conv maps each input level to ``out_channels``; when ``num_outs``
    exceeds the number of inputs, the extra levels are produced by stride-2
    3x3 convolutions stacked on top of the last level.

    :param in_channels: channel count of each input feature level
    :param out_channels: channel count of every output level
    :param num_outs: total number of output levels
    :param kernel_size: kernel size of the per-level mapping convs, defaults to 1
    :param stride: stride of the per-level mapping convs, defaults to 1
    :param groups: convolution groups, defaults to 1
    :param norm_layer: norm constructor, defaults to GroupNorm with 32 groups
    :param activation_layer: activation constructor, defaults to None
    :param dilation: convolution dilation, defaults to 1
    :param inplace: whether activations run inplace, defaults to True
    :param bias: conv bias flag (None lets Conv2dNormActivation decide)
    """
    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        num_outs: int,
        kernel_size: int = 1,
        stride: int = 1,
        groups: int = 1,
        norm_layer=partial(nn.GroupNorm, 32),
        activation_layer: nn.Module = None,
        dilation: int = 1,
        inplace: bool = True,
        bias: bool = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.convs = nn.ModuleList()
        self.num_channels = [out_channels] * num_outs
        # one mapping conv per input level
        for in_channel in in_channels:
            self.convs.append(
                Conv2dNormActivation(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=(kernel_size - 1) // 2,
                    bias=bias,
                    groups=groups,
                    dilation=dilation,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    inplace=inplace,
                )
            )
        # extra levels: the first is fed by the last input level, subsequent
        # ones chain on the previous extra output (previously this relied on
        # the loop variable leaking out of the for-loop above)
        in_channel = in_channels[-1] if in_channels else out_channels
        for _ in range(num_outs - len(in_channels)):
            self.convs.append(
                Conv2dNormActivation(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=bias,
                    groups=groups,
                    dilation=dilation,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    inplace=inplace,
                )
            )
            in_channel = out_channels
        self.init_weights()

    def init_weights(self):
        """Xavier-initialize every conv; zero its bias when one is present."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight, gain=1)
                # BUGFIX: `if layer.bias:` raises "Boolean value of Tensor with
                # more than one element is ambiguous"; test for presence instead.
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, inputs):
        """Map an OrderedDict of feature maps to a list of ``num_outs`` tensors."""
        inputs = list(inputs.values())
        assert len(inputs) == len(self.in_channels)
        outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
        # extra levels: the first consumes the last input, the rest chain
        for i in range(len(inputs), len(self.convs)):
            source = inputs[-1] if i == len(inputs) else outs[-1]
            outs.append(self.convs[i](source))
        return outs
from collections import OrderedDict
from typing import List
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from models.bricks.basic import SqueezeAndExcitation
from models.bricks.misc import Conv2dNormActivation
class RepVggPluXBlock(nn.Module):
    """RepVGG-style block: a 3x3 and a (scaled) 1x1 group convolution summed,
    passed through activation and squeeze-excitation, plus an identity
    shortcut (projected by a 1x1 conv when channel counts differ).

    :param in_channels: input channel count
    :param out_channels: output channel count
    :param activation_layer: activation constructor, defaults to nn.ReLU
    :param inplace: whether the activation runs inplace, defaults to True
    :param groups: group count for both convolutions, defaults to 4
    :param alpha: if True, scale the 1x1 branch with a learnable scalar
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation_layer: nn.Module = nn.ReLU,
        inplace: bool = True,
        groups: int = 4,
        alpha: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # BUGFIX: the `inplace` argument was previously ignored (hard-coded True)
        self.activation = activation_layer(inplace=inplace)
        self.conv1 = Conv2dNormActivation(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=groups,
            activation_layer=None,
            inplace=inplace,
        )
        self.conv2 = Conv2dNormActivation(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=groups,
            activation_layer=None,
            inplace=inplace,
        )
        # learnable scale for the 1x1 branch when requested, else a constant 1
        self.alpha = nn.Parameter(torch.tensor(1.0)) if alpha else 1.0
        self.se_module = SqueezeAndExcitation(channels=out_channels)
        if self.in_channels != self.out_channels:
            # project the shortcut when channel counts differ
            self.identity = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.identity = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """y = SE(act(conv3x3(x) + alpha * conv1x1(x))) + identity(x)."""
        y = self.conv1(x) + self.alpha * self.conv2(x)
        y = self.se_module(self.activation(y))
        return y + self.identity(x)
class CSPRepPluXLayer(nn.Module):
    """CSP-style layer: one branch runs a stack of RepVggPluXBlock bottlenecks,
    the other is a plain 1x1 projection; their sum is projected to
    ``out_channels`` (identity when widths already match)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: int = 3,
        expansion: float = 1.0,
        groups: int = 4,
        norm_layer: nn.Module = nn.BatchNorm2d,
        activation_layer: nn.Module = nn.SiLU,
    ):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        # both entry projections share the same configuration
        shared_kwargs = dict(
            kernel_size=1,
            stride=1,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            inplace=True,
        )
        self.conv1 = Conv2dNormActivation(in_channels, hidden_channels, **shared_kwargs)
        self.conv2 = Conv2dNormActivation(in_channels, hidden_channels, **shared_kwargs)
        blocks = [
            RepVggPluXBlock(
                hidden_channels,
                hidden_channels,
                groups=groups,
                activation_layer=activation_layer,
            )
            for _ in range(num_blocks)
        ]
        self.bottlenecks = nn.Sequential(*blocks)
        if hidden_channels == out_channels:
            self.conv3 = nn.Identity()
        else:
            self.conv3 = Conv2dNormActivation(
                hidden_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
            )

    def forward(self, x: Tensor) -> Tensor:
        """Sum the bottleneck branch and the shortcut branch, then project."""
        merged = self.bottlenecks(self.conv1(x)) + self.conv2(x)
        return self.conv3(merged)
class RepVGGPluXNetwork(nn.Module):
    def __init__(
        self,
        in_channels_list: List[int],
        out_channels_list: List[int],
        groups: int = 4,
        norm_layer: nn.Module = nn.BatchNorm2d,
        activation: nn.Module = nn.SiLU,
        extra_block: bool = False,
    ):
        """The implementation of RepVGGPluXNetwork; the network is basically built
        with RepVggPluXBlock upon PathAggregationNetwork (a top-down FPN-style pass
        followed by a bottom-up PAN pass).

        :param in_channels_list: input channels list, example: [256, 512, 1024, 2048]
        :param out_channels_list: output channel list, example: [256, 512, 1024, 2048]
        :param groups: number of groups used on GroupConvolution in RepVggPluXBlock, defaults to 4
        :param norm_layer: norm layer type, defaults to nn.BatchNorm2d
        :param activation: activation layer type, defaults to nn.SiLU
        :param extra_block: whether to add an extra max-pooled output level, defaults to False
        """
        super(RepVGGPluXNetwork, self).__init__()
        for idx in range(len(in_channels_list)):
            if in_channels_list[idx] == 0:
                raise ValueError("in_channels=0 is currently not supported")
        # top-down path: 1x1 lateral convs shrink level idx to the width of
        # level idx-1, layer_blocks fuse the (upsampled high, low) concat pairs
        self.lateral_convs = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        for idx in range(1, len(out_channels_list)):
            lateral_conv_module = Conv2dNormActivation(
                out_channels_list[idx],
                out_channels_list[idx - 1],
                kernel_size=1,
                stride=1,
                norm_layer=norm_layer,
                activation_layer=activation,
                inplace=True,
            )
            # input is a concat of two feature maps, hence the doubled channels
            layer_block_module = CSPRepPluXLayer(
                out_channels_list[idx - 1] * 2,
                out_channels_list[idx - 1],
                groups=groups,
                norm_layer=norm_layer,
                activation_layer=activation,
            )
            self.lateral_convs.append(lateral_conv_module)
            self.layer_blocks.append(layer_block_module)
        # bottom-up path: stride-2 convs downsample, pan_blocks fuse
        self.downsample_blocks = nn.ModuleList()
        self.pan_blocks = nn.ModuleList()
        for idx in range(len(in_channels_list) - 1):
            downsample_block_module = Conv2dNormActivation(
                out_channels_list[idx],
                out_channels_list[idx + 1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_layer=norm_layer,
                activation_layer=activation,
                inplace=True,
            )
            pan_block_module = CSPRepPluXLayer(
                out_channels_list[idx + 1] * 2,
                out_channels_list[idx + 1],
                groups=groups,
                norm_layer=norm_layer,
                activation_layer=activation,
            )
            self.downsample_blocks.append(downsample_block_module)
            self.pan_blocks.append(pan_block_module)
        self.extra_block = extra_block
        self.init_weights()

    def init_weights(self):
        # initialize parameters now to avoid modifying the initialization of top_blocks
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: OrderedDict):
        """Run top-down then bottom-up fusion over an OrderedDict of feature
        maps; returns an OrderedDict with the same keys (plus "pool" when
        ``extra_block`` is set)."""
        keys = list(x.keys())
        x = list(x.values())
        assert len(x) == len(self.layer_blocks) + 1
        # top down path
        results = x
        inner_outs = [results[-1]]
        for idx in range(len(results) - 1, 0, -1):
            feat_high = inner_outs[0]
            feat_low = results[idx - 1]
            feat_high = self.lateral_convs[idx - 1](feat_high)
            # keep the channel-reduced version for the bottom-up pass
            inner_outs[0] = feat_high
            upsample_feat = F.interpolate(
                feat_high,
                size=feat_low.shape[-2:],
                mode="nearest",
            )
            inner_out = self.layer_blocks[idx - 1](torch.cat([upsample_feat, feat_low], dim=1))
            inner_outs.insert(0, inner_out)
        # bottom up path
        results = [inner_outs[0]]
        for idx in range(len(inner_outs) - 1):
            feat_low = results[-1]
            feat_high = inner_outs[idx + 1]
            downsample_feat = self.downsample_blocks[idx](feat_low)
            out = self.pan_blocks[idx](torch.cat([downsample_feat, feat_high], dim=1))
            results.append(out)
        # output layer
        output = OrderedDict()
        for idx in range(len(x)):
            output[keys[idx]] = results[idx]
        # extra block: kernel-1, stride-2 max pool over the last output level
        if self.extra_block:
            output["pool"] = F.max_pool2d(list(output.values())[-1], 1, 2, 0)
        return output
from typing import List, Tuple, Union
from torch import nn
def match_name_keywords(name: str, name_keywords: Union[Tuple, List, str]):
    """Return True if any of the keywords occurs as a substring of ``name``.

    A bare string is treated as a single keyword.
    """
    if isinstance(name_keywords, str):
        name_keywords = [name_keywords]
    return any(keyword in name for keyword in name_keywords)
def finetune_backbone_param(model, lr):
    """Two optimizer param groups: non-backbone params at the optimizer's base
    lr, parameters whose names contain "backbone" at ``0.1 * lr``."""
    backbone_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        target = backbone_params if "backbone" in name else other_params
        target.append(param)
    return [
        {"params": other_params},
        {"params": backbone_params, "lr": lr * 0.1},
    ]
def finetune_backbone_with_no_norm_weight_decay(model, lr):
    """Build four optimizer param groups: backbone parameters train at
    ``0.1 * lr`` and normalization-layer parameters get ``weight_decay=0``.

    :param model: model whose backbone submodules contain "backbone" in their
        qualified module names
    :param lr: base learning rate (only the backbone groups get 0.1 * lr)
    :return: list of param-group dicts in the order
        [other, backbone_norm, other_norm, backbone]
    """
    norm_classes = (
        nn.modules.batchnorm._BatchNorm,
        nn.LayerNorm,
        nn.GroupNorm,
        nn.modules.instancenorm._InstanceNorm,
        nn.LocalResponseNorm,
    )
    backbone_norm = []
    other_norm = []
    backbone = []
    other = []
    for name, module in model.named_modules():
        # BUGFIX: compare with `is not None` instead of truthiness -- container
        # modules such as nn.Sequential define __len__, so an *empty* first
        # child made the bare truthiness test fall into the recursive-collect
        # branch and double-count parameters.
        if next(module.children(), None) is not None:
            # container module: only collect its direct (non-recursive)
            # parameters; children are visited separately by named_modules()
            if "backbone" in name:
                backbone.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
            else:
                other.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
        elif isinstance(module, norm_classes):
            # leaf normalization layer: exempt from weight decay
            if "backbone" in name:
                backbone_norm.extend(p for p in module.parameters() if p.requires_grad)
            else:
                other_norm.extend(p for p in module.parameters() if p.requires_grad)
        else:
            # any other leaf module
            if "backbone" in name:
                backbone.extend(p for p in module.parameters() if p.requires_grad)
            else:
                other.extend(p for p in module.parameters() if p.requires_grad)
    return [
        {
            "params": other,
        },
        {
            "params": backbone_norm,
            "lr": lr * 0.1,
            "weight_decay": 0,
        },
        {
            "params": other_norm,
            "weight_decay": 0,
        },
        {
            "params": backbone,
            "lr": lr * 0.1,
        },
    ]
def finetune_backbone_and_linear_projection(model, lr):
    """Six optimizer param groups: backbone and linear-projection parameters
    (reference_points / sampling_offsets) train at ``0.1 * lr``; any parameter
    whose name matches "norm" or "bias" additionally gets ``weight_decay=0``."""
    linear_keywords = ("reference_points", "sampling_offsets")
    norm_bias_keywords = ("norm", "bias")
    backbone = []
    backbone_norm = []
    linear_projection = []
    linear_projection_norm = []
    other = []
    other_norm = []
    for name, parameters in model.named_parameters():
        if not parameters.requires_grad:
            continue
        in_backbone = match_name_keywords(name, "backbone")
        in_linear = match_name_keywords(name, linear_keywords)
        is_norm_or_bias = match_name_keywords(name, norm_bias_keywords)
        if in_backbone and not in_linear:
            (backbone_norm if is_norm_or_bias else backbone).append(parameters)
        elif in_linear and not in_backbone:
            (linear_projection_norm if is_norm_or_bias else linear_projection).append(parameters)
        elif is_norm_or_bias:
            # params matching both (or neither) backbone and linear fall back here
            other_norm.append(parameters)
        else:
            other.append(parameters)
    return [
        {
            "params": other,
        },
        {
            "params": backbone,
            "lr": lr * 0.1,
        },
        {
            "params": backbone_norm,
            "lr": lr * 0.1,
            "weight_decay": 0,
        },
        {
            "params": linear_projection,
            "lr": lr * 0.1,
        },
        {
            "params": linear_projection_norm,
            "lr": lr * 0.1,
            "weight_decay": 0,
        },
        {
            "params": other_norm,
            "weight_decay": 0,
        },
    ]
absl-py==2.1.0
accelerate==0.30.1
albucore==0.0.8
albumentations==1.4.8
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
astunparse==1.6.3
cachetools==5.3.3
certifi==2024.6.2
charset-normalizer==3.3.2
contourpy==1.1.1
cycler==0.12.1
filelock==3.14.0
fonttools==4.53.0
fsspec==2024.6.0
fvcore==0.1.5.post20221221
google-auth==2.29.0
google-auth-oauthlib==1.0.0
grpcio==1.64.1
huggingface-hub==0.23.2
idna==3.7
imageio==2.34.1
importlib_metadata==7.1.0
importlib_resources==6.4.0
iopath==0.1.10
Jinja2==3.1.4
joblib==1.4.2
kiwisolver==1.4.5
lazy_loader==0.4
Markdown==3.6
MarkupSafe==2.1.5
matplotlib==3.7.5
mpmath==1.3.0
networkx==3.1
ninja==1.11.1.1
numpy==1.24.4
oauthlib==3.2.2
omegaconf==2.3.0
opencv-python-headless==4.10.0.82
packaging==24.0
pillow==10.3.0
portalocker==2.8.2
protobuf==5.27.0
psutil==5.9.8
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycocotools==2.0.7
pydantic==2.7.3
pydantic_core==2.18.4
pyparsing==3.1.2
python-dateutil==2.9.0.post0
PyWavelets==1.4.1
PyYAML==6.0.1
requests==2.32.3
requests-oauthlib==2.0.0
rsa==4.9
safetensors==0.4.3
scikit-image==0.21.0
scikit-learn==1.3.2
scipy==1.10.1
six==1.16.0
sympy==1.12.1
tabulate==0.9.0
tensorboard==2.14.0
tensorboard-data-server==0.7.2
termcolor==2.4.0
terminaltables==3.1.10
threadpoolctl==3.5.0
tifffile==2023.7.10
tomli==2.0.1
tqdm==4.66.4
typing_extensions==4.12.1
urllib3==1.25.1
Werkzeug==3.0.3
yacs==0.1.8
zipp==3.19.2
import argparse
import contextlib
import io
import json
import logging
import os
import tempfile
from typing import Dict
import accelerate
import torch
from accelerate import Accelerator
from pycocotools.coco import COCO
from torch.utils import data
from datasets.coco import CocoDetection
from util.coco_eval import CocoEvaluator, loadRes
from util.coco_utils import get_coco_api_from_dataset
from util.collate_fn import collate_fn
from util.engine import evaluate_acc
from util.lazy_load import Config
from util.logger import setup_logger
from util.misc import fixed_generator, seed_worker
from util.utils import load_checkpoint, load_state_dict
from util.visualize import visualize_coco_bounding_boxes
def parse_args():
    """Parse the command-line arguments for evaluating on a COCO dataset."""
    parser = argparse.ArgumentParser(description="Test on a datasets.")
    # dataset parameters
    parser.add_argument("--coco-path", type=str, required=True)
    parser.add_argument("--subset", type=str, default="val")
    parser.add_argument("--workers", type=int, default=0)
    # choose model to inference on dataset or result_file
    parser.add_argument("--model-config", type=str, default=None)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--result", type=str, default=None)
    # visualize parameters
    parser.add_argument("--show-dir", type=str, default=None)
    parser.add_argument("--show-conf", type=float, default=0.5)
    # plot parameters
    parser.add_argument("--font-scale", type=float, default=1.0)
    parser.add_argument("--box-thick", type=int, default=1)
    parser.add_argument("--fill-alpha", type=float, default=0.2)
    parser.add_argument("--text-box-color", type=int, nargs="+", default=(255, 255, 255))
    parser.add_argument("--text-font-color", type=int, nargs="+", default=None)
    parser.add_argument("--text-alpha", type=float, default=1.0)
    # engine parameters
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()
def create_test_data_loader(dataset, accelerator=None, **kwargs):
    """Deterministic, non-shuffled DataLoader; wrapped by accelerate when an
    accelerator is supplied."""
    loader = data.DataLoader(
        dataset,
        shuffle=False,
        worker_init_fn=seed_worker,
        generator=fixed_generator(),
        **kwargs,
    )
    if accelerator is None:
        return loader
    return accelerator.prepare_data_loader(loader)
def test_on_dataset():
    """Evaluate a detection model (or a saved result json) on a COCO subset and
    optionally render the predicted boxes under ``--show-dir``."""
    args = parse_args()
    # set fixed seed and deterministic algorithms
    accelerator = Accelerator(cpu=args.model_config is None)
    accelerate.utils.set_seed(args.seed, device_specific=False)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # deterministic in low version pytorch leads to RuntimeError
    # torch.use_deterministic_algorithms(True, warn_only=True)
    # setup logger
    for logger_name in ["py.warnings", "accelerate", os.path.basename(os.getcwd())]:
        setup_logger(distributed_rank=accelerator.local_process_index, name=logger_name)
    logger = logging.getLogger(os.path.basename(os.getcwd()))
    # get dataset
    dataset = CocoDetection(
        img_folder=f"{args.coco_path}/{args.subset}2017",
        ann_file=f"{args.coco_path}/annotations/instances_{args.subset}2017.json",
        transforms=None,  # the eval_transform is integrated in the model
        train=args.subset == "train",
    )
    data_loader = create_test_data_loader(
        dataset,
        accelerator=accelerator,
        batch_size=1,
        num_workers=args.workers,
        collate_fn=collate_fn,
    )
    # get evaluation results from model output
    if args.model_config:
        model = Config(args.model_config).model.eval()
        checkpoint = load_checkpoint(args.checkpoint)
        if isinstance(checkpoint, Dict) and "model" in checkpoint:
            checkpoint = checkpoint["model"]
        load_state_dict(model, checkpoint)
        model = accelerator.prepare_model(model)
        coco_evaluator = evaluate_acc(model, data_loader, 0, accelerator)
        # if not given path to save results, use temp file
        if args.result is None:
            temp_file = tempfile.NamedTemporaryFile()
            args.result = temp_file.name
        # save prediction results
        with open(args.result, "w") as f:
            det_results = coco_evaluator.predictions["bbox"]
            f.write(json.dumps(det_results))
        logger.info(f"Detection results are saved into {args.result}")
    # load detection results from json; only the main process consumes them.
    # BUGFIX: the original condition `A or B and C` parsed as `A or (B and C)`,
    # so results were loaded on every process when no model config was given.
    if (args.model_config is None or args.show_dir) and accelerator.is_main_process:
        coco_dt = loadRes(COCO(f"{args.coco_path}/annotations/instances_{args.subset}2017.json"), args.result)
    # if not given model, evaluate COCO metric on predicted json results
    if args.model_config is None and accelerator.is_main_process:
        coco = get_coco_api_from_dataset(data_loader.dataset)
        coco_evaluator = CocoEvaluator(coco, ["bbox"])
        coco_evaluator.coco_eval["bbox"].cocoDt = coco_dt
        coco_evaluator.coco_eval["bbox"].evaluate()
        redirect_string = io.StringIO()
        with contextlib.redirect_stdout(redirect_string):
            coco_evaluator.accumulate()
            coco_evaluator.summarize()
        logger.info(redirect_string.getvalue())
    # plot results for each image
    if args.show_dir and accelerator.is_main_process:
        accelerator.state.device = "cpu"  # change device to CPU for plot
        dataset.coco = coco_dt  # load predicted results into data_loader
        data_loader = create_test_data_loader(
            dataset, accelerator=accelerator, batch_size=1, num_workers=args.workers
        )
        visualize_coco_bounding_boxes(
            data_loader=data_loader,
            show_conf=args.show_conf,
            show_dir=args.show_dir,
            font_scale=args.font_scale,
            box_thick=args.box_thick,
            fill_alpha=args.fill_alpha,
            text_box_color=args.text_box_color,
            text_font_color=args.text_font_color,
            text_alpha=args.text_alpha,
        )
if __name__ == "__main__":
    # script entry point: evaluate / visualize on the chosen COCO subset
    test_on_dataset()
import argparse
import os
import sys
import numpy as np
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table
from tqdm import tqdm
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from util.lazy_load import Config
def parse_args():
    """Parse the command-line options of the benchmarking script."""
    parser = argparse.ArgumentParser(description="Benchmarking a model")
    for flags, options in [
        (["--model-config"], dict(type=str, required=True)),
        (["--shape"], dict(type=int, nargs="+", default=(1333, 800), help="input image size")),
        (["--repeat"], dict(type=int, default=50)),
        (["--device"], dict(default="cuda")),
    ]:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
def get_flops():
    """Report FLOPs, memory use, parameter count and average inference latency
    of the model described by ``--model-config`` on a random input."""
    args = parse_args()
    # initialize model
    model = Config(args.model_config).model
    model.eval_transform = None
    model.eval().to(args.device)
    # test FLOPs
    image = torch.randn(3, args.shape[0], args.shape[1]).to(args.device)
    flops = FlopCountAnalysis(model, ((image,),))
    print(flop_count_table(flops))
    # test memory allocation
    print(f"Memory allocation {torch.cuda.memory_allocated() / 1024**3} GB")
    print(f"Max memory allocation {torch.cuda.max_memory_allocated() / 1024**3} GB")
    # test model parameters
    print(f"Model parameters {sum(p.numel() for p in model.parameters()) / 1024**3} GB")
    # test inference time
    print("warm up...")
    with torch.inference_mode():
        for _ in range(10):
            _ = model((image,))
    torch.cuda.synchronize()
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    timings = np.zeros((args.repeat, 1))
    print("testing inference time...")
    with torch.inference_mode():
        for rep in tqdm(range(args.repeat)):
            starter.record()
            _ = model((image,))
            ender.record()
            torch.cuda.synchronize()
            # elapsed_time returns milliseconds
            timings[rep] = starter.elapsed_time(ender)
    # BUGFIX: the average previously divided by `rep`, which equals
    # args.repeat - 1 after the loop (off-by-one bias in the mean)
    avg = timings.sum() / args.repeat
    print(f"avg inference time per image = {avg / 1000}")
if __name__ == "__main__":
    # make the project root importable when run as a standalone script
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    get_flops()
import argparse
import os
import sys
import warnings
from typing import Dict, List, Tuple
import numpy as np
import onnx
import torch
from torch import Tensor
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from util import utils
from util.lazy_load import Config
class ONNXDetector:
    # Thin onnxruntime wrapper exposing a detector-like __call__ interface.
    def __init__(self, onnx_file):
        """Create an inference session for ``onnx_file``.

        Prefers the CUDA execution provider when available, falling back to
        CPU otherwise.
        """
        import onnxruntime
        self.session = onnxruntime.InferenceSession(
            onnx_file, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        # io_binding lets us feed tensors by pointer without extra copies
        self.io_binding = self.session.io_binding()
        # True when the installed onnxruntime build targets the GPU
        self.is_cuda_available = onnxruntime.get_device() == "GPU"
    def __call__(self, images: List[Tensor], targets: List[Dict] = None):
        """Run inference on a single C,H,W image (batch size 1 only).

        ``targets`` is accepted for interface compatibility but ignored.
        Returns the session outputs copied to CPU (numpy arrays).
        """
        if targets is not None:
            warnings.warn("Currently ONNXDetector only support inference, targets will be ignored")
        assert len(images) == 1, "Currently ONNXDetector only support batch_size=1 for inference"
        assert images[0].ndim == 3, "Each image must be with three dimensions of C, H, W"
        # stack the single image into a 4D batch tensor
        if isinstance(images, (List, Tuple)):
            images = torch.stack(images)
        # set io binding for inputs/outputs; keep the tensor on its device when
        # the runtime can read GPU memory, otherwise move it to CPU first
        device_type = images.device.type if self.is_cuda_available else "cpu"
        if not self.is_cuda_available:
            images = images.cpu()
        self.io_binding.bind_input(
            name="images",
            device_type=device_type,
            device_id=0,
            element_type=np.float32,
            shape=images.shape,
            buffer_ptr=images.data_ptr(),
        )
        for output in self.session.get_outputs():
            self.io_binding.bind_output(output.name)
        # run session to get outputs
        self.session.run_with_iobinding(self.io_binding)
        detections = self.io_binding.copy_outputs_to_cpu()
        return detections
def _str2bool(value):
    """Parse a CLI boolean string; argparse's ``type=bool`` treats every
    non-empty string (including "False") as True."""
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ("true", "t", "yes", "y", "1")

def parse_args():
    """Parse the command-line options of the ONNX export script."""
    parser = argparse.ArgumentParser(description="Convert a pytorch model to ONNX model")
    # model parameters
    parser.add_argument("--model-config", type=str, default=None)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--shape", type=int, nargs="+", default=(1333, 800))
    # save parameters
    parser.add_argument("--save-file", type=str, required=True)
    # onnx parameters
    parser.add_argument("--opset-version", type=int, default=17)
    # BUGFIX: `type=bool` made `--dynamic-export False` evaluate to True
    parser.add_argument("--dynamic-export", type=_str2bool, default=True)
    parser.add_argument("--simplify", action="store_true")
    parser.add_argument("--verify", action="store_true")
    args = parser.parse_args()
    return args
def pytorch2onnx():
    """Export the configured model to ONNX, optionally simplify the graph and
    verify numerical parity against the PyTorch model."""
    # get args from parser
    args = parse_args()
    model = Config(args.model_config).model
    model.eval()
    # optionally restore weights before export
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        utils.load_state_dict(model, checkpoint["model"] if "model" in checkpoint else checkpoint)
    image = torch.randn(1, 3, args.shape[0], args.shape[1])
    if args.dynamic_export:
        # allow variable batch size and input resolution at inference time
        dynamic_axes = {
            "images": {
                0: "batch",
                2: "height",
                3: "width",
            },
        }
    else:
        dynamic_axes = None
    torch.onnx.export(
        model=model,
        args=image,
        f=args.save_file,
        input_names=["images"],
        output_names=["scores", "labels", "boxes"],
        dynamic_axes=dynamic_axes,
        opset_version=args.opset_version,
    )
    if args.simplify:
        import onnxsim
        model_ops, check_ok = onnxsim.simplify(args.save_file)
        if check_ok:
            onnx.save(model_ops, args.save_file)
            print(f"Successfully simplified ONNX model: {args.save_file}")
        else:
            warnings.warn("Failed to simplify ONNX model.")
    print(f"Successfully exported ONNX model: {args.save_file}")
    if args.verify:
        # check by onnx
        onnx_model = onnx.load(args.save_file)
        onnx.checker.check_model(onnx_model)
        # check onnx results and pytorch results
        onnx_model = ONNXDetector(args.save_file)
        onnx_results = onnx_model(image)
        pytorch_results = list(model(image)[0].values())
        # BUGFIX: the two sentences were previously concatenated with no
        # separator ("...ONNXBut it does not...")
        err_msg = (
            "The numerical values are different between Pytorch and ONNX. "
            "But it does not necessarily mean the exported ONNX is problematic."
        )
        for onnx_res, pytorch_res in zip(onnx_results, pytorch_results):
            np.testing.assert_allclose(onnx_res, pytorch_res, rtol=1e-3, atol=1e-5, err_msg=err_msg)
        print("The numerical values are the same between Pytorch and ONNX")
if __name__ == "__main__":
    # script entry point: run the PyTorch -> ONNX conversion
    pytorch2onnx()
import argparse
import os
import sys
from torch.utils import data
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from datasets.coco import CocoDetection
from transforms import presets
from transforms import v2 as T
from util.collate_fn import collate_fn
from util.logger import setup_logger
from util.misc import fixed_generator, seed_worker
from util.visualize import visualize_coco_bounding_boxes
def parse_args():
    """Parse the command-line options of the dataset visualizer."""
    parser = argparse.ArgumentParser(description="Visualize a datasets")
    # dataset parameters
    parser.add_argument("--coco-img", type=str, required=True)
    parser.add_argument("--coco-ann", type=str, required=True)
    parser.add_argument("--transform", type=str, default=None)
    parser.add_argument("--workers", type=int, default=2)
    # visualize parameters
    parser.add_argument("--show-dir", type=str, default=None, required=True)
    parser.add_argument("--show-conf", type=float, default=0.5)
    # plot parameters
    parser.add_argument("--font-scale", type=float, default=1.0)
    parser.add_argument("--box-thick", type=int, default=1)
    parser.add_argument("--fill-alpha", type=float, default=0.2)
    parser.add_argument("--text-box-color", type=int, nargs="+", default=(255, 255, 255))
    parser.add_argument("--text-font-color", type=int, nargs="+", default=None)
    parser.add_argument("--text-alpha", type=float, default=1.0)
    # engine parameters
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()
def visualize_datasets():
    """Render dataset annotations as bounding-box overlays under --show-dir."""
    args = parse_args()
    # setup logger
    for logger_name in ["py.warnings", "accelerate", os.path.basename(os.getcwd())]:
        setup_logger(name=logger_name)
    # drop ConvertDtype / Normalize so the images stay directly drawable
    transform = None
    if args.transform:
        transform = remove_cvtdtype_normalize(getattr(presets, args.transform))
    # plot annotations for each image
    if not args.show_dir:
        return
    dataset = CocoDetection(img_folder=args.coco_img, ann_file=args.coco_ann, transforms=transform)
    data_loader = data.DataLoader(
        dataset,
        1,
        shuffle=False,
        num_workers=args.workers,
        worker_init_fn=seed_worker,
        generator=fixed_generator(),
        collate_fn=collate_fn,
    )
    visualize_coco_bounding_boxes(
        data_loader=data_loader,
        show_conf=args.show_conf,
        show_dir=args.show_dir,
        font_scale=args.font_scale,
        box_thick=args.box_thick,
        fill_alpha=args.fill_alpha,
        text_box_color=args.text_box_color,
        text_font_color=args.text_font_color,
        text_alpha=args.text_alpha,
    )
def remove_cvtdtype_normalize(transform):
    """Recursively strip ConvertDtype / Normalize transforms from a transform
    tree, rebuilding Compose nodes without them."""
    if isinstance(transform, (T.ConvertDtype, T.Normalize)):
        return None
    if isinstance(transform, T.Compose):
        kept = [remove_cvtdtype_normalize(child) for child in transform.transforms]
        return T.Compose([child for child in kept if child is not None])
    return transform
if __name__ == "__main__":
    # script entry point: render the dataset's annotations to images
    visualize_datasets()
from .transforms import *
from .autoaugment import *
from .convert_coco_polys_to_mask import ConvertCocoPolysToMask
from .simple_copy_paste import SimpleCopyPaste
import numbers
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageOps
try:
import accimage
except ImportError:
accimage = None
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    """True when ``img`` is a PIL image (or an accimage image, if available)."""
    accepted = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, accepted)
@torch.jit.unused
def get_dimensions(img: Any) -> List[int]:
    """Return [channels, height, width] of a PIL image."""
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    # accimage images expose `channels` instead of `getbands()`
    channels = len(img.getbands()) if hasattr(img, "getbands") else img.channels
    width, height = img.size
    return [channels, height, width]
@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
    """Return [width, height] of a PIL image."""
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    return list(img.size)
@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
    """Return the number of channels of a PIL image."""
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    # accimage images expose `channels` instead of `getbands()`
    return len(img.getbands()) if hasattr(img, "getbands") else img.channels
@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
    """Mirror a PIL image horizontally (left-right flip)."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return img.transpose(Image.FLIP_LEFT_RIGHT)
@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
    """Mirror a PIL image vertically (top-bottom flip)."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return img.transpose(Image.FLIP_TOP_BOTTOM)
@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
    """Adjust the brightness of a PIL image via ImageEnhance.Brightness."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Brightness(img).enhance(brightness_factor)
@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
    """Adjust the contrast of a PIL image via ImageEnhance.Contrast."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Contrast(img).enhance(contrast_factor)
@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
    """Adjust the color saturation of a PIL image via ImageEnhance.Color."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Color(img).enhance(saturation_factor)
@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
    """Shift the hue of ``img`` by ``hue_factor`` (fraction of the full hue
    wheel, restricted to [-0.5, 0.5])."""
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    input_mode = img.mode
    # grayscale / binary / integer / float modes carry no hue channel
    if input_mode in {"L", "1", "I", "F"}:
        return img
    h, s, v = img.convert("HSV").split()
    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition take cares of rotation across boundaries
    with np.errstate(over="ignore"):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, "L")
    # re-merge the shifted hue with the untouched saturation/value planes
    img = Image.merge("HSV", (h, s, v)).convert(input_mode)
    return img
@torch.jit.unused
def adjust_gamma(
    img: Image.Image,
    gamma: float,
    gain: float = 1.0,
) -> Image.Image:
    """Gamma-correct a PIL image: out = 255 * gain * (in / 255) ** gamma."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")
    input_mode = img.mode
    img = img.convert("RGB")
    # build one 256-entry lookup table replicated across the three channels;
    # (255 + 1 - 1e-3) keeps int() truncation inside [0, 255]
    lut = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3
    # use PIL's point-function to accelerate this part
    return img.point(lut).convert(input_mode)
@torch.jit.unused
def pad(
    img: Image.Image,
    padding: Union[int, List[int], Tuple[int, ...]],
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
    """Pad a PIL Image on its borders.

    ``padding`` may be a single int (all sides), a 2-element sequence
    (left/right, top/bottom) or a 4-element sequence (left, top, right,
    bottom). Negative values crop instead of pad. ``fill`` is only used for
    ``padding_mode="constant"``; the other modes are delegated to ``np.pad``.

    Raises:
        TypeError: If ``img``, ``padding``, ``fill`` or ``padding_mode`` has
            an unsupported type.
        ValueError: If ``padding`` has a bad length or ``padding_mode`` is
            not one of the four supported modes.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")
    # Normalize lists to tuples so the length checks below cover both.
    if isinstance(padding, list):
        padding = tuple(padding)
    if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
    if isinstance(padding, tuple) and len(padding) == 1:
        # Compatibility with `functional_tensor.pad`
        padding = padding[0]
    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
    if padding_mode == "constant":
        # Constant fill is handled entirely by PIL's ImageOps.expand.
        opts = _parse_fill(fill, img, name="fill")
        if img.mode == "P":
            # Palette images lose their palette through expand; save & restore it.
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, **opts)
            image.putpalette(palette)
            return image
        return ImageOps.expand(img, border=padding, **opts)
    else:
        # edge / reflect / symmetric: go through numpy. First normalize the
        # padding spec to explicit per-side amounts.
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, tuple) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, tuple) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        p = [pad_left, pad_top, pad_right, pad_bottom]
        # Negative padding means cropping: crop first, then pad with the
        # remaining (non-negative) amounts.
        cropping = -np.minimum(p, 0)
        if cropping.any():
            crop_left, crop_top, crop_right, crop_bottom = cropping
            img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))
        pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)
        if img.mode == "P":
            # Palette images: pad the index plane and restore the palette.
            palette = img.getpalette()
            img = np.asarray(img)
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
            img = Image.fromarray(img)
            img.putpalette(palette)
            return img
        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
        return Image.fromarray(img)
@torch.jit.unused
def crop(
    img: Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Image.Image:
    """Crop a ``height`` x ``width`` region whose top-left corner is (top, left).

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    # PIL's crop expects a (left, upper, right, lower) box.
    box = (left, top, left + width, top + height)
    return img.crop(box)
@torch.jit.unused
def resize(
    img: Image.Image,
    size: Union[List[int], int],
    interpolation: int = Image.BILINEAR,
) -> Image.Image:
    """Resize a PIL Image to ``size`` given as a ``[height, width]`` list.

    Raises:
        TypeError: If ``img`` is not a PIL Image, or ``size`` is not a
            2-element list.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if not (isinstance(size, list) and len(size) == 2):
        raise TypeError(f"Got inappropriate size arg: {size}")
    # PIL wants (width, height) while size arrives as [height, width].
    width_height = tuple(size[::-1])
    return img.resize(width_height, interpolation)
@torch.jit.unused
def _parse_fill(
    fill: Optional[Union[float, List[float], Tuple[float, ...]]],
    img: Image.Image,
    name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
    """Normalize a fill color for affine-style transforms.

    Returns a single-entry dict ``{name: fill}`` suitable for ``**``-expansion
    into PIL transform calls. ``None`` becomes 0; a scalar is broadcast across
    all image channels; non-float-mode images get integer components.

    Raises:
        ValueError: If a sequence ``fill`` does not match the channel count.
    """
    num_channels = get_image_num_channels(img)
    if fill is None:
        fill = 0
    # Broadcast a scalar fill over every band of a multi-channel image.
    if num_channels > 1 and isinstance(fill, (int, float)):
        fill = (fill,) * num_channels
    if isinstance(fill, (list, tuple)):
        if len(fill) != num_channels:
            msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
            raise ValueError(msg.format(len(fill), num_channels))
        fill = tuple(fill)
    # Float ("F") mode keeps float components; every other mode needs ints.
    if img.mode != "F":
        fill = tuple(int(x) for x in fill) if isinstance(fill, tuple) else int(fill)
    return {name: fill}
@torch.jit.unused
def affine(
    img: Image.Image,
    matrix: List[float],
    interpolation: int = Image.NEAREST,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    """Apply an affine transform described by a 6-element ``matrix``.

    The output canvas keeps the input image's size.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    fill_opts = _parse_fill(fill, img)
    return img.transform(img.size, Image.AFFINE, matrix, interpolation, **fill_opts)
@torch.jit.unused
def rotate(
    img: Image.Image,
    angle: float,
    interpolation: int = Image.NEAREST,
    expand: bool = False,
    center: Optional[Tuple[int, int]] = None,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    """Rotate a PIL Image by ``angle`` degrees.

    ``expand`` grows the canvas to hold the whole rotated image; ``center``
    overrides the default rotation center.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    fill_opts = _parse_fill(fill, img)
    return img.rotate(angle, interpolation, expand, center, **fill_opts)
@torch.jit.unused
def perspective(
    img: Image.Image,
    perspective_coeffs: List[float],
    interpolation: int = Image.BICUBIC,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    """Apply a perspective transform given its 8 coefficients.

    The output canvas keeps the input image's size.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    fill_opts = _parse_fill(fill, img)
    return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **fill_opts)
@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
    """Convert a PIL Image to grayscale.

    With ``num_output_channels == 1`` an "L"-mode image is returned; with 3,
    the luminance plane is replicated into three identical RGB bands.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
        ValueError: If ``num_output_channels`` is neither 1 nor 3.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")
    gray = img.convert("L")
    if num_output_channels == 1:
        return gray
    # Stack the single luminance plane into an RGB image with equal bands.
    plane = np.array(gray, dtype=np.uint8)
    stacked = np.dstack([plane, plane, plane])
    return Image.fromarray(stacked, "RGB")
@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    """Invert the colors of a PIL Image via ``ImageOps.invert``.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.invert(img)
@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    """Reduce each color channel to ``bits`` bits via ``ImageOps.posterize``.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.posterize(img, bits)
@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    """Invert all pixel values above ``threshold`` via ``ImageOps.solarize``.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.solarize(img, threshold)
@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
    """Scale the sharpness of a PIL Image.

    Args:
        img: Input PIL Image.
        sharpness_factor: Multiplier passed to ``ImageEnhance.Sharpness``.

    Returns:
        A sharpness-adjusted PIL Image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Sharpness(img).enhance(sharpness_factor)
@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    """Maximize image contrast via ``ImageOps.autocontrast``.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.autocontrast(img)
@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    """Equalize the image histogram via ``ImageOps.equalize``.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.equalize(img)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment