Unverified Commit 4d77b4c8 authored by Jingwei Zhang, committed by GitHub

[Feature] Support BEVFusion in `projects/` (#2236)

* add bevfusion models

* refactor

* build successfully

* update ImageAug3D

* support inference

* update the format of final bboxes

* add new loading func

* align test precision

* polish docstring

* refactor transformer decoder

* polish code

* fix table in readme

* fix table in readme

* fix table in readme

* update pre-commit-config

* minor changes

* revert the changes of file_client_args in LoadAnnotation3D

* remove unnecessary functions in BEVFusion

* fix loading bug

* fix docstring
#include <ATen/TensorUtils.h>
#include <torch/extension.h>
// #include "voxelization.h"
namespace {
template <typename T, typename T_int>
void dynamic_voxelize_kernel(const torch::TensorAccessor<T, 2> points,
torch::TensorAccessor<T_int, 2> coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const std::vector<int> grid_size,
const int num_points, const int num_features,
const int NDim) {
const int ndim_minus_1 = NDim - 1;
bool failed = false;
// int coor[NDim];
int* coor = new int[NDim]();
int c;
for (int i = 0; i < num_points; ++i) {
failed = false;
for (int j = 0; j < NDim; ++j) {
c = floor((points[i][j] - coors_range[j]) / voxel_size[j]);
// necessary to rm points out of range
if ((c < 0 || c >= grid_size[j])) {
failed = true;
break;
}
coor[j] = c;
}
for (int k = 0; k < NDim; ++k) {
if (failed)
coors[i][k] = -1;
else
coors[i][k] = coor[k];
}
}
delete[] coor;
return;
}
template <typename T, typename T_int>
void hard_voxelize_kernel(const torch::TensorAccessor<T, 2> points,
torch::TensorAccessor<T, 3> voxels,
torch::TensorAccessor<T_int, 2> coors,
torch::TensorAccessor<T_int, 1> num_points_per_voxel,
torch::TensorAccessor<T_int, 3> coor_to_voxelidx,
int& voxel_num, const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const std::vector<int> grid_size,
const int max_points, const int max_voxels,
const int num_points, const int num_features,
const int NDim) {
// declare a temp coors
at::Tensor temp_coors = at::zeros(
{num_points, NDim}, at::TensorOptions().dtype(at::kInt).device(at::kCPU));
// First use dynamic voxelization to get coors,
// then check max points/voxels constraints
dynamic_voxelize_kernel<T, int>(points, temp_coors.accessor<int, 2>(),
voxel_size, coors_range, grid_size,
num_points, num_features, NDim);
int voxelidx, num;
auto coor = temp_coors.accessor<int, 2>();
for (int i = 0; i < num_points; ++i) {
// T_int* coor = temp_coors.data_ptr<int>() + i * NDim;
if (coor[i][0] == -1) continue;
voxelidx = coor_to_voxelidx[coor[i][0]][coor[i][1]][coor[i][2]];
// record voxel
if (voxelidx == -1) {
voxelidx = voxel_num;
if (max_voxels != -1 && voxel_num >= max_voxels) continue;
voxel_num += 1;
coor_to_voxelidx[coor[i][0]][coor[i][1]][coor[i][2]] = voxelidx;
for (int k = 0; k < NDim; ++k) {
coors[voxelidx][k] = coor[i][k];
}
}
// put points into voxel
num = num_points_per_voxel[voxelidx];
if (max_points == -1 || num < max_points) {
for (int k = 0; k < num_features; ++k) {
voxels[voxelidx][num][k] = points[i][k];
}
num_points_per_voxel[voxelidx] += 1;
}
}
return;
}
} // namespace
namespace voxelization {
int hard_voxelize_cpu(const at::Tensor& points, at::Tensor& voxels,
at::Tensor& coors, at::Tensor& num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {
  // current version takes about 0.02s~0.03s for one frame on cpu
// check device
AT_ASSERTM(points.device().is_cpu(), "points must be a CPU tensor");
std::vector<int> grid_size(NDim);
const int num_points = points.size(0);
const int num_features = points.size(1);
for (int i = 0; i < NDim; ++i) {
grid_size[i] =
round((coors_range[NDim + i] - coors_range[i]) / voxel_size[i]);
}
// coors, num_points_per_voxel, coor_to_voxelidx are int Tensor
// printf("cpu coor_to_voxelidx size: [%d, %d, %d]\n", grid_size[2],
// grid_size[1], grid_size[0]);
at::Tensor coor_to_voxelidx =
-at::ones({grid_size[2], grid_size[1], grid_size[0]}, coors.options());
int voxel_num = 0;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "hard_voxelize_forward", [&] {
hard_voxelize_kernel<scalar_t, int>(
points.accessor<scalar_t, 2>(), voxels.accessor<scalar_t, 3>(),
coors.accessor<int, 2>(), num_points_per_voxel.accessor<int, 1>(),
coor_to_voxelidx.accessor<int, 3>(), voxel_num, voxel_size,
coors_range, grid_size, max_points, max_voxels, num_points,
num_features, NDim);
});
return voxel_num;
}
void dynamic_voxelize_cpu(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3) {
// check device
AT_ASSERTM(points.device().is_cpu(), "points must be a CPU tensor");
std::vector<int> grid_size(NDim);
const int num_points = points.size(0);
const int num_features = points.size(1);
for (int i = 0; i < NDim; ++i) {
grid_size[i] =
round((coors_range[NDim + i] - coors_range[i]) / voxel_size[i]);
}
// coors, num_points_per_voxel, coor_to_voxelidx are int Tensor
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "hard_voxelize_forward", [&] {
dynamic_voxelize_kernel<scalar_t, int>(
points.accessor<scalar_t, 2>(), coors.accessor<int, 2>(),
voxel_size, coors_range, grid_size, num_points, num_features, NDim);
});
return;
}
} // namespace voxelization
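The two CPU kernels above implement voxelization in two stages: `dynamic_voxelize_kernel` maps every point to an integer voxel coordinate (marking out-of-range points with -1), and `hard_voxelize_kernel` then groups points that share a coordinate, subject to `max_points` and `max_voxels`. Below is a minimal NumPy sketch of the first stage, with illustrative (not authoritative) voxel size and range values.

# Toy NumPy sketch of the point-to-voxel coordinate mapping (illustration only).
import numpy as np

points = np.array([[0.3, 0.4, -1.0], [70.0, 39.9, 0.5], [-5.0, 0.0, 0.0]], np.float32)
voxel_size = np.array([0.05, 0.05, 0.1], np.float32)
coors_range = np.array([0, -40, -3, 70.4, 40, 1], np.float32)  # x/y/z min, then x/y/z max

grid_size = np.round((coors_range[3:] - coors_range[:3]) / voxel_size).astype(int)
coors = np.floor((points - coors_range[:3]) / voxel_size).astype(int)
# Points outside the grid are marked with -1, mirroring dynamic_voxelize_kernel.
coors[((coors < 0) | (coors >= grid_size)).any(axis=1)] = -1
print(coors)  # the third point is out of range in x, so its row becomes (-1, -1, -1)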
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
namespace {
int const threadsPerBlock = sizeof(unsigned long long) * 8;
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
template <typename T, typename T_int>
__global__ void dynamic_voxelize_kernel(
const T* points, T_int* coors, const float voxel_x, const float voxel_y,
const float voxel_z, const float coors_x_min, const float coors_y_min,
const float coors_z_min, const float coors_x_max, const float coors_y_max,
const float coors_z_max, const int grid_x, const int grid_y,
const int grid_z, const int num_points, const int num_features,
const int NDim) {
// const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
CUDA_1D_KERNEL_LOOP(index, num_points) {
// To save some computation
auto points_offset = points + index * num_features;
auto coors_offset = coors + index * NDim;
int c_x = floor((points_offset[0] - coors_x_min) / voxel_x);
if (c_x < 0 || c_x >= grid_x) {
coors_offset[0] = -1;
return;
}
int c_y = floor((points_offset[1] - coors_y_min) / voxel_y);
if (c_y < 0 || c_y >= grid_y) {
coors_offset[0] = -1;
coors_offset[1] = -1;
return;
}
int c_z = floor((points_offset[2] - coors_z_min) / voxel_z);
if (c_z < 0 || c_z >= grid_z) {
coors_offset[0] = -1;
coors_offset[1] = -1;
coors_offset[2] = -1;
} else {
coors_offset[0] = c_x;
coors_offset[1] = c_y;
coors_offset[2] = c_z;
}
}
}
template <typename T, typename T_int>
__global__ void assign_point_to_voxel(const int nthreads, const T* points,
T_int* point_to_voxelidx,
T_int* coor_to_voxelidx, T* voxels,
const int max_points,
const int num_features,
const int num_points, const int NDim) {
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
// const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
int index = thread_idx / num_features;
int num = point_to_voxelidx[index];
int voxelidx = coor_to_voxelidx[index];
if (num > -1 && voxelidx > -1) {
auto voxels_offset =
voxels + voxelidx * max_points * num_features + num * num_features;
int k = thread_idx % num_features;
voxels_offset[k] = points[thread_idx];
}
}
}
template <typename T, typename T_int>
__global__ void assign_voxel_coors(const int nthreads, T_int* coor,
T_int* point_to_voxelidx,
T_int* coor_to_voxelidx, T_int* voxel_coors,
const int num_points, const int NDim) {
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
// const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
// if (index >= num_points) return;
int index = thread_idx / NDim;
int num = point_to_voxelidx[index];
int voxelidx = coor_to_voxelidx[index];
if (num == 0 && voxelidx > -1) {
auto coors_offset = voxel_coors + voxelidx * NDim;
int k = thread_idx % NDim;
coors_offset[k] = coor[thread_idx];
}
}
}
template <typename T_int>
__global__ void point_to_voxelidx_kernel(const T_int* coor,
T_int* point_to_voxelidx,
T_int* point_to_pointidx,
const int max_points,
const int max_voxels,
const int num_points, const int NDim) {
CUDA_1D_KERNEL_LOOP(index, num_points) {
auto coor_offset = coor + index * NDim;
// skip invalid points
if ((index >= num_points) || (coor_offset[0] == -1)) return;
int num = 0;
int coor_x = coor_offset[0];
int coor_y = coor_offset[1];
int coor_z = coor_offset[2];
// only calculate the coors before this coor[index]
for (int i = 0; i < index; ++i) {
auto prev_coor = coor + i * NDim;
if (prev_coor[0] == -1) continue;
// Find all previous points that have the same coors
// if find the same coor, record it
if ((prev_coor[0] == coor_x) && (prev_coor[1] == coor_y) &&
(prev_coor[2] == coor_z)) {
num++;
if (num == 1) {
// point to the same coor that first show up
point_to_pointidx[index] = i;
} else if (num >= max_points) {
// out of boundary
return;
}
}
}
if (num == 0) {
point_to_pointidx[index] = index;
}
if (num < max_points) {
point_to_voxelidx[index] = num;
}
}
}
template <typename T_int>
__global__ void determin_voxel_num(
// const T_int* coor,
T_int* num_points_per_voxel, T_int* point_to_voxelidx,
T_int* point_to_pointidx, T_int* coor_to_voxelidx, T_int* voxel_num,
const int max_points, const int max_voxels, const int num_points) {
// only calculate the coors before this coor[index]
for (int i = 0; i < num_points; ++i) {
// if (coor[i][0] == -1)
// continue;
int point_pos_in_voxel = point_to_voxelidx[i];
// record voxel
if (point_pos_in_voxel == -1) {
// out of max_points or invalid point
continue;
} else if (point_pos_in_voxel == 0) {
// record new voxel
int voxelidx = voxel_num[0];
if (voxel_num[0] >= max_voxels) continue;
voxel_num[0] += 1;
coor_to_voxelidx[i] = voxelidx;
num_points_per_voxel[voxelidx] = 1;
} else {
int point_idx = point_to_pointidx[i];
int voxelidx = coor_to_voxelidx[point_idx];
if (voxelidx != -1) {
coor_to_voxelidx[i] = voxelidx;
num_points_per_voxel[voxelidx] += 1;
}
}
}
}
__global__ void nondisterministic_get_assign_pos(
const int nthreads, const int32_t *coors_map, int32_t *pts_id,
int32_t *coors_count, int32_t *reduce_count, int32_t *coors_order) {
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
int coors_idx = coors_map[thread_idx];
if (coors_idx > -1) {
int32_t coors_pts_pos = atomicAdd(&reduce_count[coors_idx], 1);
pts_id[thread_idx] = coors_pts_pos;
if (coors_pts_pos == 0) {
coors_order[coors_idx] = atomicAdd(coors_count, 1);
}
}
}
}
template<typename T>
__global__ void nondisterministic_assign_point_voxel(
const int nthreads, const T *points, const int32_t *coors_map,
const int32_t *pts_id, const int32_t *coors_in,
const int32_t *reduce_count, const int32_t *coors_order,
T *voxels, int32_t *coors, int32_t *pts_count, const int max_voxels,
const int max_points, const int num_features, const int NDim) {
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
int coors_idx = coors_map[thread_idx];
int coors_pts_pos = pts_id[thread_idx];
if (coors_idx > -1) {
int coors_pos = coors_order[coors_idx];
if (coors_pos < max_voxels && coors_pts_pos < max_points) {
auto voxels_offset =
voxels + (coors_pos * max_points + coors_pts_pos) * num_features;
auto points_offset = points + thread_idx * num_features;
for (int k = 0; k < num_features; k++) {
voxels_offset[k] = points_offset[k];
}
if (coors_pts_pos == 0) {
pts_count[coors_pos] = min(reduce_count[coors_idx], max_points);
auto coors_offset = coors + coors_pos * NDim;
auto coors_in_offset = coors_in + coors_idx * NDim;
for (int k = 0; k < NDim; k++) {
coors_offset[k] = coors_in_offset[k];
}
}
}
}
}
}
namespace voxelization {
int hard_voxelize_gpu(const at::Tensor& points, at::Tensor& voxels,
at::Tensor& coors, at::Tensor& num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {
  // current version takes about 0.04s for one frame on cpu
// check device
CHECK_INPUT(points);
at::cuda::CUDAGuard device_guard(points.device());
const int num_points = points.size(0);
const int num_features = points.size(1);
const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
// map points to voxel coors
at::Tensor temp_coors =
at::zeros({num_points, NDim}, points.options().dtype(at::kInt));
dim3 grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
dim3 block(512);
// 1. link point to corresponding voxel coors
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "hard_voxelize_kernel", ([&] {
dynamic_voxelize_kernel<scalar_t, int>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
points.contiguous().data_ptr<scalar_t>(),
temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y,
voxel_z, coors_x_min, coors_y_min, coors_z_min, coors_x_max,
coors_y_max, coors_z_max, grid_x, grid_y, grid_z, num_points,
num_features, NDim);
}));
cudaDeviceSynchronize();
AT_CUDA_CHECK(cudaGetLastError());
// 2. map point to the idx of the corresponding voxel, find duplicate coor
// create some temporary variables
auto point_to_pointidx = -at::ones(
{
num_points,
},
points.options().dtype(at::kInt));
auto point_to_voxelidx = -at::ones(
{
num_points,
},
points.options().dtype(at::kInt));
dim3 map_grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
dim3 map_block(512);
AT_DISPATCH_ALL_TYPES(
temp_coors.scalar_type(), "determin_duplicate", ([&] {
point_to_voxelidx_kernel<int>
<<<map_grid, map_block, 0, at::cuda::getCurrentCUDAStream()>>>(
temp_coors.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
point_to_pointidx.contiguous().data_ptr<int>(), max_points,
max_voxels, num_points, NDim);
}));
cudaDeviceSynchronize();
AT_CUDA_CHECK(cudaGetLastError());
  // 3. determine voxel num and each voxel's coor index
  // moving this logic onto the CUDA device accelerates it by about 10x
auto coor_to_voxelidx = -at::ones(
{
num_points,
},
points.options().dtype(at::kInt));
auto voxel_num = at::zeros(
{
1,
},
points.options().dtype(at::kInt)); // must be zero from the beginning
AT_DISPATCH_ALL_TYPES(
temp_coors.scalar_type(), "determin_duplicate", ([&] {
determin_voxel_num<int><<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
num_points_per_voxel.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
point_to_pointidx.contiguous().data_ptr<int>(),
coor_to_voxelidx.contiguous().data_ptr<int>(),
voxel_num.contiguous().data_ptr<int>(), max_points, max_voxels,
num_points);
}));
cudaDeviceSynchronize();
AT_CUDA_CHECK(cudaGetLastError());
// 4. copy point features to voxels
// Step 4 & 5 could be parallel
auto pts_output_size = num_points * num_features;
dim3 cp_grid(std::min(at::cuda::ATenCeilDiv(pts_output_size, 512), 4096));
dim3 cp_block(512);
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "assign_point_to_voxel", ([&] {
assign_point_to_voxel<float, int>
<<<cp_grid, cp_block, 0, at::cuda::getCurrentCUDAStream()>>>(
pts_output_size, points.contiguous().data_ptr<float>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
coor_to_voxelidx.contiguous().data_ptr<int>(),
voxels.contiguous().data_ptr<float>(), max_points, num_features,
num_points, NDim);
}));
// cudaDeviceSynchronize();
// AT_CUDA_CHECK(cudaGetLastError());
// 5. copy coors of each voxels
auto coors_output_size = num_points * NDim;
dim3 coors_cp_grid(
std::min(at::cuda::ATenCeilDiv(coors_output_size, 512), 4096));
dim3 coors_cp_block(512);
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "assign_point_to_voxel", ([&] {
assign_voxel_coors<float, int><<<coors_cp_grid, coors_cp_block, 0,
at::cuda::getCurrentCUDAStream()>>>(
coors_output_size, temp_coors.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
coor_to_voxelidx.contiguous().data_ptr<int>(),
coors.contiguous().data_ptr<int>(), num_points, NDim);
}));
cudaDeviceSynchronize();
AT_CUDA_CHECK(cudaGetLastError());
auto voxel_num_cpu = voxel_num.to(at::kCPU);
int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
return voxel_num_int;
}
int nondisterministic_hard_voxelize_gpu(
const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {
CHECK_INPUT(points);
at::cuda::CUDAGuard device_guard(points.device());
const int num_points = points.size(0);
const int num_features = points.size(1);
if (num_points == 0)
return 0;
const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
// map points to voxel coors
at::Tensor temp_coors =
at::zeros({num_points, NDim}, points.options().dtype(torch::kInt32));
dim3 grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
dim3 block(512);
// 1. link point to corresponding voxel coors
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "hard_voxelize_kernel", ([&] {
dynamic_voxelize_kernel<scalar_t, int>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
points.contiguous().data_ptr<scalar_t>(),
temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y,
voxel_z, coors_x_min, coors_y_min, coors_z_min, coors_x_max,
coors_y_max, coors_z_max, grid_x, grid_y, grid_z, num_points,
num_features, NDim);
}));
at::Tensor coors_map;
at::Tensor coors_count;
at::Tensor coors_order;
at::Tensor reduce_count;
at::Tensor pts_id;
auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1);
std::tie(temp_coors, coors_map, reduce_count) =
at::unique_dim(coors_clean, 0, true, true, false);
if (temp_coors.index({0, 0}).lt(0).item<bool>()) {
// the first element of temp_coors is (-1,-1,-1) and should be removed
temp_coors = temp_coors.slice(0, 1);
coors_map = coors_map - 1;
}
int num_coors = temp_coors.size(0);
temp_coors = temp_coors.to(torch::kInt32);
coors_map = coors_map.to(torch::kInt32);
coors_count = coors_map.new_zeros(1);
coors_order = coors_map.new_empty(num_coors);
reduce_count = coors_map.new_zeros(num_coors);
pts_id = coors_map.new_zeros(num_points);
dim3 cp_grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
dim3 cp_block(512);
AT_DISPATCH_ALL_TYPES(points.scalar_type(), "get_assign_pos", ([&] {
nondisterministic_get_assign_pos<<<cp_grid, cp_block, 0,
at::cuda::getCurrentCUDAStream()>>>(
num_points,
coors_map.contiguous().data_ptr<int32_t>(),
pts_id.contiguous().data_ptr<int32_t>(),
coors_count.contiguous().data_ptr<int32_t>(),
reduce_count.contiguous().data_ptr<int32_t>(),
coors_order.contiguous().data_ptr<int32_t>());
}));
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "assign_point_to_voxel", ([&] {
nondisterministic_assign_point_voxel<scalar_t>
<<<cp_grid, cp_block, 0, at::cuda::getCurrentCUDAStream()>>>(
num_points, points.contiguous().data_ptr<scalar_t>(),
coors_map.contiguous().data_ptr<int32_t>(),
pts_id.contiguous().data_ptr<int32_t>(),
temp_coors.contiguous().data_ptr<int32_t>(),
reduce_count.contiguous().data_ptr<int32_t>(),
coors_order.contiguous().data_ptr<int32_t>(),
voxels.contiguous().data_ptr<scalar_t>(),
coors.contiguous().data_ptr<int32_t>(),
num_points_per_voxel.contiguous().data_ptr<int32_t>(),
max_voxels, max_points,
num_features, NDim);
}));
AT_CUDA_CHECK(cudaGetLastError());
return max_voxels < num_coors ? max_voxels : num_coors;
}
void dynamic_voxelize_gpu(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3) {
  // current version takes about 0.04s for one frame on cpu
// check device
CHECK_INPUT(points);
at::cuda::CUDAGuard device_guard(points.device());
const int num_points = points.size(0);
const int num_features = points.size(1);
const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
const int col_blocks = at::cuda::ATenCeilDiv(num_points, threadsPerBlock);
dim3 blocks(col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(points.scalar_type(), "dynamic_voxelize_kernel", [&] {
dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<scalar_t>(),
coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, NDim);
});
cudaDeviceSynchronize();
AT_CUDA_CHECK(cudaGetLastError());
return;
}
} // namespace voxelization
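In the non-deterministic GPU path above, voxel coordinates are deduplicated with `at::unique_dim` before points are scattered into voxels with atomics, so the order of points inside a voxel can vary between runs. The following is a rough PyTorch sketch of that deduplication step, assuming `temp_coors` already holds per-point integer voxel coordinates with out-of-range points marked as in `dynamic_voxelize_kernel`.

# Rough PyTorch sketch of the unique-coordinate step in nondisterministic_hard_voxelize_gpu.
import torch

temp_coors = torch.tensor([[3, 5, 1], [3, 5, 1], [-1, 2, 0], [7, 2, 0]])
# Any point with a negative coordinate is treated as fully invalid.
coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, keepdim=True), -1)
unique_coors, coors_map = torch.unique(
    coors_clean, dim=0, sorted=True, return_inverse=True)
if unique_coors[0, 0] < 0:
    unique_coors = unique_coors[1:]  # drop the all -1 row
    coors_map = coors_map - 1        # invalid points now map to index -1
print(unique_coors)  # unique voxel coordinates actually used
print(coors_map)     # per-point voxel index, -1 for the invalid point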
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from torch import nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .voxel_layer import dynamic_voxelize, hard_voxelize
class _Voxelization(Function):
@staticmethod
def forward(ctx,
points,
voxel_size,
coors_range,
max_points=35,
max_voxels=20000,
deterministic=True):
"""convert kitti points(N, >=3) to voxels.
Args:
points: [N, ndim] float tensor. points[:, :3] contain xyz points
and points[:, 3:] contain other information like reflectivity
voxel_size: [3] list/tuple or array, float. xyz, indicate voxel
size
coors_range: [6] list/tuple or array, float. indicate voxel
range. format: xyzxyz, minmax
max_points: int. indicate maximum points contained in a voxel. if
max_points=-1, it means using dynamic_voxelize
max_voxels: int. indicate maximum voxels this function create.
for second, 20000 is a good choice. Users should shuffle points
before call this function because max_voxels may drop points.
            deterministic: bool. Whether to use the deterministic
                implementation of hard voxelization. The non-deterministic
                version is considerably faster but its output is not
                deterministic; this only affects hard voxelization.
                Default: True. For more information on this argument and
                the implementation insights, please refer to the following
                links:
                https://github.com/open-mmlab/mmdetection3d/issues/894
                https://github.com/open-mmlab/mmdetection3d/pull/904
                It is an experimental feature and we will appreciate it if
                you could share the failing cases with us.
Returns:
voxels: [M, max_points, ndim] float tensor. only contain points
and returned when max_points != -1.
coordinates: [M, 3] int32 tensor, always returned.
num_points_per_voxel: [M] int32 tensor. Only returned when
max_points != -1.
"""
if max_points == -1 or max_voxels == -1:
coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int)
dynamic_voxelize(points, coors, voxel_size, coors_range, 3)
return coors
else:
voxels = points.new_zeros(
size=(max_voxels, max_points, points.size(1)))
coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int)
num_points_per_voxel = points.new_zeros(
size=(max_voxels, ), dtype=torch.int)
voxel_num = hard_voxelize(
points,
voxels,
coors,
num_points_per_voxel,
voxel_size,
coors_range,
max_points,
max_voxels,
3,
deterministic,
)
# select the valid voxels
voxels_out = voxels[:voxel_num]
coors_out = coors[:voxel_num]
num_points_per_voxel_out = num_points_per_voxel[:voxel_num]
return voxels_out, coors_out, num_points_per_voxel_out
voxelization = _Voxelization.apply
class Voxelization(nn.Module):
def __init__(self,
voxel_size,
point_cloud_range,
max_num_points,
max_voxels=20000,
deterministic=True):
super(Voxelization, self).__init__()
"""
Args:
voxel_size (list): list [x, y, z] size of three dimension
point_cloud_range (list):
[x_min, y_min, z_min, x_max, y_max, z_max]
max_num_points (int): max number of points per voxel
max_voxels (tuple or int): max number of voxels in
(training, testing) time
            deterministic: bool. Whether to use the deterministic
                implementation of hard voxelization. The non-deterministic
                version is considerably faster but its output is not
                deterministic; this only affects hard voxelization.
                Default: True. For more information on this argument and
                the implementation insights, please refer to the following
                links:
                https://github.com/open-mmlab/mmdetection3d/issues/894
                https://github.com/open-mmlab/mmdetection3d/pull/904
                It is an experimental feature and we will appreciate it if
                you could share the failing cases with us.
"""
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
self.max_num_points = max_num_points
if isinstance(max_voxels, tuple):
self.max_voxels = max_voxels
else:
self.max_voxels = _pair(max_voxels)
self.deterministic = deterministic
point_cloud_range = torch.tensor(
point_cloud_range, dtype=torch.float32)
# [0, -40, -3, 70.4, 40, 1]
voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
grid_size = (point_cloud_range[3:] -
point_cloud_range[:3]) / voxel_size
grid_size = torch.round(grid_size).long()
input_feat_shape = grid_size[:2]
self.grid_size = grid_size
        # the original shape is [x-len, y-len, z-len]
        # the [w, h, d] -> [d, h, w] reversal has been removed
self.pcd_shape = [*input_feat_shape, 1] # [::-1]
def forward(self, input):
"""
Args:
            input (torch.Tensor): Points with shape (N, C).
"""
if self.training:
max_voxels = self.max_voxels[0]
else:
max_voxels = self.max_voxels[1]
return voxelization(
input,
self.voxel_size,
self.point_cloud_range,
self.max_num_points,
max_voxels,
self.deterministic,
)
def __repr__(self):
tmpstr = self.__class__.__name__ + '('
tmpstr += 'voxel_size=' + str(self.voxel_size)
tmpstr += ', point_cloud_range=' + str(self.point_cloud_range)
tmpstr += ', max_num_points=' + str(self.max_num_points)
tmpstr += ', max_voxels=' + str(self.max_voxels)
tmpstr += ', deterministic=' + str(self.deterministic)
tmpstr += ')'
return tmpstr
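A small usage sketch of the `Voxelization` layer defined above, assuming the compiled `voxel_layer` extension and a CUDA device are available; the voxel size, range, and limits below are illustrative values, and the output shapes follow the `_Voxelization.forward` docstring.

# Hypothetical usage sketch; requires the compiled voxel_layer extension and a GPU.
import torch

voxel_layer = Voxelization(
    voxel_size=[0.075, 0.075, 0.2],
    point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0],
    max_num_points=10,
    max_voxels=(120000, 160000))
voxel_layer.eval()  # pick the test-time max_voxels

points = torch.rand(20000, 5, device='cuda')  # x, y, z plus two extra features
points[:, :2] = points[:, :2] * 108.0 - 54.0  # spread x, y over the point cloud range
points[:, 2] = points[:, 2] * 8.0 - 5.0

voxels, coors, num_points = voxel_layer(points)
# voxels: [M, 10, 5], coors: [M, 3] integer grid coordinates, num_points: [M]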
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.models.layers import make_sparse_convmodule
from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
from mmdet3d.models.middle_encoders import SparseEncoder
from mmdet3d.registry import MODELS
if IS_SPCONV2_AVAILABLE:
from spconv.pytorch import SparseConvTensor
else:
from mmcv.ops import SparseConvTensor
@MODELS.register_module()
class BEVFusionSparseEncoder(SparseEncoder):
r"""Sparse encoder for BEVFusion. The difference between this
implementation and that of ``SparseEncoder`` is that the shape order of 3D
conv is (H, W, D) in ``BEVFusionSparseEncoder`` rather than (D, H, W) in
``SparseEncoder``. This difference comes from the implementation of
``voxelization``.
Args:
in_channels (int): The number of input channels.
sparse_shape (list[int]): The sparse shape of input tensor.
order (list[str], optional): Order of conv module.
Defaults to ('conv', 'norm', 'act').
norm_cfg (dict, optional): Config of normalization layer. Defaults to
dict(type='BN1d', eps=1e-3, momentum=0.01).
base_channels (int, optional): Out channels for conv_input layer.
Defaults to 16.
output_channels (int, optional): Out channels for conv_out layer.
Defaults to 128.
encoder_channels (tuple[tuple[int]], optional):
Convolutional channels of each encode block.
Defaults to ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)).
encoder_paddings (tuple[tuple[int]], optional):
Paddings of each encode block.
Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)).
block_type (str, optional): Type of the block to use.
Defaults to 'conv_module'.
        return_middle_feats (bool): Whether to output middle features.
            Defaults to False.
"""
def __init__(self,
in_channels,
sparse_shape,
order=('conv', 'norm', 'act'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels=16,
output_channels=128,
encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
block_type='conv_module',
return_middle_feats=False):
super(SparseEncoder, self).__init__()
assert block_type in ['conv_module', 'basicblock']
self.sparse_shape = sparse_shape
self.in_channels = in_channels
self.order = order
self.base_channels = base_channels
self.output_channels = output_channels
self.encoder_channels = encoder_channels
self.encoder_paddings = encoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
self.return_middle_feats = return_middle_feats
        # Spconv initializes all weights on its own
assert isinstance(order, tuple) and len(order) == 3
assert set(order) == {'conv', 'norm', 'act'}
if self.order[0] != 'conv': # pre activate
self.conv_input = make_sparse_convmodule(
in_channels,
self.base_channels,
3,
norm_cfg=norm_cfg,
padding=1,
indice_key='subm1',
conv_type='SubMConv3d',
order=('conv', ))
else: # post activate
self.conv_input = make_sparse_convmodule(
in_channels,
self.base_channels,
3,
norm_cfg=norm_cfg,
padding=1,
indice_key='subm1',
conv_type='SubMConv3d')
encoder_out_channels = self.make_encoder_layers(
make_sparse_convmodule,
norm_cfg,
self.base_channels,
block_type=block_type)
self.conv_out = make_sparse_convmodule(
encoder_out_channels,
self.output_channels,
kernel_size=(1, 1, 3),
stride=(1, 1, 2),
norm_cfg=norm_cfg,
padding=0,
indice_key='spconv_down2',
conv_type='SparseConv3d')
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseEncoder.
Args:
voxel_features (torch.Tensor): Voxel features in shape (N, C).
coors (torch.Tensor): Coordinates in shape (N, 4),
the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
batch_size (int): Batch size.
Returns:
            torch.Tensor | tuple[torch.Tensor, list]: Spatial features,
                including:
            - spatial_features (torch.Tensor): Spatial features output by
                the last layer.
            - encode_features (List[SparseConvTensor], optional): Output
                features of the middle layers, returned when
                self.return_middle_feats is True.
"""
coors = coors.int()
input_sp_tensor = SparseConvTensor(voxel_features, coors,
self.sparse_shape, batch_size)
x = self.conv_input(input_sp_tensor)
encode_features = []
for encoder_layer in self.encoder_layers:
x = encoder_layer(x)
encode_features.append(x)
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(encode_features[-1])
spatial_features = out.dense()
N, C, H, W, D = spatial_features.shape
spatial_features = spatial_features.permute(0, 1, 4, 2, 3).contiguous()
spatial_features = spatial_features.view(N, C * D, H, W)
if self.return_middle_feats:
return spatial_features, encode_features
else:
return spatial_features
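Because the sparse shape here is ordered (H, W, D) rather than (D, H, W), the dense output of `conv_out` ends with the depth axis, and `forward` folds it into the channel dimension. A tiny standalone sketch of that reshape, with illustrative shapes:

# Toy sketch of the dense-to-BEV reshape at the end of BEVFusionSparseEncoder.forward.
import torch

N, C, H, W, D = 2, 8, 6, 6, 2
spatial_features = torch.randn(N, C, H, W, D)              # dense output in (H, W, D) order
spatial_features = spatial_features.permute(0, 1, 4, 2, 3).contiguous()
spatial_features = spatial_features.view(N, C * D, H, W)   # BEV feature map: [2, 16, 6, 6]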
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models import DetrTransformerDecoderLayer
from torch import Tensor, nn
from mmdet3d.registry import MODELS
class PositionEncodingLearned(nn.Module):
"""Absolute pos embedding, learned."""
def __init__(self, input_channel, num_pos_feats=288):
super().__init__()
self.position_embedding_head = nn.Sequential(
nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True),
nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))
def forward(self, xyz):
xyz = xyz.transpose(1, 2).contiguous()
position_embedding = self.position_embedding_head(xyz)
return position_embedding
@MODELS.register_module()
class TransformerDecoderLayer(DetrTransformerDecoderLayer):
def __init__(self,
pos_encoding_cfg=dict(input_channel=2, num_pos_feats=128),
**kwargs):
super().__init__(**kwargs)
self.self_posembed = PositionEncodingLearned(**pos_encoding_cfg)
self.cross_posembed = PositionEncodingLearned(**pos_encoding_cfg)
def forward(self,
query: Tensor,
key: Tensor = None,
value: Tensor = None,
query_pos: Tensor = None,
key_pos: Tensor = None,
self_attn_mask: Tensor = None,
cross_attn_mask: Tensor = None,
key_padding_mask: Tensor = None,
**kwargs) -> Tensor:
"""
Args:
query (Tensor): The input query, has shape (bs, num_queries, dim).
key (Tensor, optional): The input key, has shape (bs, num_keys,
dim). If `None`, the `query` will be used. Defaults to `None`.
value (Tensor, optional): The input value, has the same shape as
`key`, as in `nn.MultiheadAttention.forward`. If `None`, the
`key` will be used. Defaults to `None`.
query_pos (Tensor, optional): The positional encoding for `query`,
has the same shape as `query`. If not `None`, it will be added
to `query` before forward function. Defaults to `None`.
key_pos (Tensor, optional): The positional encoding for `key`, has
the same shape as `key`. If not `None`, it will be added to
`key` before forward function. If None, and `query_pos` has the
same shape as `key`, then `query_pos` will be used for
`key_pos`. Defaults to None.
self_attn_mask (Tensor, optional): ByteTensor mask, has shape
(num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
Defaults to None.
cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
(num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
Defaults to None.
key_padding_mask (Tensor, optional): The `key_padding_mask` of
`self_attn` input. ByteTensor, has shape (bs, num_value).
Defaults to None.
Returns:
Tensor: forwarded results, has shape (bs, num_queries, dim).
"""
if self.self_posembed is not None and query_pos is not None:
query_pos = self.self_posembed(query_pos).transpose(1, 2)
else:
query_pos = None
if self.cross_posembed is not None and key_pos is not None:
key_pos = self.cross_posembed(key_pos).transpose(1, 2)
else:
key_pos = None
query = query.transpose(1, 2)
key = key.transpose(1, 2)
# Note that the `value` (equal to `query`) is encoded with `query_pos`.
# This is different from the standard DETR Decoder Layer.
query = self.self_attn(
query=query,
key=query,
value=query + query_pos,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=self_attn_mask,
**kwargs)
query = self.norms[0](query)
# Note that the `value` (equal to `key`) is encoded with `key_pos`.
# This is different from the standard DETR Decoder Layer.
query = self.cross_attn(
query=query,
key=key,
value=key + key_pos,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=cross_attn_mask,
key_padding_mask=key_padding_mask,
**kwargs)
query = self.norms[1](query)
query = self.ffn(query)
query = self.norms[2](query)
query = query.transpose(1, 2)
return query
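In this decoder layer the positional inputs `query_pos`/`key_pos` are 2D BEV coordinates of shape (bs, num_queries, 2), and `PositionEncodingLearned` lifts them to feature-sized embeddings before they are added to the attention values. A short sketch of that step, using the `PositionEncodingLearned` module defined above with random coordinates:

# Sketch of the learned positional embedding used by TransformerDecoderLayer.
import torch

pos_embed = PositionEncodingLearned(input_channel=2, num_pos_feats=128)
query_pos = torch.rand(4, 200, 2)           # (bs, num_queries, 2) BEV coordinates
embedding = pos_embed(query_pos)            # (bs, 128, num_queries) from the Conv1d head
query_pos_feat = embedding.transpose(1, 2)  # (bs, num_queries, 128), as used in forward()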
# modify from https://github.com/mit-han-lab/bevfusion
from typing import Any, Dict
import numpy as np
import torch
from mmcv.transforms import BaseTransform
from PIL import Image
from mmdet3d.registry import TRANSFORMS
@TRANSFORMS.register_module()
class ImageAug3D(BaseTransform):
def __init__(self, final_dim, resize_lim, bot_pct_lim, rot_lim, rand_flip,
is_train):
self.final_dim = final_dim
self.resize_lim = resize_lim
self.bot_pct_lim = bot_pct_lim
self.rand_flip = rand_flip
self.rot_lim = rot_lim
self.is_train = is_train
def sample_augmentation(self, results):
H, W = results['ori_shape']
fH, fW = self.final_dim
if self.is_train:
resize = np.random.uniform(*self.resize_lim)
resize_dims = (int(W * resize), int(H * resize))
newW, newH = resize_dims
crop_h = int(
(1 - np.random.uniform(*self.bot_pct_lim)) * newH) - fH
crop_w = int(np.random.uniform(0, max(0, newW - fW)))
crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
flip = False
if self.rand_flip and np.random.choice([0, 1]):
flip = True
rotate = np.random.uniform(*self.rot_lim)
else:
resize = np.mean(self.resize_lim)
resize_dims = (int(W * resize), int(H * resize))
newW, newH = resize_dims
crop_h = int((1 - np.mean(self.bot_pct_lim)) * newH) - fH
crop_w = int(max(0, newW - fW) / 2)
crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
flip = False
rotate = 0
return resize, resize_dims, crop, flip, rotate
def img_transform(self, img, rotation, translation, resize, resize_dims,
crop, flip, rotate):
# adjust image
img = Image.fromarray(img.astype('uint8'), mode='RGB')
img = img.resize(resize_dims)
img = img.crop(crop)
if flip:
img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
img = img.rotate(rotate)
# post-homography transformation
rotation *= resize
translation -= torch.Tensor(crop[:2])
if flip:
A = torch.Tensor([[-1, 0], [0, 1]])
b = torch.Tensor([crop[2] - crop[0], 0])
rotation = A.matmul(rotation)
translation = A.matmul(translation) + b
theta = rotate / 180 * np.pi
A = torch.Tensor([
[np.cos(theta), np.sin(theta)],
[-np.sin(theta), np.cos(theta)],
])
b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
b = A.matmul(-b) + b
rotation = A.matmul(rotation)
translation = A.matmul(translation) + b
return img, rotation, translation
def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
imgs = data['img']
new_imgs = []
transforms = []
for img in imgs:
resize, resize_dims, crop, flip, rotate = self.sample_augmentation(
data)
post_rot = torch.eye(2)
post_tran = torch.zeros(2)
new_img, rotation, translation = self.img_transform(
img,
post_rot,
post_tran,
resize=resize,
resize_dims=resize_dims,
crop=crop,
flip=flip,
rotate=rotate,
)
transform = torch.eye(4)
transform[:2, :2] = rotation
transform[:2, 3] = translation
new_imgs.append(np.array(new_img).astype(np.float32))
transforms.append(transform.numpy())
data['img'] = new_imgs
# update the calibration matrices
data['img_aug_matrix'] = transforms
return data
@TRANSFORMS.register_module()
class GridMask(BaseTransform):
def __init__(
self,
use_h,
use_w,
max_epoch,
rotate=1,
offset=False,
ratio=0.5,
mode=0,
prob=1.0,
fixed_prob=False,
):
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.st_prob = prob
self.prob = prob
self.epoch = None
self.max_epoch = max_epoch
self.fixed_prob = fixed_prob
def set_epoch(self, epoch):
self.epoch = epoch
if not self.fixed_prob:
self.set_prob(self.epoch, self.max_epoch)
def set_prob(self, epoch, max_epoch):
self.prob = self.st_prob * self.epoch / self.max_epoch
def transform(self, results):
if np.random.rand() > self.prob:
return results
imgs = results['img']
h = imgs[0].shape[0]
w = imgs[0].shape[1]
self.d1 = 2
self.d2 = min(h, w)
hh = int(1.5 * h)
ww = int(1.5 * w)
d = np.random.randint(self.d1, self.d2)
if self.ratio == 1:
self.length = np.random.randint(1, d)
else:
self.length = min(max(int(d * self.ratio + 0.5), 1), d - 1)
mask = np.ones((hh, ww), np.float32)
st_h = np.random.randint(d)
st_w = np.random.randint(d)
if self.use_h:
for i in range(hh // d):
s = d * i + st_h
t = min(s + self.length, hh)
mask[s:t, :] *= 0
if self.use_w:
for i in range(ww // d):
s = d * i + st_w
t = min(s + self.length, ww)
mask[:, s:t] *= 0
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.asarray(mask)
mask = mask[(hh - h) // 2:(hh - h) // 2 + h,
(ww - w) // 2:(ww - w) // 2 + w]
mask = mask.astype(np.float32)
mask = mask[:, :, None]
if self.mode == 1:
mask = 1 - mask
# mask = mask.expand_as(imgs[0])
if self.offset:
offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float()
offset = (1 - mask) * offset
imgs = [x * mask + offset for x in imgs]
else:
imgs = [x * mask for x in imgs]
results.update(img=imgs)
return results
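Both transforms above operate on `results['img']` as a list of per-camera arrays. `ImageAug3D` additionally records a 4x4 `img_aug_matrix` per camera whose top-left 2x2 block and the first two entries of its last column map a pixel in the original image to the corresponding pixel in the augmented image. A hedged sketch that runs the transform on a dummy single-camera sample (the `final_dim`, `resize_lim`, and input resolution below are illustrative, not the project config) and then applies the matrix:

# Hedged sketch: run ImageAug3D on a dummy sample and use img_aug_matrix to map a pixel.
import numpy as np

aug = ImageAug3D(
    final_dim=(256, 704), resize_lim=(0.38, 0.55), bot_pct_lim=(0.0, 0.0),
    rot_lim=(0.0, 0.0), rand_flip=False, is_train=False)
data = dict(img=[np.zeros((900, 1600, 3), dtype=np.uint8)], ori_shape=(900, 1600))
data = aug.transform(data)

M = data['img_aug_matrix'][0]               # 4x4 augmentation matrix for the single camera
uv = np.array([800.0, 450.0])               # pixel in the original 1600x900 image
uv_aug = M[:2, :2] @ uv + M[:2, 3]          # the same pixel in the 704x256 augmented image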
# modify from https://github.com/mit-han-lab/bevfusion
import copy
from typing import List
import numpy as np
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer
from mmdet.models.task_modules import (AssignResult, PseudoSampler,
build_assigner, build_bbox_coder,
build_sampler)
from mmdet.models.utils import multi_apply
from mmengine.structures import InstanceData
from torch import nn
from mmdet3d.models import circle_nms, draw_heatmap_gaussian, gaussian_radius
from mmdet3d.models.dense_heads.centerpoint_head import SeparateHead
from mmdet3d.models.layers import nms_bev
from mmdet3d.registry import MODELS
from mmdet3d.structures import xywhr2xyxyr
def clip_sigmoid(x, eps=1e-4):
y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
return y
@MODELS.register_module()
class ConvFuser(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
super().__init__(
nn.Conv2d(
sum(in_channels), out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
)
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
return super().forward(torch.cat(inputs, dim=1))
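# A hypothetical usage sketch of ConvFuser (not part of the original file): it simply
# concatenates the camera-BEV and LiDAR-BEV feature maps along channels and fuses them
# with a single 3x3 Conv-BN-ReLU block, e.g. (channel numbers are illustrative):
#
#   fuser = ConvFuser(in_channels=[80, 256], out_channels=256)
#   fused = fuser([torch.randn(2, 80, 180, 180), torch.randn(2, 256, 180, 180)])
#   # fused: [2, 256, 180, 180]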
@MODELS.register_module()
class TransFusionHead(nn.Module):
def __init__(
self,
num_proposals=128,
auxiliary=True,
in_channels=128 * 3,
hidden_channel=128,
num_classes=4,
# config for Transformer
num_decoder_layers=3,
decoder_layer=dict(),
num_heads=8,
nms_kernel_size=1,
bn_momentum=0.1,
# config for FFN
common_heads=dict(),
num_heatmap_convs=2,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
bias='auto',
# loss
loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
loss_bbox=dict(type='mmdet.L1Loss', reduction='mean'),
loss_heatmap=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
# others
train_cfg=None,
test_cfg=None,
bbox_coder=None,
):
super(TransFusionHead, self).__init__()
self.fp16_enabled = False
self.num_classes = num_classes
self.num_proposals = num_proposals
self.auxiliary = auxiliary
self.in_channels = in_channels
self.num_heads = num_heads
self.num_decoder_layers = num_decoder_layers
self.bn_momentum = bn_momentum
self.nms_kernel_size = nms_kernel_size
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
if not self.use_sigmoid_cls:
self.num_classes += 1
self.loss_cls = MODELS.build(loss_cls)
self.loss_bbox = MODELS.build(loss_bbox)
self.loss_heatmap = MODELS.build(loss_heatmap)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.sampling = False
# a shared convolution
self.shared_conv = build_conv_layer(
dict(type='Conv2d'),
in_channels,
hidden_channel,
kernel_size=3,
padding=1,
bias=bias,
)
layers = []
layers.append(
ConvModule(
hidden_channel,
hidden_channel,
kernel_size=3,
padding=1,
bias=bias,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
))
layers.append(
build_conv_layer(
dict(type='Conv2d'),
hidden_channel,
num_classes,
kernel_size=3,
padding=1,
bias=bias,
))
self.heatmap_head = nn.Sequential(*layers)
self.class_encoding = nn.Conv1d(num_classes, hidden_channel, 1)
# transformer decoder layers for object query with LiDAR feature
self.decoder = nn.ModuleList()
for i in range(self.num_decoder_layers):
self.decoder.append(MODELS.build(decoder_layer))
# Prediction Head
self.prediction_heads = nn.ModuleList()
for i in range(self.num_decoder_layers):
heads = copy.deepcopy(common_heads)
heads.update(dict(heatmap=(self.num_classes, num_heatmap_convs)))
self.prediction_heads.append(
SeparateHead(
hidden_channel,
heads,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=bias,
))
self.init_weights()
self._init_assigner_sampler()
# Position Embedding for Cross-Attention, which is re-used during training # noqa: E501
x_size = self.test_cfg['grid_size'][0] // self.test_cfg[
'out_size_factor']
y_size = self.test_cfg['grid_size'][1] // self.test_cfg[
'out_size_factor']
self.bev_pos = self.create_2D_grid(x_size, y_size)
self.img_feat_pos = None
self.img_feat_collapsed_pos = None
def create_2D_grid(self, x_size, y_size):
meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
# NOTE: modified
batch_x, batch_y = torch.meshgrid(
*[torch.linspace(it[0], it[1], it[2]) for it in meshgrid])
batch_x = batch_x + 0.5
batch_y = batch_y + 0.5
coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)[None]
coord_base = coord_base.view(1, 2, -1).permute(0, 2, 1)
return coord_base
def init_weights(self):
# initialize transformer
for m in self.decoder.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)
if hasattr(self, 'query'):
nn.init.xavier_normal_(self.query)
self.init_bn_momentum()
def init_bn_momentum(self):
for m in self.modules():
if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
m.momentum = self.bn_momentum
def _init_assigner_sampler(self):
"""Initialize the target assigner and sampler of the head."""
if self.train_cfg is None:
return
if self.sampling:
self.bbox_sampler = build_sampler(self.train_cfg.sampler)
else:
self.bbox_sampler = PseudoSampler()
if isinstance(self.train_cfg.assigner, dict):
self.bbox_assigner = build_assigner(self.train_cfg.assigner)
elif isinstance(self.train_cfg.assigner, list):
self.bbox_assigner = [
build_assigner(res) for res in self.train_cfg.assigner
]
def forward_single(self, inputs, metas):
"""Forward function for CenterPoint.
Args:
inputs (torch.Tensor): Input feature map with the shape of
[B, 512, 128(H), 128(W)]. (consistent with L748)
Returns:
list[dict]: Output results for tasks.
"""
batch_size = inputs.shape[0]
fusion_feat = self.shared_conv(inputs)
#################################
# image to BEV
#################################
fusion_feat_flatten = fusion_feat.view(batch_size,
fusion_feat.shape[1],
-1) # [BS, C, H*W]
bev_pos = self.bev_pos.repeat(batch_size, 1, 1).to(fusion_feat.device)
#################################
# query initialization
#################################
dense_heatmap = self.heatmap_head(fusion_feat)
heatmap = dense_heatmap.detach().sigmoid()
padding = self.nms_kernel_size // 2
local_max = torch.zeros_like(heatmap)
        # equal to nms radius = voxel_size * out_size_factor * kernel_size
local_max_inner = F.max_pool2d(
heatmap, kernel_size=self.nms_kernel_size, stride=1, padding=0)
local_max[:, :, padding:(-padding),
padding:(-padding)] = local_max_inner
# for Pedestrian & Traffic_cone in nuScenes
if self.test_cfg['dataset'] == 'nuScenes':
local_max[:, 8, ] = F.max_pool2d(
heatmap[:, 8], kernel_size=1, stride=1, padding=0)
local_max[:, 9, ] = F.max_pool2d(
heatmap[:, 9], kernel_size=1, stride=1, padding=0)
elif self.test_cfg[
'dataset'] == 'Waymo': # for Pedestrian & Cyclist in Waymo
local_max[:, 1, ] = F.max_pool2d(
heatmap[:, 1], kernel_size=1, stride=1, padding=0)
local_max[:, 2, ] = F.max_pool2d(
heatmap[:, 2], kernel_size=1, stride=1, padding=0)
heatmap = heatmap * (heatmap == local_max)
heatmap = heatmap.view(batch_size, heatmap.shape[1], -1)
# top num_proposals among all classes
top_proposals = heatmap.view(batch_size, -1).argsort(
dim=-1, descending=True)[..., :self.num_proposals]
top_proposals_class = top_proposals // heatmap.shape[-1]
top_proposals_index = top_proposals % heatmap.shape[-1]
query_feat = fusion_feat_flatten.gather(
index=top_proposals_index[:, None, :].expand(
-1, fusion_feat_flatten.shape[1], -1),
dim=-1,
)
self.query_labels = top_proposals_class
# add category embedding
one_hot = F.one_hot(
top_proposals_class,
num_classes=self.num_classes).permute(0, 2, 1)
query_cat_encoding = self.class_encoding(one_hot.float())
query_feat += query_cat_encoding
query_pos = bev_pos.gather(
index=top_proposals_index[:, None, :].permute(0, 2, 1).expand(
-1, -1, bev_pos.shape[-1]),
dim=1,
)
#################################
# transformer decoder layer (Fusion feature as K,V)
#################################
ret_dicts = []
for i in range(self.num_decoder_layers):
# Transformer Decoder Layer
# :param query: B C Pq :param query_pos: B Pq 3/6
query_feat = self.decoder[i](
query_feat,
key=fusion_feat_flatten,
query_pos=query_pos,
key_pos=bev_pos)
# Prediction
res_layer = self.prediction_heads[i](query_feat)
res_layer['center'] = res_layer['center'] + query_pos.permute(
0, 2, 1)
ret_dicts.append(res_layer)
# for next level positional embedding
query_pos = res_layer['center'].detach().clone().permute(0, 2, 1)
ret_dicts[0]['query_heatmap_score'] = heatmap.gather(
index=top_proposals_index[:,
None, :].expand(-1, self.num_classes,
-1),
dim=-1,
) # [bs, num_classes, num_proposals]
ret_dicts[0]['dense_heatmap'] = dense_heatmap
if self.auxiliary is False:
# only return the results of last decoder layer
return [ret_dicts[-1]]
        # return all layers' results for auxiliary supervision
new_res = {}
for key in ret_dicts[0].keys():
if key not in [
'dense_heatmap', 'dense_heatmap_old', 'query_heatmap_score'
]:
new_res[key] = torch.cat(
[ret_dict[key] for ret_dict in ret_dicts], dim=-1)
else:
new_res[key] = ret_dicts[0][key]
return [new_res]
def forward(self, feats, metas):
"""Forward pass.
Args:
feats (list[torch.Tensor]): Multi-level features, e.g.,
features produced by FPN.
Returns:
            tuple(list[dict]): Output results, first indexed by level and
                second indexed by layer.
"""
if isinstance(feats, torch.Tensor):
feats = [feats]
res = multi_apply(self.forward_single, feats, [metas])
assert len(res) == 1, 'only support one level features.'
return res
def predict(self, batch_feats, batch_input_metas):
preds_dicts = self(batch_feats, batch_input_metas)
res = self.predict_by_feat(preds_dicts, batch_input_metas)
return res
def predict_by_feat(self,
preds_dicts,
metas,
img=None,
rescale=False,
for_roi=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
Returns:
list[list[dict]]: Decoded bbox, scores and labels for each layer
& each batch.
"""
rets = []
for layer_id, preds_dict in enumerate(preds_dicts):
batch_size = preds_dict[0]['heatmap'].shape[0]
batch_score = preds_dict[0]['heatmap'][
..., -self.num_proposals:].sigmoid()
# if self.loss_iou.loss_weight != 0:
# batch_score = torch.sqrt(batch_score * preds_dict[0]['iou'][..., -self.num_proposals:].sigmoid()) # noqa: E501
one_hot = F.one_hot(
self.query_labels,
num_classes=self.num_classes).permute(0, 2, 1)
batch_score = batch_score * preds_dict[0][
'query_heatmap_score'] * one_hot
batch_center = preds_dict[0]['center'][..., -self.num_proposals:]
batch_height = preds_dict[0]['height'][..., -self.num_proposals:]
batch_dim = preds_dict[0]['dim'][..., -self.num_proposals:]
batch_rot = preds_dict[0]['rot'][..., -self.num_proposals:]
batch_vel = None
if 'vel' in preds_dict[0]:
batch_vel = preds_dict[0]['vel'][..., -self.num_proposals:]
temp = self.bbox_coder.decode(
batch_score,
batch_rot,
batch_dim,
batch_center,
batch_height,
batch_vel,
filter=True,
)
if self.test_cfg['dataset'] == 'nuScenes':
self.tasks = [
dict(
num_class=8,
class_names=[],
indices=[0, 1, 2, 3, 4, 5, 6, 7],
radius=-1,
),
dict(
num_class=1,
class_names=['pedestrian'],
indices=[8],
radius=0.175,
),
dict(
num_class=1,
class_names=['traffic_cone'],
indices=[9],
radius=0.175,
),
]
elif self.test_cfg['dataset'] == 'Waymo':
self.tasks = [
dict(
num_class=1,
class_names=['Car'],
indices=[0],
radius=0.7),
dict(
num_class=1,
class_names=['Pedestrian'],
indices=[1],
radius=0.7),
dict(
num_class=1,
class_names=['Cyclist'],
indices=[2],
radius=0.7),
]
ret_layer = []
for i in range(batch_size):
boxes3d = temp[i]['bboxes']
scores = temp[i]['scores']
labels = temp[i]['labels']
# adopt circle nms for different categories
if self.test_cfg['nms_type'] is not None:
keep_mask = torch.zeros_like(scores)
for task in self.tasks:
task_mask = torch.zeros_like(scores)
for cls_idx in task['indices']:
task_mask += labels == cls_idx
task_mask = task_mask.bool()
if task['radius'] > 0:
if self.test_cfg['nms_type'] == 'circle':
boxes_for_nms = torch.cat(
[
boxes3d[task_mask][:, :2],
scores[:, None][task_mask],
],
dim=1,
)
task_keep_indices = torch.tensor(
circle_nms(
boxes_for_nms.detach().cpu().numpy(),
task['radius'],
))
else:
boxes_for_nms = xywhr2xyxyr(
metas[i]['box_type_3d'](
boxes3d[task_mask][:, :7], 7).bev)
top_scores = scores[task_mask]
task_keep_indices = nms_bev(
boxes_for_nms,
top_scores,
thresh=task['radius'],
pre_maxsize=self.test_cfg['pre_maxsize'],
post_max_size=self.
test_cfg['post_maxsize'],
)
else:
task_keep_indices = torch.arange(task_mask.sum())
if task_keep_indices.shape[0] != 0:
keep_indices = torch.where(
task_mask != 0)[0][task_keep_indices]
keep_mask[keep_indices] = 1
keep_mask = keep_mask.bool()
ret = dict(
bboxes=boxes3d[keep_mask],
scores=scores[keep_mask],
labels=labels[keep_mask],
)
else: # no nms
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
temp_instances = InstanceData()
ret['bboxes'][:, 2] = ret[
'bboxes'][:, 2] - ret['bboxes'][:, 5] * 0.5 # noqa: E501
temp_instances.bboxes_3d = metas[0]['box_type_3d'](
ret['bboxes'], box_dim=ret['bboxes'].shape[-1])
temp_instances.scores_3d = ret['scores']
temp_instances.labels_3d = ret['labels'].int()
ret_layer.append(temp_instances)
rets.append(ret_layer)
assert len(
rets
) == 1, f'only support one layer now, but get {len(rets)} layers'
return rets[0]
def get_targets(self, gt_bboxes_3d, gt_labels_3d, preds_dict):
"""Generate training targets.
Args:
gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes.
gt_labels_3d (torch.Tensor): Labels of boxes.
preds_dicts (tuple of dict): first index by layer (default 1)
Returns:
tuple[torch.Tensor]: Tuple of target including \
the following results in order.
- torch.Tensor: classification target. [BS, num_proposals]
- torch.Tensor: classification weights (mask)
[BS, num_proposals]
- torch.Tensor: regression target. [BS, num_proposals, 8]
- torch.Tensor: regression weights. [BS, num_proposals, 8]
"""
# change preds_dict into list of dict (index by batch_id)
# preds_dict[0]['center'].shape [bs, 3, num_proposal]
list_of_pred_dict = []
for batch_idx in range(len(gt_bboxes_3d)):
pred_dict = {}
for key in preds_dict[0].keys():
pred_dict[key] = preds_dict[0][key][batch_idx:batch_idx + 1]
list_of_pred_dict.append(pred_dict)
assert len(gt_bboxes_3d) == len(list_of_pred_dict)
res_tuple = multi_apply(
self.get_targets_single,
gt_bboxes_3d,
gt_labels_3d,
list_of_pred_dict,
np.arange(len(gt_labels_3d)),
)
labels = torch.cat(res_tuple[0], dim=0)
label_weights = torch.cat(res_tuple[1], dim=0)
bbox_targets = torch.cat(res_tuple[2], dim=0)
bbox_weights = torch.cat(res_tuple[3], dim=0)
ious = torch.cat(res_tuple[4], dim=0)
num_pos = np.sum(res_tuple[5])
matched_ious = np.mean(res_tuple[6])
heatmap = torch.cat(res_tuple[7], dim=0)
return (
labels,
label_weights,
bbox_targets,
bbox_weights,
ious,
num_pos,
matched_ious,
heatmap,
)
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d, preds_dict,
batch_idx):
"""Generate training targets for a single sample.
Args:
gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes.
gt_labels_3d (torch.Tensor): Labels of boxes.
preds_dict (dict): dict of prediction result for a single sample
Returns:
tuple[torch.Tensor]: Tuple of target including \
the following results in order.
- torch.Tensor: classification target. [1, num_proposals]
- torch.Tensor: classification weights (mask) [1, num_proposals] # noqa: E501
- torch.Tensor: regression target. [1, num_proposals, 8]
- torch.Tensor: regression weights. [1, num_proposals, 8]
- torch.Tensor: iou target. [1, num_proposals]
- int: number of positive proposals
"""
num_proposals = preds_dict['center'].shape[-1]
        # get pred boxes; be careful not to change the network outputs
score = copy.deepcopy(preds_dict['heatmap'].detach())
center = copy.deepcopy(preds_dict['center'].detach())
height = copy.deepcopy(preds_dict['height'].detach())
dim = copy.deepcopy(preds_dict['dim'].detach())
rot = copy.deepcopy(preds_dict['rot'].detach())
if 'vel' in preds_dict.keys():
vel = copy.deepcopy(preds_dict['vel'].detach())
else:
vel = None
boxes_dict = self.bbox_coder.decode(
score, rot, dim, center, height,
vel) # decode the prediction to real world metric bbox
bboxes_tensor = boxes_dict[0]['bboxes']
gt_bboxes_tensor = gt_bboxes_3d.tensor.to(score.device)
# each layer should do label assign separately.
if self.auxiliary:
num_layer = self.num_decoder_layers
else:
num_layer = 1
assign_result_list = []
for idx_layer in range(num_layer):
bboxes_tensor_layer = bboxes_tensor[self.num_proposals *
idx_layer:self.num_proposals *
(idx_layer + 1), :]
score_layer = score[..., self.num_proposals *
idx_layer:self.num_proposals *
(idx_layer + 1), ]
if self.train_cfg.assigner.type == 'HungarianAssigner3D':
assign_result = self.bbox_assigner.assign(
bboxes_tensor_layer,
gt_bboxes_tensor,
gt_labels_3d,
score_layer,
self.train_cfg,
)
elif self.train_cfg.assigner.type == 'HeuristicAssigner':
assign_result = self.bbox_assigner.assign(
bboxes_tensor_layer,
gt_bboxes_tensor,
None,
gt_labels_3d,
self.query_labels[batch_idx],
)
else:
raise NotImplementedError
assign_result_list.append(assign_result)
# combine assign result of each layer
assign_result_ensemble = AssignResult(
num_gts=sum([res.num_gts for res in assign_result_list]),
gt_inds=torch.cat([res.gt_inds for res in assign_result_list]),
max_overlaps=torch.cat(
[res.max_overlaps for res in assign_result_list]),
labels=torch.cat([res.labels for res in assign_result_list]),
)
sampling_result = self.bbox_sampler.sample(assign_result_ensemble,
bboxes_tensor,
gt_bboxes_tensor)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
assert len(pos_inds) + len(neg_inds) == num_proposals
# create target for loss computation
bbox_targets = torch.zeros([num_proposals, self.bbox_coder.code_size
]).to(center.device)
bbox_weights = torch.zeros([num_proposals, self.bbox_coder.code_size
]).to(center.device)
ious = assign_result_ensemble.max_overlaps
ious = torch.clamp(ious, min=0.0, max=1.0)
labels = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
label_weights = bboxes_tensor.new_zeros(
num_proposals, dtype=torch.long)
if gt_labels_3d is not None: # default label is -1
labels += self.num_classes
# both pos and neg have classification loss, only pos has regression
# and iou loss
if len(pos_inds) > 0:
pos_bbox_targets = self.bbox_coder.encode(
sampling_result.pos_gt_bboxes)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
if gt_labels_3d is None:
labels[pos_inds] = 1
else:
labels[pos_inds] = gt_labels_3d[
sampling_result.pos_assigned_gt_inds]
if self.train_cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg.pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
        # compute dense heatmap targets
device = labels.device
gt_bboxes_3d = torch.cat(
[gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]],
dim=1).to(device)
grid_size = torch.tensor(self.train_cfg['grid_size'])
pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
voxel_size = torch.tensor(self.train_cfg['voxel_size'])
feature_map_size = (grid_size[:2] // self.train_cfg['out_size_factor']
) # [x_len, y_len]
heatmap = gt_bboxes_3d.new_zeros(self.num_classes, feature_map_size[1],
feature_map_size[0])
for idx in range(len(gt_bboxes_3d)):
width = gt_bboxes_3d[idx][3]
length = gt_bboxes_3d[idx][4]
width = width / voxel_size[0] / self.train_cfg['out_size_factor']
length = length / voxel_size[1] / self.train_cfg['out_size_factor']
if width > 0 and length > 0:
radius = gaussian_radius(
(length, width),
min_overlap=self.train_cfg['gaussian_overlap'])
radius = max(self.train_cfg['min_radius'], int(radius))
x, y = gt_bboxes_3d[idx][0], gt_bboxes_3d[idx][1]
coor_x = ((x - pc_range[0]) / voxel_size[0] /
self.train_cfg['out_size_factor'])
coor_y = ((y - pc_range[1]) / voxel_size[1] /
self.train_cfg['out_size_factor'])
center = torch.tensor([coor_x, coor_y],
dtype=torch.float32,
device=device)
center_int = center.to(torch.int32)
# original
# draw_heatmap_gaussian(heatmap[gt_labels_3d[idx]], center_int, radius) # noqa: E501
# NOTE: fix
draw_heatmap_gaussian(heatmap[gt_labels_3d[idx]],
center_int[[1, 0]], radius)
mean_iou = ious[pos_inds].sum() / max(len(pos_inds), 1)
return (
labels[None],
label_weights[None],
bbox_targets[None],
bbox_weights[None],
ious[None],
int(pos_inds.shape[0]),
float(mean_iou),
heatmap[None],
)
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes.
gt_labels_3d (list[torch.Tensor]): Labels of boxes.
preds_dicts (list[list[dict]]): Output of forward function.
Returns:
dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
"""
(
labels,
label_weights,
bbox_targets,
bbox_weights,
ious,
num_pos,
matched_ious,
heatmap,
) = self.get_targets(gt_bboxes_3d, gt_labels_3d, preds_dicts[0])
if hasattr(self, 'on_the_image_mask'):
label_weights = label_weights * self.on_the_image_mask
bbox_weights = bbox_weights * self.on_the_image_mask[:, :, None]
num_pos = bbox_weights.max(-1).values.sum()
preds_dict = preds_dicts[0][0]
loss_dict = dict()
# compute heatmap loss
loss_heatmap = self.loss_heatmap(
clip_sigmoid(preds_dict['dense_heatmap']),
heatmap,
avg_factor=max(heatmap.eq(1).float().sum().item(), 1),
)
loss_dict['loss_heatmap'] = loss_heatmap
# compute loss for each layer
for idx_layer in range(
self.num_decoder_layers if self.auxiliary else 1):
if idx_layer == self.num_decoder_layers - 1 or (
idx_layer == 0 and self.auxiliary is False):
prefix = 'layer_-1'
else:
prefix = f'layer_{idx_layer}'
layer_labels = labels[..., idx_layer *
self.num_proposals:(idx_layer + 1) *
self.num_proposals, ].reshape(-1)
layer_label_weights = label_weights[
..., idx_layer * self.num_proposals:(idx_layer + 1) *
self.num_proposals, ].reshape(-1)
layer_score = preds_dict['heatmap'][..., idx_layer *
self.num_proposals:(idx_layer +
1) *
self.num_proposals, ]
layer_cls_score = layer_score.permute(0, 2, 1).reshape(
-1, self.num_classes)
layer_loss_cls = self.loss_cls(
layer_cls_score,
layer_labels,
layer_label_weights,
avg_factor=max(num_pos, 1),
)
layer_center = preds_dict['center'][..., idx_layer *
self.num_proposals:(idx_layer +
1) *
self.num_proposals, ]
layer_height = preds_dict['height'][..., idx_layer *
self.num_proposals:(idx_layer +
1) *
self.num_proposals, ]
layer_rot = preds_dict['rot'][..., idx_layer *
self.num_proposals:(idx_layer + 1) *
self.num_proposals, ]
layer_dim = preds_dict['dim'][..., idx_layer *
self.num_proposals:(idx_layer + 1) *
self.num_proposals, ]
preds = torch.cat(
[layer_center, layer_height, layer_dim, layer_rot],
dim=1).permute(0, 2, 1) # [BS, num_proposals, code_size]
if 'vel' in preds_dict.keys():
layer_vel = preds_dict['vel'][..., idx_layer *
self.num_proposals:(idx_layer +
1) *
self.num_proposals, ]
preds = torch.cat([
layer_center, layer_height, layer_dim, layer_rot, layer_vel
],
dim=1).permute(
0, 2,
1) # [BS, num_proposals, code_size]
code_weights = self.train_cfg.get('code_weights', None)
layer_bbox_weights = bbox_weights[:, idx_layer *
self.num_proposals:(idx_layer +
1) *
self.num_proposals, :, ]
layer_reg_weights = layer_bbox_weights * layer_bbox_weights.new_tensor( # noqa: E501
code_weights)
layer_bbox_targets = bbox_targets[:, idx_layer *
self.num_proposals:(idx_layer +
1) *
self.num_proposals, :, ]
layer_loss_bbox = self.loss_bbox(
preds,
layer_bbox_targets,
layer_reg_weights,
avg_factor=max(num_pos, 1))
loss_dict[f'{prefix}_loss_cls'] = layer_loss_cls
loss_dict[f'{prefix}_loss_bbox'] = layer_loss_bbox
# loss_dict[f'{prefix}_loss_iou'] = layer_loss_iou
loss_dict['matched_ious'] = layer_loss_cls.new_tensor(matched_ious)
return loss_dict
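# Illustrative sketch (not part of the original head): the loss above slices
# per-decoder-layer proposals out of tensors shaped
# [bs, channels, num_decoder_layers * num_proposals]. The toy shapes below are
# hypothetical; real values come from the head config.
def _demo_layer_slicing(num_decoder_layers=2, num_proposals=4, num_classes=10):
    import torch
    heatmap = torch.rand(1, num_classes, num_decoder_layers * num_proposals)
    for idx_layer in range(num_decoder_layers):
        layer_score = heatmap[..., idx_layer * num_proposals:(idx_layer + 1) *
                              num_proposals]
        # each decoder layer sees exactly its own contiguous block of proposals
        assert layer_score.shape == (1, num_classes, num_proposals)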
# modify from https://github.com/mit-han-lab/bevfusion
import torch
from mmdet.models.task_modules import AssignResult, BaseAssigner, BaseBBoxCoder
try:
from scipy.optimize import linear_sum_assignment
except ImportError:
linear_sum_assignment = None
from mmdet3d.registry import TASK_UTILS
@TASK_UTILS.register_module()
class TransFusionBBoxCoder(BaseBBoxCoder):
def __init__(
self,
pc_range,
out_size_factor,
voxel_size,
post_center_range=None,
score_threshold=None,
code_size=8,
):
self.pc_range = pc_range
self.out_size_factor = out_size_factor
self.voxel_size = voxel_size
self.post_center_range = post_center_range
self.score_threshold = score_threshold
self.code_size = code_size
def encode(self, dst_boxes):
targets = torch.zeros([dst_boxes.shape[0],
self.code_size]).to(dst_boxes.device)
targets[:, 0] = (dst_boxes[:, 0] - self.pc_range[0]) / (
self.out_size_factor * self.voxel_size[0])
targets[:, 1] = (dst_boxes[:, 1] - self.pc_range[1]) / (
self.out_size_factor * self.voxel_size[1])
targets[:, 3] = dst_boxes[:, 3].log()
targets[:, 4] = dst_boxes[:, 4].log()
targets[:, 5] = dst_boxes[:, 5].log()
# bottom center to gravity center
targets[:, 2] = dst_boxes[:, 2] + dst_boxes[:, 5] * 0.5
targets[:, 6] = torch.sin(dst_boxes[:, 6])
targets[:, 7] = torch.cos(dst_boxes[:, 6])
if self.code_size == 10:
targets[:, 8:10] = dst_boxes[:, 7:]
return targets
def decode(self, heatmap, rot, dim, center, height, vel, filter=False):
"""Decode bboxes.
Args:
            heatmap (torch.Tensor): Heatmap with the shape of
                [B, num_cls, num_proposals].
rot (torch.Tensor): Rotation with the shape of
[B, 1, num_proposals].
dim (torch.Tensor): Dim of the boxes with the shape of
[B, 3, num_proposals].
center (torch.Tensor): bev center of the boxes with the shape of
[B, 2, num_proposals]. (in feature map metric)
            height (torch.Tensor): Height of the boxes with the shape of
                [B, 1, num_proposals]. (in real world metric)
vel (torch.Tensor): Velocity with the shape of
[B, 2, num_proposals].
            filter (bool): If False, return all boxes without checking the
                score threshold and post center range. Defaults to False.
Returns:
list[dict]: Decoded boxes.
"""
# class label
final_preds = heatmap.max(1, keepdims=False).indices
final_scores = heatmap.max(1, keepdims=False).values
# change size to real world metric
center[:,
0, :] = center[:,
0, :] * self.out_size_factor * self.voxel_size[
0] + self.pc_range[0]
center[:,
1, :] = center[:,
1, :] * self.out_size_factor * self.voxel_size[
1] + self.pc_range[1]
dim[:, 0, :] = dim[:, 0, :].exp()
dim[:, 1, :] = dim[:, 1, :].exp()
dim[:, 2, :] = dim[:, 2, :].exp()
height = height - dim[:,
2:3, :] * 0.5 # gravity center to bottom center
rots, rotc = rot[:, 0:1, :], rot[:, 1:2, :]
rot = torch.atan2(rots, rotc)
if vel is None:
final_box_preds = torch.cat([center, height, dim, rot],
dim=1).permute(0, 2, 1)
else:
final_box_preds = torch.cat([center, height, dim, rot, vel],
dim=1).permute(0, 2, 1)
predictions_dicts = []
for i in range(heatmap.shape[0]):
boxes3d = final_box_preds[i]
scores = final_scores[i]
labels = final_preds[i]
predictions_dict = {
'bboxes': boxes3d,
'scores': scores,
'labels': labels
}
predictions_dicts.append(predictions_dict)
if filter is False:
return predictions_dicts
# use score threshold
if self.score_threshold is not None:
thresh_mask = final_scores > self.score_threshold
if self.post_center_range is not None:
self.post_center_range = torch.tensor(
self.post_center_range, device=heatmap.device)
mask = (final_box_preds[..., :3] >=
self.post_center_range[:3]).all(2)
mask &= (final_box_preds[..., :3] <=
self.post_center_range[3:]).all(2)
predictions_dicts = []
for i in range(heatmap.shape[0]):
cmask = mask[i, :]
if self.score_threshold:
cmask &= thresh_mask[i]
boxes3d = final_box_preds[i, cmask]
scores = final_scores[i, cmask]
labels = final_preds[i, cmask]
predictions_dict = {
'bboxes': boxes3d,
'scores': scores,
'labels': labels
}
predictions_dicts.append(predictions_dict)
else:
raise NotImplementedError(
'Need to reorganize output as a batch, only '
'support post_center_range is not None for now!')
return predictions_dicts
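# Hypothetical round-trip sketch (illustration only, not part of the original
# file): encode() maps boxes to the grid-normalized / log-dim / sin-cos
# parameterization above and decode() inverts it. The pc_range / voxel_size
# values mirror the nuScenes config; the dummy box values are arbitrary.
def _demo_coder_round_trip():
    coder = TransFusionBBoxCoder(
        pc_range=[-54.0, -54.0],
        out_size_factor=8,
        voxel_size=[0.075, 0.075],
        code_size=8)
    # a single box: (x, y, bottom z, dx, dy, dz, yaw)
    boxes = torch.tensor([[1.0, 2.0, -1.0, 4.0, 2.0, 1.5, 0.3]])
    enc = coder.encode(boxes)  # [1, 8]
    dec = coder.decode(
        heatmap=torch.rand(1, 10, 1),
        rot=enc[:, 6:8].t()[None],
        dim=enc[:, 3:6].t()[None],
        center=enc[:, 0:2].t()[None],
        height=enc[:, 2:3].t()[None],
        vel=None)
    # dec[0]['bboxes'] recovers the original (x, y, z, dx, dy, dz, yaw)
    return dec[0]['bboxes']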
@TASK_UTILS.register_module()
class BBoxBEVL1Cost(object):
def __init__(self, weight):
self.weight = weight
def __call__(self, bboxes, gt_bboxes, train_cfg):
pc_start = bboxes.new(train_cfg['point_cloud_range'][0:2])
pc_range = bboxes.new(
train_cfg['point_cloud_range'][3:5]) - bboxes.new(
train_cfg['point_cloud_range'][0:2])
# normalize the box center to [0, 1]
normalized_bboxes_xy = (bboxes[:, :2] - pc_start) / pc_range
normalized_gt_bboxes_xy = (gt_bboxes[:, :2] - pc_start) / pc_range
reg_cost = torch.cdist(
normalized_bboxes_xy, normalized_gt_bboxes_xy, p=1)
return reg_cost * self.weight
@TASK_UTILS.register_module()
class IoU3DCost(object):
def __init__(self, weight):
self.weight = weight
def __call__(self, iou):
iou_cost = -iou
return iou_cost * self.weight
@TASK_UTILS.register_module()
class HeuristicAssigner3D(BaseAssigner):
def __init__(self,
dist_thre=100,
iou_calculator=dict(type='BboxOverlaps3D')):
self.dist_thre = dist_thre # distance in meter
self.iou_calculator = TASK_UTILS.build(iou_calculator)
def assign(self,
bboxes,
gt_bboxes,
gt_bboxes_ignore=None,
gt_labels=None,
query_labels=None):
dist_thre = self.dist_thre
num_gts, num_bboxes = len(gt_bboxes), len(bboxes)
bev_dist = torch.norm(
bboxes[:, 0:2][None, :, :] - gt_bboxes[:, 0:2][:, None, :],
dim=-1) # [num_gts, num_bboxes]
if query_labels is not None:
# only match the gt box and query with same category
not_same_class = (query_labels[None] != gt_labels[:, None])
bev_dist += not_same_class * dist_thre
# for each gt box, assign it to the nearest pred box
nearest_values, nearest_indices = bev_dist.min(1) # [num_gts]
assigned_gt_inds = torch.ones([
num_bboxes,
]).to(bboxes) * 0
assigned_gt_vals = torch.ones([
num_bboxes,
]).to(bboxes) * 10000
assigned_gt_labels = torch.ones([
num_bboxes,
]).to(bboxes) * -1
for idx_gts in range(num_gts):
# for idx_pred in torch.where(bev_dist[idx_gts] < dist_thre)[0]:
# # each gt match to all the pred box within some radius
idx_pred = nearest_indices[
idx_gts] # each gt only match to the nearest pred box
if bev_dist[idx_gts, idx_pred] <= dist_thre:
# if this pred box is assigned, then compare
if bev_dist[idx_gts, idx_pred] < assigned_gt_vals[idx_pred]:
assigned_gt_vals[idx_pred] = bev_dist[idx_gts, idx_pred]
# for AssignResult, 0 is negative, -1 is ignore, 1-based
# indices are positive
assigned_gt_inds[idx_pred] = idx_gts + 1
assigned_gt_labels[idx_pred] = gt_labels[idx_gts]
max_overlaps = torch.zeros([
num_bboxes,
]).to(bboxes)
matched_indices = torch.where(assigned_gt_inds > 0)
matched_iou = self.iou_calculator(
gt_bboxes[assigned_gt_inds[matched_indices].long() - 1],
bboxes[matched_indices]).diag()
max_overlaps[matched_indices] = matched_iou
return AssignResult(
num_gts,
assigned_gt_inds.long(),
max_overlaps,
labels=assigned_gt_labels)
@TASK_UTILS.register_module()
class HungarianAssigner3D(BaseAssigner):
def __init__(self,
cls_cost=dict(type='ClassificationCost', weight=1.),
reg_cost=dict(type='BBoxBEVL1Cost', weight=1.0),
iou_cost=dict(type='IoU3DCost', weight=1.0),
iou_calculator=dict(type='BboxOverlaps3D')):
self.cls_cost = TASK_UTILS.build(cls_cost)
self.reg_cost = TASK_UTILS.build(reg_cost)
self.iou_cost = TASK_UTILS.build(iou_cost)
self.iou_calculator = TASK_UTILS.build(iou_calculator)
def assign(self, bboxes, gt_bboxes, gt_labels, cls_pred, train_cfg):
num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
# 1. assign -1 by default
assigned_gt_inds = bboxes.new_full((num_bboxes, ),
-1,
dtype=torch.long)
assigned_labels = bboxes.new_full((num_bboxes, ), -1, dtype=torch.long)
if num_gts == 0 or num_bboxes == 0:
# No ground truth or boxes, return empty assignment
if num_gts == 0:
# No ground truth, assign all to background
assigned_gt_inds[:] = 0
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)
# 2. compute the weighted costs
# see mmdetection/mmdet/core/bbox/match_costs/match_cost.py
cls_cost = self.cls_cost(cls_pred[0].T, gt_labels)
reg_cost = self.reg_cost(bboxes, gt_bboxes, train_cfg)
iou = self.iou_calculator(bboxes, gt_bboxes)
iou_cost = self.iou_cost(iou)
# weighted sum of above three costs
cost = cls_cost + reg_cost + iou_cost
# 3. do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()
if linear_sum_assignment is None:
raise ImportError('Please run "pip install scipy" '
'to install scipy first.')
matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
matched_row_inds = torch.from_numpy(matched_row_inds).to(bboxes.device)
matched_col_inds = torch.from_numpy(matched_col_inds).to(bboxes.device)
# 4. assign backgrounds and foregrounds
# assign all indices to backgrounds first
assigned_gt_inds[:] = 0
# assign foregrounds based on matching results
assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
max_overlaps = torch.zeros_like(iou.max(1).values)
max_overlaps[matched_row_inds] = iou[matched_row_inds,
matched_col_inds]
# max_overlaps = iou.max(1).values
return AssignResult(
num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
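# Minimal sketch (illustration only, not part of the original file) of the
# 1-based matching convention used in HungarianAssigner3D.assign(): matched
# proposals store gt_index + 1 and everything else stays 0 (background). The
# toy cost matrix is arbitrary and assumes scipy is installed.
def _demo_hungarian_convention():
    toy_cost = torch.tensor([[0.2, 0.9],
                             [0.8, 0.1],
                             [0.5, 0.6]])  # 3 proposals x 2 gt boxes
    rows, cols = linear_sum_assignment(toy_cost.numpy())
    assigned = torch.zeros(3, dtype=torch.long)
    assigned[torch.from_numpy(rows)] = torch.from_numpy(cols).long() + 1
    # assigned == tensor([1, 2, 0]): proposal 0 -> gt 0, proposal 1 -> gt 1,
    # proposal 2 stays background
    return assigned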
_base_ = ['mmdet3d::_base_/default_runtime.py']
custom_imports = dict(
imports=['projects.BEVFusion.bevfusion'], allow_failed_imports=False)
# model settings
# Voxel size for voxel encoder
# Usually voxel size is changed consistently with the point cloud range
# If point cloud range is modified, do remember to change all related
# keys in the config.
voxel_size = [0.075, 0.075, 0.2]
point_cloud_range = [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
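# A quick sanity check of the two values above (informal note): the BEV grid
# is (54.0 - (-54.0)) / 0.075 = 1440 voxels along x and y, and
# (3.0 - (-5.0)) / 0.2 = 40 voxels along z; the grid_size / sparse_shape of
# [1440, 1440, 41] used below keeps one extra z slice, following the usual
# mmdet3d sparse-encoder convention.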
class_names = [
'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]
metainfo = dict(classes=class_names)
dataset_type = 'NuScenesDataset'
data_root = 'data/nuscenes/'
data_prefix = dict(
pts='samples/LIDAR_TOP',
CAM_FRONT='samples/CAM_FRONT',
CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT',
CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT',
CAM_BACK='samples/CAM_BACK',
CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT',
CAM_BACK_LEFT='samples/CAM_BACK_LEFT',
sweeps='sweeps/LIDAR_TOP')
input_modality = dict(use_lidar=True, use_camera=True)
file_client_args = dict(backend='disk')
model = dict(
type='BEVFusion',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=False,
pad_size_divisor=32,
voxelize_cfg=dict(
max_num_points=10,
point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0],
voxel_size=[0.075, 0.075, 0.2],
max_voxels=[120000, 160000],
voxelize_reduce=True)),
img_backbone=dict(
type='mmdet.SwinTransformer',
embed_dims=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
patch_norm=True,
out_indices=[1, 2, 3],
with_cp=False,
convert_weights=True,
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa: E501
)),
img_neck=dict(
type='GeneralizedLSSFPN',
in_channels=[192, 384, 768],
out_channels=256,
start_level=0,
num_outs=3,
norm_cfg=dict(type='BN2d', requires_grad=True),
act_cfg=dict(type='ReLU', inplace=True),
upsample_cfg=dict(mode='bilinear', align_corners=False)),
vtransform=dict(
type='DepthLSSTransform',
in_channels=256,
out_channels=80,
image_size=[256, 704],
feature_size=[32, 88],
xbound=[-54.0, 54.0, 0.3],
ybound=[-54.0, 54.0, 0.3],
zbound=[-10.0, 10.0, 20.0],
dbound=[1.0, 60.0, 0.5],
downsample=2),
pts_voxel_encoder=dict(type='HardSimpleVFE', num_features=5),
pts_middle_encoder=dict(
type='BEVFusionSparseEncoder',
in_channels=5,
sparse_shape=[1440, 1440, 41],
order=('conv', 'norm', 'act'),
norm_cfg=dict(type='SyncBN', eps=0.001, momentum=0.01),
encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128,
128)),
encoder_paddings=((0, 0, 1), (0, 0, 1), (0, 0, (1, 1, 0)), (0, 0)),
block_type='basicblock'),
fusion_layer=dict(
type='ConvFuser', in_channels=[80, 256], out_channels=256),
pts_backbone=dict(
type='SECOND',
in_channels=256,
out_channels=[128, 256],
layer_nums=[5, 5],
layer_strides=[1, 2],
norm_cfg=dict(type='SyncBN', eps=0.001, momentum=0.01),
conv_cfg=dict(type='Conv2d', bias=False)),
pts_neck=dict(
type='SECONDFPN',
in_channels=[128, 256],
out_channels=[256, 256],
upsample_strides=[1, 2],
norm_cfg=dict(type='SyncBN', eps=0.001, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False),
use_conv_for_no_stride=True),
bbox_head=dict(
type='TransFusionHead',
num_proposals=200,
auxiliary=True,
in_channels=512,
hidden_channel=128,
num_classes=10,
nms_kernel_size=3,
bn_momentum=0.1,
num_decoder_layers=1,
decoder_layer=dict(
type='TransformerDecoderLayer',
self_attn_cfg=dict(embed_dims=128, num_heads=8, dropout=0.1),
cross_attn_cfg=dict(embed_dims=128, num_heads=8, dropout=0.1),
ffn_cfg=dict(
embed_dims=128,
feedforward_channels=256,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True),
),
norm_cfg=dict(type='LN'),
pos_encoding_cfg=dict(input_channel=2, num_pos_feats=128)),
train_cfg=dict(
dataset='nuScenes',
point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0],
grid_size=[1440, 1440, 41],
voxel_size=[0.075, 0.075, 0.2],
out_size_factor=8,
gaussian_overlap=0.1,
min_radius=2,
pos_weight=-1,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2],
assigner=dict(
type='HungarianAssigner3D',
iou_calculator=dict(type='BboxOverlaps3D', coordinate='lidar'),
cls_cost=dict(
type='mmdet.FocalLossCost',
gamma=2.0,
alpha=0.25,
weight=0.15),
reg_cost=dict(type='BBoxBEVL1Cost', weight=0.25),
iou_cost=dict(type='IoU3DCost', weight=0.25))),
test_cfg=dict(
dataset='nuScenes',
grid_size=[1440, 1440, 41],
out_size_factor=8,
voxel_size=[0.075, 0.075],
pc_range=[-54.0, -54.0],
nms_type=None),
common_heads=dict(
center=[2, 2], height=[1, 2], dim=[3, 2], rot=[2, 2], vel=[2, 2]),
bbox_coder=dict(
type='TransFusionBBoxCoder',
pc_range=[-54.0, -54.0],
post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
score_threshold=0.0,
out_size_factor=8,
voxel_size=[0.075, 0.075],
code_size=10),
loss_cls=dict(
type='mmdet.FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=1.0),
loss_heatmap=dict(
type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
loss_bbox=dict(
type='mmdet.L1Loss', reduction='mean', loss_weight=0.25)))
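# Channel bookkeeping (informal note derived from the settings above): the
# camera branch ends with 80-channel BEV features (vtransform out_channels)
# and the LiDAR branch with 256-channel BEV features, so ConvFuser takes
# in_channels=[80, 256] and emits 256 channels; SECOND consumes those 256
# channels, and SECONDFPN's two 256-channel outputs are concatenated into the
# 512 channels expected by TransFusionHead (in_channels=512).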
db_sampler = dict(
data_root=data_root,
info_path=data_root + 'nuscenes_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
classes=class_names,
sample_groups=dict(
car=5,
truck=5,
bus=5,
trailer=5,
construction_vehicle=5,
traffic_cone=5,
barrier=5,
motorcycle=5,
bicycle=5,
pedestrian=5),
points_loader=dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=5,
use_dim=[0, 1, 2, 3, 4],
reduce_beams=32))
train_pipeline = [
dict(
type='BEVLoadMultiViewImageFromFiles',
to_float32=True,
color_type='color'),
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=5,
use_dim=5,
reduce_beams=32,
load_augmented=None),
dict(
type='LoadPointsFromMultiSweeps',
sweeps_num=9,
load_dim=5,
use_dim=5,
reduce_beams=32,
pad_empty_sweeps=True,
remove_close=True,
load_augmented=None),
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
with_label_3d=True,
with_attr_label=False),
# dict(type='ObjectSampling', db_sampler=db_sampler),
dict(
type='ImageAug3D',
final_dim=[256, 704],
resize_lim=[0.38, 0.55],
bot_pct_lim=[0.0, 0.0],
rot_lim=[-5.4, 5.4],
rand_flip=True,
is_train=True),
dict(
type='GlobalRotScaleTrans',
resize_lim=[0.9, 1.1],
rot_lim=[-0.78539816, 0.78539816],
trans_lim=0.5,
is_train=True),
dict(type='RandomFlip3D'),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(
type='ObjectNameFilter',
classes=[
'car', 'truck', 'construction_vehicle', 'bus', 'trailer',
'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]),
dict(
type='GridMask',
use_h=True,
use_w=True,
max_epoch=6,
rotate=1,
offset=False,
ratio=0.5,
mode=1,
prob=0.0,
fixed_prob=True),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=[
'points', 'img', 'gt_bboxes_3d', 'gt_labels_3d', 'gt_bboxes',
'gt_labels'
])
]
test_pipeline = [
dict(
type='BEVLoadMultiViewImageFromFiles',
to_float32=True,
color_type='color'),
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
dict(
type='LoadPointsFromMultiSweeps',
sweeps_num=9,
load_dim=5,
use_dim=5,
pad_empty_sweeps=True,
remove_close=True),
dict(
type='ImageAug3D',
final_dim=[256, 704],
resize_lim=[0.48, 0.48],
bot_pct_lim=[0.0, 0.0],
rot_lim=[0.0, 0.0],
rand_flip=False,
is_train=False),
dict(
type='PointsRangeFilter',
point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]),
dict(
type='Pack3DDetInputs',
keys=['img', 'points', 'gt_bboxes_3d', 'gt_labels_3d'],
meta_keys=[
'cam2img', 'ori_cam2img', 'lidar2cam', 'lidar2img', 'cam2lidar',
'ori_lidar2img', 'img_aug_matrix', 'box_type_3d', 'sample_idx',
'lidar_path', 'img_path'
])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='nuscenes_infos_train.pkl',
pipeline=train_pipeline,
metainfo=metainfo,
modality=input_modality,
test_mode=False,
data_prefix=data_prefix,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR'))
val_dataloader = dict(
batch_size=1,
num_workers=0,
# persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='nuscenes_infos_val.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
modality=input_modality,
data_prefix=data_prefix,
test_mode=True,
box_type_3d='LiDAR'))
test_dataloader = val_dataloader
val_evaluator = dict(
type='NuScenesMetric',
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_val.pkl',
metric='bbox')
test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.33333333,
by_epoch=False,
begin=0,
end=500),
dict(
type='CosineAnnealingLR',
begin=0,
T_max=6,
end=6,
by_epoch=True,
eta_min_ratio=1e-3),
# momentum scheduler
    # During the first 2.4 epochs, momentum decreases from 1 to 0.85 / 0.95;
    # during the remaining 3.6 epochs, it increases from 0.85 / 0.95 back to 1
dict(
type='CosineAnnealingMomentum',
eta_min=0.85 / 0.95,
begin=0,
end=2.4,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
eta_min=1,
begin=2.4,
end=6,
by_epoch=True,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=6, val_interval=6)
val_cfg = dict()
test_cfg = dict()
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.01),
clip_grad=dict(max_norm=35, norm_type=2))
# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (4 GPUs) x (4 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
checkpoint=dict(type='CheckpointHook', interval=5))
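# Hypothetical usage sketch (not part of the config): once the custom ops are
# built, this config can be loaded and trained with the standard MMEngine /
# MMDetection3D tooling, e.g.
#     from mmengine.config import Config
#     cfg = Config.fromfile('projects/BEVFusion/configs/<this_config>.py')
# or by passing the config path to tools/train.py; the exact file name is
# whatever this config is saved as.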
import os
from setuptools import setup
import torch
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
CUDAExtension)
def make_cuda_ext(name,
module,
sources,
sources_cuda=[],
extra_args=[],
extra_include_path=[]):
define_macros = []
extra_compile_args = {'cxx': [] + extra_args}
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('WITH_CUDA', None)]
extension = CUDAExtension
extra_compile_args['nvcc'] = extra_args + [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_75,code=sm_75',
'-gencode=arch=compute_80,code=sm_80',
'-gencode=arch=compute_86,code=sm_86',
]
sources += sources_cuda
else:
print('Compiling {} without CUDA'.format(name))
extension = CppExtension
return extension(
name='{}.{}'.format(module, name),
sources=[os.path.join(*module.split('.'), p) for p in sources],
include_dirs=extra_include_path,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
if __name__ == '__main__':
setup(
name='bev_pool',
ext_modules=[
make_cuda_ext(
name='bev_pool_ext',
module='projects.BEVFusion.bevfusion.ops.bev_pool',
sources=[
'src/bev_pool.cpp',
'src/bev_pool_cuda.cu',
],
),
make_cuda_ext(
name='voxel_layer',
module='projects.BEVFusion.bevfusion.ops.voxel',
sources=[
'src/voxelization.cpp',
'src/scatter_points_cpu.cpp',
'src/scatter_points_cuda.cu',
'src/voxelization_cpu.cpp',
'src/voxelization_cuda.cu',
],
),
],
cmdclass={'build_ext': BuildExtension},
zip_safe=False,
)
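# Hypothetical build step (not part of the file): the extensions above are
# typically compiled in place with something like
#     python projects/BEVFusion/setup.py develop
# which requires a CUDA-enabled PyTorch install (or FORCE_CUDA=1 to force
# CUDA compilation when no GPU is visible at build time).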
is publicly available at https://github.com/TuSimple/centerformer
## Introduction
We implement CenterFormer and provide the results and checkpoints on the Waymo dataset.
We follow the below style to name config files. Contributors are advised to follow the same style.
`{xxx}` is a required field and `[yyy]` is optional.
`{model}`: model type like `centerpoint`.
`{model setting}`: voxel size and voxel type like `01voxel`, `02pillar`.
`{backbone}`: backbone type like `second`.
`{neck}`: neck type like `secfpn`.
`[batch_per_gpu x gpu]`: GPUs and samples per GPU, 4x8 is used by default.
`{schedule}`: training schedule. Options are 1x, 2x, 20e, etc. 1x and 2x mean 12 and 24 epochs respectively. 20e is adopted in cascade models and denotes 20 epochs. For 1x/2x, the initial learning rate decays by a factor of 10 at the 8/16th and 11/22nd epochs. For 20e, the initial learning rate decays by a factor of 10 at the 16th and 19th epochs.
`{dataset}`: dataset like nus-3d, kitti-3d, lyft-3d, scannet-3d, sunrgbd-3d. We also indicate the number of classes when there are multiple settings, e.g., kitti-3d-3class and kitti-3d-car denote training on the KITTI dataset with 3 classes and a single class, respectively.
## Usage