Unverified Commit cce49ba9 authored by Chengyu Wang's avatar Chengyu Wang Committed by GitHub
Browse files

Add openlane v2 (#121)

parent dbf29e61
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "cuda/dcnv3_im2col_cuda.cuh"

#include <algorithm>
#include <cstdint>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/torch.h>
/*!
 * DCNv3 forward pass on CUDA.
 *
 * input  : (N, H_in, W_in, group * group_channels), channels-last, contiguous.
 * offset : per-location sampling offsets, 2 values (w, h) per kernel tap.
 * mask   : per-location modulation weights, 1 value per kernel tap.
 * Returns output of shape (N, H_out, W_out, group * group_channels).
 *
 * The batch is processed in slices of `im2col_step` samples per kernel
 * launch; im2col_step must evenly divide the batch size.
 */
at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
                              const at::Tensor &mask, const int kernel_h,
                              const int kernel_w, const int stride_h,
                              const int stride_w, const int pad_h,
                              const int pad_w, const int dilation_h,
                              const int dilation_w, const int group,
                              const int group_channels,
                              const float offset_scale, const int im2col_step) {
    AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
    AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
    AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
    // Tensor::is_cuda() replaces the deprecated type().is_cuda().
    AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");

    // Channels-last layout: (N, H, W, C).
    const int batch = input.size(0);
    const int height_in = input.size(1);
    const int width_in = input.size(2);
    const int channels = input.size(3);

    // Standard convolution output-size formula.
    const int height_out =
        (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
        1;
    const int width_out =
        (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
        1;

    const int im2col_step_ = std::min(batch, im2col_step);
    // Message fixed: it is im2col_step that must divide batch, not the
    // other way around.
    AT_ASSERTM(batch % im2col_step_ == 0,
               "im2col_step(%d) must divide batch(%d)", im2col_step_, batch);
    AT_ASSERTM(
        channels == (group * group_channels),
        "Input channels and group times group channels wont match: (%d vs %d).",
        channels, group * group_channels);

    auto output =
        at::zeros({batch, height_out, width_out, group * group_channels},
                  input.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch / batch_n, batch_n, height_out,
                                 width_out, group * group_channels});
    // 64-bit per-sample strides: int products overflow for large feature maps.
    const int64_t per_input_size =
        static_cast<int64_t>(height_in) * width_in * group * group_channels;
    const int64_t per_offset_size = static_cast<int64_t>(height_out) *
                                    width_out * group * kernel_h * kernel_w * 2;
    const int64_t per_mask_size = static_cast<int64_t>(height_out) * width_out *
                                  group * kernel_h * kernel_w;

    for (int n = 0; n < batch / im2col_step_; ++n) {
        auto columns = output_n.select(0, n);
        // scalar_type() / data_ptr<>() replace deprecated type() / data<>().
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            input.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
                dcnv3_im2col_cuda(
                    at::cuda::getCurrentCUDAStream(),
                    input.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_input_size,
                    offset.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_offset_size,
                    mask.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_mask_size,
                    columns.data_ptr<scalar_t>(), kernel_h, kernel_w, stride_h,
                    stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
                    group_channels, batch_n, height_in, width_in, height_out,
                    width_out, offset_scale);
            }));
    }
    return output;
}
/*!
 * DCNv3 backward pass on CUDA.
 *
 * Given grad_output of the forward result, computes gradients w.r.t.
 * input, offset and mask (returned in that order). All input tensors
 * must be contiguous CUDA tensors; the batch is processed in
 * `im2col_step`-sized slices exactly as in the forward pass.
 */
std::vector<at::Tensor>
dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
                    const at::Tensor &mask, const int kernel_h,
                    const int kernel_w, const int stride_h, const int stride_w,
                    const int pad_h, const int pad_w, const int dilation_h,
                    const int dilation_w, const int group,
                    const int group_channels, const float offset_scale,
                    const at::Tensor &grad_output, const int im2col_step) {
    AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
    AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
    AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(),
               "grad_output tensor has to be contiguous");
    // Tensor::is_cuda() replaces the deprecated type().is_cuda().
    AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");
    AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = input.size(0);
    const int height_in = input.size(1);
    const int width_in = input.size(2);
    const int channels = input.size(3);
    const int height_out =
        (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
        1;
    const int width_out =
        (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
        1;

    const int im2col_step_ = std::min(batch, im2col_step);
    // Message fixed: it is im2col_step that must divide batch.
    AT_ASSERTM(batch % im2col_step_ == 0,
               "im2col_step(%d) must divide batch(%d)", im2col_step_, batch);
    AT_ASSERTM(
        channels == (group * group_channels),
        "Input channels and group times group channels wont match: (%d vs %d).",
        channels, group * group_channels);

    // Accumulate gradients in float for half inputs; the col2im kernels
    // write through opmath_t (float for half) pointers via atomicAdd.
    auto dtype = input.dtype();
    if (dtype == at::kHalf) {
        dtype = at::kFloat;
    }
    auto grad_input = at::zeros_like(input, dtype);
    auto grad_offset = at::zeros_like(offset, dtype);
    auto grad_mask = at::zeros_like(mask, dtype);

    const int batch_n = im2col_step_;
    // 64-bit per-sample strides: int products overflow for large feature maps.
    const int64_t per_input_size =
        static_cast<int64_t>(height_in) * width_in * group * group_channels;
    const int64_t per_offset_size = static_cast<int64_t>(height_out) *
                                    width_out * group * kernel_h * kernel_w * 2;
    const int64_t per_mask_size = static_cast<int64_t>(height_out) * width_out *
                                  group * kernel_h * kernel_w;
    auto grad_output_n =
        grad_output.view({batch / im2col_step_, batch_n, height_out * width_out,
                          group, group_channels});

    for (int n = 0; n < batch / im2col_step_; ++n) {
        auto grad_output_g = grad_output_n.select(0, n);
        // scalar_type() / data_ptr<>() replace deprecated type() / data<>().
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            input.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
                dcnv3_col2im_cuda(
                    at::cuda::getCurrentCUDAStream(),
                    grad_output_g.data_ptr<scalar_t>(),
                    input.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_input_size,
                    offset.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_offset_size,
                    mask.data_ptr<scalar_t>() +
                        n * im2col_step_ * per_mask_size,
                    kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
                    dilation_h, dilation_w, group, group_channels, batch_n,
                    height_in, width_in, height_out, width_out, offset_scale,
                    grad_input.data_ptr<opmath_t>() +
                        n * im2col_step_ * per_input_size,
                    grad_offset.data_ptr<opmath_t>() +
                        n * im2col_step_ * per_offset_size,
                    grad_mask.data_ptr<opmath_t>() +
                        n * im2col_step_ * per_mask_size);
            }));
    }

    // Cast float accumulators back to half when the input was half.
    if (input.dtype() == torch::kHalf) {
        return {grad_input.to(torch::kHalf), grad_offset.to(torch::kHalf),
                grad_mask.to(torch::kHalf)};
    } else {
        return {grad_input, grad_offset, grad_mask};
    }
}
\ No newline at end of file
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include <torch/extension.h>
// DCNv3 forward on CUDA. input is channels-last (N, H, W, group *
// group_channels); returns the deformed output of shape
// (N, H_out, W_out, group * group_channels). im2col_step controls how many
// batch samples are processed per kernel launch.
at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
                              const at::Tensor &mask, const int kernel_h,
                              const int kernel_w, const int stride_h,
                              const int stride_w, const int pad_h,
                              const int pad_w, const int dilation_h,
                              const int dilation_w, const int group,
                              const int group_channels,
                              const float offset_scale, const int im2col_step);

// DCNv3 backward on CUDA. Returns {grad_input, grad_offset, grad_mask}
// for the given grad_output, matching the forward's argument layout.
std::vector<at::Tensor>
dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
                    const at::Tensor &mask, const int kernel_h,
                    const int kernel_w, const int stride_h, const int stride_w,
                    const int pad_h, const int pad_w, const int dilation_h,
                    const int dilation_w, const int group,
                    const int group_channels, const float offset_scale,
                    const at::Tensor &grad_output, const int im2col_step);
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
// Grid-stride loop: each thread walks its global index across the whole
// problem so correctness does not depend on the launch configuration.
#define CUDA_KERNEL_LOOP(i, n)                                                 \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);               \
         i += blockDim.x * gridDim.x)

// Default threads-per-block used by the kernels in this file.
const int CUDA_NUM_THREADS = 256;
// Ceiling division: smallest number of blocks of `num_threads` threads
// that covers N work items.
inline int GET_BLOCKS(const int N, const int num_threads) {
    const int rounded_up = N + num_threads - 1;
    return rounded_up / num_threads;
}
// Accumulation type for the current dispatch scalar_t (e.g. float when
// scalar_t is half); only meaningful inside AT_DISPATCH_* lambdas.
#define opmath_t at::opmath_type<scalar_t>
// Bilinear sampling of one channel (group g, channel c) from a
// channels-last feature map laid out as (H, W, group * group_channels),
// at fractional location (h, w). Corners that fall outside the map
// contribute zero. Returns the interpolated value in opmath_t precision.
template <typename scalar_t>
__device__ opmath_t dcnv3_im2col_bilinear(const scalar_t *&bottom_data,
                                          const int &height, const int &width,
                                          const int &group,
                                          const int &group_channels,
                                          const opmath_t &h, const opmath_t &w,
                                          const int &g, const int &c) {
    // Integer corners surrounding (h, w) and the fractional residuals.
    const int h_low = floor(h);
    const int w_low = floor(w);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;
    const opmath_t lh = h - h_low;
    const opmath_t lw = w - w_low;
    const opmath_t hh = 1 - lh, hw = 1 - lw;
    // Flattened strides for the (H, W, group*group_channels) layout.
    const int w_stride = group * group_channels;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
    const int base_ptr = g * group_channels + c;
    // Gather the four corners, skipping out-of-bounds ones (zero padding).
    opmath_t v1 = 0;
    if (h_low >= 0 && w_low >= 0) {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
        v1 = bottom_data[ptr1];
    }
    opmath_t v2 = 0;
    if (h_low >= 0 && w_high <= width - 1) {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
        v2 = bottom_data[ptr2];
    }
    opmath_t v3 = 0;
    if (h_high <= height - 1 && w_low >= 0) {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
        v3 = bottom_data[ptr3];
    }
    opmath_t v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1) {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
        v4 = bottom_data[ptr4];
    }
    // Weighted blend of the four corners.
    const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
    const opmath_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    return val;
}
// Backward of one bilinear sample: scatters top_grad * mask into grad_im
// at the four corners (via atomicAdd, since multiple samples may hit the
// same pixel) and writes the gradients w.r.t. the sampling offset and the
// mask to grad_offset / grad_mask with plain stores. The plain stores are
// only safe because callers pass per-thread shared-memory slots here and
// reduce them afterwards.
template <typename scalar_t>
__device__ void dcnv3_col2im_bilinear(
    const scalar_t *&bottom_data, const int &height, const int &width,
    const int &nheads, const int &group_channels, const opmath_t &h,
    const opmath_t &w, const int &m, const int &c, const opmath_t offset_scale,
    const opmath_t &top_grad, const opmath_t &mask, opmath_t *&grad_im,
    opmath_t *grad_offset, opmath_t *grad_mask) {
    // Corner coordinates and bilinear residuals around (h, w).
    const int h_low = floor(h);
    const int w_low = floor(w);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;
    const opmath_t lh = h - h_low;
    const opmath_t lw = w - w_low;
    const opmath_t hh = 1 - lh, hw = 1 - lw;
    // Flattened strides for the channels-last (H, W, nheads*group_channels)
    // layout.
    const int w_stride = nheads * group_channels;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
    const int base_ptr = m * group_channels + c;
    const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
    const opmath_t top_grad_im = top_grad * mask;
    // grad_h_weight / grad_w_weight accumulate d(sample)/d(loc_h), d(loc_w).
    opmath_t grad_h_weight = 0, grad_w_weight = 0;
    opmath_t v1 = 0;
    if (h_low >= 0 && w_low >= 0) {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
        v1 = bottom_data[ptr1];
        grad_h_weight -= hw * v1;
        grad_w_weight -= hh * v1;
        atomicAdd(grad_im + ptr1, w1 * top_grad_im);
    }
    opmath_t v2 = 0;
    if (h_low >= 0 && w_high <= width - 1) {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
        v2 = bottom_data[ptr2];
        grad_h_weight -= lw * v2;
        grad_w_weight += hh * v2;
        atomicAdd(grad_im + ptr2, w2 * top_grad_im);
    }
    opmath_t v3 = 0;
    if (h_high <= height - 1 && w_low >= 0) {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
        v3 = bottom_data[ptr3];
        grad_h_weight += hw * v3;
        grad_w_weight -= lh * v3;
        atomicAdd(grad_im + ptr3, w3 * top_grad_im);
    }
    opmath_t v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1) {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
        v4 = bottom_data[ptr4];
        grad_h_weight += lw * v4;
        grad_w_weight += lh * v4;
        atomicAdd(grad_im + ptr4, w4 * top_grad_im);
    }
    // Mask gradient is the sampled value times the upstream gradient;
    // offset gradients are scaled by offset_scale (chain rule through
    // loc = p0 + offset * offset_scale). Layout: [grad_w, grad_h].
    const opmath_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    *grad_mask = top_grad * val;
    *grad_offset = offset_scale * grad_w_weight * top_grad_im;
    *(grad_offset + 1) = offset_scale * grad_h_weight * top_grad_im;
}
// Same computation as dcnv3_col2im_bilinear, but grad_offset and grad_mask
// are accumulated straight into global memory with atomicAdd ("gm"
// variant), so no shared-memory reduction is needed by the caller.
template <typename scalar_t>
__device__ void dcnv3_col2im_bilinear_gm(
    const scalar_t *&bottom_data, const int &height, const int &width,
    const int &nheads, const int &group_channels, const opmath_t &h,
    const opmath_t &w, const int &m, const int &c, const opmath_t offset_scale,
    const opmath_t &top_grad, const opmath_t &mask, opmath_t *&grad_im,
    opmath_t *grad_offset, opmath_t *grad_mask) {
    // Corner coordinates and bilinear residuals around (h, w).
    const int h_low = floor(h);
    const int w_low = floor(w);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;
    const opmath_t lh = h - h_low;
    const opmath_t lw = w - w_low;
    const opmath_t hh = 1 - lh, hw = 1 - lw;
    // Flattened strides for the channels-last layout.
    const int w_stride = nheads * group_channels;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
    const int base_ptr = m * group_channels + c;
    const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
    const opmath_t top_grad_im = top_grad * mask;
    opmath_t grad_h_weight = 0, grad_w_weight = 0;
    opmath_t v1 = 0;
    if (h_low >= 0 && w_low >= 0) {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
        v1 = bottom_data[ptr1];
        grad_h_weight -= hw * v1;
        grad_w_weight -= hh * v1;
        atomicAdd(grad_im + ptr1, w1 * top_grad_im);
    }
    opmath_t v2 = 0;
    if (h_low >= 0 && w_high <= width - 1) {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
        v2 = bottom_data[ptr2];
        grad_h_weight -= lw * v2;
        grad_w_weight += hh * v2;
        atomicAdd(grad_im + ptr2, w2 * top_grad_im);
    }
    opmath_t v3 = 0;
    if (h_high <= height - 1 && w_low >= 0) {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
        v3 = bottom_data[ptr3];
        grad_h_weight += hw * v3;
        grad_w_weight -= lh * v3;
        atomicAdd(grad_im + ptr3, w3 * top_grad_im);
    }
    opmath_t v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1) {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
        v4 = bottom_data[ptr4];
        grad_h_weight += lw * v4;
        grad_w_weight += lh * v4;
        atomicAdd(grad_im + ptr4, w4 * top_grad_im);
    }
    // Atomic accumulation: many threads (one per channel) contribute to the
    // same offset/mask gradient slot.
    const opmath_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    atomicAdd(grad_mask, top_grad * val);
    atomicAdd(grad_offset, offset_scale * grad_w_weight * top_grad_im);
    atomicAdd(grad_offset + 1, offset_scale * grad_h_weight * top_grad_im);
}
// Forward kernel: one thread per output element. The flat `index` is
// decomposed (innermost to outermost) into channel, group, output column,
// output row, and batch sample; each thread sums kernel_h * kernel_w
// bilinear samples weighted by the modulation mask into its output slot.
template <typename scalar_t>
__global__ void dcnv3_im2col_gpu_kernel(
    const int num_kernels, const scalar_t *data_im, const scalar_t *data_offset,
    const scalar_t *data_mask, scalar_t *data_col, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int group, const int group_channels, const int height_in,
    const int width_in, const int height_out, const int width_out,
    const opmath_t offset_scale) {
    CUDA_KERNEL_LOOP(index, num_kernels) {
        // Decompose the flat index: (b, h_out, w_out, g, c).
        int _temp = index;
        const int c_col = _temp % group_channels;
        _temp /= group_channels;
        const int sampling_index = _temp;
        const int g_col = _temp % group;
        _temp /= group;
        // p0 is the kernel-center reference point in input coordinates.
        const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
                         (_temp % width_out) * stride_w;
        _temp /= width_out;
        const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
                         (_temp % height_out) * stride_h;
        _temp /= height_out;
        const int b_col = _temp;
        const int input_size = height_in * width_in;
        scalar_t *data_col_ptr = data_col + index;
        const int kernel_size = kernel_h * kernel_w;
        // Offsets store 2 values (w, h) per tap; mask stores 1.
        int data_weight_ptr = sampling_index * kernel_size;
        int data_loc_w_ptr = data_weight_ptr << 1;
        const int qid_stride = group * group_channels;
        opmath_t col = 0;
        const scalar_t *data_im_ptr = data_im + b_col * input_size * qid_stride;
        // top-left corner of the (scaled) sampling window.
        const opmath_t p0_w_ =
            p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
        const opmath_t p0_h_ =
            p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
        for (int i = 0; i < kernel_w; ++i) {
            for (int j = 0; j < kernel_h; ++j) {
                const opmath_t offset_w = data_offset[data_loc_w_ptr];
                const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
                // Learned offset displaces the regular grid tap, scaled
                // by offset_scale.
                const opmath_t loc_w =
                    p0_w_ + (i * dilation_w + offset_w) * offset_scale;
                const opmath_t loc_h =
                    p0_h_ + (j * dilation_h + offset_h) * offset_scale;
                const opmath_t weight = data_mask[data_weight_ptr];
                // Skip taps that land entirely outside the input.
                if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
                    loc_w < width_in) {
                    col += dcnv3_im2col_bilinear(
                               data_im_ptr, height_in, width_in, group,
                               group_channels, loc_h, loc_w, g_col, c_col) *
                           weight;
                }
                data_weight_ptr += 1;
                data_loc_w_ptr += 2;
            }
        }
        *data_col_ptr = col;
    }
}
// Backward kernel, static-shared-memory variant with serial reduction.
// blockSize is a compile-time template parameter sized to blockDim.x;
// each thread writes its offset/mask partials into shared slots, then
// thread 0 sums them serially and stores the result.
// NOTE(review): the __syncthreads() calls sit inside CUDA_KERNEL_LOOP, so
// this is only safe when num_kernels keeps every thread of a block in the
// same loop iteration — presumably guaranteed by the launcher; confirm.
template <typename scalar_t, unsigned int blockSize>
__global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
    const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im,
    const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int group, const int group_channels, const int height_in,
    const int width_in, const int height_out, const int width_out,
    const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
    opmath_t *grad_mask) {
    CUDA_KERNEL_LOOP(index, num_kernels) {
        // Per-thread scratch: 2 offset partials (w, h) + 1 mask partial.
        __shared__ opmath_t cache_grad_offset[blockSize * 2];
        __shared__ opmath_t cache_grad_mask[blockSize];
        unsigned int tid = threadIdx.x;
        // Decompose the flat index: (b, h_out, w_out, g, c).
        int _temp = index;
        const int c_col = _temp % group_channels;
        _temp /= group_channels;
        const int sampling_index = _temp;
        const int g_col = _temp % group;
        _temp /= group;
        const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
                         (_temp % width_out) * stride_w;
        _temp /= width_out;
        const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
                         (_temp % height_out) * stride_h;
        _temp /= height_out;
        const int b_col = _temp;
        const opmath_t top_grad = grad_col[index];
        const int input_size = height_in * width_in;
        const int kernel_size = kernel_h * kernel_w;
        int data_weight_ptr = sampling_index * kernel_size;
        int data_loc_w_ptr = data_weight_ptr << 1;
        // Advance output gradient pointers to this sampling location.
        const int grad_sampling_ptr = data_weight_ptr;
        grad_offset += grad_sampling_ptr << 1;
        grad_mask += grad_sampling_ptr;
        const int qid_stride = group * group_channels;
        const int im_ptr_offset = b_col * input_size * qid_stride;
        const scalar_t *data_im_ptr = data_im + im_ptr_offset;
        opmath_t *grad_im_ptr = grad_im + im_ptr_offset;
        const opmath_t p0_w_ =
            p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
        const opmath_t p0_h_ =
            p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
        for (int i = 0; i < kernel_w; ++i) {
            for (int j = 0; j < kernel_h; ++j) {
                const opmath_t offset_w = data_offset[data_loc_w_ptr];
                const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
                const opmath_t loc_w =
                    p0_w_ + (i * dilation_w + offset_w) * offset_scale;
                const opmath_t loc_h =
                    p0_h_ + (j * dilation_h + offset_h) * offset_scale;
                const opmath_t weight = data_mask[data_weight_ptr];
                // Zero this thread's shared slots before the conditional
                // write so skipped taps contribute zero to the reduction.
                *(cache_grad_offset + (threadIdx.x << 1)) = 0;
                *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
                *(cache_grad_mask + threadIdx.x) = 0;
                if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
                    loc_w < width_in) {
                    dcnv3_col2im_bilinear(
                        data_im_ptr, height_in, width_in, group, group_channels,
                        loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
                        weight, grad_im_ptr,
                        cache_grad_offset + (threadIdx.x << 1),
                        cache_grad_mask + threadIdx.x);
                }
                __syncthreads();
                // Thread 0 serially reduces all per-thread partials.
                // (The loop variable shadows the outer `tid` on purpose.)
                if (tid == 0) {
                    opmath_t _grad_w = cache_grad_offset[0],
                             _grad_h = cache_grad_offset[1],
                             _grad_a = cache_grad_mask[0];
                    int sid = 2;
                    for (unsigned int tid = 1; tid < blockSize; ++tid) {
                        _grad_w += cache_grad_offset[sid];
                        _grad_h += cache_grad_offset[sid + 1];
                        _grad_a += cache_grad_mask[tid];
                        sid += 2;
                    }
                    *grad_offset = _grad_w;
                    *(grad_offset + 1) = _grad_h;
                    *grad_mask = _grad_a;
                }
                __syncthreads();
                data_weight_ptr += 1;
                data_loc_w_ptr += 2;
                grad_mask += 1;
                grad_offset += 2;
            }
        }
    }
}
// Backward kernel, static-shared-memory variant with tree reduction.
// Like the _v1 kernel but reduces the per-thread partials with a
// log2(blockSize) halving loop; assumes blockSize is a power of two.
// NOTE(review): __syncthreads() inside CUDA_KERNEL_LOOP — safe only if
// every thread of a block stays in the same loop iteration; confirm the
// launcher guarantees this.
template <typename scalar_t, unsigned int blockSize>
__global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
    const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im,
    const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int group, const int group_channels, const int height_in,
    const int width_in, const int height_out, const int width_out,
    const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
    opmath_t *grad_mask) {
    CUDA_KERNEL_LOOP(index, num_kernels) {
        // Per-thread scratch: 2 offset partials (w, h) + 1 mask partial.
        __shared__ opmath_t cache_grad_offset[blockSize * 2];
        __shared__ opmath_t cache_grad_mask[blockSize];
        unsigned int tid = threadIdx.x;
        // Decompose the flat index: (b, h_out, w_out, g, c).
        int _temp = index;
        const int c_col = _temp % group_channels;
        _temp /= group_channels;
        const int sampling_index = _temp;
        const int g_col = _temp % group;
        _temp /= group;
        const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
                         (_temp % width_out) * stride_w;
        _temp /= width_out;
        const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
                         (_temp % height_out) * stride_h;
        _temp /= height_out;
        const int b_col = _temp;
        const opmath_t top_grad = grad_col[index];
        const int input_size = height_in * width_in;
        const int kernel_size = kernel_h * kernel_w;
        int data_weight_ptr = sampling_index * kernel_size;
        int data_loc_w_ptr = data_weight_ptr << 1;
        // Advance output gradient pointers to this sampling location.
        const int grad_sampling_ptr = data_weight_ptr;
        grad_offset += grad_sampling_ptr << 1;
        grad_mask += grad_sampling_ptr;
        const int qid_stride = group * group_channels;
        const int im_ptr_offset = b_col * input_size * qid_stride;
        const scalar_t *data_im_ptr = data_im + im_ptr_offset;
        opmath_t *grad_im_ptr = grad_im + im_ptr_offset;
        const opmath_t p0_w_ =
            p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
        const opmath_t p0_h_ =
            p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
        for (int i = 0; i < kernel_w; ++i) {
            for (int j = 0; j < kernel_h; ++j) {
                const opmath_t offset_w = data_offset[data_loc_w_ptr];
                const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
                const opmath_t loc_w =
                    p0_w_ + (i * dilation_w + offset_w) * offset_scale;
                const opmath_t loc_h =
                    p0_h_ + (j * dilation_h + offset_h) * offset_scale;
                const opmath_t weight = data_mask[data_weight_ptr];
                // Zero shared slots so out-of-bounds taps contribute zero.
                *(cache_grad_offset + (threadIdx.x << 1)) = 0;
                *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
                *(cache_grad_mask + threadIdx.x) = 0;
                if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
                    loc_w < width_in) {
                    dcnv3_col2im_bilinear(
                        data_im_ptr, height_in, width_in, group, group_channels,
                        loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
                        weight, grad_im_ptr,
                        cache_grad_offset + (threadIdx.x << 1),
                        cache_grad_mask + threadIdx.x);
                }
                __syncthreads();
                // Power-of-two tree reduction over the shared partials.
                for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
                    if (tid < s) {
                        const unsigned int xid1 = tid << 1;
                        const unsigned int xid2 = (tid + s) << 1;
                        cache_grad_mask[tid] += cache_grad_mask[tid + s];
                        cache_grad_offset[xid1] += cache_grad_offset[xid2];
                        cache_grad_offset[xid1 + 1] +=
                            cache_grad_offset[xid2 + 1];
                    }
                    __syncthreads();
                }
                if (tid == 0) {
                    *grad_offset = cache_grad_offset[0];
                    *(grad_offset + 1) = cache_grad_offset[1];
                    *grad_mask = cache_grad_mask[0];
                }
                __syncthreads();
                data_weight_ptr += 1;
                data_loc_w_ptr += 2;
                grad_mask += 1;
                grad_offset += 2;
            }
        }
    }
}
// Backward kernel, dynamic-shared-memory variant with serial reduction.
// Shared scratch is passed at launch time (extern __shared__) and sized
// for blockDim.x threads: 2 offset partials + 1 mask partial each;
// thread 0 sums them serially per kernel tap.
// NOTE(review): __syncthreads() inside CUDA_KERNEL_LOOP — safe only if
// all threads of a block iterate in lockstep; confirm via the launcher.
template <typename scalar_t>
__global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1(
    const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im,
    const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int group, const int group_channels, const int height_in,
    const int width_in, const int height_out, const int width_out,
    const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
    opmath_t *grad_mask) {
    CUDA_KERNEL_LOOP(index, num_kernels) {
        // Dynamic shared memory, partitioned as [offsets | masks].
        extern __shared__ int _s[];
        opmath_t *cache_grad_offset = (opmath_t *)_s;
        opmath_t *cache_grad_mask = cache_grad_offset + 2 * blockDim.x;
        unsigned int tid = threadIdx.x;
        // Decompose the flat index: (b, h_out, w_out, g, c).
        int _temp = index;
        const int c_col = _temp % group_channels;
        _temp /= group_channels;
        const int sampling_index = _temp;
        const int g_col = _temp % group;
        _temp /= group;
        const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
                         (_temp % width_out) * stride_w;
        _temp /= width_out;
        const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
                         (_temp % height_out) * stride_h;
        _temp /= height_out;
        const int b_col = _temp;
        const opmath_t top_grad = grad_col[index];
        const int input_size = height_in * width_in;
        const int kernel_size = kernel_h * kernel_w;
        int data_weight_ptr = sampling_index * kernel_size;
        int data_loc_w_ptr = data_weight_ptr << 1;
        // Advance output gradient pointers to this sampling location.
        const int grad_sampling_ptr = data_weight_ptr;
        grad_offset += grad_sampling_ptr << 1;
        grad_mask += grad_sampling_ptr;
        const int qid_stride = group * group_channels;
        const int im_ptr_offset = b_col * input_size * qid_stride;
        const scalar_t *data_im_ptr = data_im + im_ptr_offset;
        opmath_t *grad_im_ptr = grad_im + im_ptr_offset;
        const opmath_t p0_w_ =
            p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
        const opmath_t p0_h_ =
            p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
        for (int i = 0; i < kernel_w; ++i) {
            for (int j = 0; j < kernel_h; ++j) {
                const opmath_t offset_w = data_offset[data_loc_w_ptr];
                const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
                const opmath_t loc_w =
                    p0_w_ + (i * dilation_w + offset_w) * offset_scale;
                const opmath_t loc_h =
                    p0_h_ + (j * dilation_h + offset_h) * offset_scale;
                const opmath_t weight = data_mask[data_weight_ptr];
                // Zero shared slots so out-of-bounds taps contribute zero.
                *(cache_grad_offset + (threadIdx.x << 1)) = 0;
                *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
                *(cache_grad_mask + threadIdx.x) = 0;
                if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
                    loc_w < width_in) {
                    dcnv3_col2im_bilinear(
                        data_im_ptr, height_in, width_in, group, group_channels,
                        loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
                        weight, grad_im_ptr,
                        cache_grad_offset + (threadIdx.x << 1),
                        cache_grad_mask + threadIdx.x);
                }
                __syncthreads();
                // Thread 0 serially reduces all per-thread partials.
                // (The loop variable shadows the outer `tid` on purpose.)
                if (tid == 0) {
                    opmath_t _grad_w = cache_grad_offset[0],
                             _grad_h = cache_grad_offset[1],
                             _grad_a = cache_grad_mask[0];
                    int sid = 2;
                    for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
                        _grad_w += cache_grad_offset[sid];
                        _grad_h += cache_grad_offset[sid + 1];
                        _grad_a += cache_grad_mask[tid];
                        sid += 2;
                    }
                    *grad_offset = _grad_w;
                    *(grad_offset + 1) = _grad_h;
                    *grad_mask = _grad_a;
                }
                __syncthreads();
                data_weight_ptr += 1;
                data_loc_w_ptr += 2;
                grad_mask += 1;
                grad_offset += 2;
            }
        }
    }
}
// Backward kernel, dynamic-shared-memory variant with tree reduction.
// Handles non-power-of-two blockDim.x: the `spre` bookkeeping folds the
// odd leftover element of each halving step back into the active range.
// NOTE(review): __syncthreads() inside CUDA_KERNEL_LOOP — safe only if
// all threads of a block iterate in lockstep; confirm via the launcher.
template <typename scalar_t>
__global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2(
    const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im,
    const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int group, const int group_channels, const int height_in,
    const int width_in, const int height_out, const int width_out,
    const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
    opmath_t *grad_mask) {
    CUDA_KERNEL_LOOP(index, num_kernels) {
        // Dynamic shared memory, partitioned as [offsets | masks].
        extern __shared__ int _s[];
        opmath_t *cache_grad_offset = (opmath_t *)_s;
        opmath_t *cache_grad_mask = cache_grad_offset + 2 * blockDim.x;
        unsigned int tid = threadIdx.x;
        // Decompose the flat index: (b, h_out, w_out, g, c).
        int _temp = index;
        const int c_col = _temp % group_channels;
        _temp /= group_channels;
        const int sampling_index = _temp;
        const int g_col = _temp % group;
        _temp /= group;
        const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
                         (_temp % width_out) * stride_w;
        _temp /= width_out;
        const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
                         (_temp % height_out) * stride_h;
        _temp /= height_out;
        const int b_col = _temp;
        const opmath_t top_grad = grad_col[index];
        const int input_size = height_in * width_in;
        const int kernel_size = kernel_h * kernel_w;
        int data_weight_ptr = sampling_index * kernel_size;
        int data_loc_w_ptr = data_weight_ptr << 1;
        // Advance output gradient pointers to this sampling location.
        const int grad_sampling_ptr = data_weight_ptr;
        grad_offset += grad_sampling_ptr << 1;
        grad_mask += grad_sampling_ptr;
        const int qid_stride = group * group_channels;
        const int im_ptr_offset = b_col * input_size * qid_stride;
        const scalar_t *data_im_ptr = data_im + im_ptr_offset;
        opmath_t *grad_im_ptr = grad_im + im_ptr_offset;
        const opmath_t p0_w_ =
            p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
        const opmath_t p0_h_ =
            p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
        for (int i = 0; i < kernel_w; ++i) {
            for (int j = 0; j < kernel_h; ++j) {
                const opmath_t offset_w = data_offset[data_loc_w_ptr];
                const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
                const opmath_t loc_w =
                    p0_w_ + (i * dilation_w + offset_w) * offset_scale;
                const opmath_t loc_h =
                    p0_h_ + (j * dilation_h + offset_h) * offset_scale;
                const opmath_t weight = data_mask[data_weight_ptr];
                // Zero shared slots so out-of-bounds taps contribute zero.
                *(cache_grad_offset + (threadIdx.x << 1)) = 0;
                *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
                *(cache_grad_mask + threadIdx.x) = 0;
                if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
                    loc_w < width_in) {
                    dcnv3_col2im_bilinear(
                        data_im_ptr, height_in, width_in, group, group_channels,
                        loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
                        weight, grad_im_ptr,
                        cache_grad_offset + (threadIdx.x << 1),
                        cache_grad_mask + threadIdx.x);
                }
                __syncthreads();
                // Tree reduction; spre tracks the previous active width so
                // the stray element of an odd-sized range is folded in.
                for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
                     s >>= 1, spre >>= 1) {
                    if (tid < s) {
                        const unsigned int xid1 = tid << 1;
                        const unsigned int xid2 = (tid + s) << 1;
                        cache_grad_mask[tid] += cache_grad_mask[tid + s];
                        cache_grad_offset[xid1] += cache_grad_offset[xid2];
                        cache_grad_offset[xid1 + 1] +=
                            cache_grad_offset[xid2 + 1];
                        if (tid + (s << 1) < spre) {
                            cache_grad_mask[tid] +=
                                cache_grad_mask[tid + (s << 1)];
                            cache_grad_offset[xid1] +=
                                cache_grad_offset[xid2 + (s << 1)];
                            cache_grad_offset[xid1 + 1] +=
                                cache_grad_offset[xid2 + 1 + (s << 1)];
                        }
                    }
                    __syncthreads();
                }
                if (tid == 0) {
                    *grad_offset = cache_grad_offset[0];
                    *(grad_offset + 1) = cache_grad_offset[1];
                    *grad_mask = cache_grad_mask[0];
                }
                __syncthreads();
                data_weight_ptr += 1;
                data_loc_w_ptr += 2;
                grad_mask += 1;
                grad_offset += 2;
            }
        }
    }
}
// DCNv3 backward (col2im) kernel, shared-memory tree-reduction variant for the
// multi-block case (see launcher: group_channels > 1024 and divisible by 1024,
// so the channels of one sampling location span several thread blocks).
//
// Each thread owns one element of grad_col, i.e. one
// (batch, out_h, out_w, group, channel) tuple. Per-thread partial gradients
// w.r.t. the sampling offset (w, h pair) and the modulation mask are
// accumulated in dynamic shared memory and tree-reduced within the block;
// because several blocks contribute to the same sampling point, the block-level
// sums are published to global memory with atomicAdd.
//
// Dynamic shared memory layout (launcher passes num_threads * 3 * opmath_t):
//   cache_grad_offset: 2 * blockDim.x values, (w, h) interleaved per thread
//   cache_grad_mask:   blockDim.x values, one per thread
template <typename scalar_t>
__global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
    const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im,
    const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int group, const int group_channels, const int height_in,
    const int width_in, const int height_out, const int width_out,
    const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
    opmath_t *grad_mask) {
    CUDA_KERNEL_LOOP(index, num_kernels) {
        extern __shared__ int _s[];
        // Carve the dynamic shared memory into the two caches.
        opmath_t *cache_grad_offset = (opmath_t *)_s;
        opmath_t *cache_grad_mask = cache_grad_offset + 2 * blockDim.x;
        unsigned int tid = threadIdx.x;
        // Decode the flat index into (b_col, out_y, out_x, g_col, c_col);
        // channel is fastest-varying, batch slowest.
        int _temp = index;
        const int c_col = _temp % group_channels;
        _temp /= group_channels;
        const int sampling_index = _temp;
        const int g_col = _temp % group;
        _temp /= group;
        // Reference position of the kernel window in the input; the >> 1 term
        // centers the window (half of the dilated kernel extent).
        const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w +
                         (_temp % width_out) * stride_w;
        _temp /= width_out;
        const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h +
                         (_temp % height_out) * stride_h;
        _temp /= height_out;
        const int b_col = _temp;
        const opmath_t top_grad = grad_col[index];
        const int input_size = height_in * width_in;
        const int kernel_size = kernel_h * kernel_w;
        // data_offset stores (w, h) pairs, hence the << 1 stride.
        int data_weight_ptr = sampling_index * kernel_size;
        int data_loc_w_ptr = data_weight_ptr << 1;
        const int grad_sampling_ptr = data_weight_ptr;
        grad_offset += grad_sampling_ptr << 1;
        grad_mask += grad_sampling_ptr;
        const int qid_stride = group * group_channels;
        const int im_ptr_offset = b_col * input_size * qid_stride;
        const scalar_t *data_im_ptr = data_im + im_ptr_offset;
        opmath_t *grad_im_ptr = grad_im + im_ptr_offset;
        // Base position with the center compensation rescaled by offset_scale.
        const opmath_t p0_w_ =
            p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
        const opmath_t p0_h_ =
            p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
        // Walk the kernel_w x kernel_h sampling points in column-major order
        // (matches the layout of data_offset / data_mask).
        for (int i = 0; i < kernel_w; ++i) {
            for (int j = 0; j < kernel_h; ++j) {
                const opmath_t offset_w = data_offset[data_loc_w_ptr];
                const opmath_t offset_h = data_offset[data_loc_w_ptr + 1];
                const opmath_t loc_w =
                    p0_w_ + (i * dilation_w + offset_w) * offset_scale;
                const opmath_t loc_h =
                    p0_h_ + (j * dilation_h + offset_h) * offset_scale;
                const opmath_t weight = data_mask[data_weight_ptr];
                // Zero this thread's shared slots first so out-of-bounds
                // samples contribute exactly 0 to the reduction.
                *(cache_grad_offset + (threadIdx.x << 1)) = 0;
                *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0;
                *(cache_grad_mask + threadIdx.x) = 0;
                if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
                    loc_w < width_in) {
                    // Writes the input gradient through grad_im_ptr and this
                    // thread's offset/mask partials into shared memory.
                    dcnv3_col2im_bilinear(
                        data_im_ptr, height_in, width_in, group, group_channels,
                        loc_h, loc_w, g_col, c_col, offset_scale, top_grad,
                        weight, grad_im_ptr,
                        cache_grad_offset + (threadIdx.x << 1),
                        cache_grad_mask + threadIdx.x);
                }
                __syncthreads();
                // Block-wide tree reduction; the `spre` check folds in the
                // trailing element when the active width is odd.
                for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
                     s >>= 1, spre >>= 1) {
                    if (tid < s) {
                        const unsigned int xid1 = tid << 1;
                        const unsigned int xid2 = (tid + s) << 1;
                        cache_grad_mask[tid] += cache_grad_mask[tid + s];
                        cache_grad_offset[xid1] += cache_grad_offset[xid2];
                        cache_grad_offset[xid1 + 1] +=
                            cache_grad_offset[xid2 + 1];
                        if (tid + (s << 1) < spre) {
                            cache_grad_mask[tid] +=
                                cache_grad_mask[tid + (s << 1)];
                            cache_grad_offset[xid1] +=
                                cache_grad_offset[xid2 + (s << 1)];
                            cache_grad_offset[xid1 + 1] +=
                                cache_grad_offset[xid2 + 1 + (s << 1)];
                        }
                    }
                    __syncthreads();
                }
                if (tid == 0) {
                    // Several blocks cover the channels of one sampling point,
                    // so block sums must be combined atomically.
                    atomicAdd(grad_offset, cache_grad_offset[0]);
                    atomicAdd(grad_offset + 1, cache_grad_offset[1]);
                    atomicAdd(grad_mask, cache_grad_mask[0]);
                }
                __syncthreads();
                // Advance to the next sampling point of this query.
                data_weight_ptr += 1;
                data_loc_w_ptr += 2;
                grad_mask += 1;
                grad_offset += 2;
            }
        }
    }
}
// DCNv3 backward (col2im) kernel, global-memory fallback: every thread pushes
// its offset/mask/input gradients straight to global memory (no shared-memory
// reduction). Used by the launcher for irregular channel counts.
// One thread per grad_col element, i.e. one
// (batch, out_h, out_w, group, channel) tuple.
template <typename scalar_t>
__global__ void dcnv3_col2im_gpu_kernel_gm(
    const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im,
    const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int group, const int group_channels, const int height_in,
    const int width_in, const int height_out, const int width_out,
    const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset,
    opmath_t *grad_mask) {
    CUDA_KERNEL_LOOP(index, num_kernels) {
        // Decode the flat index: channel fastest, then group, out_x, out_y,
        // batch slowest.
        int rem = index;
        const int c_col = rem % group_channels;
        rem /= group_channels;
        const int sampling_index = rem;
        const int g_col = rem % group;
        rem /= group;
        // Half of the dilated kernel extent, used to center the window.
        const int half_kw = (dilation_w * (kernel_w - 1)) >> 1;
        const int half_kh = (dilation_h * (kernel_h - 1)) >> 1;
        const int p0_w = half_kw - pad_w + (rem % width_out) * stride_w;
        rem /= width_out;
        const int p0_h = half_kh - pad_h + (rem % height_out) * stride_h;
        rem /= height_out;
        const int b_col = rem;
        const opmath_t top_grad = grad_col[index];
        const int kernel_size = kernel_h * kernel_w;
        // Offsets are stored as (w, h) pairs -> the << 1 stride.
        int weight_idx = sampling_index * kernel_size;
        int loc_idx = weight_idx << 1;
        grad_offset += weight_idx << 1;
        grad_mask += weight_idx;
        const int qid_stride = group * group_channels;
        const int im_base = b_col * (height_in * width_in) * qid_stride;
        const scalar_t *im_ptr = data_im + im_base;
        opmath_t *grad_im_ptr = grad_im + im_base;
        // Base position with the center compensation in offset_scale units.
        const opmath_t base_w = p0_w - half_kw * offset_scale;
        const opmath_t base_h = p0_h - half_kh * offset_scale;
        for (int i = 0; i < kernel_w; ++i) {
            for (int j = 0; j < kernel_h; ++j) {
                const opmath_t offset_w = data_offset[loc_idx];
                const opmath_t offset_h = data_offset[loc_idx + 1];
                const opmath_t loc_w =
                    base_w + (i * dilation_w + offset_w) * offset_scale;
                const opmath_t loc_h =
                    base_h + (j * dilation_h + offset_h) * offset_scale;
                const opmath_t weight = data_mask[weight_idx];
                // Samples strictly outside the input contribute nothing.
                if (loc_h > -1 && loc_w > -1 && loc_h < height_in &&
                    loc_w < width_in) {
                    dcnv3_col2im_bilinear_gm(im_ptr, height_in, width_in, group,
                                             group_channels, loc_h, loc_w,
                                             g_col, c_col, offset_scale,
                                             top_grad, weight, grad_im_ptr,
                                             grad_offset, grad_mask);
                }
                // Next sampling point; gradient pointers advance in lockstep.
                ++weight_idx;
                loc_idx += 2;
                ++grad_mask;
                grad_offset += 2;
            }
        }
    }
}
// Host-side launcher for the DCNv3 forward (im2col) kernel.
// Launches one thread per element of data_col, i.e. one
// (batch, out_h, out_w, group, channel) tuple, on the given stream.
template <typename scalar_t>
void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im,
                       const scalar_t *data_offset, const scalar_t *data_mask,
                       scalar_t *data_col, const int kernel_h,
                       const int kernel_w, const int stride_h,
                       const int stride_w, const int pad_h, const int pad_w,
                       const int dilation_h, const int dilation_w,
                       const int group, const int group_channels,
                       const int batch_n, const int height_in,
                       const int width_in, const int height_out,
                       const int width_out, const opmath_t offset_scale) {
    const int total =
        batch_n * height_out * width_out * group * group_channels;
    const int threads = CUDA_NUM_THREADS;
    dcnv3_im2col_gpu_kernel<scalar_t>
        <<<GET_BLOCKS(total, threads), threads, 0, stream>>>(
            total, data_im, data_offset, data_mask, data_col, kernel_h,
            kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
            group, group_channels, height_in, width_in, height_out, width_out,
            offset_scale);
    // Surface launch-configuration errors; execution errors appear at the
    // next synchronizing call.
    const cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        printf("error in dcnv3_im2col_cuda: %s\n", cudaGetErrorString(status));
    }
}
// Host-side launcher for the DCNv3 backward (col2im) kernels.
//
// One thread is launched per grad_col element; the block size equals
// group_channels clamped to CUDA_NUM_THREADS. The kernel variant is chosen
// by group_channels:
//   * > 1024, divisible by 1024: shared-memory reduction + cross-block atomics
//   * > 1024, otherwise:         plain global-memory atomics
//   * power of two <= 1024:      template-specialized block-size-aware
//                                reduction (v1 for <= 32, v2 for >= 64)
//   * any other value <= 1024:   generic shared-memory reduction
//                                (v1 for < 64, v2 otherwise)
//
// Improvement over the original: the 23-argument launch list was repeated
// verbatim in all 13 branches; it is now kept in a single macro so the lists
// cannot drift apart. Behavior and interface are unchanged.
template <typename scalar_t>
void dcnv3_col2im_cuda(
    cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_im,
    const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
    const int pad_w, const int dilation_h, const int dilation_w,
    const int group, const int group_channels, const int batch_n,
    const int height_in, const int width_in, const int height_out,
    const int width_out, const opmath_t offset_scale, opmath_t *grad_im,
    opmath_t *grad_offset, opmath_t *grad_mask) {
    const int num_threads =
        (group_channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : group_channels;
    const int num_kernels =
        batch_n * height_out * width_out * group * group_channels;
    const int num_blocks = GET_BLOCKS(num_kernels, num_threads);
    // Shared memory used by the shm-reduction kernels: 2 * blockDim.x values
    // for the (w, h) offset gradients plus blockDim.x for the mask gradients.
    const size_t shm_size = num_threads * 3 * sizeof(opmath_t);
// Common argument list shared by every col2im kernel below.
#define DCNV3_COL2IM_ARGS                                                      \
    num_kernels, grad_col, data_im, data_offset, data_mask, kernel_h,          \
        kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,    \
        group, group_channels, height_in, width_in, height_out, width_out,     \
        offset_scale, grad_im, grad_offset, grad_mask
    if (group_channels > 1024) {
        if ((group_channels & 1023) == 0) {
            // Divisible by 1024: block-level shared-memory reduction combined
            // with atomics across the cooperating blocks.
            dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
                <<<num_blocks, num_threads, shm_size, stream>>>(
                    DCNV3_COL2IM_ARGS);
        } else {
            // Irregular large channel count: global-memory atomics only.
            dcnv3_col2im_gpu_kernel_gm<scalar_t>
                <<<num_blocks, num_threads, 0, stream>>>(DCNV3_COL2IM_ARGS);
        }
    } else {
        switch (group_channels) {
// Power-of-two block sizes use a kernel specialized on blockDim (static
// shared memory, fully unrolled reduction): v1 for <= 32, v2 for >= 64.
#define DCNV3_COL2IM_BLOCKSIZE_CASE(N, V)                                      \
    case N:                                                                    \
        dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_##V<scalar_t, N>    \
            <<<num_blocks, num_threads, 0, stream>>>(DCNV3_COL2IM_ARGS);       \
        break;
            DCNV3_COL2IM_BLOCKSIZE_CASE(1, v1)
            DCNV3_COL2IM_BLOCKSIZE_CASE(2, v1)
            DCNV3_COL2IM_BLOCKSIZE_CASE(4, v1)
            DCNV3_COL2IM_BLOCKSIZE_CASE(8, v1)
            DCNV3_COL2IM_BLOCKSIZE_CASE(16, v1)
            DCNV3_COL2IM_BLOCKSIZE_CASE(32, v1)
            DCNV3_COL2IM_BLOCKSIZE_CASE(64, v2)
            DCNV3_COL2IM_BLOCKSIZE_CASE(128, v2)
            DCNV3_COL2IM_BLOCKSIZE_CASE(256, v2)
            DCNV3_COL2IM_BLOCKSIZE_CASE(512, v2)
            DCNV3_COL2IM_BLOCKSIZE_CASE(1024, v2)
#undef DCNV3_COL2IM_BLOCKSIZE_CASE
        default:
            // Non-power-of-two block size: generic shared-memory reduction
            // with dynamic shared memory.
            if (group_channels < 64) {
                dcnv3_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
                    <<<num_blocks, num_threads, shm_size, stream>>>(
                        DCNV3_COL2IM_ARGS);
            } else {
                dcnv3_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
                    <<<num_blocks, num_threads, shm_size, stream>>>(
                        DCNV3_COL2IM_ARGS);
            }
        }
    }
#undef DCNV3_COL2IM_ARGS
    // Surface launch-configuration errors; execution errors appear at the
    // next synchronizing call.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("error in dcnv3_col2im_cuda: %s\n", cudaGetErrorString(err));
    }
}
\ No newline at end of file
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#pragma once
#include "cpu/dcnv3_cpu.h"
#ifdef WITH_CUDA
#include "cuda/dcnv3_cuda.h"
#endif
// Dispatch entry point for DCNv3 forward. Only a CUDA implementation exists;
// CPU input raises. Returns the output tensor produced by the CUDA kernel.
//
// Improvement: `input.type().is_cuda()` used the deprecated
// `Tensor::type()` API; replaced with the direct `Tensor::is_cuda()`.
at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset,
                         const at::Tensor &mask, const int kernel_h,
                         const int kernel_w, const int stride_h,
                         const int stride_w, const int pad_h, const int pad_w,
                         const int dilation_h, const int dilation_w,
                         const int group, const int group_channels,
                         const float offset_scale, const int im2col_step) {
    if (input.is_cuda()) {
#ifdef WITH_CUDA
        return dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w,
                                  stride_h, stride_w, pad_h, pad_w, dilation_h,
                                  dilation_w, group, group_channels,
                                  offset_scale, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
// Dispatch entry point for DCNv3 backward. Only a CUDA implementation exists;
// CPU input raises. Returns the gradient tensors (w.r.t. input, offset, mask)
// computed by the CUDA kernel.
//
// Improvement: `input.type().is_cuda()` used the deprecated
// `Tensor::type()` API; replaced with the direct `Tensor::is_cuda()`.
std::vector<at::Tensor>
dcnv3_backward(const at::Tensor &input, const at::Tensor &offset,
               const at::Tensor &mask, const int kernel_h, const int kernel_w,
               const int stride_h, const int stride_w, const int pad_h,
               const int pad_w, const int dilation_h, const int dilation_w,
               const int group, const int group_channels,
               const float offset_scale, const at::Tensor &grad_output,
               const int im2col_step) {
    if (input.is_cuda()) {
#ifdef WITH_CUDA
        return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w,
                                   stride_h, stride_w, pad_h, pad_w, dilation_h,
                                   dilation_w, group, group_channels,
                                   offset_scale, grad_output, im2col_step);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include "dcnv3.h"
// Python bindings: expose the DCNv3 forward/backward dispatchers declared in
// dcnv3.h as a torch extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("dcnv3_forward", &dcnv3_forward, "dcnv3_forward");
    m.def("dcnv3_backward", &dcnv3_backward, "dcnv3_backward");
}
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
# Shared test configuration, read by every check_* function below.
H_in, W_in = 8, 8       # input spatial size
N, M, D = 2, 4, 16      # batch, groups, channels per group
Kh, Kw = 3, 3           # kernel size
P = Kh * Kw             # sampling points per group
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
# Standard convolution output-size formula.
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
# Fixed seed: every check consumes the global RNG stream in order, so
# statement order below matters for reproducibility.
torch.manual_seed(3)
@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    """Compare the CUDA forward against the pure-PyTorch reference in float64.

    Prints whether torch.allclose (default tolerances) passes, plus the max
    absolute/relative error. NOTE: the torch.rand calls consume the globally
    seeded RNG stream, so their order must not change.
    """
    input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    # Mask is normalized to sum to 1 over the P sampling points per group.
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M*P)
    output_pytorch = dcnv3_core_pytorch(
        input.double(),
        offset.double(),
        mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu()
    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input.double(),
        offset.double(),
        mask.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step).detach().cpu()
    # Default tolerances: float64 results should agree almost exactly.
    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward double')
    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    """Compare the CUDA forward against the pure-PyTorch reference in float32.

    Same as the double-precision check but with relaxed tolerances
    (rtol=1e-2, atol=1e-3) to allow for float32 rounding differences.
    """
    input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    # Mask is normalized to sum to 1 over the P sampling points per group.
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M*P)
    output_pytorch = dcnv3_core_pytorch(
        input,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu()
    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input,
        offset,
        mask,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step).detach().cpu()
    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward float')
    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
    """Compare CUDA backward gradients against the reference in float64.

    Runs both implementations on identical inputs, backprops sum() through
    each, and compares input/offset/mask gradients with rtol=1e-2, atol=1e-3.

    Args:
        channels: channels per group (D); varied by the caller to hit the
            different CUDA backward-kernel dispatch branches.
        grad_input/grad_offset/grad_mask: which inputs require gradients.
    """
    # H_in, W_in = 4, 4
    # Smaller batch/group count than the forward checks (shadows globals).
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    D = channels
    input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    # Mask is normalized to sum to 1 over the P sampling points per group.
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M*P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask
    # Reference path: pure-PyTorch implementation.
    output_pytorch = dcnv3_core_pytorch(
        input0.double(),
        offset0.double(),
        mask0.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale)
    output_pytorch.sum().backward()
    # CUDA path: detached copies of the same leaves so gradients accumulate
    # independently.
    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask
    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1.double(),
        offset1.double(),
        mask1.double(),
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step)
    output_cuda.sum().backward()
    print(f'>>> backward double: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() /
                   input0.grad.abs()).max()
    print(
        f'* {bwdok} input_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(
        f'* {bwdok} offset_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() /
                   mask0.grad.abs()).max()
    print(
        f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
    """Compare CUDA backward gradients against the reference in float32.

    Same procedure as the double-precision variant; identical tolerances
    (rtol=1e-2, atol=1e-3) applied to float32 gradients.

    Args:
        channels: channels per group (D); varied by the caller to hit the
            different CUDA backward-kernel dispatch branches.
        grad_input/grad_offset/grad_mask: which inputs require gradients.
    """
    # H_in, W_in = 4, 4
    # Smaller batch/group count than the forward checks (shadows globals).
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    D = channels
    input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    # Mask is normalized to sum to 1 over the P sampling points per group.
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M*P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask
    # Reference path: pure-PyTorch implementation.
    output_pytorch = dcnv3_core_pytorch(
        input0,
        offset0,
        mask0,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale)
    output_pytorch.sum().backward()
    # CUDA path: detached copies of the same leaves so gradients accumulate
    # independently.
    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask
    im2col_step = 2
    output_cuda = DCNv3Function.apply(
        input1,
        offset1,
        mask1,
        Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale,
        im2col_step)
    output_cuda.sum().backward()
    print(f'>>> backward float: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() /
                   input0.grad.abs()).max()
    print(
        f'* {bwdok} input_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(
        f'* {bwdok} offset_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() /
                   mask0.grad.abs()).max()
    print(
        f'* {bwdok} mask_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
@torch.no_grad()
def check_time_cost(im2col_step=128):
    """Benchmark the CUDA forward pass and print the mean per-call time.

    Runs `repeat` warm-up iterations, synchronizes, then times `repeat`
    more iterations around explicit cuda.synchronize() calls.

    Fixes over the original: 'foward' typo in the printed message, and the
    unused loop variable is now `_`.

    Args:
        im2col_step: batch chunk size forwarded to the CUDA launcher.
    """
    # Larger workload than the correctness checks (shadows the globals).
    N = 512
    H_in, W_in = 64, 64
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
    input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M*P)
    print(
        f'>>> time cost: im2col_step {im2col_step}; input {input.shape}; points {P} ')
    repeat = 100
    # Warm-up: absorb one-time kernel/caching costs before timing.
    for _ in range(repeat):
        output_cuda = DCNv3Function.apply(
            input,
            offset,
            mask,
            Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0,
            im2col_step)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        output_cuda = DCNv3Function.apply(
            input,
            offset,
            mask,
            Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0,
            im2col_step)
    torch.cuda.synchronize()
    print(f'forward time cost: {(time.time() - start) / repeat}')
if __name__ == '__main__':
    # Forward correctness against the pure-PyTorch reference.
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()
    # Backward correctness over power-of-two, odd, and >1024 channel counts
    # to exercise every CUDA backward-kernel dispatch branch.
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_double(channels, True, True, True)
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_float(channels, True, True, True)
    # Timing with im2col_step = 128, 256, 512.
    for i in range(3):
        im2col_step = 128 * (2 ** i)
        check_time_cost(im2col_step)
from .baseline import Baseline
from .road_bev import ROAD_BEVFormer
# ==============================================================================
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# baseline.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
# Contact wanghuijie@pjlab.org.cn if you have any issue.
#
# Copyright (c) 2023 The OpenLane-v2 Dataset Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
from mmdet3d.models import DETECTORS, build_neck, build_head
from mmdet3d.models.detectors import MVXTwoStageDetector
@DETECTORS.register_module()
class Baseline(MVXTwoStageDetector):
    """OpenLane-V2 baseline detector.

    Pipeline: image backbone/neck -> BEV view transformer -> four heads:
      * lc_head:   lane-centerline detection on the BEV feature map
      * te_head:   traffic-element detection on the first camera's PV feature
      * lclc_head: lane-lane topology over lc decoder queries
      * lcte_head: lane-traffic-element topology over lc/te decoder queries

    Improvement over the original: the misspelled locals ``ouput_H``/
    ``ouput_W`` in simple_forward were renamed; documentation added.
    Behavior and interface are unchanged.
    """

    def __init__(self,
                 img_backbone=None,
                 img_neck=None,
                 img_view_transformer=None,
                 lc_head=None,
                 te_head=None,
                 lclc_head=None,
                 lcte_head=None,
                 **kwargs):
        """Build the view transformer and the four heads from their configs."""
        super().__init__(img_backbone=img_backbone, img_neck=img_neck, **kwargs)
        self.img_view_transformer = build_neck(img_view_transformer)
        self.lc_head = build_head(lc_head)
        self.te_head = build_head(te_head)
        self.lclc_head = build_head(lclc_head)
        self.lcte_head = build_head(lcte_head)

    def simple_forward(self, img, img_metas):
        """Shared forward pass used by both training and testing.

        Args:
            img: multi-view images of shape (B, N, C, H, W).
            img_metas: list of B meta dicts; must provide 'lidar2img',
                'img_shape', 'pad_shape' and 'scale_factor'.

        Returns:
            dict with raw outputs of all four heads plus the per-head
            img_metas used to produce them.
        """
        # extract image features
        B, N, C, imH, imW = img.shape
        img = img.view(B * N, C, imH, imW)
        x = self.img_backbone(img)
        if self.with_img_neck:
            x = self.img_neck(x)
            if type(x) in [list, tuple]:
                x = x[0]
        # view transformation: stack the per-camera lidar2img matrices into a
        # (B, N, 4, 4) tensor on the feature's device
        bev_feat = self.img_view_transformer(
            x,
            torch.cat([
                torch.cat([torch.tensor(l, device=x.device, dtype=torch.float32).unsqueeze(0) for l in img_metas[b]['lidar2img']], dim=0).unsqueeze(0)
                for b in range(B)], dim=0),
            (img_metas[0]['img_shape'][0][0], img_metas[0]['img_shape'][0][1]),
        )
        # perspective-view feature of the first camera, for the te head
        _, output_dim, output_H, output_W = x.shape
        pv_feat = x.view(B, N, output_dim, output_H, output_W)[:, 0, ...]
        # lc: treat the BEV map as the "image" for the 2D detection head
        lc_img_metas = [{
            'batch_input_shape': (bev_feat.shape[-2], bev_feat.shape[-1]),
            'img_shape': (bev_feat.shape[-2], bev_feat.shape[-1], None),
            'scale_factor': None,  # dummy
        } for _ in range(B)]
        all_lc_cls_scores_list, all_lc_preds_list, lc_outs_dec_list = self.lc_head(
            [bev_feat],
            lc_img_metas,
        )
        # te: runs on the first camera's PV feature with its real image metas
        te_img_metas = [{
            'batch_input_shape': (img_metas[b]['pad_shape'][0][0], img_metas[b]['pad_shape'][0][1]),
            'img_shape': (img_metas[b]['img_shape'][0][0], img_metas[b]['img_shape'][0][1], None),
            'scale_factor': img_metas[b]['scale_factor'],
        } for b in range(B)]
        all_te_cls_scores_list, all_te_preds_list, te_outs_dec_list = self.te_head(
            [pv_feat],
            te_img_metas,
        )
        # topology_lclc: lc queries related to themselves
        all_lclc_preds_list = self.lclc_head(
            lc_outs_dec_list,
            lc_outs_dec_list,
        )
        # topology_lcte: lc queries related to te queries
        all_lcte_preds_list = self.lcte_head(
            lc_outs_dec_list,
            te_outs_dec_list,
        )
        return {
            'all_lc_cls_scores_list': all_lc_cls_scores_list,
            'all_lc_preds_list': all_lc_preds_list,
            'lc_img_metas': lc_img_metas,
            'all_te_cls_scores_list': all_te_cls_scores_list,
            'all_te_preds_list': all_te_preds_list,
            'te_img_metas': te_img_metas,
            'all_lclc_preds_list': all_lclc_preds_list,
            'all_lcte_preds_list': all_lcte_preds_list,
        }

    def forward_train(self,
                      img,
                      img_metas,
                      gt_lc=None,
                      gt_lc_labels=None,
                      gt_te=None,
                      gt_te_labels=None,
                      gt_topology_lclc=None,
                      gt_topology_lcte=None,
                      **kwargs):
        """Compute losses for all four heads on one batch.

        Returns:
            dict of losses, keys prefixed per head (lc_/te_/topology_*_).
        """
        outs = self.simple_forward(img, img_metas)
        losses = dict()
        # lc: detection loss; the assignment results are reused by the
        # topology heads so pairings stay consistent
        lc_loss_dict, lc_assign_results = self.lc_head.loss(
            outs['all_lc_cls_scores_list'],
            outs['all_lc_preds_list'],
            gt_lc,
            gt_lc_labels,
            outs['lc_img_metas'],
        )
        losses.update({
            f'lc_{key}': val for key, val in lc_loss_dict.items()
        })
        # te: detection loss
        te_loss_dict, te_assign_results = self.te_head.loss(
            outs['all_te_cls_scores_list'],
            outs['all_te_preds_list'],
            gt_te,
            gt_te_labels,
            outs['te_img_metas'],
        )
        losses.update({
            f'te_{key}': val for key, val in te_loss_dict.items()
        })
        # topology_lclc
        topology_lclc_loss_dict = self.lclc_head.loss(
            outs['all_lclc_preds_list'],
            lc_assign_results,
            lc_assign_results,
            gt_topology_lclc,
        )
        losses.update({
            f'topology_lclc_{key}': val for key, val in topology_lclc_loss_dict.items()
        })
        # topology_lcte
        topology_lcte_loss_dict = self.lcte_head.loss(
            outs['all_lcte_preds_list'],
            lc_assign_results,
            te_assign_results,
            gt_topology_lcte
        )
        losses.update({
            f'topology_lcte_{key}': val for key, val in topology_lcte_loss_dict.items()
        })
        return losses

    def forward_test(self, img, img_metas, **kwargs):
        """Run inference; only batch size 1 is supported.

        Returns:
            single-element list with lc/te predictions and both topologies.
        """
        outs = self.simple_forward(img, img_metas)
        pred_lc = self.lc_head.get_bboxes(
            outs['all_lc_cls_scores_list'],
            outs['all_lc_preds_list'],
            outs['lc_img_metas'],
        )
        pred_te = self.te_head.get_bboxes(
            outs['all_te_cls_scores_list'],
            outs['all_te_preds_list'],
            outs['te_img_metas'],
            rescale=True,
        )
        pred_topology_lclc = self.lclc_head.get_topology(outs['all_lclc_preds_list'])
        pred_topology_lcte = self.lcte_head.get_topology(outs['all_lcte_preds_list'])
        assert len(pred_lc) == len(pred_te) == 1, \
            'evaluation implemented for bs=1'
        return [{
            'pred_lc': pred_lc[0],
            'pred_te': pred_te[0],
            'pred_topology_lclc': pred_topology_lclc[0],
            'pred_topology_lcte': pred_topology_lcte[0],
        }]
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Tianyu Li
# ---------------------------------------------
import time
import copy
import numpy as np
import torch
from mmcv.runner import force_fp32, auto_fp16
from mmdet.core import bbox2result
from mmdet.models import DETECTORS
from mmdet.models.builder import build_head
from mmdet3d.models.builder import build_neck
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
@DETECTORS.register_module()
class ROAD_BEVFormer(MVXTwoStageDetector):
    def __init__(self,
                 pts_voxel_layer=None,
                 pts_voxel_encoder=None,
                 pts_middle_encoder=None,
                 pts_fusion_layer=None,
                 img_backbone=None,
                 pts_backbone=None,
                 img_neck=None,
                 pts_neck=None,
                 pts_bbox_head=None,
                 img_roi_head=None,
                 img_rpn_head=None,
                 bev_constructor=None,
                 bbox_head=None,
                 bbox_train_cfg=None,
                 lclc_head=None,
                 lcte_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 video_test_mode=False
                 ):
        """Build the BEV constructor and optional heads from their configs.

        The mmdet3d base-class arguments are forwarded positionally to
        MVXTwoStageDetector; the extra components (bev_constructor,
        bbox_head, lclc_head, lcte_head) are built only when configured,
        otherwise left as None for callers to check.

        Args:
            bev_constructor: neck config producing the BEV feature.
            bbox_head: head config; receives bbox_train_cfg as its train_cfg.
            lclc_head / lcte_head: topology head configs.
            video_test_mode: enable temporal (multi-frame) testing state.
        """
        super(ROAD_BEVFormer,
              self).__init__(pts_voxel_layer, pts_voxel_encoder,
                             pts_middle_encoder, pts_fusion_layer,
                             img_backbone, pts_backbone, img_neck, pts_neck,
                             pts_bbox_head, img_roi_head, img_rpn_head,
                             train_cfg, test_cfg, pretrained)
        if bev_constructor is not None:
            self.bev_constructor = build_neck(bev_constructor)
        if bbox_head is not None:
            # train_cfg is injected into the head config before building.
            bbox_head.update(train_cfg=bbox_train_cfg)
            self.bbox_head = build_head(bbox_head)
        else:
            self.bbox_head = None
        if lclc_head is not None:
            self.lclc_head = build_head(lclc_head)
        else:
            self.lclc_head = None
        if lcte_head is not None:
            self.lcte_head = build_head(lcte_head)
        else:
            self.lcte_head = None
        self.fp16_enabled = False
        # temporal
        self.video_test_mode = video_test_mode
        # Carried between test frames when video_test_mode is on: previous BEV
        # feature, scene id, and ego pose/angle.
        self.prev_frame_info = {
            'prev_bev': None,
            'scene_token': None,
            'prev_pos': 0,
            'prev_angle': 0,
        }
def extract_img_feat(self, img, img_metas, len_queue=None):
"""Extract features of images."""
B = img.size(0)
if img is not None:
# input_shape = img.shape[-2:]
# # update real input shape of each single img
# for img_meta in img_metas:
# img_meta.update(input_shape=input_shape)
if img.dim() == 5 and img.size(0) == 1:
img.squeeze_()
elif img.dim() == 5 and img.size(0) > 1:
B, N, C, H, W = img.size()
img = img.reshape(B * N, C, H, W)
img_feats = self.img_backbone(img)
if isinstance(img_feats, dict):
img_feats = list(img_feats.values())
else:
return None
if self.with_img_neck:
img_feats = self.img_neck(img_feats)
img_feats_reshaped = []
for img_feat in img_feats:
BN, C, H, W = img_feat.size()
if len_queue is not None:
img_feats_reshaped.append(img_feat.view(int(B/len_queue), len_queue, int(BN / B), C, H, W))
else:
img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
return img_feats_reshaped
@auto_fp16(apply_to=('img'))
def extract_feat(self, img, img_metas=None, len_queue=None):
"""Extract features from images and points."""
img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
return img_feats
def forward_dummy(self, img):
dummy_metas = None
return self.forward_test(img=img, img_metas=[[dummy_metas]])
def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Note this setting will change the expected inputs. When
`return_loss=True`, img and img_metas are single-nested (i.e.
torch.Tensor and list[dict]), and when `resturn_loss=False`, img and
img_metas should be double nested (i.e. list[torch.Tensor],
list[list[dict]]), with the outer list indicating test time
augmentations.
"""
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def obtain_history_bev(self, imgs_queue, img_metas_list):
"""Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
"""
self.eval()
with torch.no_grad():
prev_bev = None
bs, len_queue, num_cams, C, H, W = imgs_queue.shape
imgs_queue = imgs_queue.reshape(bs*len_queue, num_cams, C, H, W)
img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
for i in range(len_queue):
img_metas = [each[i] for each in img_metas_list]
# img_feats = self.extract_feat(img=img, img_metas=img_metas)
img_feats = [each_scale[:, i] for each_scale in img_feats_list]
prev_bev = self.bev_constructor(img_feats, img_metas, prev_bev)
self.train()
return prev_bev
@auto_fp16(apply_to=('img', 'points'))
def forward_train(self,
img=None,
img_metas=None,
gt_te=None,
gt_te_labels=None,
gt_lc=None,
gt_lc_labels=None,
gt_topology_lclc=None,
gt_topology_lcte=None,
gt_bboxes_ignore=None,
):
prev_bev = None
img_feats = self.extract_feat(img=img, img_metas=img_metas)
bev_feats = self.bev_constructor(img_feats, img_metas, prev_bev)
losses = dict()
outs = self.pts_bbox_head(img_feats, bev_feats, img_metas)
loss_inputs = [outs, gt_lc, gt_lc_labels]
lane_losses, lane_assign_result = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas)
for loss in lane_losses:
losses['lane_head.' + loss] = lane_losses[loss]
lane_feats = outs['history_states']
if self.lclc_head is not None:
lclc_losses = self.lclc_head.forward_train(lane_feats, lane_assign_result, lane_feats, lane_assign_result, gt_topology_lclc)
for loss in lclc_losses:
losses['lclc_head.' + loss] = lclc_losses[loss]
if self.bbox_head is not None:
front_view_img_feats = [lvl[:, 0] for lvl in img_feats]
batch_input_shape = tuple(img[0, 0].size()[-2:])
bbox_img_metas = []
for img_meta in img_metas:
bbox_img_metas.append(
dict(
batch_input_shape=batch_input_shape,
img_shape=img_meta['img_shape'][0],
scale_factor=img_meta['scale_factor'][0]))
img_meta['batch_input_shape'] = batch_input_shape
te_losses = {}
bbox_outs = self.bbox_head(front_view_img_feats, bbox_img_metas)
bbox_losses, te_assign_result = self.bbox_head.loss(bbox_outs, gt_te, gt_te_labels, bbox_img_metas, gt_bboxes_ignore)
for loss in bbox_losses:
te_losses['bbox_head.' + loss] = bbox_losses[loss]
if self.lcte_head is not None:
te_feats = bbox_outs['history_states']
lcte_losses = self.lcte_head.forward_train(lane_feats, lane_assign_result, te_feats, te_assign_result, gt_topology_lcte)
for loss in lcte_losses:
te_losses['lcte_head.' + loss] = lcte_losses[loss]
num_gt_bboxes = sum([len(gt) for gt in gt_te_labels])
if num_gt_bboxes == 0:
for loss in te_losses:
te_losses[loss] *= 0
losses.update(te_losses)
return losses
def forward_test(self, img_metas, img=None, **kwargs):
for var, name in [(img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
img = [img] if img is None else img
new_prev_bev, results_list = self.simple_test(
img_metas, img, prev_bev=None, **kwargs)
return results_list
def simple_test_pts(self, x, img_metas, img=None, prev_bev=None, rescale=False):
"""Test function"""
batchsize = len(img_metas)
bev_feats = self.bev_constructor(x, img_metas, prev_bev)
outs = self.pts_bbox_head(x, bev_feats, img_metas)
lane_results = self.pts_bbox_head.get_lanes(
outs, img_metas, rescale=rescale)
if self.lclc_head is not None:
lane_feats = outs['history_states']
lclc_results = self.lclc_head.get_relationship(lane_feats, lane_feats)
lclc_results = [result.detach().cpu().numpy() for result in lclc_results]
else:
lclc_results = [None for _ in range(batchsize)]
if self.bbox_head is not None:
front_view_img_feats = [lvl[:, 0] for lvl in x]
batch_input_shape = tuple(img[0, 0].size()[-2:])
bbox_img_metas = []
for img_meta in img_metas:
bbox_img_metas.append(
dict(
batch_input_shape=batch_input_shape,
img_shape=img_meta['img_shape'][0],
scale_factor=img_meta['scale_factor'][0]))
img_meta['batch_input_shape'] = batch_input_shape
bbox_outs = self.bbox_head(front_view_img_feats, bbox_img_metas)
bbox_results = self.bbox_head.get_bboxes(bbox_outs, bbox_img_metas, rescale=rescale)
else:
bbox_results = [None for _ in range(batchsize)]
if self.bbox_head is not None and self.lcte_head is not None:
te_feats = bbox_outs['history_states']
lcte_results = self.lcte_head.get_relationship(lane_feats, te_feats)
lcte_results = [result.detach().cpu().numpy() for result in lcte_results]
else:
lcte_results = [None for _ in range(batchsize)]
return bev_feats, bbox_results, lane_results, lclc_results, lcte_results
def simple_test(self, img_metas, img=None, prev_bev=None, rescale=False):
"""Test function without augmentaiton."""
img_feats = self.extract_feat(img=img, img_metas=img_metas)
results_list = [dict() for i in range(len(img_metas))]
new_prev_bev, bbox_results, lane_results, lclc_results, lcte_results = self.simple_test_pts(
img_feats, img_metas, img, prev_bev, rescale=rescale)
for result_dict, bbox, lane, lclc, lcte in zip(results_list, bbox_results, lane_results, lclc_results, lcte_results):
result_dict['pred_te'] = bbox
result_dict['pred_lc'] = lane
result_dict['pred_topology_lclc'] = lclc
result_dict['pred_topology_lcte'] = lcte
return new_prev_bev, results_list
from .custom_detr_head import *
from .topology_head import *
from .lc_deformable_detr_head import LCDeformableDETRHead
from .te_deformable_detr_head import TEDeformableDETRHead
from .relationship_head import RelationshipHead
\ No newline at end of file
# ==============================================================================
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
# Contact wanghuijie@pjlab.org.cn if you have any issue.
#
# Copyright (c) 2023 The OpenLane-v2 Dataset Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
import torch.nn.functional as F
from mmcv.cnn import Linear
from mmdet.core import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, multi_apply, reduce_mean
from mmdet.models import HEADS, DETRHead
@HEADS.register_module()
class CustomDETRHead(DETRHead):
    """DETR head shared by the lane ('lane') and traffic-element ('bbox')
    tasks of OpenLane-V2.

    The transformer / positional-encoding configs are assembled internally
    from a few scalar hyper-parameters, so user configs only specify sizes.
    Lanes are regressed as ``num_reg_dim // 3`` 3D points normalized into
    ``bev_range``; traffic elements use the standard 4-dim cxcywh box.
    """
    def __init__(self,
                 num_classes,
                 in_channels,
                 num_query,
                 object_type,
                 num_reg_dim=4,
                 num_layers=1,
                 feedforward_channels=512,
                 embed_dims=64,
                 num_heads=4,
                 dropout=0.1,
                 ffn_dropout=0.1,
                 **kwargs):
        self.object_type = object_type
        if self.object_type == 'lane':
            self.num_reg_dim = num_reg_dim
            # Each lane point is (x, y, z), hence a multiple of 3.
            assert self.num_reg_dim % 3 == 0
            self.bev_range = kwargs['bev_range']
        elif self.object_type == 'bbox':
            self.num_reg_dim = 4
            assert self.num_reg_dim == num_reg_dim == 4
        else:
            raise NotImplementedError

        # Build a vanilla DETR encoder/decoder config from the scalar knobs.
        transformer=dict(
            type='Transformer',
            encoder=dict(
                type='DetrTransformerEncoder',
                num_layers=num_layers,
                transformerlayers=dict(
                    type='BaseTransformerLayer',
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=embed_dims,
                            num_heads=num_heads,
                            dropout=dropout)
                    ],
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=feedforward_channels,
                        num_fcs=2,
                        ffn_drop=0.,
                        act_cfg=dict(type='ReLU', inplace=True),
                    ),
                    feedforward_channels=feedforward_channels,
                    ffn_dropout=ffn_dropout,
                    operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
            decoder=dict(
                type='DetrTransformerDecoder',
                return_intermediate=True,
                num_layers=num_layers,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=embed_dims,
                        num_heads=num_heads,
                        dropout=dropout),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=feedforward_channels,
                        num_fcs=2,
                        ffn_drop=0.,
                        act_cfg=dict(type='ReLU', inplace=True),
                    ),
                    feedforward_channels=feedforward_channels,
                    ffn_dropout=ffn_dropout,
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                     'ffn', 'norm')),
            ))
        positional_encoding=dict(
            type='SinePositionalEncoding', num_feats=embed_dims//2, normalize=True)
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            num_query=num_query,
            transformer=transformer,
            positional_encoding=positional_encoding,
            **kwargs,
        )

    def _init_layers(self):
        """As DETRHead, but the regression branch outputs num_reg_dim values
        instead of a fixed 4."""
        super()._init_layers()
        self.fc_reg = Linear(self.embed_dims, self.num_reg_dim)

    def forward_single(self, x, img_metas):
        """Run the transformer on one feature level.

        Returns (all_cls_scores, all_bbox_preds, outs_dec); outs_dec are the
        per-layer decoder states, shape [nb_dec, bs, num_query, embed_dim].
        """
        # construct binary masks which used for the transformer.
        # NOTE following the official DETR repo, non-zero values representing
        # ignored positions, while zero values means valid positions.
        if self.object_type == 'lane':
            # BEV features: every position is valid.
            masks = x.new_zeros((x.shape[0], x.shape[2], x.shape[3]))
        else:
            # Image features: mask out the padded region of each image.
            batch_size = x.size(0)
            input_img_h, input_img_w = img_metas[0]['batch_input_shape']
            masks = x.new_ones((batch_size, input_img_h, input_img_w))
            for img_id in range(batch_size):
                img_h, img_w, _ = img_metas[img_id]['img_shape']
                masks[img_id, :img_h, :img_w] = 0

        x = self.input_proj(x)
        # interpolate masks to have the same spatial shape with x
        masks = F.interpolate(
            masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
        # position encoding
        pos_embed = self.positional_encoding(masks)  # [bs, embed_dim, h, w]
        # outs_dec: [nb_dec, bs, num_query, embed_dim]
        outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
                                       pos_embed)

        all_cls_scores = self.fc_cls(outs_dec)
        all_bbox_preds = self.fc_reg(self.activate(
            self.reg_ffn(outs_dec))).sigmoid()
        return all_cls_scores, all_bbox_preds, outs_dec

    def loss(self,
             all_cls_scores_list,
             all_bbox_preds_list,
             gt_bboxes_list,
             gt_labels_list,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses over all decoder layers.

        Returns (loss_dict, assign_result) where assign_result comes from the
        per-layer Hungarian matching (used by the topology heads).
        """
        # NOTE defaultly only the outputs from the last feature scale is used.
        all_cls_scores = all_cls_scores_list[-1]
        all_bbox_preds = all_bbox_preds_list[-1]
        assert gt_bboxes_ignore is None, \
            'Only supports for gt_bboxes_ignore setting to None.'

        num_dec_layers = len(all_cls_scores)
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_bboxes_ignore_list = [
            gt_bboxes_ignore for _ in range(num_dec_layers)
        ]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]

        losses_cls, losses_bbox, losses_iou, assign_result = multi_apply(
            self.loss_single, all_cls_scores, all_bbox_preds,
            all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
            all_gt_bboxes_ignore_list)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]
        # Lanes have no IoU loss (losses_iou entries are None for them).
        if self.object_type != 'lane':
            loss_dict['loss_iou'] = losses_iou[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
                                                       losses_bbox[:-1],
                                                       losses_iou[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            if self.object_type != 'lane':
                loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
            num_dec_layer += 1
        return loss_dict, assign_result

    def loss_single(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes_list,
                    gt_labels_list,
                    img_metas,
                    gt_bboxes_ignore_list=None):
        """Loss for the outputs of a single decoder layer.

        Returns (loss_cls, loss_bbox, loss_iou, assign_result); loss_iou is
        None for the lane object type.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           gt_bboxes_list, gt_labels_list,
                                           img_metas, gt_bboxes_ignore_list)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg, assign_result) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        if self.object_type == 'lane':
            bbox_preds = bbox_preds.reshape(-1, self.num_reg_dim)
            loss_iou = None
        else:
            # construct factors used for rescale bboxes
            factors = []
            for img_meta, bbox_pred in zip(img_metas, bbox_preds):
                img_h, img_w, _ = img_meta['img_shape']
                factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                               img_h]).unsqueeze(0).repeat(
                                                   bbox_pred.size(0), 1)
                factors.append(factor)
            factors = torch.cat(factors, 0)

            # DETR regress the relative position of boxes (cxcywh) in the image,
            # thus the learning target is normalized by the image size. So here
            # we need to re-scale them for calculating IoU loss
            bbox_preds = bbox_preds.reshape(-1, self.num_reg_dim)
            bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
            bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

            # regression IoU loss, defaultly GIoU loss
            loss_iou = self.loss_iou(
                bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou, assign_result

    def get_targets(self,
                    cls_scores_list,
                    bbox_preds_list,
                    gt_bboxes_list,
                    gt_labels_list,
                    img_metas,
                    gt_bboxes_ignore_list=None):
        """Per-batch target assembly; also returns the matching result so the
        topology heads can reuse the query<->GT assignment."""
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'
        num_imgs = len(cls_scores_list)
        gt_bboxes_ignore_list = [
            gt_bboxes_ignore_list for _ in range(num_imgs)
        ]

        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, pos_inds_list, neg_inds_list, pos_assigned_gt_inds_list) = multi_apply(
             self._get_target_single, cls_scores_list, bbox_preds_list,
             gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        assign_result = dict(
            pos_inds=pos_inds_list, neg_inds=neg_inds_list, pos_assigned_gt_inds=pos_assigned_gt_inds_list
        )
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg, assign_result)

    def _get_target_single(self,
                           cls_score,
                           bbox_pred,
                           gt_bboxes,
                           gt_labels,
                           img_meta,
                           gt_bboxes_ignore=None):
        """Targets for one image: Hungarian assignment followed by label and
        regression-target construction (lane points normalized by bev_range,
        boxes normalized by image size and converted to cxcywh)."""
        num_bboxes = bbox_pred.size(0)
        # assigner and sampler
        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
                                             gt_labels, img_meta,
                                             gt_bboxes_ignore)
        sampling_result = self.sampler.sample(assign_result, bbox_pred,
                                              gt_bboxes)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds

        # label targets (background class = num_classes)
        labels = gt_bboxes.new_full((num_bboxes, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)

        # bbox targets
        bbox_targets = torch.zeros_like(bbox_pred)
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0

        if self.object_type == 'lane':
            # Normalize each 3D lane point into [0, 1] using bev_range
            # (x: [0]..[3], y: [1]..[4], z: [2]..[5]).
            pos_gt_bboxes = sampling_result.pos_gt_bboxes
            pos_gt_bboxes_normalized = torch.zeros_like(pos_gt_bboxes)
            for p in range(self.num_reg_dim // 3):
                pos_gt_bboxes_normalized[..., 3*p] = (pos_gt_bboxes[..., 3*p] - self.bev_range[0]) / (self.bev_range[3] - self.bev_range[0])
                pos_gt_bboxes_normalized[..., 3*p+1] = (pos_gt_bboxes[..., 3*p+1] - self.bev_range[1]) / (self.bev_range[4] - self.bev_range[1])
                pos_gt_bboxes_normalized[..., 3*p+2] = (pos_gt_bboxes[..., 3*p+2] - self.bev_range[2]) / (self.bev_range[5] - self.bev_range[2])
            pos_gt_bboxes_targets = pos_gt_bboxes_normalized
        else:
            img_h, img_w, _ = img_meta['img_shape']
            # DETR regress the relative position of boxes (cxcywh) in the image.
            # Thus the learning target should be normalized by the image size, also
            # the box format should be converted from defaultly x1y1x2y2 to cxcywh.
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0)
            pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
            pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
        bbox_targets[pos_inds] = pos_gt_bboxes_targets

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds, pos_assigned_gt_inds)

    def _get_bboxes_single(self,
                           cls_score,
                           bbox_pred,
                           img_shape,
                           scale_factor,
                           rescale=False):
        """Decode one image's predictions.

        Lanes: denormalize points back into bev_range and return them with
        per-query sigmoid scores. Boxes: convert to xyxy, scale to the image,
        optionally rescale, and append the predicted label.
        """
        assert len(cls_score) == len(bbox_pred)
        if self.object_type == 'lane':
            cls_score = cls_score.sigmoid()
            # Clone before the in-place scale/clamp_ below so the caller's
            # prediction tensor is not mutated.
            det_bboxes = bbox_pred.clone()
            for p in range(self.num_reg_dim // 3):
                det_bboxes[..., 3*p] = det_bboxes[..., 3*p] * (self.bev_range[3] - self.bev_range[0]) + self.bev_range[0]
                det_bboxes[..., 3*p+1] = det_bboxes[..., 3*p+1] * (self.bev_range[4] - self.bev_range[1]) + self.bev_range[1]
                det_bboxes[..., 3*p+2] = det_bboxes[..., 3*p+2] * (self.bev_range[5] - self.bev_range[2]) + self.bev_range[2]
                det_bboxes[..., 3*p].clamp_(min=self.bev_range[0], max=self.bev_range[3])
                det_bboxes[..., 3*p+1].clamp_(min=self.bev_range[1], max=self.bev_range[4])
                det_bboxes[..., 3*p+2].clamp_(min=self.bev_range[2], max=self.bev_range[5])
        else:
            # exclude background
            if self.loss_cls.use_sigmoid:
                cls_score = cls_score.sigmoid()
            else:
                cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
            cls_score, det_labels = cls_score.max(-1)
            det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
            det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
            det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
            det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
            det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
            if rescale:
                det_bboxes /= det_bboxes.new_tensor(scale_factor)
            det_bboxes = torch.cat((det_bboxes, det_labels.unsqueeze(1)), -1)
        return det_bboxes.cpu().numpy(), cls_score.cpu().numpy()

    def onnx_export(self, **kwargs):
        """ONNX export is not supported for variable num_reg_dim."""
        raise NotImplementedError(f'TODO: replace 4 with self.num_reg_dim : {self.num_reg_dim}')
# ==============================================================================
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
# Contact litianyu@pjlab.org.cn if you have any issue.
#
# Copyright (c) 2023 The OpenLane-v2 Dataset Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import mmcv
from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
from mmcv.runner import auto_fp16, force_fp32
from mmcv.utils import TORCH_VERSION, digit_version
from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.dense_heads import AnchorFreeHead
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
@HEADS.register_module()
class LCDeformableDETRHead(AnchorFreeHead):
    def __init__(self,
                 num_classes,
                 in_channels,
                 num_query=100,
                 with_box_refine=False,
                 with_shared_param=None,
                 transformer=None,
                 num_reg_fcs=2,
                 code_weights=None,
                 pc_range=None,
                 bev_h=30,
                 bev_w=30,
                 sync_cls_avg_factor=False,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     bg_cls_weight=0.1,
                     use_sigmoid=False,
                     loss_weight=1.0,
                     class_weight=1.0),
                 loss_bbox=dict(type='L1Loss', loss_weight=5.0),
                 loss_iou=dict(type='GIoULoss', loss_weight=2.0),
                 train_cfg=dict(
                     assigner=dict(
                         type='HungarianAssigner',
                         cls_cost=dict(type='ClassificationCost', weight=1.),
                         reg_cost=dict(type='BBoxL1Cost', weight=5.0),
                         iou_cost=dict(
                             type='IoUCost', iou_mode='giou', weight=2.0))),
                 test_cfg=dict(max_per_img=100),
                 init_cfg=None,
                 **kwargs):
        """Deformable-DETR-style lane head operating on BEV features.

        Lanes are regressed as ``code_size // 3`` 3D points inside
        ``pc_range``; losses/assigner weights must agree so matching and
        training optimize the same objective.
        """
        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
        # since it brings inconvenience when the initialization of
        # `AnchorFreeHead` is called.
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.bg_cls_weight = 0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided '\
                'when train_cfg is set.'
            assigner = train_cfg['assigner']
            # Loss weights and matcher costs must match, otherwise the
            # Hungarian assignment optimizes a different objective than
            # the training loss.
            assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \
                'The classification weight for loss and matcher should be' \
                'exactly the same.'
            assert loss_bbox['loss_weight'] == assigner['reg_cost'][
                'weight'], 'The regression L1 weight for loss and matcher ' \
                'should be exactly the same.'
            assert loss_iou['loss_weight'] == assigner['iou_cost']['weight'], \
                'The regression iou weight for loss and matcher should be' \
                'exactly the same.'
            self.assigner = build_assigner(assigner)
            # DETR sampling=False, so use PseudoSampler
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.num_query = num_query
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.loss_iou = build_loss(loss_iou)
        # Sigmoid classification predicts num_classes logits; softmax needs
        # an extra background channel.
        if self.loss_cls.use_sigmoid:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.act_cfg = transformer.get('act_cfg',
                                       dict(type='ReLU', inplace=True))
        self.activate = build_activation_layer(self.act_cfg)
        self.transformer = build_transformer(transformer)
        self.embed_dims = self.transformer.embed_dims
        self.bev_h = bev_h
        self.bev_w = bev_w
        # NOTE(review): fp16_enabled is assigned twice (also above) --
        # redundant but harmless.
        self.fp16_enabled = False
        self.with_box_refine = with_box_refine
        if with_shared_param is not None:
            self.with_shared_param = with_shared_param
        else:
            # Box refinement needs per-layer branches; otherwise share them.
            self.with_shared_param = not self.with_box_refine
        self.as_two_stage = False
        if 'code_size' in kwargs:
            self.code_size = kwargs['code_size']
        else:
            self.code_size = 6
        if code_weights is not None:
            self.code_weights = code_weights
        else:
            self.code_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
        # Registered as a non-trainable parameter so it moves with .to(device).
        self.code_weights = nn.Parameter(torch.tensor(
            self.code_weights, requires_grad=False), requires_grad=False)
        # Fallback GT coordinate dim, used by _get_target_single when a
        # sample has no ground-truth lanes.
        self.gt_c_save = self.code_size

        self.pc_range = pc_range
        self.real_w = self.pc_range[3] - self.pc_range[0]
        self.real_h = self.pc_range[4] - self.pc_range[1]
        self.num_reg_fcs = num_reg_fcs
        self._init_layers()
def _init_layers(self):
"""Initialize classification branch and regression branch of head."""
cls_branch = []
for _ in range(self.num_reg_fcs):
cls_branch.append(Linear(self.embed_dims, self.embed_dims))
cls_branch.append(nn.LayerNorm(self.embed_dims))
cls_branch.append(nn.ReLU(inplace=True))
cls_branch.append(Linear(self.embed_dims, self.cls_out_channels))
fc_cls = nn.Sequential(*cls_branch)
reg_branch = []
for _ in range(self.num_reg_fcs):
reg_branch.append(Linear(self.embed_dims, self.embed_dims))
reg_branch.append(nn.ReLU())
reg_branch.append(Linear(self.embed_dims, self.code_size))
reg_branch = nn.Sequential(*reg_branch)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
# last reg_branch is used to generate proposal from
# encode feature map when as_two_stage is True.
num_pred = (self.transformer.decoder.num_layers + 1) if \
self.as_two_stage else self.transformer.decoder.num_layers
if not self.with_shared_param:
self.cls_branches = _get_clones(fc_cls, num_pred)
self.reg_branches = _get_clones(reg_branch, num_pred)
else:
self.cls_branches = nn.ModuleList(
[fc_cls for _ in range(num_pred)])
self.reg_branches = nn.ModuleList(
[reg_branch for _ in range(num_pred)])
self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2)
def init_weights(self):
"""Initialize weights of the DeformDETR head."""
self.transformer.init_weights()
if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01)
for m in self.cls_branches:
nn.init.constant_(m[-1].bias, bias_init)
@auto_fp16(apply_to=('mlvl_feats'))
def forward(self, mlvl_feats, bev_feats, img_metas):
"""Forward function.
Args:
mlvl_feats (tuple[Tensor]): Features from the upstream
network, each is a 5D-tensor with shape
(B, N, C, H, W).
prev_bev: previous bev featues
only_bev: only compute BEV features with encoder.
Returns:
all_cls_scores (Tensor): Outputs from the classification head, \
shape [nb_dec, bs, num_query, cls_out_channels]. Note \
cls_out_channels should includes background.
all_lanes_preds (Tensor): Sigmoid outputs from the regression \
head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \
Shape [nb_dec, bs, num_query, 9].
"""
dtype = mlvl_feats[0].dtype
object_query_embeds = self.query_embedding.weight.to(dtype)
outputs = self.transformer(
mlvl_feats,
bev_feats,
object_query_embeds,
bev_h=self.bev_h,
bev_w=self.bev_w,
reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501
cls_branches=None,
img_metas=img_metas
)
hs, init_reference, inter_references = outputs
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
outputs_class = self.cls_branches[lvl](hs[lvl])
tmp = self.reg_branches[lvl](hs[lvl])
assert reference.shape[-1] == 3
for p in range(self.code_size // 3):
tmp[..., 3*p:3*p+3] = tmp[..., 3*p:3*p+3] + reference
tmp[..., 3*p:3*p+3] = tmp[..., 3*p:3*p+3].sigmoid()
tmp[..., 3*p] = tmp[..., 3*p] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
tmp[..., 3*p+1] = tmp[..., 3*p+1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
tmp[..., 3*p+2] = tmp[..., 3*p+2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]
outputs_coord = tmp
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
outputs_classes = torch.stack(outputs_classes)
outputs_coords = torch.stack(outputs_coords)
outs = {
'all_cls_scores': outputs_classes,
'all_lanes_preds': outputs_coords,
'enc_cls_scores': None,
'enc_bbox_preds': None,
'history_states': hs
}
return outs
    def _get_target_single(self,
                           cls_score,
                           lanes_pred,
                           gt_labels,
                           gt_lanes,
                           gt_bboxes_ignore=None):
        """"Compute regression and classification targets for one image.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            cls_score (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
            lanes_pred (Tensor): Regressed lane points from a single decoder
                layer for one image, shape [num_query, code_size].
            gt_labels (Tensor): Ground truth class indices for one image
                with shape (num_gts, ).
            gt_lanes (Tensor): Ground truth lanes for one image with
                shape (num_gts, C) where C is the flattened point dim.
            gt_bboxes_ignore (Tensor, optional): Bounding boxes
                which can be ignored. Default None.
        Returns:
            tuple[Tensor]: a tuple containing the following for one image.
                - labels (Tensor): Labels of each image.
                - label_weights (Tensor]): Label weights of each image.
                - bbox_targets (Tensor): BBox targets of each image.
                - bbox_weights (Tensor): BBox weights of each image.
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
                - pos_assigned_gt_inds (Tensor): GT index matched to each
                  positive query (consumed by the topology heads).
        """
        num_bboxes = lanes_pred.size(0)
        # assigner and sampler
        assign_result = self.assigner.assign(lanes_pred, cls_score, gt_lanes,
                                             gt_labels, gt_bboxes_ignore)
        sampling_result = self.sampler.sample(assign_result, lanes_pred,
                                              gt_lanes)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds

        # label targets: unmatched queries default to the background class.
        labels = gt_lanes.new_full((num_bboxes,), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds].long()
        label_weights = gt_lanes.new_ones(num_bboxes)

        # bbox targets
        gt_c = gt_lanes.shape[-1]
        if gt_c == 0:
            # No GT lanes in this sample: reuse the coordinate dim seen on the
            # last non-empty sample (gt_c_save) so downstream shapes stay
            # consistent, and replace pos_gt_bboxes with an empty tensor.
            gt_c = self.gt_c_save
            sampling_result.pos_gt_bboxes = torch.zeros((0, gt_c)).to(sampling_result.pos_gt_bboxes.device)
        else:
            # Remember the GT coordinate dim for future empty samples.
            self.gt_c_save = gt_c

        bbox_targets = torch.zeros_like(lanes_pred)[..., :gt_c]
        bbox_weights = torch.zeros_like(lanes_pred)
        bbox_weights[pos_inds] = 1.0

        # DETR
        bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes

        return (labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds, pos_assigned_gt_inds)
def get_targets(self,
cls_scores_list,
lanes_preds_list,
gt_lanes_list,
gt_labels_list,
gt_bboxes_ignore_list=None):
""""Compute regression and classification targets for a batch image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_scores_list (list[Tensor]): Box score logits from a single
decoder layer for each image with shape [num_query,
cls_out_channels].
bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
decoder layer for each image, with normalized coordinate
(cx, cy, w, h) and shape [num_query, 4].
gt_lanes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
tuple: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels for all images.
- label_weights_list (list[Tensor]): Label weights for all \
images.
- bbox_targets_list (list[Tensor]): BBox targets for all \
images.
- bbox_weights_list (list[Tensor]): BBox weights for all \
images.
- num_total_pos (int): Number of positive samples in all \
images.
- num_total_neg (int): Number of negative samples in all \
images.
"""
assert gt_bboxes_ignore_list is None, \
'Only supports for gt_bboxes_ignore setting to None.'
num_imgs = len(cls_scores_list)
gt_bboxes_ignore_list = [
gt_bboxes_ignore_list for _ in range(num_imgs)
]
(labels_list, label_weights_list, lanes_targets_list,
lanes_weights_list, pos_inds_list, neg_inds_list, pos_assigned_gt_inds_list) = multi_apply(
self._get_target_single, cls_scores_list, lanes_preds_list,
gt_labels_list, gt_lanes_list, gt_bboxes_ignore_list)
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
assign_result = dict(
pos_inds=pos_inds_list, neg_inds=neg_inds_list, pos_assigned_gt_inds=pos_assigned_gt_inds_list
)
return (labels_list, label_weights_list, lanes_targets_list,
lanes_weights_list, num_total_pos, num_total_neg, assign_result)
def loss_single(self,
cls_scores,
lanes_preds,
gt_lanes_list,
gt_labels_list,
gt_bboxes_ignore_list=None):
""""Loss function for outputs from a single decoder layer of a single
feature level.
Args:
cls_scores (Tensor): Box score logits from a single decoder layer
for all images. Shape [bs, num_query, cls_out_channels].
lanes_preds (Tensor): Sigmoid outputs from a single decoder layer
for all images, with normalized coordinate (cx, cy, w, h) and
shape [bs, num_query, 4].
gt_lanes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components for outputs from
a single decoder layer.
"""
num_imgs = cls_scores.size(0)
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
lanes_preds_list = [lanes_preds[i] for i in range(num_imgs)]
cls_reg_targets = self.get_targets(cls_scores_list, lanes_preds_list,
gt_lanes_list, gt_labels_list,
gt_bboxes_ignore_list)
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg, assign_result) = cls_reg_targets
labels = torch.cat(labels_list, 0)
label_weights = torch.cat(label_weights_list, 0)
bbox_targets = torch.cat(bbox_targets_list, 0)
bbox_weights = torch.cat(bbox_weights_list, 0)
# classification loss
cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor = num_total_pos * 1.0 + \
num_total_neg * self.bg_cls_weight
if self.sync_cls_avg_factor:
cls_avg_factor = reduce_mean(
cls_scores.new_tensor([cls_avg_factor]))
cls_avg_factor = max(cls_avg_factor, 1)
loss_cls = self.loss_cls(
cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
# Compute the average number of gt boxes accross all gpus, for
# normalization purposes
num_total_pos = loss_cls.new_tensor([num_total_pos])
num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
# regression L1 loss
lanes_preds = lanes_preds.reshape(-1, lanes_preds.size(-1))
isnotnan = torch.isfinite(bbox_targets).all(dim=-1)
bbox_weights = bbox_weights * self.code_weights
loss_bbox = self.loss_bbox(
lanes_preds[isnotnan, :self.code_size],
bbox_targets[isnotnan, :self.code_size],
bbox_weights[isnotnan, :self.code_size],
avg_factor=num_total_pos)
if digit_version(TORCH_VERSION) >= digit_version('1.8'):
loss_cls = torch.nan_to_num(loss_cls)
loss_bbox = torch.nan_to_num(loss_bbox)
return loss_cls, loss_bbox, assign_result
@force_fp32(apply_to=('preds_dicts'))
def loss(self,
preds_dicts,
gt_lanes_3d,
gt_labels_list,
gt_bboxes_ignore=None,
img_metas=None):
""""Loss function.
Args:
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
preds_dicts:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_lanes_preds (Tensor): Sigmoid regression
outputs of all decode layers. Each is a 4D-tensor with
normalized coordinate format (cx, cy, w, h) and shape
[nb_dec, bs, num_query, 4].
enc_cls_scores (Tensor): Classification scores of
points on encode feature map , has shape
(N, h*w, num_classes). Only be passed when as_two_stage is
True, otherwise is None.
enc_bbox_preds (Tensor): Regression results of each points
on the encode feature map, has shape (N, h*w, 4). Only be
passed when as_two_stage is True, otherwise is None.
gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert gt_bboxes_ignore is None, \
f'{self.__class__.__name__} only supports ' \
f'for gt_bboxes_ignore setting to None.'
all_cls_scores = preds_dicts['all_cls_scores']
all_lanes_preds = preds_dicts['all_lanes_preds']
enc_cls_scores = preds_dicts['enc_cls_scores']
enc_bbox_preds = preds_dicts['enc_bbox_preds']
num_dec_layers = len(all_cls_scores)
device = gt_labels_list[0].device
gt_lanes_list = [lane for lane in gt_lanes_3d]
all_gt_lanes_list = [gt_lanes_list for _ in range(num_dec_layers)]
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
losses_cls, losses_bbox, assign_result = multi_apply(
self.loss_single, all_cls_scores, all_lanes_preds,
all_gt_lanes_list, all_gt_labels_list)
loss_dict = dict()
# loss of proposal generated from encode feature map.
if enc_cls_scores is not None:
binary_labels_list = [
torch.zeros_like(gt_labels_list[i])
for i in range(len(all_gt_labels_list))
]
enc_loss_cls, enc_losses_bbox = \
self.loss_single(enc_cls_scores, enc_bbox_preds,
gt_lanes_3d, binary_labels_list, gt_bboxes_ignore)
loss_dict['enc_loss_cls'] = enc_loss_cls
loss_dict['enc_loss_bbox'] = enc_losses_bbox
# loss from the last decoder layer
loss_dict['loss_lane_cls'] = losses_cls[-1]
loss_dict['loss_lane_reg'] = losses_bbox[-1]
# loss from other decoder layers
num_dec_layer = 0
for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1],
losses_bbox[:-1]):
loss_dict[f'd{num_dec_layer}.loss_lane_cls'] = loss_cls_i
loss_dict[f'd{num_dec_layer}.loss_lane_reg'] = loss_bbox_i
num_dec_layer += 1
return loss_dict, assign_result
@force_fp32(apply_to=('preds_dicts'))
def get_lanes(self, preds_dicts, img_metas, rescale=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
img_metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
"""
all_cls_scores = preds_dicts['all_cls_scores'][-1]
all_lanes_preds = preds_dicts['all_lanes_preds'][-1]
batch_size = all_cls_scores.size()[0]
predictions_list = []
for i in range(batch_size):
cls_scores = all_cls_scores[i].sigmoid()
predictions_list.append([
all_lanes_preds[i].detach().cpu().numpy(),
cls_scores.detach().cpu().numpy()])
return predictions_list
import copy
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import mmcv
from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
from mmcv.cnn.bricks.transformer import build_feedforward_network
from mmcv.runner import auto_fp16, force_fp32
from mmcv.utils import TORCH_VERSION, digit_version
from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.dense_heads import AnchorFreeHead
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
class MLP(nn.Module):
    """A simple multi-layer perceptron.

    Builds ``num_layers`` linear layers mapping
    input_dim -> hidden_dim -> ... -> output_dim, applying ReLU after every
    layer except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Chain of layer sizes: [in, hidden, ..., hidden, out].
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out)
            for d_in, d_out in zip(dims[:-1], dims[1:]))

    def forward(self, x):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = F.relu(x)
        return x
@HEADS.register_module()
class RelationshipHead(nn.Module):
    """Predicts pairwise relationships between two sets of query embeddings.

    Each query embedding of object set ``o1`` is paired with each query
    embedding of object set ``o2``; a small MLP classifier scores every
    (o1, o2) pair, producing an adjacency-style prediction matrix.

    Args:
        in_channels_o1 (int): Embedding dim of the first object set.
        in_channels_o2 (int, optional): Embedding dim of the second object
            set. Only used when ``shared_param`` is False.
        shared_param (bool): If True, both object sets share the same
            projection MLP, and ``loss`` reuses the o1 assignment for o2.
        loss_rel (dict): Config of the relationship classification loss.
    """

    def __init__(self,
                 in_channels_o1,
                 in_channels_o2=None,
                 shared_param=True,
                 loss_rel=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25)):
        super().__init__()
        # Both projection MLPs map to a 128-d space; the classifier then
        # consumes the 256-d concatenation of an (o1, o2) pair.
        self.MLP_o1 = MLP(in_channels_o1, in_channels_o1, 128, 3)
        self.shared_param = shared_param
        if shared_param:
            # Same module object: o1 and o2 share projection weights.
            self.MLP_o2 = self.MLP_o1
        else:
            self.MLP_o2 = MLP(in_channels_o2, in_channels_o2, 128, 3)
        self.classifier = MLP(256, 256, 1, 3)
        self.loss_rel = build_loss(loss_rel)

    def forward_train(self, o1_feats, o1_assign_results, o2_feats, o2_assign_results, gt_adj):
        # Forward pass then supervised loss against the ground-truth
        # adjacency, using the matching results of both query sets.
        rel_pred = self.forward(o1_feats, o2_feats)
        losses = self.loss(rel_pred, gt_adj, o1_assign_results, o2_assign_results)
        return losses

    def get_relationship(self, o1_feats, o2_feats):
        # Inference: sigmoid scores, returned as a list of per-sample
        # [num_query_o1, num_query_o2] tensors.
        rel_pred = self.forward(o1_feats, o2_feats)
        rel_results = rel_pred.squeeze(-1).sigmoid()
        rel_results = [_ for _ in rel_results]
        return rel_results

    def forward(self, o1_feats, o2_feats):
        # feats: D, B, num_query, num_embedding
        # Only the last decoder layer ([-1]) of each feature stack is used.
        o1_embeds = self.MLP_o1(o1_feats[-1])
        o2_embeds = self.MLP_o2(o2_feats[-1])
        num_query_o1 = o1_embeds.size(1)
        num_query_o2 = o2_embeds.size(1)
        # Broadcast both sets to a [B, n_o1, n_o2, 128] grid so every
        # (o1, o2) pair is concatenated along the last dim.
        o1_tensor = o1_embeds.unsqueeze(2).repeat(1, 1, num_query_o2, 1)
        o2_tensor = o2_embeds.unsqueeze(1).repeat(1, num_query_o1, 1, 1)
        relationship_tensor = torch.cat([o1_tensor, o2_tensor], dim=-1)
        # Output shape: [B, n_o1, n_o2, 1] (raw logits).
        relationship_pred = self.classifier(relationship_tensor)
        return relationship_pred

    def loss(self, rel_preds, gt_adjs, o1_assign_results, o2_assign_results):
        """Compute the relationship loss against ground-truth adjacency.

        Args:
            rel_preds (Tensor): Pair logits, [B, n_o1, n_o2, 1].
            gt_adjs (list[Tensor]): Per-image GT adjacency matrices indexed
                by ground-truth instance indices.
            o1_assign_results (list[dict]): Per-layer assignment dicts for
                o1; only the last layer ([-1]) is used.
            o2_assign_results (list[dict]): Same for o2 (ignored when
                ``shared_param`` is True).

        Returns:
            dict[str, Tensor]: ``loss_rel``.
        """
        B, num_query_o1, num_query_o2, _ = rel_preds.size()
        o1_assign = o1_assign_results[-1]
        o1_pos_inds = o1_assign['pos_inds']
        o1_pos_assigned_gt_inds = o1_assign['pos_assigned_gt_inds']
        if self.shared_param:
            o2_assign = o1_assign
            o2_pos_inds = o1_pos_inds
            o2_pos_assigned_gt_inds = o1_pos_assigned_gt_inds
        else:
            o2_assign = o2_assign_results[-1]
            o2_pos_inds = o2_assign['pos_inds']
            o2_pos_assigned_gt_inds = o2_assign['pos_assigned_gt_inds']
        targets = []
        for i in range(B):
            gt_adj = gt_adjs[i]
            # Dense [n_o1, n_o2] target initialized to 0 (no relation).
            target = torch.zeros_like(rel_preds[i].squeeze(-1), dtype=gt_adj.dtype, device=rel_preds.device)
            # xs/ys form the cartesian product of the positive query
            # indices of both sets; the matched sub-block of gt_adj
            # (re-ordered by the assigned GT indices) is scattered there.
            xs = o1_pos_inds[i].unsqueeze(-1).repeat(1, o2_pos_inds[i].size(0))
            ys = o2_pos_inds[i].unsqueeze(0).repeat(o1_pos_inds[i].size(0), 1)
            target[xs, ys] = gt_adj[o1_pos_assigned_gt_inds[i]][:, o2_pos_assigned_gt_inds[i]]
            targets.append(target)
        targets = torch.stack(targets, dim=0)
        # Label inversion: adjacency 1 -> class 0, adjacency 0 -> class 1.
        # NOTE(review): presumably because mmdet's sigmoid FocalLoss treats
        # class index 0 as the (single) foreground class -- confirm.
        targets = 1 - targets.view(-1).long()
        rel_preds = rel_preds.view(-1, 1)
        # weight = (1 - targets) * 3 + targets
        loss_rel = self.loss_rel(rel_preds, targets)
        # nan_to_num is only available from torch 1.8 onwards.
        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
            loss_rel = torch.nan_to_num(loss_rel)
        return dict(loss_rel=loss_rel)
# ==============================================================================
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
# Contact litianyu@pjlab.org.cn if you have any issue.
#
# Copyright (c) 2023 The OpenLane-V2 Dataset Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear, bias_init_with_prob, constant_init
from mmcv.runner import force_fp32
from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
multi_apply, reduce_mean)
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import HEADS, build_loss
from mmdet.models.dense_heads import DETRHead
@HEADS.register_module()
class TEDeformableDETRHead(DETRHead):
"""Head of DeformDETR: Deformable DETR: Deformable Transformers for End-to-
End Object Detection.
Code is modified from the `official github repo
<https://github.com/fundamentalvision/Deformable-DETR>`_.
More details can be found in the `paper
<https://arxiv.org/abs/2010.04159>`_ .
Args:
with_box_refine (bool): Whether to refine the reference points
in the decoder. Defaults to False.
as_two_stage (bool) : Whether to generate the proposal from
the outputs of encoder.
transformer (obj:`ConfigDict`): ConfigDict is used for building
the Encoder and Decoder.
"""
    def __init__(self,
                 *args,
                 with_box_refine=False,
                 as_two_stage=False,
                 transformer=None,
                 **kwargs):
        # Set the flags before calling the parent constructor so they are
        # already available when the base class builds the head layers via
        # the overridden ``_init_layers`` below -- TODO confirm DETRHead
        # calls ``_init_layers`` inside ``__init__``.
        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage
        if self.as_two_stage:
            # Propagate the two-stage setting into the transformer config.
            transformer['as_two_stage'] = self.as_two_stage
        super(TEDeformableDETRHead, self).__init__(
            *args, transformer=transformer, **kwargs)
    def _init_layers(self):
        """Initialize classification branch and regression branch of head."""
        # Single classification layer and a small FC regression branch
        # ending in 4 box parameters.
        fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(Linear(self.embed_dims, 4))
        reg_branch = nn.Sequential(*reg_branch)

        # Deep-copies a module N times (independent weights per copy).
        def _get_clones(module, N):
            return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
        # last reg_branch is used to generate proposal from
        # encode feature map when as_two_stage is True.
        num_pred = (self.transformer.decoder.num_layers + 1) if \
            self.as_two_stage else self.transformer.decoder.num_layers
        if self.with_box_refine:
            # Box refinement needs independent weights per decoder layer.
            self.cls_branches = _get_clones(fc_cls, num_pred)
            self.reg_branches = _get_clones(reg_branch, num_pred)
        else:
            # Intentional weight sharing: every entry references the SAME
            # module object, so all decoder layers share one head.
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(num_pred)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(num_pred)])
        if not self.as_two_stage:
            # Learned query embeddings (content + positional halves, hence
            # embed_dims * 2) are only needed in single-stage mode.
            self.query_embedding = nn.Embedding(self.num_query,
                                                self.embed_dims * 2)
    def init_weights(self):
        """Initialize weights of the DeformDETR head."""
        self.transformer.init_weights()
        if self.loss_cls.use_sigmoid:
            # Focal-loss style prior: bias so initial foreground
            # probability is ~0.01.
            bias_init = bias_init_with_prob(0.01)
            for m in self.cls_branches:
                nn.init.constant_(m.bias, bias_init)
        # Zero-init the final regression layer of every branch.
        for m in self.reg_branches:
            constant_init(m[-1], 0, bias=0)
        # Bias entries from index 2 onward of the first branch's last layer
        # set to -2.0 (presumably shrinking the initial w/h after sigmoid
        # -- TODO confirm the coordinate layout).
        nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
        if self.as_two_stage:
            # Two-stage mode overwrites that offset back to 0 for all
            # branches (note: runs after the -2.0 init above).
            for m in self.reg_branches:
                nn.init.constant_(m[-1].bias.data[2:], 0.0)
    def forward(self, mlvl_feats, img_metas):
        """Forward function.
        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each is a 4D-tensor with shape
                (N, C, H, W).
            img_metas (list[dict]): List of image information.
        Returns:
            dict: Contains the following entries:

                - all_cls_scores (Tensor): Classification outputs of every \
                    decoder layer, shape \
                    [nb_dec, bs, num_query, cls_out_channels].
                - all_bbox_preds (Tensor): Sigmoid box outputs of every \
                    decoder layer in normalized (cx, cy, w, h) format, \
                    shape [nb_dec, bs, num_query, 4].
                - enc_cls_scores (Tensor | None): Encoder proposal scores, \
                    only set when as_two_stage is True.
                - enc_bbox_preds (Tensor | None): Sigmoid encoder proposal \
                    boxes, only set when as_two_stage is True.
                - history_states (Tensor): Decoder hidden states of every \
                    layer (``hs`` after permute).
        """
        batch_size = mlvl_feats[0].size(0)
        # Build a padding mask: 1 on padded pixels, 0 on valid image area.
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        img_masks = mlvl_feats[0].new_ones(
            (batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            img_masks[img_id, :img_h, :img_w] = 0
        # Downsample the mask to each feature level and compute the
        # positional encoding per level.
        mlvl_masks = []
        mlvl_positional_encodings = []
        for feat in mlvl_feats:
            mlvl_masks.append(
                F.interpolate(img_masks[None],
                              size=feat.shape[-2:]).to(torch.bool).squeeze(0))
            mlvl_positional_encodings.append(
                self.positional_encoding(mlvl_masks[-1]))
        # Learned queries are only used in single-stage mode; in two-stage
        # mode the transformer derives queries from encoder proposals.
        query_embeds = None
        if not self.as_two_stage:
            query_embeds = self.query_embedding.weight
        hs, init_reference, inter_references, \
            enc_outputs_class, enc_outputs_coord = self.transformer(
                mlvl_feats,
                mlvl_masks,
                query_embeds,
                mlvl_positional_encodings,
                reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501
                cls_branches=self.cls_branches if self.as_two_stage else None # noqa:E501
            )
        # [num_layers, num_query, bs, dim] -> [num_layers, bs, num_query, dim]
        hs = hs.permute(0, 2, 1, 3)
        outputs_classes = []
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            # Reference points: initial ones for layer 0, refined ones for
            # later layers; un-sigmoid so offsets add in logit space.
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.cls_branches[lvl](hs[lvl])
            tmp = self.reg_branches[lvl](hs[lvl])
            # References are either full boxes (4) or centers only (2).
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)
        outs = {
            'all_cls_scores': outputs_classes,
            'all_bbox_preds': outputs_coords,
            'enc_cls_scores': enc_outputs_class if self.as_two_stage else None,
            'enc_bbox_preds': enc_outputs_coord.sigmoid() if self.as_two_stage else None,
            'history_states': hs
        }
        return outs
@force_fp32(apply_to=('preds_dicts'))
def loss(self,
preds_dicts,
gt_bboxes_list,
gt_labels_list,
img_metas,
gt_bboxes_ignore=None):
""""Loss function.
Args:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_bbox_preds (Tensor): Sigmoid regression
outputs of all decode layers. Each is a 4D-tensor with
normalized coordinate format (cx, cy, w, h) and shape
[nb_dec, bs, num_query, 4].
enc_cls_scores (Tensor): Classification scores of
points on encode feature map , has shape
(N, h*w, num_classes). Only be passed when as_two_stage is
True, otherwise is None.
enc_bbox_preds (Tensor): Regression results of each points
on the encode feature map, has shape (N, h*w, 4). Only be
passed when as_two_stage is True, otherwise is None.
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
img_metas (list[dict]): List of image meta information.
gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert gt_bboxes_ignore is None, \
f'{self.__class__.__name__} only supports ' \
f'for gt_bboxes_ignore setting to None.'
all_cls_scores = preds_dicts['all_cls_scores']
all_bbox_preds = preds_dicts['all_bbox_preds']
enc_cls_scores = preds_dicts['enc_cls_scores']
enc_bbox_preds = preds_dicts['enc_bbox_preds']
num_dec_layers = len(all_cls_scores)
all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
all_gt_bboxes_ignore_list = [
gt_bboxes_ignore for _ in range(num_dec_layers)
]
img_metas_list = [img_metas for _ in range(num_dec_layers)]
losses_cls, losses_bbox, losses_iou, assign_result = multi_apply(
self.loss_single, all_cls_scores, all_bbox_preds,
all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
all_gt_bboxes_ignore_list)
loss_dict = dict()
# loss of proposal generated from encode feature map.
if enc_cls_scores is not None:
binary_labels_list = [
torch.zeros_like(gt_labels_list[i])
for i in range(len(img_metas))
]
enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
self.loss_single(enc_cls_scores, enc_bbox_preds,
gt_bboxes_list, binary_labels_list,
img_metas, gt_bboxes_ignore)
loss_dict['enc_loss_cls'] = enc_loss_cls
loss_dict['enc_loss_bbox'] = enc_losses_bbox
loss_dict['enc_loss_iou'] = enc_losses_iou
# loss from the last decoder layer
loss_dict['loss_cls'] = losses_cls[-1]
loss_dict['loss_bbox'] = losses_bbox[-1]
loss_dict['loss_iou'] = losses_iou[-1]
# loss from other decoder layers
num_dec_layer = 0
for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
losses_bbox[:-1],
losses_iou[:-1]):
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
num_dec_layer += 1
return loss_dict, assign_result
    def loss_single(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes_list,
                    gt_labels_list,
                    img_metas,
                    gt_bboxes_ignore_list=None):
        """Loss function for outputs from a single decoder layer of a single
        feature level.
        Args:
            cls_scores (Tensor): Box score logits from a single decoder layer
                for all images. Shape [bs, num_query, cls_out_channels].
            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
                for all images, with normalized coordinate (cx, cy, w, h) and
                shape [bs, num_query, 4].
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Default None.
        Returns:
            tuple: (loss_cls, loss_bbox, loss_iou, assign_result) for this
            decoder layer, where ``assign_result`` is the dict produced by
            ``get_targets``.
        """
        num_imgs = cls_scores.size(0)
        # Split the batched predictions into per-image lists for
        # get_targets / the assigner.
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           gt_bboxes_list, gt_labels_list,
                                           img_metas, gt_bboxes_ignore_list)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg, assign_result) = cls_reg_targets
        # Flatten per-image targets so all queries of the batch are scored
        # together.
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)
        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)
        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
        # construct factors used for rescale bboxes
        factors = []
        for img_meta, bbox_pred in zip(img_metas, bbox_preds):
            img_h, img_w, _ = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors, 0)
        # DETR regress the relative position of boxes (cxcywh) in the image,
        # thus the learning target is normalized by the image size. So here
        # we need to re-scale them for calculating IoU loss
        bbox_preds = bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors
        # regression IoU loss, defaultly GIoU loss
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)
        # regression L1 loss (computed in the normalized cxcywh space)
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou, assign_result
    def get_targets(self,
                    cls_scores_list,
                    bbox_preds_list,
                    gt_bboxes_list,
                    gt_labels_list,
                    img_metas,
                    gt_bboxes_ignore_list=None):
        """Compute regression and classification targets for a batch image.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            cls_scores_list (list[Tensor]): Box score logits from a single
                decoder layer for each image with shape [num_query,
                cls_out_channels].
            bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
                decoder layer for each image, with normalized coordinate
                (cx, cy, w, h) and shape [num_query, 4].
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Only ``None`` is
                supported. Default None.
        Returns:
            tuple: (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_total_pos, num_total_neg, assign_result)
            where ``assign_result`` is a dict preserving the per-image
            positive/negative query indices for later use (e.g. by the
            relationship head).
        """
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'
        num_imgs = len(cls_scores_list)
        # Broadcast the (None) ignore setting to every image.
        gt_bboxes_ignore_list = [
            gt_bboxes_ignore_list for _ in range(num_imgs)
        ]
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, pos_inds_list, neg_inds_list, pos_assigned_gt_inds_list) = multi_apply(
            self._get_target_single, cls_scores_list, bbox_preds_list,
            gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        assign_result = dict(
            pos_inds=pos_inds_list, neg_inds=neg_inds_list, pos_assigned_gt_inds=pos_assigned_gt_inds_list
        )
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg, assign_result)
    def _get_target_single(self,
                           cls_score,
                           bbox_pred,
                           gt_bboxes,
                           gt_labels,
                           img_meta,
                           gt_bboxes_ignore=None):
        """Compute regression and classification targets for one image.
        Outputs from a single decoder layer of a single feature level are used.
        Args:
            cls_score (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
            bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
                for one image, with normalized coordinate (cx, cy, w, h) and
                shape [num_query, 4].
            gt_bboxes (Tensor): Ground truth bboxes for one image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor): Ground truth class indices for one image
                with shape (num_gts, ).
            img_meta (dict): Meta information for one image.
            gt_bboxes_ignore (Tensor, optional): Bounding boxes
                which can be ignored. Default None.
        Returns:
            tuple[Tensor]: a tuple containing the following for one image.
                - labels (Tensor): Labels of each image.
                - label_weights (Tensor): Label weights of each image.
                - bbox_targets (Tensor): BBox targets of each image.
                - bbox_weights (Tensor): BBox weights of each image.
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
                - pos_assigned_gt_inds (Tensor): GT index assigned to each
                  positive query.
        """
        num_bboxes = bbox_pred.size(0)
        # assigner and sampler (one-to-one Hungarian-style matching in
        # DETR-family heads; exact behavior depends on the configured
        # assigner)
        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
                                             gt_labels, img_meta,
                                             gt_bboxes_ignore)
        sampling_result = self.sampler.sample(assign_result, bbox_pred,
                                              gt_bboxes)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds
        # label targets: default to the background index (num_classes);
        # matched queries get their assigned GT label.
        labels = gt_bboxes.new_full((num_bboxes, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)
        # bbox targets: zero for unmatched queries, weight 1 on positives.
        bbox_targets = torch.zeros_like(bbox_pred)
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0
        img_h, img_w, _ = img_meta['img_shape']
        # DETR regress the relative position of boxes (cxcywh) in the image.
        # Thus the learning target should be normalized by the image size, also
        # the box format should be converted from defaultly x1y1x2y2 to cxcywh.
        factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                       img_h]).unsqueeze(0)
        pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
        bbox_targets[pos_inds] = pos_gt_bboxes_targets
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds, pos_assigned_gt_inds)
    def simple_test_bboxes(self, feats, img_metas, rescale=False):
        """Test det bboxes without test-time augmentation.
        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
        Returns:
            list: Per-image prediction pairs as produced by
            ``_get_bboxes_single`` (bboxes-with-labels array and scores
            array); see ``get_bboxes`` for details.
        """
        # forward of this head requires img_metas
        outs = self.forward(feats, img_metas)
        results_list = self.get_bboxes(outs, img_metas, rescale=rescale)
        return results_list
    @force_fp32(apply_to=('preds_dicts'))
    def get_bboxes(self,
                   preds_dicts,
                   img_metas,
                   rescale=False):
        """Transform network outputs for a batch into bbox predictions.
        Args:
            preds_dicts (dict): Head outputs containing ``all_cls_scores``
                ([nb_dec, bs, num_query, cls_out_channels]) and
                ``all_bbox_preds`` ([nb_dec, bs, num_query, 4], sigmoid
                (cx, cy, w, h)); only the last decoder layer is decoded.
                ``enc_cls_scores`` / ``enc_bbox_preds`` are read but not
                used here.
            img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If True, return boxes in original
                image space. Default False.
        Returns:
            list: One entry per image, each the pair returned by
            ``_get_bboxes_single`` ([bboxes-with-labels, scores] numpy
            arrays).
        """
        all_cls_scores = preds_dicts['all_cls_scores']
        all_bbox_preds = preds_dicts['all_bbox_preds']
        # NOTE(review): the encoder outputs are fetched but never used in
        # this method.
        enc_cls_scores = preds_dicts['enc_cls_scores']
        enc_bbox_preds = preds_dicts['enc_bbox_preds']
        # Decode only the last decoder layer.
        cls_scores = all_cls_scores[-1]
        bbox_preds = all_bbox_preds[-1]
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(cls_score, bbox_pred,
                                                img_shape, scale_factor,
                                                rescale)
            result_list.append(proposals)
        return result_list
def _get_bboxes_single(self,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=False):
"""Transform outputs from the last decoder layer into bbox predictions
for each image.
Args:
cls_score (Tensor): Box score logits from the last decoder layer
for each image. Shape [num_query, cls_out_channels].
bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
for each image, with coordinate format (cx, cy, w, h) and
shape [num_query, 4].
img_shape (tuple[int]): Shape of input image, (height, width, 3).
scale_factor (ndarray, optional): Scale factor of the image arange
as (w_scale, h_scale, w_scale, h_scale).
rescale (bool, optional): If True, return boxes in original image
space. Default False.
Returns:
tuple[Tensor]: Results of detected bboxes and labels.
- det_bboxes: Predicted bboxes with shape [num_query, 5], \
where the first 4 columns are bounding box positions \
(tl_x, tl_y, br_x, br_y) and the 5-th column are scores \
between 0 and 1.
- det_labels: Predicted labels of the corresponding box with \
shape [num_query].
"""
assert len(cls_score) == len(bbox_pred)
# exclude background
if self.loss_cls.use_sigmoid:
cls_score = cls_score.sigmoid()
else:
cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
scores, det_labels = cls_score.max(-1)
det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
if rescale:
det_bboxes /= det_bboxes.new_tensor(scale_factor)
det_bboxes = torch.cat((det_bboxes, det_labels.unsqueeze(1)), -1)
pred = [
det_bboxes.detach().cpu().numpy(),
scores.detach().cpu().numpy()
]
return pred
# ==============================================================================
# Binaries and/or source for the following packages or projects
# are presented under one or more of the following open source licenses:
# topology_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
#
# Contact wanghuijie@pjlab.org.cn if you have any issue.
#
# Copyright (c) 2023 The OpenLane-v2 Dataset Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule
from mmdet.models import HEADS, build_loss
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
@HEADS.register_module()
class TopologyHead(BaseModule):
    """Predicts a topology (adjacency) score matrix between two query sets.

    Every row-query feature is paired with every column-query feature;
    the concatenated pair is scored by a small MLP, producing a soft
    adjacency matrix (e.g. lane-lane / lane-element relations in
    OpenLane-V2).
    """
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 num_layers,
                 loss_cls):
        # in_channels must equal the sum of the row- and column-feature
        # dims; this is asserted in forward().
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mlp = MLP(self.in_channels, hidden_channels, out_channels, num_layers)
        self.loss_cls = build_loss(loss_cls)
    def forward(self, all_a_preds_list, all_b_preds_list):
        """Score all (row, column) query pairs.

        Args:
            all_a_preds_list: Per-scale row-query features; each entry is
                (num_out, bs, num_row, C_a).
            all_b_preds_list: Per-scale column-query features; each entry
                is (num_out, bs, num_column, C_b).

        Returns:
            list[Tensor]: For each of the ``num_out`` decoder outputs, a
            (bs, num_row, num_column, out_channels) sigmoid score map.
        """
        # NOTE by default only the outputs from the last feature scale are used.
        all_a_preds_list = all_a_preds_list[-1]
        all_b_preds_list = all_b_preds_list[-1]
        num_out = all_a_preds_list.shape[0]
        assert num_out == all_a_preds_list.shape[0] == all_b_preds_list.shape[0]
        num_row, num_column = all_a_preds_list.shape[-2], all_b_preds_list.shape[-2]
        assert self.in_channels == all_a_preds_list.shape[-1] + all_b_preds_list.shape[-1], \
            f'self.in_channels = {self.in_channels} != all_a_preds_list.shape[-1] {all_a_preds_list.shape[-1]} + all_b_preds_list.shape[-1] {all_b_preds_list.shape[-1]}'
        outs = []
        for o in range(num_out):
            # Broadcast row features along columns and column features
            # along rows, then concatenate -> (bs, num_row, num_column, C).
            adj = torch.cat([
                all_a_preds_list[o].unsqueeze(2).repeat(1, 1, num_column, 1),
                all_b_preds_list[o].unsqueeze(1).repeat(1, num_row, 1, 1),
            ], dim=-1)
            outs.append(self.mlp(adj).sigmoid())
        return outs
    def loss(self, pred_adj_list, row_assign_results, column_assign_results, gt_adj):
        """Classification loss on the predicted adjacency matrix.

        Builds a per-batch binary target matrix by scattering the
        ground-truth adjacency (re-indexed through the row/column
        assigners' matched GT indices) into the positions of the
        positively-assigned queries; everything else stays 0.
        """
        # NOTE by default only the outputs from the last decoder layer are used.
        pred_adj = pred_adj_list[-1]
        row_assign_result = row_assign_results[-1]
        column_assign_result = column_assign_results[-1]
        targets = []
        for b in range(pred_adj.shape[0]):
            target = pred_adj.new_zeros(pred_adj[b].shape[:-1])
            # Index grids over positive row/column query positions.
            rs = row_assign_result['pos_inds'][b].unsqueeze(-1).repeat(1, column_assign_result['pos_inds'][b].shape[0])
            cs = column_assign_result['pos_inds'][b].unsqueeze(0).repeat(row_assign_result['pos_inds'][b].shape[0], 1)
            # GT adjacency re-ordered to match the assigned queries.
            target[rs, cs] = gt_adj[b][row_assign_result['pos_assigned_gt_inds'][b]][:, column_assign_result['pos_assigned_gt_inds'][b]].float()
            targets.append(target)
        targets = 1 - torch.stack(targets, dim=0) # 0 as positive
        loss_dict = dict()
        # Flatten to (N, out_channels) logits vs. (N,) class-index targets.
        pred_adj = pred_adj.reshape(-1, self.out_channels)
        targets = targets.long().reshape(-1)
        loss_dict['loss_cls'] = self.loss_cls(pred_adj, targets)
        return loss_dict
    def get_topology(self, pred_adj_list):
        """Return the last decoder layer's adjacency scores as a numpy array."""
        # NOTE by default only the outputs from the last decoder layer are used.
        pred_adj = pred_adj_list[-1].squeeze(-1)
        return pred_adj.cpu().numpy()
from .spatial_cross_attention import SpatialCrossAttention, MSDeformableAttention3D
from .temporal_self_attention import TemporalSelfAttention
from .encoder import BEVFormerEncoder, BEVFormerLayer
from .decoder import LaneDetectionTransformerDecoder
from .bevformer_constructer import BEVFormerConstructer
from .transformer import PerceptionTransformer
import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import normal_
from torchvision.transforms.functional import rotate
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence, build_positional_encoding
from mmcv.runner.base_module import BaseModule
from mmcv.runner import force_fp32, auto_fp16
from mmdet.models.utils.builder import TRANSFORMER
from mmdet3d.models import NECKS
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
@NECKS.register_module()
class BEVFormerConstructer(BaseModule):
    """Implements the BEVFormer BEV Constructer.

    Builds a bird's-eye-view (BEV) feature map from multi-camera,
    multi-scale image features with a transformer encoder, optionally
    warping the previous frame's BEV features according to ego motion
    (CAN-bus signals).

    Args:
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        num_cams (int): Number of cameras. Default: 6.
        embed_dims (int): Embedding dimension of BEV queries. Default: 256.
        rotate_prev_bev (bool): Rotate the previous BEV features by the
            ego yaw change before temporal fusion. Default: True.
        use_shift (bool): Shift the BEV reference by the ego translation.
            Default: True.
        use_can_bus (bool): Add an MLP-embedded CAN-bus signal to the BEV
            queries. Default: True.
        can_bus_norm (bool): Append a LayerNorm to the CAN-bus MLP.
            Default: True.
        use_cams_embeds (bool): Add learnable per-camera embeddings to the
            image features. Default: True.
        pc_range (list[float]): Point-cloud range
            [x_min, y_min, z_min, x_max, y_max, z_max] covered by the BEV.
        bev_h, bev_w (int): BEV grid resolution. Default: 200.
        rotate_center (list[int]): Pixel center used when rotating the
            previous BEV. Default: [100, 100].
        encoder (dict): Config of the BEV transformer encoder.
        positional_encoding (dict): Config of the BEV positional encoding.
    """
    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 embed_dims=256,
                 rotate_prev_bev=True,
                 use_shift=True,
                 use_can_bus=True,
                 can_bus_norm=True,
                 use_cams_embeds=True,
                 pc_range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0],
                 bev_h=200,
                 bev_w=200,
                 rotate_center=[100, 100],
                 encoder=None,
                 positional_encoding=None,
                 **kwargs):
        super(BEVFormerConstructer, self).__init__(**kwargs)
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False
        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds
        self.encoder = build_transformer_layer_sequence(encoder)
        self.positional_encoding = build_positional_encoding(positional_encoding)
        self.pc_range = pc_range
        # Metric extent of the BEV grid, derived from the point-cloud range.
        self.real_w = self.pc_range[3] - self.pc_range[0]
        self.real_h = self.pc_range[4] - self.pc_range[1]
        self.bev_h = bev_h
        self.bev_w = bev_w
        self.rotate_center = rotate_center
        self.init_layers()
    def init_layers(self):
        """Create the learnable BEV query / level / camera embeddings and the CAN-bus MLP."""
        self.bev_embedding = nn.Embedding(
            self.bev_h * self.bev_w, self.embed_dims)
        self.level_embeds = nn.Parameter(torch.Tensor(
            self.num_feature_levels, self.embed_dims))
        self.cams_embeds = nn.Parameter(
            torch.Tensor(self.num_cams, self.embed_dims))
        # 18 is the length of the CAN-bus signal vector in img_metas.
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        if self.can_bus_norm:
            self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))
    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformableAttention3D) or isinstance(m, TemporalSelfAttention) \
                    or isinstance(m, CustomMSDeformableAttention):
                try:
                    m.init_weight()
                except AttributeError:
                    m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)
    # @auto_fp16(apply_to=('mlvl_feats', 'prev_bev'))
    def forward(self, mlvl_feats, img_metas, prev_bev=None, **kwargs):
        """Build BEV features from multi-level, multi-camera image features.

        Args:
            mlvl_feats (list[Tensor]): Per-level features, each of shape
                (bs, num_cam, C, H, W).
            img_metas (list[dict]): Per-sample meta info; must contain the
                'can_bus' signal vector.
            prev_bev (Tensor | None): BEV features from the previous frame,
                used for temporal fusion. Default: None.

        Returns:
            Tensor: The encoded BEV embedding.
        """
        bs, num_cam, _, _, _ = mlvl_feats[0].shape
        dtype = mlvl_feats[0].dtype
        bev_queries = self.bev_embedding.weight.to(dtype)
        bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                               device=bev_queries.device).to(dtype)
        bev_pos = self.positional_encoding(bev_mask).to(dtype)
        bev_pos = bev_pos.flatten(2).permute(2, 0, 1)
        # obtain rotation angle and shift with ego motion
        delta_x = np.array([each['can_bus'][0]
                           for each in img_metas])
        delta_y = np.array([each['can_bus'][1]
                           for each in img_metas])
        ego_angle = np.array(
            [each['can_bus'][-2] / np.pi * 180 for each in img_metas])
        grid_length_y = self.real_h / self.bev_h
        grid_length_x = self.real_w / self.bev_w
        # Convert the ego translation into a fractional shift of the BEV
        # grid (per batch sample), expressed in BEV-cell units.
        translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
        translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
        bev_angle = ego_angle - translation_angle
        shift_y = translation_length * \
            np.cos(bev_angle / 180 * np.pi) / grid_length_y / self.bev_h
        shift_x = translation_length * \
            np.sin(bev_angle / 180 * np.pi) / grid_length_x / self.bev_w
        shift_y = shift_y * self.use_shift
        shift_x = shift_x * self.use_shift
        shift = bev_queries.new_tensor(
            [shift_x, shift_y]).permute(1, 0)  # xy, bs -> bs, xy
        if prev_bev is not None:
            if prev_bev.shape[1] == self.bev_h * self.bev_w:
                prev_bev = prev_bev.permute(1, 0, 2)
            if self.rotate_prev_bev:
                # Rotate each sample's previous BEV map by the ego yaw
                # change (last CAN-bus entry) around rotate_center.
                for i in range(bs):
                    # num_prev_bev = prev_bev.size(1)
                    rotation_angle = img_metas[i]['can_bus'][-1]
                    tmp_prev_bev = prev_bev[:, i].reshape(
                        self.bev_h, self.bev_w, -1).permute(2, 0, 1)
                    tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
                                          center=self.rotate_center)
                    tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                        self.bev_h * self.bev_w, 1, -1)
                    prev_bev[:, i] = tmp_prev_bev[:, 0]
        # add can bus signals
        can_bus = bev_queries.new_tensor(
            [each['can_bus'] for each in img_metas])  # [:, :]
        can_bus = self.can_bus_mlp(can_bus)[None, :, :]
        bev_queries = bev_queries + can_bus * self.use_can_bus
        # Flatten all camera/level features and tag them with per-camera
        # and per-level embeddings before feeding the encoder.
        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            feat = feat.flatten(3).permute(1, 0, 3, 2)
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None,
                                            None, lvl:lvl + 1, :].to(feat.dtype)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)
        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=bev_pos.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        feat_flatten = feat_flatten.permute(
            0, 2, 1, 3)  # (num_cam, H*W, bs, embed_dims)
        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            feat_flatten,
            bev_h=self.bev_h,
            bev_w=self.bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,
            img_metas=img_metas,
            **kwargs
        )
        return bev_embed
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import copy
import warnings
import torch
import torch.nn as nn
from mmcv import ConfigDict, deprecated_api_warning
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401
warnings.warn(
ImportWarning(
'``MultiScaleDeformableAttention`` has been moved to '
'``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501
'``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501
'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501
))
except ImportError:
warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
'``mmcv.ops.multi_scale_deform_attn``, '
'You should install ``mmcv-full`` if you need this module. ')
from mmcv.cnn.bricks.transformer import build_feedforward_network, build_attention
@TRANSFORMER_LAYER.register_module()
class MyCustomBaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and support more flexible
    customization, for example, using any number of `FFN or LN ` and
    use different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specifying `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for `self_attention` or `cross_attention` modules,
            The order of the configs in the list should be consistent with
            corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for FFN, The order of the configs in the list should be
            consistent with corresponding ffn in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Support `prenorm` when you specifying first element as `norm`.
            Default:None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to True.
    """
    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=True,
                 **kwargs):
        # Work on a private copy so the shared mutable default dict is
        # never modified across instances (the deprecated-arg handling
        # below writes into ffn_cfgs).
        ffn_cfgs = copy.deepcopy(ffn_cfgs)
        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ')
                ffn_cfgs[new_name] = kwargs[ori_name]
        super(MyCustomBaseTransformerLayer, self).__init__(init_cfg)
        self.batch_first = batch_first
        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contains all four operation type ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            # Broadcast a single attention config to every attention op.
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attention' \
                f'in operation_order {operation_order}.'
        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()
        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1
        self.embed_dims = self.attentions[0].embed_dims
        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            # Broadcast a single FFN config to every ffn op.
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                # BUGFIX: the original wrote `ffn_cfgs['embed_dims'] = ...`,
                # indexing the *list* with a string key, which raises
                # TypeError whenever a user config omits `embed_dims`.
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index]))
        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'
        # Execute the configured operations in order, threading `identity`
        # through as the residual input for pre-norm layers.
        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query
            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query
            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1
        return query
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from cmath import pi
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import mmcv
import cv2 as cv
import copy
import warnings
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION, TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence
import math
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
to_2tuple)
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
MultiScaleDeformableAttnFunction_fp16
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
def inverse_sigmoid(x, eps=1e-5):
    """Inverse of the sigmoid (logit) function, numerically stabilized.

    Args:
        x (Tensor): Input tensor; values are first clamped into [0, 1].
        eps (float): Lower bound applied to both the numerator and the
            (1 - x) denominator to avoid log(0) / division by zero.
            Defaults 1e-5.

    Returns:
        Tensor: ``log(x / (1 - x))`` with the same shape as the input.
    """
    x = x.clamp(min=0, max=1)
    numerator = x.clamp(min=eps)
    denominator = (1 - x).clamp(min=eps)
    return torch.log(numerator / denominator)
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class LaneDetectionTransformerDecoder(TransformerLayerSequence):
    """Transformer decoder for lane detection with iterative refinement.

    Each layer attends with 2D (xy) reference points; when ``reg_branches``
    is supplied, the 3D reference points are refined from the layer's
    regression output after every layer.

    Args:
        return_intermediate (bool): Whether to return the output and
            reference points of every decoder layer instead of only the
            last one. Default: False.
    """

    def __init__(self, *args, return_intermediate=False, **kwargs):
        super(LaneDetectionTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.fp16_enabled = False

    def forward(self,
                query,
                *args,
                reference_points=None,
                reg_branches=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `LaneDetectionTransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): Reference points of shape
                (bs, num_query, 3); exactly 3 coordinates are required
                whenever `reg_branches` is given (asserted below).
            reg_branches (obj:`nn.ModuleList`): Regression heads used for
                iterative refinement. Pass `None` to disable refinement.
            key_padding_mask (Tensor): ByteTensor mask for the keys.
                Default: None.

        Returns:
            tuple: ``(output, reference_points)`` of the last layer when
            `return_intermediate` is False; otherwise the stacked
            per-layer outputs and reference points.
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            # Only xy drives the (deformable) sampling; add a singleton
            # level dim: BS NUM_QUERY NUM_LEVEL 2.
            reference_points_input = reference_points[..., :2].unsqueeze(2)
            output = layer(
                output,
                *args,
                reference_points=reference_points_input,
                key_padding_mask=key_padding_mask,
                **kwargs)
            output = output.permute(1, 0, 2)
            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                assert reference_points.shape[-1] == 3
                # Midpoint of the predicted start/end offsets shifts the
                # reference points in inverse-sigmoid (logit) space.
                # (A dead `torch.zeros_like` store that was immediately
                # overwritten has been removed.)
                ref_center = (tmp[..., :3] + tmp[..., -3:]) / 2
                new_reference_points = (
                    ref_center + inverse_sigmoid(reference_points)).sigmoid()
                reference_points = new_reference_points.detach()
            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)
        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)
        return output, reference_points
@TRANSFORMER_LAYER.register_module()
class CustomDetrTransformerDecoderLayer(BaseTransformerLayer):
    """Decoder layer for the DETR transformer.

    A thin wrapper over :class:`BaseTransformerLayer` that additionally
    validates the layout: exactly six operations covering all of
    `self_attn`, `norm`, `cross_attn` and `ffn`
    (e.g. ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')).

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict): Configs
            for self-/cross-attention, ordered as in `operation_order`;
            a single dict is broadcast to every attention op.
        ffn_cfgs (dict | list[dict]): Config(s) for the FFN module(s).
        operation_order (tuple[str]): The execution order of operations.
            Default: None.
        norm_cfg (dict): Config dict for normalization layers.
            Default: dict(type='LN').
    """

    def __init__(self,
                 attn_cfgs,
                 ffn_cfgs,
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 **kwargs):
        super(CustomDetrTransformerDecoderLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            ffn_cfgs=ffn_cfgs,
            operation_order=operation_order,
            norm_cfg=norm_cfg,
            **kwargs)
        # Sanity-check the op layout after the base class has built it.
        assert len(operation_order) == 6
        assert set(operation_order) == {'self_attn', 'norm', 'cross_attn',
                                        'ffn'}
@ATTENTION.register_module()
class CustomMSDeformableAttention(BaseModule):
    """An attention module used in Deformable-Detr.

    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
    <https://arxiv.org/pdf/2010.04159.pdf>`_.

    Args:
        embed_dims (int): The embedding dimension of Attention.
            Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
        num_levels (int): The number of feature map used in
            Attention. Default: 4.
        num_points (int): The number of sampling points for
            each query in each head. Default: 4.
        im2col_step (int): The step used in image_to_column.
            Default: 64.
        dropout (float): A Dropout layer on `inp_identity`.
            Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """
    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=False,
                 norm_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        self.fp16_enabled = False
        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0
        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'MultiScaleDeformAttention to make '
                'the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')
        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        # Linear layers predicting per-query sampling offsets and
        # attention weights, plus value/output projections.
        self.sampling_offsets = nn.Linear(
            embed_dims, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(embed_dims,
                                           num_heads * num_levels * num_points)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()
    def init_weights(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.)
        # Initialize sampling-offset biases to a ring of unit directions,
        # scaled outward per sampling point.
        thetas = torch.arange(
            self.num_heads,
            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
                         self.num_heads, 1, 1,
                         2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1
        self.sampling_offsets.bias.data = grid_init.view(-1)
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.value_proj, distribution='uniform', bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
        self._is_init = True
    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                flag='decoder',
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`. Default
                None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements is range in [0, 1], top-left (0,0),
                bottom-right (1, 1), including padding area.
                or (N, Length_{query}, num_levels, 4), add
                additional two dimensions is (w, h) to
                form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query ,embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)
        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        # Softmax jointly over all levels*points per head.
        attention_weights = attention_weights.softmax(-1)
        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            # Offsets are normalized by each level's (w, h).
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            # Box references: scale offsets by half the box size.
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:
            # using fp16 deformable attention is unstable because it performs many sum operations
            # NOTE: both branches deliberately select the fp32 kernel.
            if value.dtype == torch.float16:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            else:
                MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            # CPU / non-CUDA fallback.
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        if not self.batch_first:
            # (num_query, bs ,embed_dims)
            output = output.permute(1, 0, 2)
        return self.dropout(output) + identity
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
import copy
import warnings
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmcv.runner import force_fp32, auto_fp16
import numpy as np
import torch
import cv2 as cv
import mmcv
from mmcv.utils import TORCH_VERSION, digit_version
from mmcv.utils import ext_loader
# Load mmcv's compiled extension providing the multi-scale deformable
# attention CUDA kernels used by the attention modules below.
ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class BEVFormerEncoder(TransformerLayerSequence):
    """BEVFormer encoder: a stack of layers mixing temporal self-attention
    (TSA, over the BEV plane) and spatial cross-attention (SCA, over
    multi-camera image features).

    Attention with both self and cross
    Implements the decoder in DETR transformer.
    Args:
        pc_range (list[float]): Point-cloud range
            ``[x_min, y_min, z_min, x_max, y_max, z_max]``; used to map
            normalized 3D reference points back to metric lidar space.
        num_points_in_pillar (int): Number of points sampled uniformly
            along the vertical axis of each BEV pillar. Default: 4.
        return_intermediate (bool): Whether to return intermediate outputs.
        dataset_type (str): Dataset name; accepted but unused here.
        coder_norm_cfg (dict): Config of last normalization layer. Default:
            `LN`.
    """
    def __init__(self, *args, pc_range=None, num_points_in_pillar=4, return_intermediate=False, dataset_type='nuscenes',
                 **kwargs):
        super(BEVFormerEncoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.num_points_in_pillar = num_points_in_pillar
        self.pc_range = pc_range
        # fp16 bookkeeping is off; point_sampling is forced to fp32 below.
        self.fp16_enabled = False
    @staticmethod
    def get_reference_points(H, W, Z=8, num_points_in_pillar=4, dim='3d', bs=1, device='cuda', dtype=torch.float):
        """Get the reference points used in SCA and TSA.
        Args:
            H, W: spatial shape of bev.
            Z: hight (z-extent) of pillar, in the same units as pc_range.
            num_points_in_pillar: sample this many points uniformly from
                each pillar (the ``D`` axis below).
            dim (str): '3d' returns pillar points for SCA; '2d' returns
                BEV-plane points for TSA.
            bs (int): batch size to broadcast over.
            device (obj:`device`): The device where
                reference_points should be.
        Returns:
            Tensor: for '3d', shape (bs, D, H*W, 3); for '2d',
            shape (bs, H*W, 1, 2). All coordinates normalized to [0, 1].
        """
        # reference points in 3D space, used in spatial cross-attention (SCA)
        if dim == '3d':
            # Each coordinate is a cell center (the 0.5 offsets) divided by
            # its extent, giving normalized coordinates in (0, 1).
            zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,
                                device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z
            xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
                                device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W
            ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
                                device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
            ref_3d = torch.stack((xs, ys, zs), -1)
            # (D, H, W, 3) -> (D, 3, H*W) -> (D, H*W, 3)
            ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
            # Broadcast over the batch: (bs, D, H*W, 3).
            ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
            return ref_3d
        # reference points on 2D bev plane, used in temporal self-attention (TSA).
        elif dim == '2d':
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=dtype, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=dtype, device=device)
            )
            ref_y = ref_y.reshape(-1)[None] / H
            ref_x = ref_x.reshape(-1)[None] / W
            ref_2d = torch.stack((ref_x, ref_y), -1)
            # (bs, H*W, 1, 2); the singleton axis is the level dimension.
            ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
            return ref_2d
        # NOTE(review): any other `dim` silently returns None — callers only
        # pass '3d' or '2d'.
    # This function must use fp32!!!
    @force_fp32(apply_to=('reference_points', 'img_metas'))
    def point_sampling(self, reference_points, pc_range, img_metas):
        """Project normalized 3D reference points into every camera image.

        Args:
            reference_points (Tensor): Output of
                ``get_reference_points(dim='3d')``, shape (bs, D, num_query, 3),
                normalized to [0, 1].
            pc_range (list[float]): Point-cloud range used to de-normalize.
            img_metas (list[dict]): Per-sample metas; must provide
                'lidar2img' (per-camera 4x4 projection matrices) and
                'img_shape'. NOTE(review): only img_metas[0]'s first
                img_shape is used for normalization — assumes all cameras
                and samples share one image size; confirm against callers.
        Returns:
            tuple:
                - reference_points_cam (Tensor): normalized image-plane
                  coords, shape (num_cam, bs, num_query, D, 2).
                - bev_mask (Tensor): mask of points with positive depth
                  that land inside the image, shape (num_cam, bs, num_query, D).
        """
        lidar2img = []
        for img_meta in img_metas:
            lidar2img.append(img_meta['lidar2img'])
        lidar2img = np.asarray(lidar2img)
        lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
        reference_points = reference_points.clone()
        # De-normalize [0, 1] coordinates to metric lidar space.
        reference_points[..., 0:1] = reference_points[..., 0:1] * \
            (pc_range[3] - pc_range[0]) + pc_range[0]
        reference_points[..., 1:2] = reference_points[..., 1:2] * \
            (pc_range[4] - pc_range[1]) + pc_range[1]
        reference_points[..., 2:3] = reference_points[..., 2:3] * \
            (pc_range[5] - pc_range[2]) + pc_range[2]
        # Append homogeneous coordinate: (..., 3) -> (..., 4).
        reference_points = torch.cat(
            (reference_points, torch.ones_like(reference_points[..., :1])), -1)
        # (bs, D, num_query, 4) -> (D, bs, num_query, 4)
        reference_points = reference_points.permute(1, 0, 2, 3)
        D, B, num_query = reference_points.size()[:3]
        num_cam = lidar2img.size(1)
        # Tile points per camera and matrices per point/query so a single
        # batched matmul projects everything at once.
        reference_points = reference_points.view(
            D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)
        lidar2img = lidar2img.view(
            1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
        # Projection is done in fp32 for numerical stability (see decorator).
        reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
                                            reference_points.to(torch.float32)).squeeze(-1)
        eps = 1e-5
        # Keep only points in front of the camera (positive depth).
        bev_mask = (reference_points_cam[..., 2:3] > eps)
        # Perspective divide, with depth clamped away from zero.
        reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
            reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)
        # Normalize pixel coords by image width (index 1) and height (index 0).
        reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
        reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
        # Also require the projection to land strictly inside the image.
        bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
                    & (reference_points_cam[..., 1:2] < 1.0)
                    & (reference_points_cam[..., 0:1] < 1.0)
                    & (reference_points_cam[..., 0:1] > 0.0))
        # Scrub NaNs from the mask (API differs across torch versions).
        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
            bev_mask = torch.nan_to_num(bev_mask)
        else:
            bev_mask = bev_mask.new_tensor(
                np.nan_to_num(bev_mask.cpu().numpy()))
        # -> (num_cam, bs, num_query, D, 2) and (num_cam, bs, num_query, D).
        reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
        bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
        return reference_points_cam, bev_mask
    @auto_fp16()
    def forward(self,
                bev_query,
                key,
                value,
                *args,
                bev_h=None,
                bev_w=None,
                bev_pos=None,
                spatial_shapes=None,
                level_start_index=None,
                valid_ratios=None,
                prev_bev=None,
                shift=0.,
                **kwargs):
        """Forward function for `TransformerDecoder`.
        Args:
            bev_query (Tensor): Input BEV query with shape
                `(num_query, bs, embed_dims)`.
            key & value (Tensor): Input multi-camera features with shape
                (num_cam, num_value, bs, embed_dims)
            reference_points (Tensor): The reference
                points of offset. has shape
                (bs, num_query, 4) when as_two_stage,
                otherwise has shape ((bs, num_query, 2).
            valid_ratios (Tensor): The radios of valid
                points on the feature map, has shape
                (bs, num_levels, 2)
            prev_bev (Tensor | None): BEV features from the previous frame;
                when given, they are stacked with the current query for
                temporal self-attention.
            shift (Tensor): Per-sample ego-motion shift applied to the 2D
                reference points used against prev_bev.
        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        output = bev_query
        intermediate = []
        # 3D pillar points for SCA; pillar height is the z-extent of pc_range.
        ref_3d = self.get_reference_points(
            bev_h, bev_w, self.pc_range[5]-self.pc_range[2], self.num_points_in_pillar, dim='3d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
        ref_2d = self.get_reference_points(
            bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
        reference_points_cam, bev_mask = self.point_sampling(
            ref_3d, self.pc_range, kwargs['img_metas'])
        # Shifted copy of the 2D reference points, aligned with prev_bev.
        # NOTE(review): upstream BEVFormer deliberately omitted .clone() here
        # (aliasing ref_2d) to reproduce paper numbers; this copy uses
        # .clone(), so ref_2d itself is NOT shifted below.
        shift_ref_2d = ref_2d.clone()  # .clone()
        shift_ref_2d += shift[:, None, None, :]
        # (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
        bev_query = bev_query.permute(1, 0, 2)
        bev_pos = bev_pos.permute(1, 0, 2)
        bs, len_bev, num_bev_level, _ = ref_2d.shape
        if prev_bev is not None:
            prev_bev = prev_bev.permute(1, 0, 2)
            # Interleave previous and current BEV along the batch axis so
            # temporal self-attention can attend to both (bs*2, ...).
            prev_bev = torch.stack(
                [prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1)
            hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
                bs*2, len_bev, num_bev_level, 2)
        else:
            # No history: duplicate the unshifted points to keep shapes.
            hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
                bs*2, len_bev, num_bev_level, 2)
        for lid, layer in enumerate(self.layers):
            output = layer(
                bev_query,
                key,
                value,
                *args,
                bev_pos=bev_pos,
                ref_2d=hybird_ref_2d,
                ref_3d=ref_3d,
                bev_h=bev_h,
                bev_w=bev_w,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                reference_points_cam=reference_points_cam,
                bev_mask=bev_mask,
                prev_bev=prev_bev,
                **kwargs)
            # Each layer's output feeds the next layer as its query.
            bev_query = output
            if self.return_intermediate:
                intermediate.append(output)
        if self.return_intermediate:
            return torch.stack(intermediate)
        return output
@TRANSFORMER_LAYER.register_module()
class BEVFormerLayer(MyCustomBaseTransformerLayer):
    """Implements decoder layer in DETR transformer.

    One BEVFormer encoder layer: temporal self-attention over the BEV
    queries plus spatial cross-attention into camera features, executed in
    the order given by ``operation_order``.
    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
            Configs for self_attention or cross_attention, the order
            should be consistent with it in `operation_order`. If it is
            a dict, it would be expand to the number of attention in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default 0.0.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Default:None
        act_cfg (dict): The activation config for FFNs. Default: `LN`
        norm_cfg (dict): Config dict for normalization layer.
            Default: `LN`.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default:2.
    """
    def __init__(self,
                 attn_cfgs,
                 ffn_cfgs,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 **kwargs):
        super(BEVFormerLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            ffn_cfgs=ffn_cfgs,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            **kwargs)
        self.fp16_enabled = False
        # Expect exactly six ops built from self_attn/norm/cross_attn/ffn,
        # e.g. ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm').
        assert len(operation_order) == 6
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])
    def forward(self,
                query,
                key=None,
                value=None,
                bev_pos=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                ref_2d=None,
                ref_3d=None,
                bev_h=None,
                bev_w=None,
                reference_points_cam=None,
                mask=None,
                spatial_shapes=None,
                level_start_index=None,
                prev_bev=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.
        **kwargs contains some specific arguments of attentions.
        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_keys]. Default: None.
        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """
        # Cursors into self.attentions / self.norms / self.ffns, advanced as
        # operation_order is walked.
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            # One mask given: replicate it for every attention module.
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'
        for layer in self.operation_order:
            # temporal self attention: BEV queries attend to prev_bev
            # (key/value) at the 2D reference points over a single
            # (bev_h, bev_w) level.
            if layer == 'self_attn':
                query = self.attentions[attn_index](
                    query,
                    prev_bev,
                    prev_bev,
                    identity if self.pre_norm else None,
                    query_pos=bev_pos,
                    key_pos=bev_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    reference_points=ref_2d,
                    spatial_shapes=torch.tensor(
                        [[bev_h, bev_w]], device=query.device),
                    level_start_index=torch.tensor([0], device=query.device),
                    **kwargs)
                attn_index += 1
                # The residual for the next op is the post-attention query.
                identity = query
            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1
            # spatial cross attention: BEV queries attend to camera features
            # at the projected 3D reference points.
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    reference_points=ref_3d,
                    reference_points_cam=reference_points_cam,
                    mask=mask,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    **kwargs)
                attn_index += 1
                identity = query
            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1
        return query
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment