Commit 19472568 authored by 雍大凯's avatar 雍大凯
Browse files

将子模块转换为普通目录

parent 51e55208
#pragma once
#include <torch/extension.h>
at::Tensor geometric_kernel_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step);
std::vector<at::Tensor> geometric_kernel_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step);
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <THH/THHAtomics.cuh>
#include <vector>
#include "geometric_kernel_attn_hip_kernel.cuh"
at::Tensor geometric_kernel_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = ::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "multiscale_kernel_attn_forward_cuda", ([&] {
multiscale_kernel_attn_forward_cuda(at::hip::getCurrentHIPStreamMasqueradingAsCUDA(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<int64_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> geometric_kernel_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = ::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "multiscale_kernel_attn_backward_cuda", ([&] {
multiscale_kernel_attn_backward_cuda(at::hip::getCurrentHIPStreamMasqueradingAsCUDA(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<int64_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_attn_weight
};
}
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <algorithm>
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}
__device__ int clip(int n, int lower, int upper) {
n = n >= lower ? n : lower;
return n < upper ? n : upper;
}
template <typename scalar_t>
__device__ scalar_t multi_scale_kernel_attn_sampling(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const int &h,
const int &w, const int &m, const int &c) {
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int base_ptr = m * channels + c;
const int h_ptr_offset = h_stride * h;
const int w_ptr_offset = w_stride * w;
scalar_t val = bottom_data[base_ptr + h_ptr_offset + w_ptr_offset];
return val;
}
template <typename scalar_t>
__device__ void multiscale_kernel_attn_sampling_backward(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const int &h,
const int &w, const int &m, const int &c, const scalar_t &top_grad,
const scalar_t &attn_weight, scalar_t *&grad_value, scalar_t *grad_attn_weight) {
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_ptr_offset = h_stride * h;
const int w_ptr_offset = w_stride * w;
const int base_ptr = m * channels + c;
const scalar_t top_grad_value = top_grad * attn_weight;
// scalar_t grad_h_weight = 0, grad_w_weight = 0;
const int ptr = base_ptr + h_ptr_offset + w_ptr_offset;
scalar_t val = bottom_data[ptr];
atomicAdd(grad_value + ptr, top_grad_value);
*grad_attn_weight = top_grad * val;
}
template <typename scalar_t>
__global__ void multiscale_kernel_attn_forward_gpu_kernel(
const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const int64_t *data_sampling_loc,
const scalar_t *data_attn_weight, const int batch_size,
const int spatial_size, const int num_heads, const int channels,
const int num_levels, const int num_query, const int num_point,
scalar_t *data_col) {
CUDA_KERNEL_LOOP(index, n) {
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr =
data_value +
(data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col = 0; p_col < num_point; ++p_col) {
const int loc_w = data_sampling_loc[data_loc_w_ptr];
const int loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const int loc_h_ = clip(loc_h, 0, spatial_h-1);
const int loc_w_ = clip(loc_w, 0, spatial_w-1);
col += multi_scale_kernel_attn_sampling(data_value_ptr, spatial_h, spatial_w, num_heads,
channels, loc_h_, loc_w_, m_col, c_col) * weight;
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const int64_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
// grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
// const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const int loc_w = data_sampling_loc[data_loc_w_ptr];
const int loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
*(cache_grad_attn_weight+threadIdx.x)=0;
const int loc_h_ = clip(loc_h, 0, spatial_h-1);
const int loc_w_ = clip(loc_w, 0, spatial_w-1);
multiscale_kernel_attn_sampling_backward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, loc_h_, loc_w_, m_col, c_col,
top_grad, weight, grad_value_ptr, cache_grad_attn_weight+threadIdx.x);
__syncthreads();
for (unsigned int s=blockSize/2; s>0; s>>=1)
{
if (tid < s) {
// const unsigned int xid1 = tid << 1;
//const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
}
__syncthreads();
}
if (tid == 0)
{
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
}
}
}
}
template <typename scalar_t>
__global__ void multiscale_kernel_attn_backward_gpu_kernel_shm_reduce_v2(
const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const int64_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
// grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
// const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const int loc_w = data_sampling_loc[data_loc_w_ptr];
const int loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
*(cache_grad_attn_weight+threadIdx.x)=0;
const int loc_h_ = clip(loc_h, 0, spatial_h-1);
const int loc_w_ = clip(loc_w, 0, spatial_w-1);
multiscale_kernel_attn_sampling_backward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, loc_h_, loc_w_, m_col, c_col,
top_grad, weight, grad_value_ptr, cache_grad_attn_weight+threadIdx.x);
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
// const unsigned int xid1 = tid << 1;
// const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
}
}
}
}
template <typename scalar_t>
void multiscale_kernel_attn_forward_cuda(cudaStream_t stream,
const scalar_t* data_value,
const int64_t* data_spatial_shapes,
const int64_t* data_level_start_index,
const int64_t* data_sampling_loc,
const scalar_t* data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* data_col)
{
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
multiscale_kernel_attn_forward_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in multiscale_kernel_attn_forward_cuda: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
void multiscale_kernel_attn_backward_cuda(cudaStream_t stream,
const scalar_t* grad_col,
const scalar_t* data_value,
const int64_t * data_spatial_shapes,
const int64_t * data_level_start_index,
const int64_t * data_sampling_loc,
const scalar_t * data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* grad_value,
scalar_t* grad_attn_weight)
{
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
switch(channels) {
case 128:
multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
break;
case 256:
multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
break;
case 512:
multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
break;
case 1024:
multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
break;
default:
multiscale_kernel_attn_backward_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in multiscale_kernel_attn_backward_cuda: %s\n", cudaGetErrorString(err));
}
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <THH/THHAtomics.cuh>
#include <algorithm>
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}
__device__ int clip(int n, int lower, int upper) {
n = n >= lower ? n : lower;
return n < upper ? n : upper;
}
template <typename scalar_t>
__device__ scalar_t multi_scale_kernel_attn_sampling(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const int &h,
const int &w, const int &m, const int &c) {
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int base_ptr = m * channels + c;
const int h_ptr_offset = h_stride * h;
const int w_ptr_offset = w_stride * w;
scalar_t val = bottom_data[base_ptr + h_ptr_offset + w_ptr_offset];
return val;
}
template <typename scalar_t>
__device__ void multiscale_kernel_attn_sampling_backward(
const scalar_t *&bottom_data, const int &height, const int &width,
const int &nheads, const int &channels, const int &h,
const int &w, const int &m, const int &c, const scalar_t &top_grad,
const scalar_t &attn_weight, scalar_t *&grad_value, scalar_t *grad_attn_weight) {
const int w_stride = nheads * channels;
const int h_stride = width * w_stride;
const int h_ptr_offset = h_stride * h;
const int w_ptr_offset = w_stride * w;
const int base_ptr = m * channels + c;
const scalar_t top_grad_value = top_grad * attn_weight;
// scalar_t grad_h_weight = 0, grad_w_weight = 0;
const int ptr = base_ptr + h_ptr_offset + w_ptr_offset;
scalar_t val = bottom_data[ptr];
atomicAdd(grad_value + ptr, top_grad_value);
*grad_attn_weight = top_grad * val;
}
template <typename scalar_t>
__global__ void multiscale_kernel_attn_forward_gpu_kernel(
const int n, const scalar_t *data_value, const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index, const int64_t *data_sampling_loc,
const scalar_t *data_attn_weight, const int batch_size,
const int spatial_size, const int num_heads, const int channels,
const int num_levels, const int num_query, const int num_point,
scalar_t *data_col) {
CUDA_KERNEL_LOOP(index, n) {
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
scalar_t *data_col_ptr = data_col + index;
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
scalar_t col = 0;
for (int l_col = 0; l_col < num_levels; ++l_col) {
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const scalar_t *data_value_ptr =
data_value +
(data_value_ptr_init_offset + level_start_id * qid_stride);
for (int p_col = 0; p_col < num_point; ++p_col) {
const int loc_w = data_sampling_loc[data_loc_w_ptr];
const int loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
const int loc_h_ = clip(loc_h, 0, spatial_h-1);
const int loc_w_ = clip(loc_w, 0, spatial_w-1);
col += multi_scale_kernel_attn_sampling(data_value_ptr, spatial_h, spatial_w, num_heads,
channels, loc_h_, loc_w_, m_col, c_col) * weight;
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
template <typename scalar_t, unsigned int blockSize>
__global__ void multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const int64_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
__shared__ scalar_t cache_grad_attn_weight[blockSize];
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
// grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
// const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const int loc_w = data_sampling_loc[data_loc_w_ptr];
const int loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
*(cache_grad_attn_weight+threadIdx.x)=0;
const int loc_h_ = clip(loc_h, 0, spatial_h-1);
const int loc_w_ = clip(loc_w, 0, spatial_w-1);
multiscale_kernel_attn_sampling_backward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, loc_h_, loc_w_, m_col, c_col,
top_grad, weight, grad_value_ptr, cache_grad_attn_weight+threadIdx.x);
__syncthreads();
for (unsigned int s=blockSize/2; s>0; s>>=1)
{
if (tid < s) {
// const unsigned int xid1 = tid << 1;
//const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
}
__syncthreads();
}
if (tid == 0)
{
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
}
}
}
}
template <typename scalar_t>
__global__ void multiscale_kernel_attn_backward_gpu_kernel_shm_reduce_v2(
const int n,
const scalar_t *grad_col,
const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const int64_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t *grad_value,
scalar_t *grad_attn_weight)
{
CUDA_KERNEL_LOOP(index, n)
{
extern __shared__ int _s[];
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
unsigned int tid = threadIdx.x;
int _temp = index;
const int c_col = _temp % channels;
_temp /= channels;
const int sampling_index = _temp;
const int m_col = _temp % num_heads;
_temp /= num_heads;
const int q_col = _temp % num_query;
_temp /= num_query;
const int b_col = _temp;
const scalar_t top_grad = grad_col[index];
int data_weight_ptr = sampling_index * num_levels * num_point;
int data_loc_w_ptr = data_weight_ptr << 1;
const int grad_sampling_ptr = data_weight_ptr;
// grad_sampling_loc += grad_sampling_ptr << 1;
grad_attn_weight += grad_sampling_ptr;
const int grad_weight_stride = 1;
// const int grad_loc_stride = 2;
const int qid_stride = num_heads * channels;
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
for (int l_col=0; l_col < num_levels; ++l_col)
{
const int level_start_id = data_level_start_index[l_col];
const int spatial_h_ptr = l_col << 1;
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
for (int p_col=0; p_col < num_point; ++p_col)
{
const int loc_w = data_sampling_loc[data_loc_w_ptr];
const int loc_h = data_sampling_loc[data_loc_w_ptr + 1];
const scalar_t weight = data_attn_weight[data_weight_ptr];
*(cache_grad_attn_weight+threadIdx.x)=0;
const int loc_h_ = clip(loc_h, 0, spatial_h-1);
const int loc_w_ = clip(loc_w, 0, spatial_w-1);
multiscale_kernel_attn_sampling_backward(
data_value_ptr, spatial_h, spatial_w, num_heads, channels, loc_h_, loc_w_, m_col, c_col,
top_grad, weight, grad_value_ptr, cache_grad_attn_weight+threadIdx.x);
__syncthreads();
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
{
if (tid < s) {
// const unsigned int xid1 = tid << 1;
// const unsigned int xid2 = (tid + s) << 1;
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
if (tid + (s << 1) < spre)
{
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
}
}
__syncthreads();
}
if (tid == 0)
{
*grad_attn_weight = cache_grad_attn_weight[0];
}
__syncthreads();
data_weight_ptr += 1;
data_loc_w_ptr += 2;
grad_attn_weight += grad_weight_stride;
}
}
}
}
template <typename scalar_t>
void multiscale_kernel_attn_forward_cuda(hipStream_t stream,
const scalar_t* data_value,
const int64_t* data_spatial_shapes,
const int64_t* data_level_start_index,
const int64_t* data_sampling_loc,
const scalar_t* data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* data_col)
{
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
hipLaunchKernelGGL(( multiscale_kernel_attn_forward_gpu_kernel<scalar_t>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
hipError_t err = hipGetLastError();
if (err != hipSuccess)
{
printf("error in multiscale_kernel_attn_forward_cuda: %s\n", hipGetErrorString(err));
}
}
template <typename scalar_t>
void multiscale_kernel_attn_backward_cuda(hipStream_t stream,
const scalar_t* grad_col,
const scalar_t* data_value,
const int64_t * data_spatial_shapes,
const int64_t * data_level_start_index,
const int64_t * data_sampling_loc,
const scalar_t * data_attn_weight,
const int batch_size,
const int spatial_size,
const int num_heads,
const int channels,
const int num_levels,
const int num_query,
const int num_point,
scalar_t* grad_value,
scalar_t* grad_attn_weight)
{
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
switch(channels) {
case 128:
hipLaunchKernelGGL(( multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),
0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
break;
case 256:
hipLaunchKernelGGL(( multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads),0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
break;
case 512:
hipLaunchKernelGGL(( multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads), 0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
break;
case 1024:
hipLaunchKernelGGL(( multiscale_kernel_attn_backward_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads), 0, stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
break;
default:
hipLaunchKernelGGL(( multiscale_kernel_attn_backward_gpu_kernel_shm_reduce_v2<scalar_t>)
, dim3(GET_BLOCKS(num_actual_kernels, num_threads)), dim3(num_threads), num_threads*3*sizeof(scalar_t), stream,
num_kernels,
grad_col,
data_value,
data_spatial_shapes,
data_level_start_index,
data_sampling_loc,
data_attn_weight,
batch_size,
spatial_size,
num_heads,
channels,
num_levels,
num_query,
num_point,
grad_value,
grad_attn_weight);
}
hipError_t err = hipGetLastError();
if (err != hipSuccess)
{
printf("error in multiscale_kernel_attn_backward_cuda: %s\n", hipGetErrorString(err));
}
}
#include "geometric_kernel_attn.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("geometric_kernel_attn_cuda_forward", &geometric_kernel_attn_cuda_forward, "geometric_kernel_attn_cuda_forward");
m.def("geometric_kernel_attn_cuda_backward", &geometric_kernel_attn_cuda_backward, "geometric_kernel_attn_cuda_backward");
}
import copy
import torch
import torch.nn as nn
import numpy as np
from torch.nn.init import normal_
import torch.nn.functional as F
from mmdet.models.utils.builder import TRANSFORMER
from mmcv.cnn import Linear, bias_init_with_prob, xavier_init, constant_init
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from torchvision.transforms.functional import rotate
from projects.mmdet3d_plugin.bevformer.modules.temporal_self_attention import TemporalSelfAttention
from projects.mmdet3d_plugin.bevformer.modules.spatial_cross_attention import MSDeformableAttention3D
from projects.mmdet3d_plugin.bevformer.modules.decoder import CustomMSDeformableAttention
from .builder import build_fuser, FUSERS
from typing import List
@FUSERS.register_module()
class ConvFuser(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
super().__init__(
nn.Conv2d(sum(in_channels), out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
)
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
return super().forward(torch.cat(inputs, dim=1))
@TRANSFORMER.register_module()
class MapTRPerceptionTransformer(BaseModule):
"""Implements the Detr3D transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN:
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
def __init__(self,
num_feature_levels=4,
num_cams=6,
z_cfg=dict(
pred_z_flag=False,
gt_z_flag=False,
),
two_stage_num_proposals=300,
fuser=None,
encoder=None,
decoder=None,
embed_dims=256,
rotate_prev_bev=True,
use_shift=True,
use_can_bus=True,
can_bus_norm=True,
use_cams_embeds=True,
rotate_center=[100, 100],
modality='vision',
feat_down_sample_indice=-1,
**kwargs):
super(MapTRPerceptionTransformer, self).__init__(**kwargs)
if modality == 'fusion':
self.fuser = build_fuser(fuser) #TODO
# self.use_attn_bev = encoder['type'] == 'BEVFormerEncoder'
self.use_attn_bev = 'BEVFormerEncoder' in encoder['type']
self.encoder = build_transformer_layer_sequence(encoder)
self.decoder = build_transformer_layer_sequence(decoder)
self.embed_dims = embed_dims
self.num_feature_levels = num_feature_levels
self.num_cams = num_cams
self.fp16_enabled = False
self.rotate_prev_bev = rotate_prev_bev
self.use_shift = use_shift
self.use_can_bus = use_can_bus
self.can_bus_norm = can_bus_norm
self.use_cams_embeds = use_cams_embeds
self.two_stage_num_proposals = two_stage_num_proposals
self.z_cfg=z_cfg
self.init_layers()
self.rotate_center = rotate_center
self.feat_down_sample_indice = feat_down_sample_indice
def init_layers(self):
"""Initialize layers of the Detr3DTransformer."""
self.level_embeds = nn.Parameter(torch.Tensor(
self.num_feature_levels, self.embed_dims))
self.cams_embeds = nn.Parameter(
torch.Tensor(self.num_cams, self.embed_dims))
self.reference_points = nn.Linear(self.embed_dims, 2) if not self.z_cfg['gt_z_flag'] \
else nn.Linear(self.embed_dims, 3)
self.can_bus_mlp = nn.Sequential(
nn.Linear(18, self.embed_dims // 2),
nn.ReLU(inplace=True),
nn.Linear(self.embed_dims // 2, self.embed_dims),
nn.ReLU(inplace=True),
)
if self.can_bus_norm:
self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))
def init_weights(self):
"""Initialize the transformer weights."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformableAttention3D) or isinstance(m, TemporalSelfAttention) \
or isinstance(m, CustomMSDeformableAttention):
try:
m.init_weight()
except AttributeError:
m.init_weights()
normal_(self.level_embeds)
normal_(self.cams_embeds)
xavier_init(self.reference_points, distribution='uniform', bias=0.)
xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)
# TODO apply fp16 to this module cause grad_norm NAN
# @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'prev_bev', 'bev_pos'), out_fp32=True)
def attn_bev_encode(
self,
mlvl_feats,
bev_queries,
bev_h,
bev_w,
grid_length=[0.512, 0.512],
bev_pos=None,
prev_bev=None,
**kwargs):
bs = mlvl_feats[0].size(0)
bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
bev_pos = bev_pos.flatten(2).permute(2, 0, 1)
# obtain rotation angle and shift with ego motion
delta_x = np.array([each['can_bus'][0]
for each in kwargs['img_metas']])
delta_y = np.array([each['can_bus'][1]
for each in kwargs['img_metas']])
ego_angle = np.array(
[each['can_bus'][-2] / np.pi * 180 for each in kwargs['img_metas']])
grid_length_y = grid_length[0]
grid_length_x = grid_length[1]
translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
bev_angle = ego_angle - translation_angle
shift_y = translation_length * \
np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
shift_x = translation_length * \
np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
shift_y = shift_y * self.use_shift
shift_x = shift_x * self.use_shift
shift = bev_queries.new_tensor(
[shift_x, shift_y]).permute(1, 0) # xy, bs -> bs, xy
if prev_bev is not None:
if prev_bev.shape[1] == bev_h * bev_w:
prev_bev = prev_bev.permute(1, 0, 2)
if self.rotate_prev_bev:
for i in range(bs):
# num_prev_bev = prev_bev.size(1)
rotation_angle = kwargs['img_metas'][i]['can_bus'][-1]
tmp_prev_bev = prev_bev[:, i].reshape(
bev_h, bev_w, -1).permute(2, 0, 1)
tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
center=self.rotate_center)
tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
bev_h * bev_w, 1, -1)
prev_bev[:, i] = tmp_prev_bev[:, 0]
# add can bus signals
can_bus = bev_queries.new_tensor(
[each['can_bus'] for each in kwargs['img_metas']]) # [:, :]
can_bus = self.can_bus_mlp(can_bus)[None, :, :]
bev_queries = bev_queries + can_bus * self.use_can_bus
feat_flatten = []
spatial_shapes = []
for lvl, feat in enumerate(mlvl_feats):
bs, num_cam, c, h, w = feat.shape
spatial_shape = (h, w)
feat = feat.flatten(3).permute(1, 0, 3, 2)
if self.use_cams_embeds:
feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
feat = feat + self.level_embeds[None,
None, lvl:lvl + 1, :].to(feat.dtype)
spatial_shapes.append(spatial_shape)
feat_flatten.append(feat)
feat_flatten = torch.cat(feat_flatten, 2)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=bev_pos.device)
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
feat_flatten = feat_flatten.permute(
0, 2, 1, 3) # (num_cam, H*W, bs, embed_dims)
ret_dict = self.encoder(
bev_queries,
feat_flatten,
feat_flatten,
mlvl_feats=mlvl_feats,
bev_h=bev_h,
bev_w=bev_w,
bev_pos=bev_pos,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
prev_bev=prev_bev,
shift=shift,
**kwargs
)
return ret_dict
def lss_bev_encode(
self,
mlvl_feats,
prev_bev=None,
**kwargs):
# import ipdb;ipdb.set_trace()
# assert len(mlvl_feats) == 1, 'Currently we only use last single level feat in LSS'
# import ipdb;ipdb.set_trace()
images = mlvl_feats[self.feat_down_sample_indice]
img_metas = kwargs['img_metas']
encoder_outputdict = self.encoder(images,img_metas)
bev_embed = encoder_outputdict['bev']
depth = encoder_outputdict['depth']
bs, c, _,_ = bev_embed.shape
bev_embed = bev_embed.view(bs,c,-1).permute(0,2,1).contiguous()
ret_dict = dict(
bev=bev_embed,
depth=depth
)
return ret_dict
def get_bev_features(
self,
mlvl_feats,
lidar_feat,
bev_queries,
bev_h,
bev_w,
grid_length=[0.512, 0.512],
bev_pos=None,
prev_bev=None,
**kwargs):
"""
obtain bev features.
"""
if self.use_attn_bev:
ret_dict = self.attn_bev_encode(
mlvl_feats,
bev_queries,
bev_h,
bev_w,
grid_length=grid_length,
bev_pos=bev_pos,
prev_bev=prev_bev,
**kwargs)
bev_embed = ret_dict['bev']
depth = ret_dict['depth']
else:
ret_dict = self.lss_bev_encode(
mlvl_feats,
prev_bev=prev_bev,
**kwargs)
bev_embed = ret_dict['bev']
depth = ret_dict['depth']
if lidar_feat is not None:
bs = mlvl_feats[0].size(0)
bev_embed = bev_embed.view(bs, bev_h, bev_w, -1).permute(0,3,1,2).contiguous()
lidar_feat = lidar_feat.permute(0,1,3,2).contiguous() # B C H W
lidar_feat = nn.functional.interpolate(lidar_feat, size=(bev_h,bev_w), mode='bicubic', align_corners=False)
fused_bev = self.fuser([bev_embed, lidar_feat])
fused_bev = fused_bev.flatten(2).permute(0,2,1).contiguous()
bev_embed = fused_bev
ret_dict = dict(
bev=bev_embed,
depth=depth
)
return ret_dict
#@torch.compile(mode="max-autotune-no-cudagraphs")
def format_feats(self, mlvl_feats):
bs = mlvl_feats[0].size(0)
feat_flatten = []
spatial_shapes = []
for lvl, feat in enumerate(mlvl_feats):
bs, num_cam, c, h, w = feat.shape
spatial_shape = (h, w)
feat = feat.flatten(3).permute(1, 0, 3, 2)
if self.use_cams_embeds:
feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
feat = feat + self.level_embeds[None,
None, lvl:lvl + 1, :].to(feat.dtype)
spatial_shapes.append(spatial_shape)
feat_flatten.append(feat)
feat_flatten = torch.cat(feat_flatten, 2)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=feat.device)
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
feat_flatten = feat_flatten.permute(
0, 2, 1, 3) # (num_cam, H*W, bs, embed_dims)
return feat_flatten, spatial_shapes, level_start_index
# TODO apply fp16 to this module cause grad_norm NAN
# @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed', 'prev_bev', 'bev_pos'))
#@torch.compile(mode="max-autotune-no-cudagraphs")
def initialize_queries_and_bev(self, object_query_embed, bev_embed, bs, bev_h, bev_w):
query_pos, query = torch.split(
object_query_embed, self.embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid()
init_reference_out = reference_points
query = query.permute(1, 0, 2)
query_pos = query_pos.permute(1, 0, 2)
bev_embed = bev_embed.permute(1, 0, 2)
spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device)
level_start_index=torch.tensor([0], device=query.device)
return query, bev_embed, query_pos, reference_points, spatial_shapes, level_start_index, init_reference_out
def forward(self,
mlvl_feats,
lidar_feat,
bev_queries,
object_query_embed,
bev_h,
bev_w,
grid_length=[0.512, 0.512],
bev_pos=None,
reg_branches=None,
cls_branches=None,
prev_bev=None,
**kwargs):
"""Forward function for `Detr3DTransformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
different level. Each element has shape
[bs, num_cams, embed_dims, h, w].
bev_queries (Tensor): (bev_h*bev_w, c)
bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w)
object_query_embed (Tensor): The query embedding for decoder,
with shape [num_query, c].
reg_branches (obj:`nn.ModuleList`): Regression heads for
feature maps from each decoder layer. Only would
be passed when `with_box_refine` is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- bev_embed: BEV features
- inter_states: Outputs from decoder. If
return_intermediate_dec is True output has shape \
(num_dec_layers, bs, num_query, embed_dims), else has \
shape (1, bs, num_query, embed_dims).
- init_reference_out: The initial value of reference \
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference \
points in decoder, has shape \
(num_dec_layers, bs,num_query, embed_dims)
- enc_outputs_class: The classification score of \
proposals generated from \
encoder's feature maps, has shape \
(batch, h*w, num_classes). \
Only would be returned when `as_two_stage` is True, \
otherwise None.
- enc_outputs_coord_unact: The regression results \
generated from encoder's feature maps., has shape \
(batch, h*w, 4). Only would \
be returned when `as_two_stage` is True, \
otherwise None.
"""
ouput_dic = self.get_bev_features(
mlvl_feats,
lidar_feat,
bev_queries,
bev_h,
bev_w,
grid_length=grid_length,
bev_pos=bev_pos,
prev_bev=prev_bev,
**kwargs) # bev_embed shape: bs, bev_h*bev_w, embed_dims
bev_embed = ouput_dic['bev']
depth = ouput_dic['depth']
bs = mlvl_feats[0].size(0)
# query_pos, query = torch.split(
# object_query_embed, self.embed_dims, dim=1)
# query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
# query = query.unsqueeze(0).expand(bs, -1, -1)
# reference_points = self.reference_points(query_pos)
# reference_points = reference_points.sigmoid()
# init_reference_out = reference_points
# query = query.permute(1, 0, 2)
# query_pos = query_pos.permute(1, 0, 2)
# bev_embed = bev_embed.permute(1, 0, 2)
query, bev_embed, query_pos, reference_points, spatial_shapes, level_start_index, init_reference_out = self.initialize_queries_and_bev(object_query_embed, bev_embed, bs, bev_h, bev_w)
feat_flatten, feat_spatial_shapes, feat_level_start_index \
= self.format_feats(mlvl_feats)
inter_states, inter_references = self.decoder(
query=query,
key=None,
value=bev_embed,
query_pos=query_pos,
reference_points=reference_points,
reg_branches=reg_branches,
cls_branches=cls_branches,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
mlvl_feats=mlvl_feats,
feat_flatten=feat_flatten,
feat_spatial_shapes=feat_spatial_shapes,
feat_level_start_index=feat_level_start_index,
**kwargs)
inter_references_out = inter_references
return bev_embed, depth, inter_states, init_reference_out, inter_references_out
from .vovnet import VoVNet
from .efficientnet import EfficientNet
from .swin import SwinTransformer
__all__ = ['VoVNet']
\ No newline at end of file
import copy
import math
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import ConvModule, DropPath
from mmcv.runner import BaseModule, Sequential
from mmdet.models.builder import BACKBONES
from ..utils import InvertedResidual, SELayer, make_divisible
class EdgeResidual(BaseModule):
"""Edge Residual Block.
Args:
in_channels (int): The input channels of this module.
out_channels (int): The output channels of this module.
mid_channels (int): The input channels of the second convolution.
kernel_size (int): The kernel size of the first convolution.
Defaults to 3.
stride (int): The stride of the first convolution. Defaults to 1.
se_cfg (dict, optional): Config dict for se layer. Defaults to None,
which means no se layer.
with_residual (bool): Use residual connection. Defaults to True.
conv_cfg (dict, optional): Config dict for convolution layer.
Defaults to None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='BN')``.
act_cfg (dict): Config dict for activation layer.
Defaults to ``dict(type='ReLU')``.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict | list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
out_channels,
mid_channels,
kernel_size=3,
stride=1,
se_cfg=None,
with_residual=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
drop_path_rate=0.,
with_cp=False,
init_cfg=None,
**kwargs):
super(EdgeResidual, self).__init__(init_cfg=init_cfg)
assert stride in [1, 2]
self.with_cp = with_cp
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.with_se = se_cfg is not None
self.with_residual = (
stride == 1 and in_channels == out_channels and with_residual)
if self.with_se:
assert isinstance(se_cfg, dict)
self.conv1 = ConvModule(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if self.with_se:
self.se = SELayer(**se_cfg)
self.conv2 = ConvModule(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
def forward(self, x):
def _inner_forward(x):
out = x
out = self.conv1(out)
if self.with_se:
out = self.se(out)
out = self.conv2(out)
if self.with_residual:
return x + self.drop_path(out)
else:
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
def model_scaling(layer_setting, arch_setting):
"""Scaling operation to the layer's parameters according to the
arch_setting."""
# scale width
new_layer_setting = copy.deepcopy(layer_setting)
for layer_cfg in new_layer_setting:
for block_cfg in layer_cfg:
block_cfg[1] = make_divisible(block_cfg[1] * arch_setting[0], 8)
# scale depth
split_layer_setting = [new_layer_setting[0]]
for layer_cfg in new_layer_setting[1:-1]:
tmp_index = [0]
for i in range(len(layer_cfg) - 1):
if layer_cfg[i + 1][1] != layer_cfg[i][1]:
tmp_index.append(i + 1)
tmp_index.append(len(layer_cfg))
for i in range(len(tmp_index) - 1):
split_layer_setting.append(layer_cfg[tmp_index[i]:tmp_index[i +
1]])
split_layer_setting.append(new_layer_setting[-1])
num_of_layers = [len(layer_cfg) for layer_cfg in split_layer_setting[1:-1]]
new_layers = [
int(math.ceil(arch_setting[1] * num)) for num in num_of_layers
]
merge_layer_setting = [split_layer_setting[0]]
for i, layer_cfg in enumerate(split_layer_setting[1:-1]):
if new_layers[i] <= num_of_layers[i]:
tmp_layer_cfg = layer_cfg[:new_layers[i]]
else:
tmp_layer_cfg = copy.deepcopy(layer_cfg) + [layer_cfg[-1]] * (
new_layers[i] - num_of_layers[i])
if tmp_layer_cfg[0][3] == 1 and i != 0:
merge_layer_setting[-1] += tmp_layer_cfg.copy()
else:
merge_layer_setting.append(tmp_layer_cfg.copy())
merge_layer_setting.append(split_layer_setting[-1])
return merge_layer_setting
@BACKBONES.register_module(force=True)
class EfficientNet(BaseModule):
"""EfficientNet backbone.
Args:
arch (str): Architecture of efficientnet. Defaults to b0.
out_indices (Sequence[int]): Output from which stages.
Defaults to (6, ).
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
conv_cfg (dict): Config dict for convolution layer.
Defaults to None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='Swish').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
"""
# Parameters to build layers.
# 'b' represents the architecture of normal EfficientNet family includes
# 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8'.
# 'e' represents the architecture of EfficientNet-EdgeTPU including 'es',
# 'em', 'el'.
# 6 parameters are needed to construct a layer, From left to right:
# - kernel_size: The kernel size of the block
# - out_channel: The number of out_channels of the block
# - se_ratio: The sequeeze ratio of SELayer.
# - stride: The stride of the block
# - expand_ratio: The expand_ratio of the mid_channels
# - block_type: -1: Not a block, 0: InvertedResidual, 1: EdgeResidual
layer_settings = {
'b': [[[3, 32, 0, 2, 0, -1]],
[[3, 16, 4, 1, 1, 0]],
[[3, 24, 4, 2, 6, 0],
[3, 24, 4, 1, 6, 0]],
[[5, 40, 4, 2, 6, 0],
[5, 40, 4, 1, 6, 0]],
[[3, 80, 4, 2, 6, 0],
[3, 80, 4, 1, 6, 0],
[3, 80, 4, 1, 6, 0],
[5, 112, 4, 1, 6, 0],
[5, 112, 4, 1, 6, 0],
[5, 112, 4, 1, 6, 0]],
[[5, 192, 4, 2, 6, 0],
[5, 192, 4, 1, 6, 0],
[5, 192, 4, 1, 6, 0],
[5, 192, 4, 1, 6, 0],
[3, 320, 4, 1, 6, 0]],
[[1, 1280, 0, 1, 0, -1]]
],
'e': [[[3, 32, 0, 2, 0, -1]],
[[3, 24, 0, 1, 3, 1]],
[[3, 32, 0, 2, 8, 1],
[3, 32, 0, 1, 8, 1]],
[[3, 48, 0, 2, 8, 1],
[3, 48, 0, 1, 8, 1],
[3, 48, 0, 1, 8, 1],
[3, 48, 0, 1, 8, 1]],
[[5, 96, 0, 2, 8, 0],
[5, 96, 0, 1, 8, 0],
[5, 96, 0, 1, 8, 0],
[5, 96, 0, 1, 8, 0],
[5, 96, 0, 1, 8, 0],
[5, 144, 0, 1, 8, 0],
[5, 144, 0, 1, 8, 0],
[5, 144, 0, 1, 8, 0],
[5, 144, 0, 1, 8, 0]],
[[5, 192, 0, 2, 8, 0],
[5, 192, 0, 1, 8, 0]],
[[1, 1280, 0, 1, 0, -1]]
]
} # yapf: disable
# Parameters to build different kinds of architecture.
# From left to right: scaling factor for width, scaling factor for depth,
# resolution.
arch_settings = {
'b0': (1.0, 1.0, 224),
'b1': (1.0, 1.1, 240),
'b2': (1.1, 1.2, 260),
'b3': (1.2, 1.4, 300),
'b4': (1.4, 1.8, 380),
'b5': (1.6, 2.2, 456),
'b6': (1.8, 2.6, 528),
'b7': (2.0, 3.1, 600),
'b8': (2.2, 3.6, 672),
'es': (1.0, 1.0, 224),
'em': (1.0, 1.1, 240),
'el': (1.2, 1.4, 300)
}
def __init__(self,
arch='b0',
drop_path_rate=0.,
out_indices=(6, ),
frozen_stages=0,
conv_cfg=dict(type='Conv2dAdaptivePadding'),
norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='Swish'),
norm_eval=False,
with_cp=False,
init_cfg=[
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
layer=['_BatchNorm', 'GroupNorm'],
val=1)
]):
super(EfficientNet, self).__init__(init_cfg)
assert arch in self.arch_settings, \
f'"{arch}" is not one of the arch_settings ' \
f'({", ".join(self.arch_settings.keys())})'
self.arch_setting = self.arch_settings[arch]
self.layer_setting = self.layer_settings[arch[:1]]
for index in out_indices:
if index not in range(0, len(self.layer_setting)):
raise ValueError('the item in out_indices must in '
f'range(0, {len(self.layer_setting)}). '
f'But received {index}')
if frozen_stages not in range(len(self.layer_setting) + 1):
raise ValueError('frozen_stages must be in range(0, '
f'{len(self.layer_setting) + 1}). '
f'But received {frozen_stages}')
self.drop_path_rate = drop_path_rate
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.layer_setting = model_scaling(self.layer_setting,
self.arch_setting)
block_cfg_0 = self.layer_setting[0][0]
block_cfg_last = self.layer_setting[-1][0]
self.in_channels = make_divisible(block_cfg_0[1], 8)
self.out_channels = block_cfg_last[1]
self.layers = nn.ModuleList()
self.layers.append(
ConvModule(
in_channels=3,
out_channels=self.in_channels,
kernel_size=block_cfg_0[0],
stride=block_cfg_0[3],
padding=block_cfg_0[0] // 2,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.make_layer()
# Avoid building unused layers in mmdetection.
if len(self.layers) < max(self.out_indices) + 1:
self.layers.append(
ConvModule(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=block_cfg_last[0],
stride=block_cfg_last[3],
padding=block_cfg_last[0] // 2,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
def make_layer(self):
# Without the first and the final conv block.
layer_setting = self.layer_setting[1:-1]
total_num_blocks = sum([len(x) for x in layer_setting])
block_idx = 0
dpr = [
x.item()
for x in torch.linspace(0, self.drop_path_rate, total_num_blocks)
] # stochastic depth decay rule
for i, layer_cfg in enumerate(layer_setting):
# Avoid building unused layers in mmdetection.
if i > max(self.out_indices) - 1:
break
layer = []
for i, block_cfg in enumerate(layer_cfg):
(kernel_size, out_channels, se_ratio, stride, expand_ratio,
block_type) = block_cfg
mid_channels = int(self.in_channels * expand_ratio)
out_channels = make_divisible(out_channels, 8)
if se_ratio <= 0:
se_cfg = None
else:
# In mmdetection, the `divisor` is deleted to align
# the logic of SELayer with mmcls.
se_cfg = dict(
channels=mid_channels,
ratio=expand_ratio * se_ratio,
act_cfg=(self.act_cfg, dict(type='Sigmoid')))
if block_type == 1: # edge tpu
if i > 0 and expand_ratio == 3:
with_residual = False
expand_ratio = 4
else:
with_residual = True
mid_channels = int(self.in_channels * expand_ratio)
if se_cfg is not None:
# In mmdetection, the `divisor` is deleted to align
# the logic of SELayer with mmcls.
se_cfg = dict(
channels=mid_channels,
ratio=se_ratio * expand_ratio,
act_cfg=(self.act_cfg, dict(type='Sigmoid')))
block = partial(EdgeResidual, with_residual=with_residual)
else:
block = InvertedResidual
layer.append(
block(
in_channels=self.in_channels,
out_channels=out_channels,
mid_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
se_cfg=se_cfg,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
drop_path_rate=dpr[block_idx],
with_cp=self.with_cp,
# In mmdetection, `with_expand_conv` is set to align
# the logic of InvertedResidual with mmcls.
with_expand_conv=(mid_channels != self.in_channels)))
self.in_channels = out_channels
block_idx += 1
self.layers.append(Sequential(*layer))
def forward(self, x):
outs = []
# import pdb;pdb.set_trace()
for i, layer in enumerate(self.layers):
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def _freeze_stages(self):
for i in range(self.frozen_stages):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
super(EfficientNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer, trunc_normal_init
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import constant_init
from mmcv.runner import _load_checkpoint
from mmcv.runner.base_module import BaseModule, ModuleList
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.utils import _pair as to_2tuple
import torch.utils.checkpoint as checkpoint
from mmseg.ops import resize
from mmdet3d.utils import get_root_logger
from mmdet.models.builder import BACKBONES
from mmcv.cnn.bricks.registry import ATTENTION
from ..utils import PatchEmbed, swin_convert
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer use nn.Unfold to group feature map by kernel_size, and use norm
and linear layer to embed grouped feature map.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
stride (int | tuple): the stride of the sliding length in the
unfold layer. Defaults: 2. (Default to be equal with kernel_size).
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Defaults: None.
"""
def __init__(self,
in_channels,
out_channels,
stride=2,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.sampler = nn.Unfold(
kernel_size=stride, dilation=1, padding=0, stride=stride)
sample_dim = stride**2 * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, hw_shape):
"""
x: x.shape -> [B, H*W, C]
hw_shape: (H, W)
"""
B, L, C = x.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# stride is fixed to be equal to kernel_size.
if (H % self.stride != 0) or (W % self.stride != 0):
x = F.pad(x, (0, W % self.stride, 0, H % self.stride))
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
x = self.sampler(x) # B, 4*C, H/2*W/2
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
down_hw_shape = (H + 1) // 2, (W + 1) // 2
return x, down_hw_shape
@ATTENTION.register_module()
class WindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None):
super().__init__()
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
self.init_cfg = init_cfg
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# About 2x faster than original impl
Wh, Ww = self.window_size
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()
self.register_buffer('relative_position_index', rel_position_index)
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
self.softmax = nn.Softmax(dim=-1)
def init_weights(self):
trunc_normal_init(self.relative_position_bias_table, std=0.02)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor | None, Optional): mask with shape of (num_windows,
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def double_step_seq(step1, len1, step2, len2):
seq1 = torch.arange(0, step1 * len1, step1)
seq2 = torch.arange(0, step2 * len2, step2)
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
"""Shift Window Multihead Self-Attention Module.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window.
shift_size (int, optional): The shift step of each window towards
right-bottom. If zero, act as regular window-msa. Defaults to 0.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Defaults: 0.
proj_drop_rate (float, optional): Dropout ratio of output.
Defaults: 0.
dropout_layer (dict, optional): The dropout_layer used before output.
Defaults: dict(type='DropPath', drop_prob=0.).
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
shift_size=0,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0,
proj_drop_rate=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
init_cfg=None):
super().__init__(init_cfg)
self.window_size = window_size
self.shift_size = shift_size
assert 0 <= self.shift_size < self.window_size
self.w_msa = WindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=to_2tuple(window_size),
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=proj_drop_rate,
init_cfg=None)
self.drop = build_dropout(dropout_layer)
def forward(self, query, hw_shape):
B, L, C = query.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
query = query.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
H_pad, W_pad = query.shape[1], query.shape[2]
# cyclic shift
if self.shift_size > 0:
shifted_query = torch.roll(
query,
shifts=(-self.shift_size, -self.shift_size),
dims=(1, 2))
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, H_pad, W_pad, 1),
device=query.device) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
# w_slices = (slice(0, -self.window_size),
# slice(-self.window_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = self.window_partition(img_mask)
mask_windows = mask_windows.view(
-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
else:
shifted_query = query
attn_mask = None
# nW*B, window_size, window_size, C
query_windows = self.window_partition(shifted_query)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, self.window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size, C)
# B H' W' C
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = self.drop(x)
return x
def window_reverse(self, windows, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
window_size = self.window_size
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
def window_partition(self, x):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
window_size = self.window_size
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
class SwinBlock(BaseModule):
""""
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
window size (int, optional): The local window scale. Default: 7.
shift (bool): whether to shift window or not. Default False.
qkv_bias (int, optional): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of nomalization.
Default: dict(type='LN').
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
window_size=7,
shift=False,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(SwinBlock, self).__init__()
self.init_cfg = init_cfg
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=window_size,
shift_size=window_size // 2 if shift else 0,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
init_cfg=None)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=2,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=True,
init_cfg=None)
self.hw_shape = None
def forward(self, x):
hw_shape = self.hw_shape
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
class SwinBlockSequence(BaseModule):
"""Implements one stage in Swin Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
depth (int): The number of blocks in this stage.
window size (int): The local window scale. Default: 7.
qkv_bias (int): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
downsample (BaseModule | None, optional): The downsample operation
module. Default: None.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of nomalization.
Default: dict(type='LN').
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
depth,
window_size=7,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
downsample=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None,
with_cp=True):
super().__init__()
self.init_cfg = init_cfg
drop_path_rate = drop_path_rate if isinstance(
drop_path_rate,
list) else [deepcopy(drop_path_rate) for _ in range(depth)]
self.blocks = ModuleList()
for i in range(depth):
block = SwinBlock(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
window_size=window_size,
shift=False if i % 2 == 0 else True,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate[i],
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=None)
self.blocks.append(block)
self.downsample = downsample
self.with_cp = with_cp
def forward(self, x, hw_shape):
for block in self.blocks:
block.hw_shape=hw_shape
if self.with_cp:
x = checkpoint.checkpoint(block, x)
else:
x = block(x)
if self.downsample:
x_down, down_hw_shape = self.downsample(x, hw_shape)
return x_down, down_hw_shape, x, hw_shape
else:
return x, hw_shape, x, hw_shape
@BACKBONES.register_module(force=True)
class SwinTransformer(BaseModule):
""" Swin Transformer
A PyTorch implement of : `Swin Transformer:
Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/abs/2103.14030
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
pretrain_img_size (int | tuple[int]): The size of input image when
pretrain. Defaults: 224.
in_channels (int): The num of input channels.
Defaults: 3.
embed_dims (int): The feature dimension. Default: 96.
patch_size (int | tuple[int]): Patch size. Default: 4.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
Default: 4.
depths (tuple[int]): Depths of each Swin Transformer stage.
Default: (2, 2, 6, 2).
num_heads (tuple[int]): Parallel attention heads of each Swin
Transformer stage. Default: (3, 6, 12, 24).
strides (tuple[int]): The patch merging or patch embedding stride of
each Swin Transformer stage. (In swin, we set kernel size equal to
stride.) Default: (4, 2, 2, 2).
out_indices (tuple[int]): Output from which stages.
Default: (0, 1, 2, 3).
qkv_bias (bool, optional): If True, add a learnable bias to query, key,
value. Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
patch_norm (bool): If add a norm layer for patch embed and patch
merging. Default: True.
drop_rate (float): Dropout rate. Defaults: 0.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults: False.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='LN').
norm_cfg (dict): Config dict for normalization layer at
output of backone. Defaults: dict(type='LN').
pretrain_style (str): Choose to use official or mmcls pretrain weights.
Default: official.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
pretrain_img_size=224,
in_channels=3,
embed_dims=96,
patch_size=4,
window_size=7,
mlp_ratio=4,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
strides=(4, 2, 2, 2),
out_indices=(0, 1, 2, 3),
qkv_bias=True,
qk_scale=None,
patch_norm=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
pretrain_style='official',
pretrained=None,
init_cfg=None,
with_cp=True,
output_missing_index_as_none=False,
frozen_stages=-1):
super(SwinTransformer, self).__init__()
if isinstance(pretrain_img_size, int):
pretrain_img_size = to_2tuple(pretrain_img_size)
elif isinstance(pretrain_img_size, tuple):
if len(pretrain_img_size) == 1:
pretrain_img_size = to_2tuple(pretrain_img_size[0])
assert len(pretrain_img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(pretrain_img_size)}'
assert pretrain_style in ['official', 'mmcls'], 'We only support load '
'official ckpt and mmcls ckpt.'
if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
else:
raise TypeError('pretrained must be a str or None')
num_layers = len(depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.pretrain_style = pretrain_style
self.pretrained = pretrained
self.init_cfg = init_cfg
self.frozen_stages = frozen_stages
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=strides[0],
pad_to_patch_size=True,
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
if self.use_abs_pos_embed:
patch_row = pretrain_img_size[0] // patch_size
patch_col = pretrain_img_size[1] // patch_size
num_patches = patch_row * patch_col
self.absolute_pos_embed = nn.Parameter(
torch.zeros((1, num_patches, embed_dims)))
self.drop_after_pos = nn.Dropout(p=drop_rate)
# stochastic depth
total_depth = sum(depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
in_channels = embed_dims
for i in range(num_layers):
if i < num_layers - 1:
downsample = PatchMerging(
in_channels=in_channels,
out_channels=2 * in_channels,
stride=strides[i + 1],
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
else:
downsample = None
stage = SwinBlockSequence(
embed_dims=in_channels,
num_heads=num_heads[i],
feedforward_channels=mlp_ratio * in_channels,
depth=depths[i],
window_size=window_size,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[:depths[i]],
downsample=downsample,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=None,
with_cp=with_cp)
self.stages.append(stage)
dpr = dpr[depths[i]:]
if downsample:
in_channels = downsample.out_channels
self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
# Add a norm layer for each output
for i in out_indices:
layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
layer_name = f'norm{i}'
self.add_module(layer_name, layer)
self.output_missing_index_as_none = output_missing_index_as_none
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.use_abs_pos_embed:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.drop_after_pos.eval()
for i in range(0, self.frozen_stages - 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self):
if self.pretrained is None:
super().init_weights()
if self.use_abs_pos_embed:
trunc_normal_init(self.absolute_pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, Linear):
trunc_normal_init(m.weight, std=.02)
if m.bias is not None:
constant_init(m.bias, 0)
elif isinstance(m, LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
elif isinstance(self.pretrained, str):
logger = get_root_logger()
ckpt = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
state_dict = ckpt['state_dict']
elif 'model' in ckpt:
state_dict = ckpt['model']
else:
state_dict = ckpt
if self.pretrain_style == 'official':
state_dict = swin_convert(state_dict)
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# reshape absolute position embedding
if state_dict.get('absolute_pos_embed') is not None:
absolute_pos_embed = state_dict['absolute_pos_embed']
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = self.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H * W:
logger.warning('Error in loading absolute_pos_embed, pass')
else:
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
# interpolate position bias table if needed
relative_position_bias_table_keys = [
k for k in state_dict.keys()
if 'relative_position_bias_table' in k
]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
table_current = self.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f'Error in loading {table_key}, pass')
else:
if L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = resize(
table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(
nH2, L2).permute(1, 0).contiguous()
# load state_dict
self.load_state_dict(state_dict, False)
def forward(self, x):
x = self.patch_embed(x)
hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(out)
out = out.view(-1, *out_hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
elif self.output_missing_index_as_none:
outs.append(None)
return outs
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
\ No newline at end of file
from collections import OrderedDict
from mmcv.runner import BaseModule
from mmdet.models.builder import BACKBONES
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
VoVNet19_slim_dw_eSE = {
'stem': [64, 64, 64],
'stage_conv_ch': [64, 80, 96, 112],
'stage_out_ch': [112, 256, 384, 512],
"layer_per_block": 3,
"block_per_stage": [1, 1, 1, 1],
"eSE": True,
"dw": True
}
VoVNet19_dw_eSE = {
'stem': [64, 64, 64],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 3,
"block_per_stage": [1, 1, 1, 1],
"eSE": True,
"dw": True
}
VoVNet19_slim_eSE = {
'stem': [64, 64, 128],
'stage_conv_ch': [64, 80, 96, 112],
'stage_out_ch': [112, 256, 384, 512],
'layer_per_block': 3,
'block_per_stage': [1, 1, 1, 1],
'eSE': True,
"dw": False
}
VoVNet19_eSE = {
'stem': [64, 64, 128],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 3,
"block_per_stage": [1, 1, 1, 1],
"eSE": True,
"dw": False
}
VoVNet39_eSE = {
'stem': [64, 64, 128],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 5,
"block_per_stage": [1, 1, 2, 2],
"eSE": True,
"dw": False
}
VoVNet57_eSE = {
'stem': [64, 64, 128],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 5,
"block_per_stage": [1, 1, 4, 3],
"eSE": True,
"dw": False
}
VoVNet99_eSE = {
'stem': [64, 64, 128],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 5,
"block_per_stage": [1, 3, 9, 3],
"eSE": True,
"dw": False
}
_STAGE_SPECS = {
"V-19-slim-dw-eSE": VoVNet19_slim_dw_eSE,
"V-19-dw-eSE": VoVNet19_dw_eSE,
"V-19-slim-eSE": VoVNet19_slim_eSE,
"V-19-eSE": VoVNet19_eSE,
"V-39-eSE": VoVNet39_eSE,
"V-57-eSE": VoVNet57_eSE,
"V-99-eSE": VoVNet99_eSE,
}
def dw_conv3x3(in_channels, out_channels, module_name, postfix, stride=1, kernel_size=3, padding=1):
"""3x3 convolution with padding"""
return [
(
'{}_{}/dw_conv3x3'.format(module_name, postfix),
nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=out_channels,
bias=False
)
),
(
'{}_{}/pw_conv1x1'.format(module_name, postfix),
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
),
('{}_{}/pw_norm'.format(module_name, postfix), nn.BatchNorm2d(out_channels)),
('{}_{}/pw_relu'.format(module_name, postfix), nn.ReLU(inplace=True)),
]
def conv3x3(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=3, padding=1):
"""3x3 convolution with padding"""
return [
(
f"{module_name}_{postfix}/conv",
nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
),
),
(f"{module_name}_{postfix}/norm", nn.BatchNorm2d(out_channels)),
(f"{module_name}_{postfix}/relu", nn.ReLU(inplace=True)),
]
def conv1x1(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=1, padding=0):
"""1x1 convolution with padding"""
return [
(
f"{module_name}_{postfix}/conv",
nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
),
),
(f"{module_name}_{postfix}/norm", nn.BatchNorm2d(out_channels)),
(f"{module_name}_{postfix}/relu", nn.ReLU(inplace=True)),
]
class Hsigmoid(nn.Module):
def __init__(self, inplace=True):
super(Hsigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
return F.relu6(x + 3.0, inplace=self.inplace) / 6.0
class eSEModule(nn.Module):
def __init__(self, channel, reduction=4):
super(eSEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0)
self.hsigmoid = Hsigmoid()
def forward(self, x):
input = x
x = self.avg_pool(x)
x = self.fc(x)
x = self.hsigmoid(x)
return input * x
class _OSA_module(nn.Module):
def __init__(
self, in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE=False, identity=False, depthwise=False
):
super(_OSA_module, self).__init__()
self.identity = identity
self.depthwise = depthwise
self.isReduced = False
self.layers = nn.ModuleList()
in_channel = in_ch
if self.depthwise and in_channel != stage_ch:
self.isReduced = True
self.conv_reduction = nn.Sequential(
OrderedDict(conv1x1(in_channel, stage_ch, "{}_reduction".format(module_name), "0"))
)
for i in range(layer_per_block):
if self.depthwise:
self.layers.append(nn.Sequential(OrderedDict(dw_conv3x3(stage_ch, stage_ch, module_name, i))))
else:
self.layers.append(nn.Sequential(OrderedDict(conv3x3(in_channel, stage_ch, module_name, i))))
in_channel = stage_ch
# feature aggregation
in_channel = in_ch + layer_per_block * stage_ch
self.concat = nn.Sequential(OrderedDict(conv1x1(in_channel, concat_ch, module_name, "concat")))
self.ese = eSEModule(concat_ch)
def forward(self, x):
identity_feat = x
output = []
output.append(x)
if self.depthwise and self.isReduced:
x = self.conv_reduction(x)
for layer in self.layers:
x = layer(x)
output.append(x)
x = torch.cat(output, dim=1)
xt = self.concat(x)
xt = self.ese(xt)
if self.identity:
xt = xt + identity_feat
return xt
class _OSA_stage(nn.Sequential):
def __init__(
self, in_ch, stage_ch, concat_ch, block_per_stage, layer_per_block, stage_num, SE=False, depthwise=False
):
super(_OSA_stage, self).__init__()
if not stage_num == 2:
self.add_module("Pooling", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True))
if block_per_stage != 1:
SE = False
module_name = f"OSA{stage_num}_1"
self.add_module(
module_name, _OSA_module(in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE, depthwise=depthwise)
)
for i in range(block_per_stage - 1):
if i != block_per_stage - 2: # last block
SE = False
module_name = f"OSA{stage_num}_{i + 2}"
self.add_module(
module_name,
_OSA_module(
concat_ch,
stage_ch,
concat_ch,
layer_per_block,
module_name,
SE,
identity=True,
depthwise=depthwise
),
)
@BACKBONES.register_module()
class VoVNet(BaseModule):
def __init__(self, spec_name, input_ch=3, out_features=None,
frozen_stages=-1, norm_eval=True, pretrained=None, init_cfg=None):
"""
Args:
input_ch(int) : the number of input channel
out_features (list[str]): name of the layers whose outputs should
be returned in forward. Can be anything in "stem", "stage2" ...
"""
super(VoVNet, self).__init__(init_cfg)
self.frozen_stages = frozen_stages
self.norm_eval = norm_eval
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
stage_specs = _STAGE_SPECS[spec_name]
stem_ch = stage_specs["stem"]
config_stage_ch = stage_specs["stage_conv_ch"]
config_concat_ch = stage_specs["stage_out_ch"]
block_per_stage = stage_specs["block_per_stage"]
layer_per_block = stage_specs["layer_per_block"]
SE = stage_specs["eSE"]
depthwise = stage_specs["dw"]
self._out_features = out_features
# Stem module
conv_type = dw_conv3x3 if depthwise else conv3x3
stem = conv3x3(input_ch, stem_ch[0], "stem", "1", 2)
stem += conv_type(stem_ch[0], stem_ch[1], "stem", "2", 1)
stem += conv_type(stem_ch[1], stem_ch[2], "stem", "3", 2)
self.add_module("stem", nn.Sequential((OrderedDict(stem))))
current_stirde = 4
self._out_feature_strides = {"stem": current_stirde, "stage2": current_stirde}
self._out_feature_channels = {"stem": stem_ch[2]}
stem_out_ch = [stem_ch[2]]
in_ch_list = stem_out_ch + config_concat_ch[:-1]
# OSA stages
self.stage_names = []
for i in range(4): # num_stages
name = "stage%d" % (i + 2) # stage 2 ... stage 5
self.stage_names.append(name)
self.add_module(
name,
_OSA_stage(
in_ch_list[i],
config_stage_ch[i],
config_concat_ch[i],
block_per_stage[i],
layer_per_block,
i + 2,
SE,
depthwise,
),
)
self._out_feature_channels[name] = config_concat_ch[i]
if not i == 0:
self._out_feature_strides[name] = current_stirde = int(current_stirde * 2)
# initialize weights
# self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
def forward(self, x):
outputs = {}
x = self.stem(x)
if "stem" in self._out_features:
outputs["stem"] = x
for name in self.stage_names:
x = getattr(self, name)(x)
if name in self._out_features:
outputs[name] = x
return outputs
def _freeze_stages(self):
if self.frozen_stages >= 0:
m = getattr(self, 'stem')
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = getattr(self, f'stage{i+1}')
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super(VoVNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
\ No newline at end of file
from .hooks import GradChecker
\ No newline at end of file
from mmcv.runner.hooks.hook import HOOKS, Hook
from projects.mmdet3d_plugin.models.utils import run_time
@HOOKS.register_module()
class GradChecker(Hook):
def after_train_iter(self, runner):
for key, val in runner.model.named_parameters():
if val.grad == None and val.requires_grad:
print('WARNNING: {key}\'s parameters are not be used!!!!'.format(key=key))
from .adamw import AdamW2
\ No newline at end of file
try:
from torch.optim import _functional as F
except:
print('WARNING!!!, I recommend using torch>=1.8')
import torch
from torch.optim.optimizer import Optimizer
from mmcv.runner.optimizer.builder import OPTIMIZERS
@OPTIMIZERS.register_module()
class AdamW2(Optimizer):
r"""Implements AdamW algorithm. Solve the bug of torch 1.8
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=1e-2, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(AdamW2, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamW2, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
# put this line here for solving bug
beta1, beta2 = group['betas']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'])
F.adamw(params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad,
beta1,
beta2,
group['lr'],
group['weight_decay'],
group['eps'])
return loss
\ No newline at end of file
import lightop
import torch
from torch.optim.optimizer import Optimizer
from mmcv.runner.optimizer.builder import OPTIMIZERS
@OPTIMIZERS.register_module()
class Miopen_AdamW(Optimizer):
r"""Implements AdamW algorithm. Solve the bug of torch 1.8
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=1e-2, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(Miopen_AdamW, self).__init__(params, defaults)
def __setstate__(self, state):
super(Miopen_AdamW, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
# put this line here for solving bug
beta1, beta2 = group['betas']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p.contiguous())
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
grads.append(p.grad.contiguous())
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avgs.append(state['exp_avg'].contiguous())
exp_avg_sqs.append(state['exp_avg_sq'].contiguous())
if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'].contiguous())
# record the step after step update
state_steps.append(torch.tensor(state['step']))
lightop.miopen_adamw(
params=params_with_grad,
grads=grads,
exp_avgs=exp_avgs,
exp_avg_sqs=exp_avg_sqs,
max_exp_avg_sqs=max_exp_avg_sqs,
state_steps=state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'])
return loss
from .bricks import run_time
from .grid_mask import GridMask
from .position_embedding import RelPositionEmbedding
from .visual import save_tensor
from .inverted_residual import InvertedResidual
from .se_layer import DyReLU, SELayer
from .make_divisible import make_divisible
from .ckpt_convert import swin_convert, vit_convert
from .embed import PatchEmbed
\ No newline at end of file
import functools
import time
from collections import defaultdict
import torch
time_maps = defaultdict(lambda :0.)
count_maps = defaultdict(lambda :0.)
def run_time(name):
def middle(fn):
def wrapper(*args, **kwargs):
torch.cuda.synchronize()
start = time.time()
res = fn(*args, **kwargs)
torch.cuda.synchronize()
time_maps['%s : %s'%(name, fn.__name__) ] += time.time()-start
count_maps['%s : %s'%(name, fn.__name__) ] +=1
print("%s : %s takes up %f "% (name, fn.__name__,time_maps['%s : %s'%(name, fn.__name__) ] /count_maps['%s : %s'%(name, fn.__name__) ] ))
return res
return wrapper
return middle
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
def swin_convert(ckpt):
new_ckpt = OrderedDict()
def correct_unfold_reduction_order(x):
out_channel, in_channel = x.shape
x = x.reshape(out_channel, 4, in_channel // 4)
x = x[:, [0, 2, 1, 3], :].transpose(1,
2).reshape(out_channel, in_channel)
return x
def correct_unfold_norm_order(x):
in_channel = x.shape[0]
x = x.reshape(4, in_channel // 4)
x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
return x
for k, v in ckpt.items():
if k.startswith('head'):
continue
elif k.startswith('layers'):
new_v = v
if 'attn.' in k:
new_k = k.replace('attn.', 'attn.w_msa.')
elif 'mlp.' in k:
if 'mlp.fc1.' in k:
new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
elif 'mlp.fc2.' in k:
new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
else:
new_k = k.replace('mlp.', 'ffn.')
elif 'downsample' in k:
new_k = k
if 'reduction.' in k:
new_v = correct_unfold_reduction_order(v)
elif 'norm.' in k:
new_v = correct_unfold_norm_order(v)
else:
new_k = k
new_k = new_k.replace('layers', 'stages', 1)
elif k.startswith('patch_embed'):
new_v = v
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
else:
new_v = v
new_k = k
new_ckpt[new_k] = new_v
return new_ckpt
def vit_convert(ckpt):
new_ckpt = OrderedDict()
for k, v in ckpt.items():
if k.startswith('head'):
continue
if k.startswith('norm'):
new_k = k.replace('norm.', 'ln1.')
elif k.startswith('patch_embed'):
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
elif k.startswith('blocks'):
if 'norm' in k:
new_k = k.replace('norm', 'ln')
elif 'mlp.fc1' in k:
new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
elif 'mlp.fc2' in k:
new_k = k.replace('mlp.fc2', 'ffn.layers.1')
elif 'attn.qkv' in k:
new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
elif 'attn.proj' in k:
new_k = k.replace('attn.proj', 'attn.attn.out_proj')
else:
new_k = k
new_k = new_k.replace('blocks.', 'layers.')
else:
new_k = k
new_ckpt[new_k] = v
return new_ckpt
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule
from torch.nn.modules.utils import _pair as to_2tuple
# Modified from Pytorch-Image-Models
class PatchEmbed(BaseModule):
"""Image to Patch Embedding V2.
We use a conv layer to implement PatchEmbed.
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (dict, optional): The config dict for conv layers type
selection. Default: None.
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int): The slide stride of embedding conv.
Default: None (Default to be equal with kernel_size).
padding (int): The padding length of embedding conv. Default: 0.
dilation (int): The dilation rate of embedding conv. Default: 1.
pad_to_patch_size (bool, optional): Whether to pad feature map shape
to multiple patch size. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=768,
conv_type=None,
kernel_size=16,
stride=16,
padding=0,
dilation=1,
pad_to_patch_size=True,
norm_cfg=None,
init_cfg=None):
super(PatchEmbed, self).__init__()
self.embed_dims = embed_dims
self.init_cfg = init_cfg
if stride is None:
stride = kernel_size
self.pad_to_patch_size = pad_to_patch_size
# The default setting of patch size is equal to kernel size.
patch_size = kernel_size
if isinstance(patch_size, int):
patch_size = to_2tuple(patch_size)
elif isinstance(patch_size, tuple):
if len(patch_size) == 1:
patch_size = to_2tuple(patch_size[0])
assert len(patch_size) == 2, \
f'The size of patch should have length 1 or 2, ' \
f'but got {len(patch_size)}'
self.patch_size = patch_size
# Use conv layer to embed
conv_type = conv_type or 'Conv2d'
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
out_channels=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
def forward(self, x):
H, W = x.shape[2], x.shape[3]
# TODO: Process overlapping op
if self.pad_to_patch_size:
# Modify H, W to multiple of patch size.
if H % self.patch_size[0] != 0:
x = F.pad(
x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
if W % self.patch_size[1] != 0:
x = F.pad(
x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
x = self.projection(x)
self.DH, self.DW = x.shape[2], x.shape[3]
x = x.flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment