Commit 2dad86c2 authored by YirongYan, committed by GitHub

[Feature]Support NeRF-Det (#2732)

parent 8deeb6e2
# NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection
> [NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection](https://arxiv.org/abs/2307.14620)
<!-- [ALGORITHM] -->
## Abstract
NeRF-Det is a novel method for indoor 3D detection with posed RGB images as input. Unlike existing indoor 3D detection methods that struggle to model scene geometry, NeRF-Det makes novel use of NeRF in an end-to-end manner to explicitly estimate 3D geometry, thereby improving 3D detection performance. Specifically, to avoid the significant extra latency associated with per-scene optimization of NeRF, NeRF-Det introduces sufficient geometry priors to enhance the generalizability of the NeRF-MLP. Furthermore, it subtly connects the detection and NeRF branches through a shared MLP, enabling an efficient adaptation of NeRF to detection and yielding geometry-aware volumetric representations for 3D detection. NeRF-Det outperforms the state of the art by 3.9 mAP and 3.1 mAP on the ScanNet and ARKitScenes benchmarks, respectively. The authors provide extensive analysis to shed light on how NeRF-Det works. As a result of the joint-training design, NeRF-Det generalizes well to unseen scenes for object detection, view synthesis, and depth estimation tasks without requiring per-scene optimization. Code will be available at https://github.com/facebookresearch/NeRF-Det.
<div align=center>
<img src="https://chenfengxu714.github.io/nerfdet/static/images/method-cropped_1.png" width="800"/>
</div>
## Introduction
This directory contains the implementation of NeRF-Det (https://arxiv.org/abs/2307.14620). Our implementation is built on top of MMDetection3D. We have updated NeRF-Det to be compatible with the latest mmdet3d version; the codebase and config files have all been changed accordingly. All previously released pretrained models have been verified, with the results listed below. Newly trained models have not been uploaded yet.
## Dataset
The ScanNet dataset format in the latest version of mmdet3d only supports LiDAR tasks, so NeRF-Det needs a new format of the ScanNet dataset.
Please follow the instructions in mmdet3d to prepare the raw ScanNet data. After that, use the following command to generate the pkl files used by NeRF-Det.
```bash
python projects/NeRF-Det/prepare_infos.py --root-path ./data/scannet --out-dir ./data/scannet
```
The new pkl format is organized as follows:
- scannet_infos_train.pkl: The training data infos. The detailed info of each scan is as follows:
  - info\['instances'\]: A list of dicts containing all annotations; each dict holds the annotation information of a single instance. For the i-th instance:
    - info\['instances'\]\[i\]\['bbox_3d'\]: A list of 6 numbers representing the axis-aligned bounding box in the depth coordinate system, in (x, y, z, l, w, h) order.
    - info\['instances'\]\[i\]\['bbox_label_3d'\]: The label of the 3D bounding box.
  - info\['cam2img'\]: The intrinsic matrix. Every scene has one matrix.
  - info\['lidar2cam'\]: The extrinsic matrices. Every scene has 300 matrices.
  - info\['img_paths'\]: The paths of the 300 RGB images.
  - info\['axis_align_matrix'\]: The alignment matrix. Every scene has one matrix.
After preparing your ScanNet dataset pkl files, please change the paths in the configs to fit your project.
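The per-scan fields listed above can be sanity-checked before training. The following is a minimal sketch, not part of this project: the function name `check_scan_info` and the toy record are hypothetical, and the 4x4 matrix shape used in the toy data is an assumption. It only checks the properties stated above: 6 numbers per box, a label per instance, one extrinsic matrix per image, and the presence of the per-scene matrices.

```python
from typing import Any, Dict, List


def check_scan_info(info: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one per-scan info dict."""
    problems = []
    # Every instance needs a 6-number axis-aligned box and a label.
    for i, inst in enumerate(info.get('instances', [])):
        if len(inst.get('bbox_3d', [])) != 6:
            problems.append(f'instance {i}: bbox_3d must have 6 numbers')
        if 'bbox_label_3d' not in inst:
            problems.append(f'instance {i}: missing bbox_label_3d')
    # There should be one extrinsic matrix per RGB image.
    n_ext = len(info.get('lidar2cam', []))
    n_img = len(info.get('img_paths', []))
    if n_ext != n_img:
        problems.append(f'{n_ext} extrinsics but {n_img} image paths')
    # These matrices are stored once per scene.
    for key in ('cam2img', 'axis_align_matrix'):
        if key not in info:
            problems.append(f'missing {key}')
    return problems


# Hypothetical toy record (not real ScanNet data), with 2 views instead of 300.
# The 4x4 identity matrices are placeholders; the real shapes are an assumption.
identity = [[1.0 if r == c else 0.0 for c in range(4)] for r in range(4)]
toy_info = {
    'instances': [{
        'bbox_3d': [0.0, 0.0, 0.5, 1.0, 1.0, 1.0],
        'bbox_label_3d': 2
    }],
    'cam2img': identity,
    'lidar2cam': [identity, identity],
    'img_paths': ['view_0.jpg', 'view_1.jpg'],
    'axis_align_matrix': identity,
}
print(check_scan_info(toy_info))
```

For a real pkl file, the same function could be applied to each scan record after loading the file; how the records are wrapped inside the pkl is not specified here, so that step is left out.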
## Train
In MMDet3D's root directory, run the following command to train the model:
```bash
python tools/train.py projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py ${WORK_DIR}
```
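The configs in this commit train for 12 epochs with AdamW (`lr=0.0002`) and a `MultiStepLR` scheduler with `milestones=[8, 11]` and `gamma=0.1`. As a rough illustration of what that schedule does, the sketch below recomputes the per-epoch learning rate in plain Python. The helper `multistep_lr` is hypothetical and only mimics the usual step-decay rule; it does not call the mmengine scheduler, and it ignores the `lr_mult=0.1` applied to the backbone parameters.

```python
from typing import List


def multistep_lr(base_lr: float, milestones: List[int], gamma: float,
                 epoch: int) -> float:
    """Learning rate in effect during `epoch` (0-based) under step decay."""
    # The rate is multiplied by gamma once for every milestone already reached.
    n_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma**n_decays


# Values copied from the configs in this commit.
BASE_LR = 0.0002
MILESTONES = [8, 11]
GAMMA = 0.1
MAX_EPOCHS = 12

schedule = [
    multistep_lr(BASE_LR, MILESTONES, GAMMA, e) for e in range(MAX_EPOCHS)
]
for epoch, lr in enumerate(schedule):
    print(f'epoch {epoch:2d}: lr = {lr:.1e}')
```

Under this rule the rate stays at 2e-4 for epochs 0 to 7, drops to 2e-5 for epochs 8 to 10, and to 2e-6 for the last epoch. Because the training set is wrapped in `RepeatDataset` with `times=6`, each of these epochs iterates over the training scans six times.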
## Results and Models
### NeRF-Det
| Backbone | mAP@25 | mAP@50 | Log |
| :-------------------------------------------------------------: | :----: | :----: | :-------: |
| [NeRF-Det-R50](./configs/nerfdet_res50_2x_low_res.py) | 53.0 | 26.8 | [log](<>) |
| [NeRF-Det-R50\*](./configs/nerfdet_res50_2x_low_res_depth.py) | 52.2 | 28.5 | [log](<>) |
| [NeRF-Det-R101\*](./configs/nerfdet_res101_2x_low_res_depth.py) | 52.3 | 28.5 | [log](<>) |
(Here, NeRF-Det-R50\* means that the model uses depth information during training.)
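The mAP@25 and mAP@50 columns presumably follow the usual indoor 3D detection convention, where a prediction counts as a match when its IoU with a ground-truth box reaches 0.25 or 0.5. The sketch below shows how such an IoU could be computed for axis-aligned boxes in (x, y, z, l, w, h) form. The function `axis_aligned_iou_3d` is illustrative and not taken from this project, and treating (x, y, z) as the box center is an assumption of the sketch.

```python
from typing import Sequence


def axis_aligned_iou_3d(box_a: Sequence[float],
                        box_b: Sequence[float]) -> float:
    """IoU of two axis-aligned boxes given as (x, y, z, l, w, h).

    Assumes (x, y, z) is the box center and (l, w, h) are the sizes along
    the x, y and z axes.
    """
    inter = 1.0
    for axis in range(3):
        half_a = box_a[axis + 3] / 2
        half_b = box_b[axis + 3] / 2
        lo = max(box_a[axis] - half_a, box_b[axis] - half_b)
        hi = min(box_a[axis] + half_a, box_b[axis] + half_b)
        # Overlap along this axis is zero when the intervals are disjoint.
        inter *= max(0.0, hi - lo)
    vol_a = box_a[3] * box_a[4] * box_a[5]
    vol_b = box_b[3] * box_b[4] * box_b[5]
    union = vol_a + vol_b - inter
    return inter / union if union > 0 else 0.0


# Toy boxes (not from ScanNet): a 2 m cube and the same cube shifted 1 m in x.
gt_box = (0.0, 0.0, 0.0, 2.0, 2.0, 2.0)
pred_box = (1.0, 0.0, 0.0, 2.0, 2.0, 2.0)
iou = axis_aligned_iou_3d(gt_box, pred_box)
print(f'IoU = {iou:.3f}')
print('match at 0.25:', iou >= 0.25)
print('match at 0.50:', iou >= 0.50)
```

In this toy case the intersection is 4 cubic meters and the union is 12, so the IoU is one third: the prediction would count as a match at the 0.25 threshold but not at 0.5, which is why mAP@50 is the stricter of the two numbers.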
### Notes
- The values shown in the table all represent the best mAP during training.
- Since the model's behavior involves a lot of randomness, we ran three experiments for each config and took the average. The mAP values shown in the table above are all averages.
- We also ran the same experiments with the original code; the results are shown below.
| Backbone | mAP@25 | mAP@50 |
| :-------------: | :----: | :----: |
| NeRF-Det-R50 | 52.8 | 26.8 |
| NeRF-Det-R50\* | 52.4 | 27.5 |
| NeRF-Det-R101\* | 52.8 | 28.6 |
- Attention: Because of the randomness in the construction of the ScanNet dataset itself and in the behavior of the model, the training results fluctuate considerably. In our experience, the results vary by plus or minus 1.5 points.
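As a quick check of the two tables above against the stated fluctuation of plus or minus 1.5 points, the sketch below computes the gap between the mmdet3d results and the original-code results for each model. The numbers are copied from the tables; the variable and function names are illustrative only.

```python
from typing import Dict, Tuple

Scores = Dict[str, Tuple[float, float]]  # name -> (mAP@25, mAP@50)

TOLERANCE = 1.5  # stated run-to-run fluctuation, in mAP points

# Results of this mmdet3d implementation (first table).
mmdet3d_port: Scores = {
    'NeRF-Det-R50': (53.0, 26.8),
    'NeRF-Det-R50*': (52.2, 28.5),
    'NeRF-Det-R101*': (52.3, 28.5),
}
# Results obtained with the original code (second table).
original_code: Scores = {
    'NeRF-Det-R50': (52.8, 26.8),
    'NeRF-Det-R50*': (52.4, 27.5),
    'NeRF-Det-R101*': (52.8, 28.6),
}


def gaps(port: Scores, reference: Scores) -> Scores:
    """Per-model difference (port minus reference), rounded to 0.1 mAP."""
    return {
        name: tuple(round(p - r, 1) for p, r in zip(port[name], reference[name]))
        for name in port
    }


differences = gaps(mmdet3d_port, original_code)
within_tolerance = all(
    abs(d) <= TOLERANCE for pair in differences.values() for d in pair)
for name, (d25, d50) in differences.items():
    print(f'{name}: mAP@25 {d25:+.1f}, mAP@50 {d50:+.1f}')
print('all gaps within tolerance:', within_tolerance)
```

The largest gap is 1.0 point (mAP@50 for NeRF-Det-R50\*), so every difference between the two implementations falls inside the reported fluctuation range.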
## Evaluation using pretrained models
1. Download the pretrained checkpoints through the links in the table above.
2. Testing
To test, use:
```bash
python tools/test.py projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py ${CHECKPOINT_PATH}
```
## Citation
```latex
@inproceedings{
xu2023nerfdet,
title={NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection},
author={Xu, Chenfeng and Wu, Bichen and Hou, Ji and Tsai, Sam and Li, Ruilong and Wang, Jialiang and Zhan, Wei and He, Zijian and Vajda, Peter and Keutzer, Kurt and Tomizuka, Masayoshi},
booktitle={ICCV},
year={2023},
}
@inproceedings{
park2023time,
title={Time Will Tell: New Outlooks and A Baseline for Temporal Multi-View 3D Object Detection},
author={Jinhyung Park and Chenfeng Xu and Shijia Yang and Kurt Keutzer and Kris M. Kitani and Masayoshi Tomizuka and Wei Zhan},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=H3HcEJA2Um}
}
```
_base_ = ['../../../configs/_base_/default_runtime.py']
custom_imports = dict(imports=['projects.NeRF-Det.nerfdet'])
prior_generator = dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-3.2, -3.2, -1.28, 3.2, 3.2, 1.28]],
rotations=[.0])
model = dict(
type='NerfDet',
data_preprocessor=dict(
type='NeRFDetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=10),
backbone=dict(
type='mmdet.ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet101'),
style='pytorch'),
neck=dict(
type='mmdet.FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
neck_3d=dict(
type='IndoorImVoxelNeck',
in_channels=256,
out_channels=128,
n_blocks=[1, 1, 1]),
bbox_head=dict(
type='NerfDetHead',
bbox_loss=dict(type='AxisAlignedIoULoss', loss_weight=1.0),
n_classes=18,
n_levels=3,
n_channels=128,
n_reg_outs=6,
pts_assign_threshold=27,
pts_center_threshold=18,
prior_generator=prior_generator),
prior_generator=prior_generator,
voxel_size=[.16, .16, .2],
n_voxels=[40, 40, 16],
aabb=([-2.7, -2.7, -0.78], [3.7, 3.7, 1.78]),
near_far_range=[0.2, 8.0],
N_samples=64,
N_rand=2048,
nerf_mode='image',
depth_supervise=True,
use_nerf_mask=True,
nerf_sample_view=20,
squeeze_scale=4,
nerf_density=True,
train_cfg=dict(),
test_cfg=dict(nms_pre=1000, iou_thr=.25, score_thr=.01))
dataset_type = 'MultiViewScanNetDataset'
data_root = 'data/scannet/'
class_names = [
'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf',
'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain',
'toilet', 'sink', 'bathtub', 'garbagebin'
]
metainfo = dict(CLASSES=class_names)
file_client_args = dict(backend='disk')
input_modality = dict(
use_camera=True,
use_depth=True,
use_lidar=False,
use_neuralrecon_depth=False,
use_ray=True)
backend_args = None
train_collect_keys = [
'img', 'gt_bboxes_3d', 'gt_labels_3d', 'depth', 'lightpos', 'nerf_sizes',
'raydirs', 'gt_images', 'gt_depths', 'denorm_images'
]
test_collect_keys = [
'img',
'depth',
'lightpos',
'nerf_sizes',
'raydirs',
'gt_images',
'gt_depths',
'denorm_images',
]
train_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=48,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=10),
dict(type='RandomShiftOrigin', std=(.7, .7, .0)),
dict(type='PackNeRFDetInputs', keys=train_collect_keys)
]
test_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=101,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=1),
dict(type='PackNeRFDetInputs', keys=test_collect_keys)
]
train_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=6,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_train_new.pkl',
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo)))
val_dataloader = dict(
batch_size=1,
num_workers=5,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_val_new.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo))
test_dataloader = val_dataloader
val_evaluator = dict(type='IndoorMetric')
test_evaluator = val_evaluator
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
test_cfg = dict()
val_cfg = dict()
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}),
clip_grad=dict(max_norm=35., norm_type=2))
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]
# hooks
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=12))
# runtime
find_unused_parameters = True # only 1 of 4 FPN outputs is used
_base_ = ['./nerfdet_res50_2x_low_res_depth.py']
model = dict(depth_supervise=False)
dataset_type = 'MultiViewScanNetDataset'
data_root = 'data/scannet/'
class_names = [
'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf',
'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain',
'toilet', 'sink', 'bathtub', 'garbagebin'
]
metainfo = dict(CLASSES=class_names)
file_client_args = dict(backend='disk')
input_modality = dict(use_depth=False)
backend_args = None
train_collect_keys = [
'img', 'gt_bboxes_3d', 'gt_labels_3d', 'lightpos', 'nerf_sizes', 'raydirs',
'gt_images', 'gt_depths', 'denorm_images'
]
test_collect_keys = [
'img',
'lightpos',
'nerf_sizes',
'raydirs',
'gt_images',
'gt_depths',
'denorm_images',
]
train_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=50,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=10),
dict(type='RandomShiftOrigin', std=(.7, .7, .0)),
dict(type='PackNeRFDetInputs', keys=train_collect_keys)
]
test_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=101,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=1),
dict(type='PackNeRFDetInputs', keys=test_collect_keys)
]
train_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=6,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_train_new.pkl',
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo)))
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_val_new.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo))
test_dataloader = val_dataloader
_base_ = ['../../../configs/_base_/default_runtime.py']
custom_imports = dict(imports=['projects.NeRF-Det.nerfdet'])
prior_generator = dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-3.2, -3.2, -1.28, 3.2, 3.2, 1.28]],
rotations=[.0])
model = dict(
type='NerfDet',
data_preprocessor=dict(
type='NeRFDetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=10),
backbone=dict(
type='mmdet.ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
style='pytorch'),
neck=dict(
type='mmdet.FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
neck_3d=dict(
type='IndoorImVoxelNeck',
in_channels=256,
out_channels=128,
n_blocks=[1, 1, 1]),
bbox_head=dict(
type='NerfDetHead',
bbox_loss=dict(type='AxisAlignedIoULoss', loss_weight=1.0),
n_classes=18,
n_levels=3,
n_channels=128,
n_reg_outs=6,
pts_assign_threshold=27,
pts_center_threshold=18,
prior_generator=prior_generator),
prior_generator=prior_generator,
voxel_size=[.16, .16, .2],
n_voxels=[40, 40, 16],
aabb=([-2.7, -2.7, -0.78], [3.7, 3.7, 1.78]),
near_far_range=[0.2, 8.0],
N_samples=64,
N_rand=2048,
nerf_mode='image',
depth_supervise=True,
use_nerf_mask=True,
nerf_sample_view=20,
squeeze_scale=4,
nerf_density=True,
train_cfg=dict(),
test_cfg=dict(nms_pre=1000, iou_thr=.25, score_thr=.01))
dataset_type = 'MultiViewScanNetDataset'
data_root = 'data/scannet/'
class_names = [
'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf',
'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain',
'toilet', 'sink', 'bathtub', 'garbagebin'
]
metainfo = dict(CLASSES=class_names)
file_client_args = dict(backend='disk')
input_modality = dict(
use_camera=True,
use_depth=True,
use_lidar=False,
use_neuralrecon_depth=False,
use_ray=True)
backend_args = None
train_collect_keys = [
'img', 'gt_bboxes_3d', 'gt_labels_3d', 'depth', 'lightpos', 'nerf_sizes',
'raydirs', 'gt_images', 'gt_depths', 'denorm_images'
]
test_collect_keys = [
'img',
'depth',
'lightpos',
'nerf_sizes',
'raydirs',
'gt_images',
'gt_depths',
'denorm_images',
]
train_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=50,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=10),
dict(type='RandomShiftOrigin', std=(.7, .7, .0)),
dict(type='PackNeRFDetInputs', keys=train_collect_keys)
]
test_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=101,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=1),
dict(type='PackNeRFDetInputs', keys=test_collect_keys)
]
train_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=6,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_train_new.pkl',
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo)))
val_dataloader = dict(
batch_size=1,
num_workers=5,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_val_new.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo))
test_dataloader = val_dataloader
val_evaluator = dict(type='IndoorMetric')
test_evaluator = val_evaluator
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
test_cfg = dict()
val_cfg = dict()
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}),
clip_grad=dict(max_norm=35., norm_type=2))
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]
# hooks
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=12))
# runtime
find_unused_parameters = True # only 1 of 4 FPN outputs is used
from .data_preprocessor import NeRFDetDataPreprocessor
from .formating import PackNeRFDetInputs
from .multiview_pipeline import MultiViewPipeline, RandomShiftOrigin
from .nerfdet import NerfDet
from .nerfdet_head import NerfDetHead
from .scannet_multiview_dataset import MultiViewScanNetDataset
__all__ = [
'MultiViewScanNetDataset', 'MultiViewPipeline', 'RandomShiftOrigin',
'PackNeRFDetInputs', 'NeRFDetDataPreprocessor', 'NerfDetHead', 'NerfDet'
]
# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from mmdet.models import DetDataPreprocessor
from mmdet.models.utils.misc import samplelist_boxtype2tensor
from mmengine.model import stack_batch
from mmengine.utils import is_seq_of
from torch import Tensor
from torch.nn import functional as F
from mmdet3d.models.data_preprocessors.utils import multiview_img_stack_batch
from mmdet3d.models.data_preprocessors.voxelize import (
VoxelizationByGridShape, dynamic_scatter_3d)
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import OptConfigType
@MODELS.register_module()
class NeRFDetDataPreprocessor(DetDataPreprocessor):
    """In NeRF-Det, some extra information is needed in the NeRF branch. The
    data preprocessing operations for this new information, such as the stack
    and pack operations, are placed in this class. The stack operations are in
    the method ``collate_data`` and the pack operations are in
    ``simple_process``. The remaining code is the same as in the default class
    ``DetDataPreprocessor``.

    Points / Image pre-processor for point clouds / vision-only /
    multi-modality 3D detection tasks.

    It provides the data pre-processing as follows

    - Collate and move image and point cloud data to the target device.

    - 1) For image data:

      - Pad images in inputs to the maximum size of current batch with defined
        ``pad_value``. The padding size can be divisible by a defined
        ``pad_size_divisor``.
      - Stack images in inputs to batch_imgs.
      - Convert images in inputs from bgr to rgb if the shape of input is
        (3, H, W).
      - Normalize images in inputs with defined std and mean.
      - Do batch augmentations during training.

    - 2) For point cloud data:

      - If no voxelization, directly return list of point cloud data.
      - If voxelization is applied, voxelize point cloud according to
        ``voxel_type`` and obtain ``voxels``.

    Args:
        voxel (bool): Whether to apply voxelization to point cloud.
            Defaults to False.
        voxel_type (str): Voxelization type. Two voxelization types are
            provided: 'hard' and 'dynamic', respectively for hard voxelization
            and dynamic voxelization. Defaults to 'hard'.
        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
        batch_first (bool): Whether to put the batch dimension to the first
            dimension when getting voxel coordinates. Defaults to True.
        max_voxels (int, optional): Maximum number of voxels in each voxel
            grid. Defaults to None.
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be divisible by
            ``pad_size_divisor``. Defaults to 1.
        pad_value (float or int): The padded pixel value. Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic segmentation
            maps. Defaults to 255.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of
            bboxes data to ``Tensor`` type. Defaults to True.
        non_blocking (bool): Whether to block current process when transferring
            data to device. Defaults to False.
        batch_augments (List[dict], optional): Batch-level augmentations.
            Defaults to None.
    """
    def __init__(self,
                 voxel: bool = False,
                 voxel_type: str = 'hard',
                 voxel_layer: OptConfigType = None,
                 batch_first: bool = True,
                 max_voxels: Optional[int] = None,
                 mean: Sequence[Number] = None,
                 std: Sequence[Number] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 pad_mask: bool = False,
                 mask_pad_value: int = 0,
                 pad_seg: bool = False,
                 seg_pad_value: int = 255,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 boxtype2tensor: bool = True,
                 non_blocking: bool = False,
                 batch_augments: Optional[List[dict]] = None) -> None:
        super(NeRFDetDataPreprocessor, self).__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            pad_mask=pad_mask,
            mask_pad_value=mask_pad_value,
            pad_seg=pad_seg,
            seg_pad_value=seg_pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            boxtype2tensor=boxtype2tensor,
            non_blocking=non_blocking,
            batch_augments=batch_augments)
        self.voxel = voxel
        self.voxel_type = voxel_type
        self.batch_first = batch_first
        self.max_voxels = max_voxels
        if voxel:
            self.voxel_layer = VoxelizationByGridShape(**voxel_layer)
    def forward(self,
                data: Union[dict, List[dict]],
                training: bool = False) -> Union[dict, List[dict]]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict or List[dict]): Data from dataloader. The dict contains
                the whole batch data, when it is a list[dict], the list
                indicates test time augmentation.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict or List[dict]: Data in the same format as the model input.
        """
        if isinstance(data, list):
            num_augs = len(data)
            aug_batch_data = []
            for aug_id in range(num_augs):
                single_aug_batch_data = self.simple_process(
                    data[aug_id], training)
                aug_batch_data.append(single_aug_batch_data)
            return aug_batch_data
        else:
            return self.simple_process(data, training)
    def simple_process(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion for img data
        based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel`
        is set to be True.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.
                Defaults to False.

        Returns:
            dict: Data in the same format as the model input.
        """
        if 'img' in data['inputs']:
            batch_pad_shape = self._get_pad_shape(data)

        data = self.collate_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        batch_inputs = dict()

        if 'points' in inputs:
            batch_inputs['points'] = inputs['points']

            if self.voxel:
                voxel_dict = self.voxelize(inputs['points'], data_samples)
                batch_inputs['voxels'] = voxel_dict

        if 'imgs' in inputs:
            imgs = inputs['imgs']

            if data_samples is not None:
                # NOTE the batched image size information may be useful, e.g.
                # in DETR, this is needed for the construction of masks, which
                # is then used for the transformer_head.
                batch_input_shape = tuple(imgs[0].size()[-2:])
                for data_sample, pad_shape in zip(data_samples,
                                                  batch_pad_shape):
                    data_sample.set_metainfo({
                        'batch_input_shape': batch_input_shape,
                        'pad_shape': pad_shape
                    })

                if self.boxtype2tensor:
                    samplelist_boxtype2tensor(data_samples)
                if self.pad_mask:
                    self.pad_gt_masks(data_samples)
                if self.pad_seg:
                    self.pad_gt_sem_seg(data_samples)

            if training and self.batch_augments is not None:
                for batch_aug in self.batch_augments:
                    imgs, data_samples = batch_aug(imgs, data_samples)
            batch_inputs['imgs'] = imgs
        # Hard code here, will be changed later.
        # if len(inputs['depth']) != 0:
        if 'depth' in inputs.keys():
            batch_inputs['depth'] = inputs['depth']
        batch_inputs['lightpos'] = inputs['lightpos']
        batch_inputs['nerf_sizes'] = inputs['nerf_sizes']
        batch_inputs['denorm_images'] = inputs['denorm_images']
        batch_inputs['raydirs'] = inputs['raydirs']

        return {'inputs': batch_inputs, 'data_samples': data_samples}
    def preprocess_img(self, _batch_img: Tensor) -> Tensor:
        # channel transform
        if self._channel_conversion:
            _batch_img = _batch_img[[2, 1, 0], ...]
        # Convert to float after channel conversion to ensure
        # efficiency
        _batch_img = _batch_img.float()
        # Normalization.
        if self._enable_normalize:
            if self.mean.shape[0] == 3:
                assert _batch_img.dim() == 3 and _batch_img.shape[0] == 3, (
                    'If the mean has 3 values, the input tensor '
                    'should in shape of (3, H, W), but got the '
                    f'tensor with shape {_batch_img.shape}')
            _batch_img = (_batch_img - self.mean) / self.std
        return _batch_img
    def collate_data(self, data: dict) -> dict:
        """Copy data to the target device and perform normalization, padding
        and bgr2rgb conversion and stack based on ``BaseDataPreprocessor``.

        Collates the data sampled from dataloader into a list of dict and list
        of labels, and then copies tensor to the target device.

        Args:
            data (dict): Data sampled from dataloader.

        Returns:
            dict: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore

        if 'img' in data['inputs']:
            _batch_imgs = data['inputs']['img']
            # Process data with `pseudo_collate`.
            if is_seq_of(_batch_imgs, torch.Tensor):
                batch_imgs = []
                img_dim = _batch_imgs[0].dim()
                for _batch_img in _batch_imgs:
                    if img_dim == 3:  # standard img
                        _batch_img = self.preprocess_img(_batch_img)
                    elif img_dim == 4:
                        _batch_img = [
                            self.preprocess_img(_img) for _img in _batch_img
                        ]
                        _batch_img = torch.stack(_batch_img, dim=0)
                    batch_imgs.append(_batch_img)
                # Pad and stack Tensor.
                if img_dim == 3:
                    batch_imgs = stack_batch(batch_imgs, self.pad_size_divisor,
                                             self.pad_value)
                elif img_dim == 4:
                    batch_imgs = multiview_img_stack_batch(
                        batch_imgs, self.pad_size_divisor, self.pad_value)
            # Process data with `default_collate`.
            elif isinstance(_batch_imgs, torch.Tensor):
                assert _batch_imgs.dim() == 4, (
                    'The input of `ImgDataPreprocessor` should be a NCHW '
                    'tensor or a list of tensor, but got a tensor with '
                    f'shape: {_batch_imgs.shape}')
                if self._channel_conversion:
                    _batch_imgs = _batch_imgs[:, [2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_imgs = _batch_imgs.float()
                if self._enable_normalize:
                    _batch_imgs = (_batch_imgs - self.mean) / self.std
                h, w = _batch_imgs.shape[2:]
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                batch_imgs = F.pad(_batch_imgs, (0, pad_w, 0, pad_h),
                                   'constant', self.pad_value)
            else:
                raise TypeError(
                    'Output of `cast_data` should be a list of dict '
                    'or a tuple with inputs and data_samples, but got '
                    f'{type(data)}: {data}')

            data['inputs']['imgs'] = batch_imgs

        if 'raydirs' in data['inputs']:
            _batch_dirs = data['inputs']['raydirs']
            batch_dirs = stack_batch(_batch_dirs)
            data['inputs']['raydirs'] = batch_dirs

        if 'lightpos' in data['inputs']:
            _batch_poses = data['inputs']['lightpos']
            batch_poses = stack_batch(_batch_poses)
            data['inputs']['lightpos'] = batch_poses

        if 'denorm_images' in data['inputs']:
            _batch_denorm_imgs = data['inputs']['denorm_images']
            # Process data with `pseudo_collate`.
            if is_seq_of(_batch_denorm_imgs, torch.Tensor):
                denorm_img_dim = _batch_denorm_imgs[0].dim()
                # Pad and stack Tensor.
                if denorm_img_dim == 3:
                    batch_denorm_imgs = stack_batch(_batch_denorm_imgs,
                                                    self.pad_size_divisor,
                                                    self.pad_value)
                elif denorm_img_dim == 4:
                    batch_denorm_imgs = multiview_img_stack_batch(
                        _batch_denorm_imgs, self.pad_size_divisor,
                        self.pad_value)
                data['inputs']['denorm_images'] = batch_denorm_imgs

        data.setdefault('data_samples', None)

        return data
    def _get_pad_shape(self, data: dict) -> List[Tuple[int, int]]:
        """Get the pad_shape of each image based on data and
        pad_size_divisor."""
        # rewrite `_get_pad_shape` for obtaining image inputs.
        _batch_inputs = data['inputs']['img']
        # Process data with `pseudo_collate`.
        if is_seq_of(_batch_inputs, torch.Tensor):
            batch_pad_shape = []
            for ori_input in _batch_inputs:
                if ori_input.dim() == 4:
                    # A 4-dim tensor means multi-view input; select one of
                    # the images to calculate the pad shape.
                    ori_input = ori_input[0]
                # Each image is (C, H, W), so H and W are dims 1 and 2.
                pad_h = int(
                    np.ceil(ori_input.shape[1] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                pad_w = int(
                    np.ceil(ori_input.shape[2] /
                            self.pad_size_divisor)) * self.pad_size_divisor
                batch_pad_shape.append((pad_h, pad_w))
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be a NCHW tensor '
                'or a list of tensor, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            # The tensor is NCHW, so H and W are dims 2 and 3 (dim 1 is the
            # channel dimension), consistent with `collate_data`.
            pad_h = int(
                np.ceil(_batch_inputs.shape[2] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            pad_w = int(
                np.ceil(_batch_inputs.shape[3] /
                        self.pad_size_divisor)) * self.pad_size_divisor
            batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
        else:
            raise TypeError('Output of `cast_data` should be a list of dict '
                            'or a tuple with inputs and data_samples, but got '
                            f'{type(data)}: {data}')
        return batch_pad_shape
@torch.no_grad()
def voxelize(self, points: List[Tensor],
data_samples: SampleList) -> Dict[str, Tensor]:
"""Apply voxelization to point cloud.
Args:
points (List[Tensor]): Point cloud in one data batch.
data_samples: (list[:obj:`NeRFDet3DDataSample`]): The annotation
data of every samples. Add voxel-wise annotation for
segmentation.
Returns:
Dict[str, Tensor]: Voxelization information.
- voxels (Tensor): Features of voxels, shape is MxNxC for hard
voxelization, NxC for dynamic voxelization.
- coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
where 1 represents the batch index.
- num_points (Tensor, optional): Number of points in each voxel.
- voxel_centers (Tensor, optional): Centers of voxels.
"""
voxel_dict = dict()
if self.voxel_type == 'hard':
voxels, coors, num_points, voxel_centers = [], [], [], []
for i, res in enumerate(points):
res_voxels, res_coors, res_num_points = self.voxel_layer(res)
res_voxel_centers = (
res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
self.voxel_layer.voxel_size) + res_voxels.new_tensor(
self.voxel_layer.point_cloud_range[0:3])
res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
voxels.append(res_voxels)
coors.append(res_coors)
num_points.append(res_num_points)
voxel_centers.append(res_voxel_centers)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
num_points = torch.cat(num_points, dim=0)
voxel_centers = torch.cat(voxel_centers, dim=0)
voxel_dict['num_points'] = num_points
voxel_dict['voxel_centers'] = voxel_centers
elif self.voxel_type == 'dynamic':
coors = []
# dynamic voxelization only provides a coors mapping
for i, res in enumerate(points):
res_coors = self.voxel_layer(res)
res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
coors.append(res_coors)
voxels = torch.cat(points, dim=0)
coors = torch.cat(coors, dim=0)
elif self.voxel_type == 'cylindrical':
voxels, coors = [], []
for i, (res, data_sample) in enumerate(zip(points, data_samples)):
rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2)
phi = torch.atan2(res[:, 1], res[:, 0])
polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1)
min_bound = polar_res.new_tensor(
self.voxel_layer.point_cloud_range[:3])
max_bound = polar_res.new_tensor(
self.voxel_layer.point_cloud_range[3:])
try: # only support PyTorch >= 1.9.0
polar_res_clamp = torch.clamp(polar_res, min_bound,
max_bound)
except TypeError:
polar_res_clamp = polar_res.clone()
for coor_idx in range(3):
polar_res_clamp[:, coor_idx][
polar_res[:, coor_idx] >
max_bound[coor_idx]] = max_bound[coor_idx]
polar_res_clamp[:, coor_idx][
polar_res[:, coor_idx] <
min_bound[coor_idx]] = min_bound[coor_idx]
res_coors = torch.floor(
(polar_res_clamp - min_bound) / polar_res_clamp.new_tensor(
self.voxel_layer.voxel_size)).int()
self.get_voxel_seg(res_coors, data_sample)
res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]),
dim=-1)
voxels.append(res_voxels)
coors.append(res_coors)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
elif self.voxel_type == 'minkunet':
voxels, coors = [], []
voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
for i, (res, data_sample) in enumerate(zip(points, data_samples)):
res_coors = torch.round(res[:, :3] / voxel_size).int()
res_coors -= res_coors.min(0)[0]
res_coors_numpy = res_coors.cpu().numpy()
inds, point2voxel_map = self.sparse_quantize(
res_coors_numpy, return_index=True, return_inverse=True)
point2voxel_map = torch.from_numpy(point2voxel_map).cuda()
if self.training and self.max_voxels is not None:
if len(inds) > self.max_voxels:
inds = np.random.choice(
inds, self.max_voxels, replace=False)
inds = torch.from_numpy(inds).cuda()
if hasattr(data_sample.gt_pts_seg, 'pts_semantic_mask'):
data_sample.gt_pts_seg.voxel_semantic_mask \
= data_sample.gt_pts_seg.pts_semantic_mask[inds]
res_voxel_coors = res_coors[inds]
res_voxels = res[inds]
if self.batch_first:
res_voxel_coors = F.pad(
res_voxel_coors, (1, 0), mode='constant', value=i)
data_sample.batch_idx = res_voxel_coors[:, 0]
else:
res_voxel_coors = F.pad(
res_voxel_coors, (0, 1), mode='constant', value=i)
data_sample.batch_idx = res_voxel_coors[:, -1]
data_sample.point2voxel_map = point2voxel_map.long()
voxels.append(res_voxels)
coors.append(res_voxel_coors)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
else:
raise ValueError(f'Invalid voxelization type {self.voxel_type}')
voxel_dict['voxels'] = voxels
voxel_dict['coors'] = coors
return voxel_dict
def get_voxel_seg(self, res_coors: Tensor,
data_sample: SampleList) -> None:
"""Get voxel-wise segmentation label and point2voxel map.
Args:
res_coors (Tensor): The voxel coordinates of points, Nx3.
data_sample (:obj:`NeRFDet3DDataSample`): The annotation data of
one sample. Voxel-wise annotations for segmentation are added to it.
"""
if self.training:
pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
voxel_semantic_mask, _, point2voxel_map = dynamic_scatter_3d(
F.one_hot(pts_semantic_mask.long()).float(), res_coors, 'mean',
True)
voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
data_sample.point2voxel_map = point2voxel_map
else:
pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float()
_, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
res_coors, 'mean', True)
data_sample.point2voxel_map = point2voxel_map
def ravel_hash(self, x: np.ndarray) -> np.ndarray:
"""Get voxel coordinates hash for np.unique.
Args:
x (np.ndarray): The voxel coordinates of points, Nx3.
Returns:
np.ndarray: Voxel coordinates hash.
"""
assert x.ndim == 2, x.shape
x = x - np.min(x, axis=0)
x = x.astype(np.uint64, copy=False)
xmax = np.max(x, axis=0).astype(np.uint64) + 1
h = np.zeros(x.shape[0], dtype=np.uint64)
for k in range(x.shape[1] - 1):
h += x[:, k]
h *= xmax[k + 1]
h += x[:, -1]
return h
def sparse_quantize(self,
coords: np.ndarray,
return_index: bool = False,
return_inverse: bool = False) -> List[np.ndarray]:
"""Sparse Quantization for voxel coordinates used in Minkunet.
Args:
coords (np.ndarray): The voxel coordinates of points, Nx3.
return_index (bool): Whether to return the indices of the unique
coords, shape (M,).
return_inverse (bool): Whether to return the indices of the
original coords, shape (N,).
Returns:
List[np.ndarray]: The index and the inverse map, each included only
if return_index or return_inverse, respectively, is True.
"""
_, indices, inverse_indices = np.unique(
self.ravel_hash(coords), return_index=True, return_inverse=True)
coords = coords[indices]
outputs = []
if return_index:
outputs += [indices]
if return_inverse:
outputs += [inverse_indices]
return outputs
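The hashing and de-duplication performed by `ravel_hash` and `sparse_quantize` can be sketched in isolation with NumPy alone. This is a minimal illustration, not the project's code path: the toy coordinates and the names `voxel_coords`, `first_index`, `inverse`, and `unique_voxels` are assumptions made for the example.

```python
import numpy as np


def ravel_hash(coords: np.ndarray) -> np.ndarray:
    """Map each row of integer coordinates to one scalar key (row-major)."""
    coords = coords - coords.min(axis=0)  # shift so every axis starts at 0
    coords = coords.astype(np.uint64, copy=False)
    extent = coords.max(axis=0).astype(np.uint64) + 1  # size of each axis
    key = np.zeros(coords.shape[0], dtype=np.uint64)
    for k in range(coords.shape[1] - 1):
        key += coords[:, k]
        key *= extent[k + 1]
    key += coords[:, -1]
    return key


# Hypothetical toy input: five points that fall into three distinct voxels.
voxel_coords = np.array([[0, 0, 0], [1, 2, 3], [0, 0, 0], [1, 2, 3],
                         [4, 0, 1]])

# `first_index` picks one representative point per voxel,
# `inverse` maps every point back to its voxel.
_, first_index, inverse = np.unique(
    ravel_hash(voxel_coords), return_index=True, return_inverse=True)
unique_voxels = voxel_coords[first_index]
```

Indexing `unique_voxels` with `inverse` reproduces the original per-point coordinates, which is the role `point2voxel_map` plays in the `minkunet` branch above.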
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Union
import mmengine
import numpy as np
import torch
from mmcv import BaseTransform
from mmengine.structures import InstanceData
from numpy import dtype
from mmdet3d.registry import TRANSFORMS
from mmdet3d.structures import BaseInstance3DBoxes, PointData
from mmdet3d.structures.points import BasePoints
# from .det3d_data_sample import Det3DDataSample
from .nerf_det3d_data_sample import NeRFDet3DDataSample
def to_tensor(
data: Union[torch.Tensor, np.ndarray, Sequence, int,
float]) -> torch.Tensor:
"""Convert objects of various python types to :obj:`torch.Tensor`.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
Args:
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
be converted.
Returns:
torch.Tensor: the converted data.
"""
if isinstance(data, torch.Tensor):
return data
elif isinstance(data, np.ndarray):
if data.dtype == dtype('float64'):
data = data.astype(np.float32)
return torch.from_numpy(data)
elif isinstance(data, Sequence) and not mmengine.is_str(data):
return torch.tensor(data)
elif isinstance(data, int):
return torch.LongTensor([data])
elif isinstance(data, float):
return torch.FloatTensor([data])
else:
raise TypeError(f'type {type(data)} cannot be converted to tensor.')
@TRANSFORMS.register_module()
class PackNeRFDetInputs(BaseTransform):
INPUTS_KEYS = ['points', 'img']
NERF_INPUT_KEYS = [
'img', 'denorm_images', 'depth', 'lightpos', 'nerf_sizes', 'raydirs'
]
INSTANCEDATA_3D_KEYS = [
'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d'
]
INSTANCEDATA_2D_KEYS = [
'gt_bboxes',
'gt_bboxes_labels',
]
NERF_3D_KEYS = ['gt_images', 'gt_depths']
SEG_KEYS = [
'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask',
'gt_semantic_seg'
]
def __init__(
self,
keys: tuple,
meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape',
'scale_factor', 'flip', 'pcd_horizontal_flip',
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
'img_norm_cfg', 'num_pts_feats', 'pcd_trans',
'sample_idx', 'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat',
'affine_aug', 'sweep_img_metas', 'ori_cam2img',
'cam2global', 'crop_offset', 'img_crop_offset',
'resize_img_shape', 'lidar2cam', 'ori_lidar2img',
'num_ref_frames', 'num_views', 'ego2global',
'axis_align_matrix')
) -> None:
self.keys = keys
self.meta_keys = meta_keys
def _remove_prefix(self, key: str) -> str:
if key.startswith('gt_'):
key = key[3:]
return key
def transform(self, results: Union[dict,
List[dict]]) -> Union[dict, List[dict]]:
"""Method to pack the input data. When ``results`` is a list, the
transform is usually running in test-time augmentation.
Args:
results (dict | list[dict]): Result dict from the data pipeline.
Returns:
dict | List[dict]:
- 'inputs' (dict): The forward data of models. It usually contains
following keys:
- points
- img
- 'data_samples' (:obj:`NeRFDet3DDataSample`): The annotation info
of the sample.
"""
# augtest
if isinstance(results, list):
if len(results) == 1:
# simple test
return self.pack_single_results(results[0])
pack_results = []
for single_result in results:
pack_results.append(self.pack_single_results(single_result))
return pack_results
# normal training and simple testing
elif isinstance(results, dict):
return self.pack_single_results(results)
else:
raise NotImplementedError
def pack_single_results(self, results: dict) -> dict:
"""Method to pack a single input sample. When a value in this dict is
a list, it usually holds multiple views of the same frame.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict: A dict contains
- 'inputs' (dict): The forward data of models. It usually contains
following keys:
- points
- img
- 'data_samples' (:obj:`NeRFDet3DDataSample`): The annotation info
of the sample.
"""
# Format 3D data
if 'points' in results:
if isinstance(results['points'], BasePoints):
results['points'] = results['points'].tensor
if 'img' in results:
if isinstance(results['img'], list):
# process multiple imgs in single frame
imgs = np.stack(results['img'], axis=0)
if imgs.flags.c_contiguous:
imgs = to_tensor(imgs).permute(0, 3, 1, 2).contiguous()
else:
imgs = to_tensor(
np.ascontiguousarray(imgs.transpose(0, 3, 1, 2)))
results['img'] = imgs
else:
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
# To improve the computational speed by 3-5 times, apply:
# `torch.permute()` rather than `np.transpose()`.
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if img.flags.c_contiguous:
img = to_tensor(img).permute(2, 0, 1).contiguous()
else:
img = to_tensor(
np.ascontiguousarray(img.transpose(2, 0, 1)))
results['img'] = img
if 'depth' in results:
if isinstance(results['depth'], list):
# process multiple depth imgs in single frame
depth_imgs = np.stack(results['depth'], axis=0)
if depth_imgs.flags.c_contiguous:
depth_imgs = to_tensor(depth_imgs).contiguous()
else:
depth_imgs = to_tensor(np.ascontiguousarray(depth_imgs))
results['depth'] = depth_imgs
else:
depth_img = results['depth']
if len(depth_img.shape) < 3:
depth_img = np.expand_dims(depth_img, -1)
if depth_img.flags.c_contiguous:
depth_img = to_tensor(depth_img).contiguous()
else:
depth_img = to_tensor(np.ascontiguousarray(depth_img))
results['depth'] = depth_img
if 'ray_info' in results:
if isinstance(results['raydirs'], list):
raydirs = np.stack(results['raydirs'], axis=0)
if raydirs.flags.c_contiguous:
raydirs = to_tensor(raydirs).contiguous()
else:
raydirs = to_tensor(np.ascontiguousarray(raydirs))
results['raydirs'] = raydirs
if isinstance(results['lightpos'], list):
lightposes = np.stack(results['lightpos'], axis=0)
if lightposes.flags.c_contiguous:
lightposes = to_tensor(lightposes).contiguous()
else:
lightposes = to_tensor(np.ascontiguousarray(lightposes))
lightposes = lightposes.unsqueeze(1).repeat(
1, raydirs.shape[1], 1)
results['lightpos'] = lightposes
if isinstance(results['gt_images'], list):
gt_images = np.stack(results['gt_images'], axis=0)
if gt_images.flags.c_contiguous:
gt_images = to_tensor(gt_images).contiguous()
else:
gt_images = to_tensor(np.ascontiguousarray(gt_images))
results['gt_images'] = gt_images
if isinstance(results['gt_depths'],
list) and len(results['gt_depths']) != 0:
gt_depths = np.stack(results['gt_depths'], axis=0)
if gt_depths.flags.c_contiguous:
gt_depths = to_tensor(gt_depths).contiguous()
else:
gt_depths = to_tensor(np.ascontiguousarray(gt_depths))
results['gt_depths'] = gt_depths
if isinstance(results['denorm_images'], list):
denorm_imgs = np.stack(results['denorm_images'], axis=0)
if denorm_imgs.flags.c_contiguous:
denorm_imgs = to_tensor(denorm_imgs).permute(
0, 3, 1, 2).contiguous()
else:
denorm_imgs = to_tensor(
np.ascontiguousarray(
denorm_imgs.transpose(0, 3, 1, 2)))
results['denorm_images'] = denorm_imgs
for key in [
'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',
'gt_bboxes_labels', 'attr_labels', 'pts_instance_mask',
'pts_semantic_mask', 'centers_2d', 'depths', 'gt_labels_3d'
]:
if key not in results:
continue
if isinstance(results[key], list):
results[key] = [to_tensor(res) for res in results[key]]
else:
results[key] = to_tensor(results[key])
if 'gt_bboxes_3d' in results:
if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes):
results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d'])
if 'gt_semantic_seg' in results:
results['gt_semantic_seg'] = to_tensor(
results['gt_semantic_seg'][None])
if 'gt_seg_map' in results:
results['gt_seg_map'] = results['gt_seg_map'][None, ...]
if 'gt_images' in results:
results['gt_images'] = to_tensor(results['gt_images'])
if 'gt_depths' in results:
results['gt_depths'] = to_tensor(results['gt_depths'])
data_sample = NeRFDet3DDataSample()
gt_instances_3d = InstanceData()
gt_instances = InstanceData()
gt_pts_seg = PointData()
gt_nerf_images = InstanceData()
gt_nerf_depths = InstanceData()
data_metas = {}
for key in self.meta_keys:
if key in results:
data_metas[key] = results[key]
elif 'images' in results:
if len(results['images'].keys()) == 1:
cam_type = list(results['images'].keys())[0]
# single-view image
if key in results['images'][cam_type]:
data_metas[key] = results['images'][cam_type][key]
else:
# multi-view image
img_metas = []
cam_types = list(results['images'].keys())
for cam_type in cam_types:
if key in results['images'][cam_type]:
img_metas.append(results['images'][cam_type][key])
if len(img_metas) > 0:
data_metas[key] = img_metas
elif 'lidar_points' in results:
if key in results['lidar_points']:
data_metas[key] = results['lidar_points'][key]
data_sample.set_metainfo(data_metas)
inputs = {}
for key in self.keys:
if key in results:
# if key in self.INPUTS_KEYS:
if key in self.NERF_INPUT_KEYS:
inputs[key] = results[key]
elif key in self.NERF_3D_KEYS:
if key == 'gt_images':
gt_nerf_images[self._remove_prefix(key)] = results[key]
else:
gt_nerf_depths[self._remove_prefix(key)] = results[key]
elif key in self.INSTANCEDATA_3D_KEYS:
gt_instances_3d[self._remove_prefix(key)] = results[key]
elif key in self.INSTANCEDATA_2D_KEYS:
if key == 'gt_bboxes_labels':
gt_instances['labels'] = results[key]
else:
gt_instances[self._remove_prefix(key)] = results[key]
elif key in self.SEG_KEYS:
gt_pts_seg[self._remove_prefix(key)] = results[key]
else:
raise NotImplementedError(f'Please modify '
f'`PackNeRFDetInputs` '
f'to put {key} in the '
f'corresponding field')
data_sample.gt_instances_3d = gt_instances_3d
data_sample.gt_instances = gt_instances
data_sample.gt_pts_seg = gt_pts_seg
data_sample.gt_nerf_images = gt_nerf_images
data_sample.gt_nerf_depths = gt_nerf_depths
if 'eval_ann_info' in results:
data_sample.eval_ann_info = results['eval_ann_info']
else:
data_sample.eval_ann_info = None
packed_results = dict()
packed_results['data_samples'] = data_sample
packed_results['inputs'] = inputs
return packed_results
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(keys={self.keys}, '
repr_str += f'meta_keys={self.meta_keys})'
return repr_str
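The multi-view image branch of `pack_single_results` stacks a list of H x W x C arrays and moves the channel axis forward. The sketch below shows only that layout change, using NumPy instead of `torch.permute` so that it stays self-contained; the image size and the names `views`, `stacked`, and `nchw` are illustrative assumptions.

```python
import numpy as np

# Hypothetical input: four views of a 240 x 320 RGB frame in HWC layout.
views = [np.zeros((240, 320, 3), dtype=np.float32) for _ in range(4)]
views[2][10, 20, 1] = 7.0  # mark one pixel so the layout change is traceable

# Stack along a new leading axis: (N, H, W, C).
stacked = np.stack(views, axis=0)

# Move channels in front of the spatial axes: (N, C, H, W), stored contiguously.
nchw = np.ascontiguousarray(stacked.transpose(0, 3, 1, 2))
```

The marked pixel at row 10, column 20, channel 1 of view 2 ends up at `nchw[2, 1, 10, 20]`, which is the indexing order the model consumes.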
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform, Compose
from PIL import Image
from mmdet3d.registry import TRANSFORMS
def get_dtu_raydir(pixelcoords, intrinsic, rot, dir_norm=None):
# rot is c2w
# pixelcoords: H x W x 2
x = (pixelcoords[..., 0] + 0.5 - intrinsic[0, 2]) / intrinsic[0, 0]
y = (pixelcoords[..., 1] + 0.5 - intrinsic[1, 2]) / intrinsic[1, 1]
z = np.ones_like(x)
dirs = np.stack([x, y, z], axis=-1)
# dirs = np.sum(dirs[...,None,:] * rot[:,:], axis=-1) # h*w*1*3 x 3*3
dirs = dirs @ rot[:, :].T  # rotate camera-frame directions by c2w
if dir_norm:
dirs = dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5)
return dirs
@TRANSFORMS.register_module()
class MultiViewPipeline(BaseTransform):
"""MultiViewPipeline used in nerfdet.
Required Keys:
- depth_info
- img_prefix
- img_info
- lidar2img
- c2w
- camrotc2w
- lightpos
- ray_info
Modified Keys:
- lidar2img
Added Keys:
- img
- denorm_images
- depth
- c2w
- camrotc2w
- lightpos
- pixels
- raydirs
- gt_images
- gt_depths
- nerf_sizes
- depth_range
Args:
transforms (list[dict]): The transform pipeline
used to process the imgs.
n_images (int): The number of sampled views.
mean (array): The mean values used in normalization.
std (array): The standard deviation values used in normalization.
margin (int): The margin value. Defaults to 10.
depth_range (array): The range of the depth.
Defaults to [0.5, 5.5].
loading (str): The mode of loading. Defaults to 'random'.
nerf_target_views (int): The number of novel views.
sample_freq (int): The frequency of sampling.
"""
def __init__(self,
transforms: list,
n_images: int,
mean: tuple = [123.675, 116.28, 103.53],
std: tuple = [58.395, 57.12, 57.375],
margin: int = 10,
depth_range: tuple = [0.5, 5.5],
loading: str = 'random',
nerf_target_views: int = 0,
sample_freq: int = 3):
self.transforms = Compose(transforms)
self.depth_transforms = Compose(transforms[1])
self.n_images = n_images
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.margin = margin
self.depth_range = depth_range
self.loading = loading
self.sample_freq = sample_freq
self.nerf_target_views = nerf_target_views
def transform(self, results: dict) -> dict:
"""Nerfdet transform function.
Args:
results (dict): Result dict from loading pipeline
Returns:
dict: The result dict containing the processed results.
Updated key and value are described below.
- img (list): The loaded original images.
- denorm_images (list): The denormalized images.
- depth (list): The original depth images.
- c2w (list): The c2w matrices.
- camrotc2w (list): The rotation matrices.
- lightpos (list): The transform parameters of the camera.
- pixels (list): Some pixel information.
- raydirs (list): The ray-directions.
- gt_images (list): The groundtruth images.
- gt_depths (list): The groundtruth depth images.
- nerf_sizes (array): The size of the groundtruth images.
- depth_range (array): The range of the depth.
Here we give a detailed explanation of some keys mentioned above.
Let P_c be a point in camera coordinates and P_w the same point in
world coordinates. They are related by P_w = R @ P_c + T.
The 'camrotc2w' mentioned above corresponds to the R matrix here.
The 'lightpos' corresponds to the T vector here. Putting R and T
together gives the camera-to-world extrinsic matrix, which
corresponds to the 'c2w' mentioned above.
"""
imgs = []
depths = []
extrinsics = []
c2ws = []
camrotc2ws = []
lightposes = []
pixels = []
raydirs = []
gt_images = []
gt_depths = []
denorm_imgs_list = []
nerf_sizes = []
if self.loading == 'random':
ids = np.arange(len(results['img_info']))
replace = self.n_images > len(ids)
ids = np.random.choice(ids, self.n_images, replace=replace)
if self.nerf_target_views != 0:
target_id = np.random.choice(
ids, self.nerf_target_views, replace=False)
ids = np.setdiff1d(ids, target_id)
ids = ids.tolist()
target_id = target_id.tolist()
else:
ids = np.arange(len(results['img_info']))
begin_id = 0
ids = np.arange(begin_id,
begin_id + self.n_images * self.sample_freq,
self.sample_freq)
if self.nerf_target_views != 0:
target_id = ids
ratio = 0
size = (240, 320)
for i in ids:
_results = dict()
_results['img_path'] = results['img_info'][i]['filename']
_results = self.transforms(_results)
imgs.append(_results['img'])
# normalize
for key in _results.get('img_fields', ['img']):
_results[key] = mmcv.imnormalize(_results[key], self.mean,
self.std, True)
_results['img_norm_cfg'] = dict(
mean=self.mean, std=self.std, to_rgb=True)
# pad
for key in _results.get('img_fields', ['img']):
padded_img = mmcv.impad(_results[key], shape=size, pad_val=0)
_results[key] = padded_img
_results['pad_shape'] = padded_img.shape
_results['pad_fixed_size'] = size
ori_shape = _results['ori_shape']
aft_shape = _results['img_shape']
ratio = ori_shape[0] / aft_shape[0]
# prepare the depth information
if 'depth_info' in results.keys():
if '.npy' in results['depth_info'][i]['filename']:
_results['depth'] = np.load(
results['depth_info'][i]['filename'])
else:
_results['depth'] = np.asarray((Image.open(
results['depth_info'][i]['filename']))) / 1000
_results['depth'] = mmcv.imresize(
_results['depth'], (aft_shape[1], aft_shape[0]))
depths.append(_results['depth'])
denorm_img = mmcv.imdenormalize(
_results['img'], self.mean, self.std, to_bgr=True).astype(
np.uint8) / 255.0
denorm_imgs_list.append(denorm_img)
height, width = padded_img.shape[:2]
extrinsics.append(results['lidar2img']['extrinsic'][i])
# prepare the nerf information
if 'ray_info' in results.keys():
intrinsics_nerf = results['lidar2img']['intrinsic'].copy()
intrinsics_nerf[:2] = intrinsics_nerf[:2] / ratio
assert self.nerf_target_views > 0
for i in target_id:
c2ws.append(results['c2w'][i])
camrotc2ws.append(results['camrotc2w'][i])
lightposes.append(results['lightpos'][i])
px, py = np.meshgrid(
np.arange(self.margin,
width - self.margin).astype(np.float32),
np.arange(self.margin,
height - self.margin).astype(np.float32))
pixelcoords = np.stack((px, py),
axis=-1).astype(np.float32) # H x W x 2
pixels.append(pixelcoords)
raydir = get_dtu_raydir(pixelcoords, intrinsics_nerf,
results['camrotc2w'][i])
raydirs.append(np.reshape(raydir.astype(np.float32), (-1, 3)))
# read target images
temp_results = dict()
temp_results['img_path'] = results['img_info'][i]['filename']
temp_results_ = self.transforms(temp_results)
# normalize
for key in temp_results.get('img_fields', ['img']):
temp_results[key] = mmcv.imnormalize(
temp_results[key], self.mean, self.std, True)
temp_results['img_norm_cfg'] = dict(
mean=self.mean, std=self.std, to_rgb=True)
# pad
for key in temp_results.get('img_fields', ['img']):
padded_img = mmcv.impad(
temp_results[key], shape=size, pad_val=0)
temp_results[key] = padded_img
temp_results['pad_shape'] = padded_img.shape
temp_results['pad_fixed_size'] = size
# denormalize target_images.
denorm_imgs = mmcv.imdenormalize(
temp_results_['img'], self.mean, self.std,
to_bgr=True).astype(np.uint8)
gt_rgb_shape = denorm_imgs.shape
gt_image = denorm_imgs[py.astype(np.int32),
px.astype(np.int32), :]
nerf_sizes.append(np.array(gt_image.shape))
gt_image = np.reshape(gt_image, (-1, 3))
gt_images.append(gt_image / 255.0)
if 'depth_info' in results.keys():
if '.npy' in results['depth_info'][i]['filename']:
_results['depth'] = np.load(
results['depth_info'][i]['filename'])
else:
depth_image = Image.open(
results['depth_info'][i]['filename'])
_results['depth'] = np.asarray(depth_image) / 1000
_results['depth'] = mmcv.imresize(
_results['depth'],
(gt_rgb_shape[1], gt_rgb_shape[0]))
gt_depth = _results['depth'][py.astype(np.int32),
px.astype(np.int32)]
gt_depths.append(gt_depth)
for key in _results.keys():
if key not in ['img', 'img_info']:
results[key] = _results[key]
results['img'] = imgs
if 'ray_info' in results.keys():
results['c2w'] = c2ws
results['camrotc2w'] = camrotc2ws
results['lightpos'] = lightposes
results['pixels'] = pixels
results['raydirs'] = raydirs
results['gt_images'] = gt_images
results['gt_depths'] = gt_depths
results['nerf_sizes'] = nerf_sizes
results['denorm_images'] = denorm_imgs_list
results['depth_range'] = np.array([self.depth_range])
if len(depths) != 0:
results['depth'] = depths
results['lidar2img']['extrinsic'] = extrinsics
return results
@TRANSFORMS.register_module()
class RandomShiftOrigin(BaseTransform):
def __init__(self, std):
self.std = std
def transform(self, results):
shift = np.random.normal(.0, self.std, 3)
results['lidar2img']['origin'] += shift
return results
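The ray-direction computation in `get_dtu_raydir` earlier in this file can be checked on a small example. The sketch below re-implements the same pinhole formula with NumPy; the intrinsics `K`, the image size, and the helper name `pixel_to_ray` are hypothetical values chosen so that the results are easy to verify by hand.

```python
import numpy as np


def pixel_to_ray(pixelcoords, intrinsic, cam2world_rot, normalize=False):
    """Back-project pixel centres to ray directions and rotate them."""
    x = (pixelcoords[..., 0] + 0.5 - intrinsic[0, 2]) / intrinsic[0, 0]
    y = (pixelcoords[..., 1] + 0.5 - intrinsic[1, 2]) / intrinsic[1, 1]
    dirs = np.stack([x, y, np.ones_like(x)], axis=-1)
    dirs = dirs @ cam2world_rot.T  # same as applying the rotation to each ray
    if normalize:
        dirs = dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5)
    return dirs


# Assumed intrinsics: focal length 100 px, principal point at (160.5, 120.5).
K = np.array([[100.0, 0.0, 160.5], [0.0, 100.0, 120.5], [0.0, 0.0, 1.0]])
px, py = np.meshgrid(
    np.arange(320, dtype=np.float32), np.arange(240, dtype=np.float32))
pixelcoords = np.stack((px, py), axis=-1)  # H x W x 2, (x, y) per pixel

# Identity rotation: the pixel at the principal point looks along +z.
rays = pixel_to_ray(pixelcoords, K, np.eye(3))

# A 90 degree rotation about z maps the camera x axis onto the world y axis.
rot_z90 = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
rotated = pixel_to_ray(pixelcoords, K, rot_z90)
```

With these values the pixel at column 160, row 120 maps to the direction (0, 0, 1), and the pixel 100 columns to its right maps to (1, 0, 1) before rotation and (0, 1, 1) after it.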
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmengine.structures import InstanceData
from mmdet3d.structures import Det3DDataSample
class NeRFDet3DDataSample(Det3DDataSample):
"""A data structure interface inherited from Det3DDataSample. Some new
attributes are added to match the NeRF-Det project.
The attributes added in ``NeRFDet3DDataSample`` are divided into two parts:
- ``gt_nerf_images`` (InstanceData): Ground truth of the images which
will be used in the NeRF branch.
- ``gt_nerf_depths`` (InstanceData): Ground truth of the depth images
which will be used in the NeRF branch if needed.
For more details and examples, please refer to the 'Det3DDataSample' file.
"""
@property
def gt_nerf_images(self) -> InstanceData:
return self._gt_nerf_images
@gt_nerf_images.setter
def gt_nerf_images(self, value: InstanceData) -> None:
self.set_field(value, '_gt_nerf_images', dtype=InstanceData)
@gt_nerf_images.deleter
def gt_nerf_images(self) -> None:
del self._gt_nerf_images
@property
def gt_nerf_depths(self) -> InstanceData:
return self._gt_nerf_depths
@gt_nerf_depths.setter
def gt_nerf_depths(self, value: InstanceData) -> None:
self.set_field(value, '_gt_nerf_depths', dtype=InstanceData)
@gt_nerf_depths.deleter
def gt_nerf_depths(self) -> None:
del self._gt_nerf_depths
SampleList = List[NeRFDet3DDataSample]
OptSampleList = Optional[SampleList]
ForwardResults = Union[Dict[str, torch.Tensor], List[NeRFDet3DDataSample],
Tuple[torch.Tensor], torch.Tensor]
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
"""The MLP module used in NerfDet.
Args:
input_dim (int): The number of input tensor channels.
output_dim (int): The number of output tensor channels.
net_depth (int): The depth of the MLP. Defaults to 8.
net_width (int): The width of the MLP. Defaults to 256.
skip_layer (int): A skip connection that concatenates the input is
added after every hidden layer whose index is a positive multiple
of this value. Defaults to 4.
hidden_init (Callable): The initialization method of the hidden
layers. Defaults to nn.init.xavier_uniform_.
hidden_activation (Callable): The activation function of hidden
layers. Defaults to nn.ReLU().
output_enabled (bool): If True, the output layer will be used.
Defaults to True.
output_init (Optional[Callable]): The initialization method of the
output layer. Defaults to nn.init.xavier_uniform_.
output_activation (Optional[Callable]): The activation function of
the output layer. Defaults to nn.Identity().
bias_enabled (bool): If True, the bias will be used.
Defaults to True.
bias_init (Callable): The initialization method of the bias.
Defaults to nn.init.zeros_.
"""
def __init__(
self,
input_dim: int,
output_dim: Optional[int] = None,
net_depth: int = 8,
net_width: int = 256,
skip_layer: int = 4,
hidden_init: Callable = nn.init.xavier_uniform_,
hidden_activation: Callable = nn.ReLU(),
output_enabled: bool = True,
output_init: Optional[Callable] = nn.init.xavier_uniform_,
output_activation: Optional[Callable] = nn.Identity(),
bias_enabled: bool = True,
bias_init: Callable = nn.init.zeros_,
):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.net_depth = net_depth
self.net_width = net_width
self.skip_layer = skip_layer
self.hidden_init = hidden_init
self.hidden_activation = hidden_activation
self.output_enabled = output_enabled
self.output_init = output_init
self.output_activation = output_activation
self.bias_enabled = bias_enabled
self.bias_init = bias_init
self.hidden_layers = nn.ModuleList()
in_features = self.input_dim
for i in range(self.net_depth):
self.hidden_layers.append(
nn.Linear(in_features, self.net_width, bias=bias_enabled))
if (self.skip_layer is not None) and (i % self.skip_layer
== 0) and (i > 0):
in_features = self.net_width + self.input_dim
else:
in_features = self.net_width
if self.output_enabled:
self.output_layer = nn.Linear(
in_features, self.output_dim, bias=bias_enabled)
else:
self.output_dim = in_features
self.initialize()
def initialize(self):
def init_func_hidden(m):
if isinstance(m, nn.Linear):
if self.hidden_init is not None:
self.hidden_init(m.weight)
if self.bias_enabled and self.bias_init is not None:
self.bias_init(m.bias)
self.hidden_layers.apply(init_func_hidden)
if self.output_enabled:
def init_func_output(m):
if isinstance(m, nn.Linear):
if self.output_init is not None:
self.output_init(m.weight)
if self.bias_enabled and self.bias_init is not None:
self.bias_init(m.bias)
self.output_layer.apply(init_func_output)
def forward(self, x):
inputs = x
for i in range(self.net_depth):
x = self.hidden_layers[i](x)
x = self.hidden_activation(x)
if (self.skip_layer is not None) and (i % self.skip_layer
== 0) and (i > 0):
x = torch.cat([x, inputs], dim=-1)
if self.output_enabled:
x = self.output_layer(x)
x = self.output_activation(x)
return x
class DenseLayer(MLP):
def __init__(self, input_dim, output_dim, **kwargs):
super().__init__(
input_dim=input_dim,
output_dim=output_dim,
net_depth=0, # no hidden layers
**kwargs,
)
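The way `MLP.__init__` sizes its hidden layers around the skip connection can be traced without PyTorch. The helper below is a hypothetical pure-Python sketch (the name `mlp_layer_inputs` is not part of the project); it replays the same loop and returns the input width of every hidden layer plus the width seen by the output layer.

```python
def mlp_layer_inputs(input_dim, net_depth=8, net_width=256, skip_layer=4):
    """Return (input width of each hidden layer, input width of the head)."""
    widths = []
    in_features = input_dim
    for i in range(net_depth):
        widths.append(in_features)
        if skip_layer is not None and i % skip_layer == 0 and i > 0:
            # The raw input is concatenated after this layer, so the next
            # layer sees net_width + input_dim channels.
            in_features = net_width + input_dim
        else:
            in_features = net_width
    return widths, in_features


# Assumed example: a 63-channel encoded position fed to the default MLP.
layer_inputs, head_input = mlp_layer_inputs(63)

# net_depth=0 mirrors DenseLayer: no hidden layers, the head sees the input.
dense_inputs, dense_head = mlp_layer_inputs(63, net_depth=0)
```

With the defaults, only the layer at index 5 receives the widened 319-channel input, and the head still sees 256 channels because the last hidden layer is not followed by a concatenation.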
class NerfMLP(nn.Module):
"""The Nerf-MLP Module.
Args:
input_dim (int): The number of input tensor channels.
condition_dim (int): The number of condition tensor channels.
feature_dim (int): The number of feature channels. Defaults to 0.
net_depth (int): The depth of the MLP. Defaults to 8.
net_width (int): The width of the MLP. Defaults to 256.
skip_layer (int): The layer to add skip layers to. Defaults to 4.
net_depth_condition (int): The depth of the second part of MLP.
Defaults to 1.
net_width_condition (int): The width of the second part of MLP.
Defaults to 128.
"""
def __init__(
self,
input_dim: int,
condition_dim: int,
feature_dim: int = 0,
net_depth: int = 8,
net_width: int = 256,
skip_layer: int = 4,
net_depth_condition: int = 1,
net_width_condition: int = 128,
):
super().__init__()
self.base = MLP(
input_dim=input_dim + feature_dim,
net_depth=net_depth,
net_width=net_width,
skip_layer=skip_layer,
output_enabled=False,
)
hidden_features = self.base.output_dim
self.sigma_layer = DenseLayer(hidden_features, 1)
if condition_dim > 0:
self.bottleneck_layer = DenseLayer(hidden_features, net_width)
self.rgb_layer = MLP(
input_dim=net_width + condition_dim,
output_dim=3,
net_depth=net_depth_condition,
net_width=net_width_condition,
skip_layer=None,
)
else:
self.rgb_layer = DenseLayer(hidden_features, 3)
def query_density(self, x, features=None):
"""Calculate the raw sigma."""
if features is not None:
x = self.base(torch.cat([x, features], dim=-1))
else:
x = self.base(x)
raw_sigma = self.sigma_layer(x)
return raw_sigma
def forward(self, x, condition=None, features=None):
if features is not None:
x = self.base(torch.cat([x, features], dim=-1))
else:
x = self.base(x)
raw_sigma = self.sigma_layer(x)
if condition is not None:
if condition.shape[:-1] != x.shape[:-1]:
num_rays, n_dim = condition.shape
condition = condition.view(
[num_rays] + [1] * (x.dim() - condition.dim()) +
[n_dim]).expand(list(x.shape[:-1]) + [n_dim])
bottleneck = self.bottleneck_layer(x)
x = torch.cat([bottleneck, condition], dim=-1)
raw_rgb = self.rgb_layer(x)
return raw_rgb, raw_sigma
class SinusoidalEncoder(nn.Module):
"""Sinusoidal Positional Encoder used in NeRF."""
def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
super().__init__()
self.x_dim = x_dim
self.min_deg = min_deg
self.max_deg = max_deg
self.use_identity = use_identity
self.register_buffer(
'scales', torch.tensor([2**i for i in range(min_deg, max_deg)]))
@property
def latent_dim(self) -> int:
return (int(self.use_identity) +
(self.max_deg - self.min_deg) * 2) * self.x_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.max_deg == self.min_deg:
return x
xb = torch.reshape(
(x[Ellipsis, None, :] * self.scales[:, None]),
list(x.shape[:-1]) + [(self.max_deg - self.min_deg) * self.x_dim],
)
latent = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1))
if self.use_identity:
latent = torch.cat([x] + [latent], dim=-1)
return latent
class VanillaNeRF(nn.Module):
"""The Nerf-MLP with the positional encoder.
Args:
net_depth (int): The depth of the MLP. Defaults to 8.
net_width (int): The width of the MLP. Defaults to 256.
skip_layer (int): The layer to add skip layers to. Defaults to 4.
feature_dim (int): The number of feature channels. Defaults to 0.
net_depth_condition (int): The depth of the second part of MLP.
Defaults to 1.
net_width_condition (int): The width of the second part of MLP.
Defaults to 128.
"""
def __init__(self,
net_depth: int = 8,
net_width: int = 256,
skip_layer: int = 4,
feature_dim: int = 0,
net_depth_condition: int = 1,
net_width_condition: int = 128):
super().__init__()
self.posi_encoder = SinusoidalEncoder(3, 0, 10, True)
self.view_encoder = SinusoidalEncoder(3, 0, 4, True)
self.mlp = NerfMLP(
input_dim=self.posi_encoder.latent_dim,
condition_dim=self.view_encoder.latent_dim,
feature_dim=feature_dim,
net_depth=net_depth,
net_width=net_width,
skip_layer=skip_layer,
net_depth_condition=net_depth_condition,
net_width_condition=net_width_condition,
)
def query_density(self, x, features=None):
x = self.posi_encoder(x)
sigma = self.mlp.query_density(x, features)
return F.relu(sigma)
def forward(self, x, condition=None, features=None):
x = self.posi_encoder(x)
if condition is not None:
condition = self.view_encoder(condition)
rgb, sigma = self.mlp(x, condition=condition, features=features)
return torch.sigmoid(rgb), F.relu(sigma)
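The encoder above maps each input coordinate to sines and cosines at frequencies `2**min_deg` up to `2**(max_deg - 1)`, optionally keeping the raw input. A minimal sketch of the same idea in plain Python, assuming a single point given as a list of floats (the names `sinusoidal_encode` and `latent_dim` are illustrative helpers, not part of the codebase), shows where the input widths of the MLP come from:

```python
import math


def sinusoidal_encode(x, min_deg, max_deg, use_identity=True):
    """Encode one point (list of floats) with sin/cos features."""
    if max_deg == min_deg:
        return list(x)
    # Degree-major layout: all coordinates at scale 2**d, then the next d.
    scaled = [v * 2**d for d in range(min_deg, max_deg) for v in x]
    # cos(t) is written as sin(t + pi / 2), as in the encoder above.
    latent = [math.sin(s) for s in scaled]
    latent += [math.sin(s + 0.5 * math.pi) for s in scaled]
    return (list(x) if use_identity else []) + latent


def latent_dim(x_dim, min_deg, max_deg, use_identity=True):
    """Output width of the encoding for an x_dim-dimensional input."""
    return (int(use_identity) + (max_deg - min_deg) * 2) * x_dim


position_code = sinusoidal_encode([0.1, -0.2, 0.3], 0, 10)
view_code = sinusoidal_encode([0.0, 0.0, 1.0], 0, 4)
print(len(position_code), len(view_code))  # 63 27
```

With the settings used by `VanillaNeRF` (degrees 0 to 10 for positions, 0 to 4 for view directions), the position code has 63 entries and the view code has 27.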
# Copyright (c) OpenMMLab. All rights reserved.
# Attention: This file is mainly modified based on the file with the same
# name in the original project. For more details, please refer to the
# original project.
import torch
import torch.nn.functional as F
class Projector():
def __init__(self, device='cuda'):
self.device = device
def inbound(self, pixel_locations, h, w):
"""Check if the pixel locations are within the valid image range."""
return (pixel_locations[..., 0] <= w - 1.) & \
(pixel_locations[..., 0] >= 0) & \
(pixel_locations[..., 1] <= h - 1.) &\
(pixel_locations[..., 1] >= 0)
def normalize(self, pixel_locations, h, w):
resize_factor = torch.tensor([w - 1., h - 1.
]).to(pixel_locations.device)[None,
None, :]
normalized_pixel_locations = 2 * pixel_locations / resize_factor - 1.
return normalized_pixel_locations
def compute_projections(self, xyz, train_cameras):
"""project 3D points into cameras."""
original_shape = xyz.shape[:2]
xyz = xyz.reshape(-1, 3)
num_views = len(train_cameras)
train_intrinsics = train_cameras[:, 2:18].reshape(-1, 4, 4)
train_poses = train_cameras[:, -16:].reshape(-1, 4, 4)
xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1)
# projections = train_intrinsics.bmm(torch.inverse(train_poses))
# The poses are already inverted in the dataloader, so there is
# no need to invert them here.
projections = train_intrinsics.bmm(train_poses) \
.bmm(xyz_h.t()[None, ...].repeat(num_views, 1, 1))
projections = projections.permute(0, 2, 1)
pixel_locations = projections[..., :2] / torch.clamp(
projections[..., 2:3], min=1e-8)
pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6)
mask = projections[..., 2] > 0
return pixel_locations.reshape((num_views, ) + original_shape + (2, )), \
mask.reshape((num_views, ) + original_shape) # noqa
def compute_angle(self, xyz, query_camera, train_cameras):
original_shape = xyz.shape[:2]
xyz = xyz.reshape(-1, 3)
train_poses = train_cameras[:, -16:].reshape(-1, 4, 4)
num_views = len(train_poses)
query_pose = query_camera[-16:].reshape(-1, 4,
4).repeat(num_views, 1, 1)
ray2tar_pose = (query_pose[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6)
ray2train_pose = (
train_poses[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
ray2train_pose /= (
torch.norm(ray2train_pose, dim=-1, keepdim=True) + 1e-6)
ray_diff = ray2tar_pose - ray2train_pose
ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
ray_diff_dot = torch.sum(
ray2tar_pose * ray2train_pose, dim=-1, keepdim=True)
ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
ray_diff = ray_diff.reshape((num_views, ) + original_shape + (4, ))
return ray_diff
def compute(self,
xyz,
train_imgs,
train_cameras,
featmaps=None,
grid_sample=True):
assert (train_imgs.shape[0] == 1) \
and (train_cameras.shape[0] == 1)
# only support batch_size=1 for now
train_imgs = train_imgs.squeeze(0)
train_cameras = train_cameras.squeeze(0)
train_imgs = train_imgs.permute(0, 3, 1, 2)
h, w = train_cameras[0][:2]
# compute the projection of the query points to each reference image
pixel_locations, mask_in_front = self.compute_projections(
xyz, train_cameras)
normalized_pixel_locations = self.normalize(pixel_locations, h, w)
# rgb sampling
rgbs_sampled = F.grid_sample(
train_imgs, normalized_pixel_locations, align_corners=True)
rgb_sampled = rgbs_sampled.permute(2, 3, 0, 1)
# deep feature sampling
if featmaps is not None:
if grid_sample:
feat_sampled = F.grid_sample(
featmaps, normalized_pixel_locations, align_corners=True)
feat_sampled = feat_sampled.permute(
2, 3, 0, 1) # [n_rays, n_samples, n_views, d]
rgb_feat_sampled = torch.cat(
[rgb_sampled, feat_sampled],
dim=-1) # [n_rays, n_samples, n_views, d+3]
# rgb_feat_sampled = feat_sampled
else:
n_images, n_channels, f_h, f_w = featmaps.shape
resize_factor = torch.tensor([f_w / w - 1., f_h / h - 1.]).to(
pixel_locations.device)[None, None, :]
sample_location = (pixel_locations *
resize_factor).round().long()
n_images, n_ray, n_sample, _ = sample_location.shape
sample_x = sample_location[..., 0].view(n_images, -1)
sample_y = sample_location[..., 1].view(n_images, -1)
valid = (sample_x >= 0) & (sample_y >=
0) & (sample_x < f_w) & (
sample_y < f_h)
valid = valid * mask_in_front.view(n_images, -1)
feat_sampled = torch.zeros(
(n_images, n_channels, sample_x.shape[-1]),
device=featmaps.device)
for i in range(n_images):
# featmaps is [n_images, C, f_h, f_w]: rows use y, columns use x
feat_sampled[i, :,
valid[i]] = featmaps[i, :, sample_y[i,
valid[i]],
sample_x[i, valid[i]]]
feat_sampled = feat_sampled.view(n_images, n_channels, n_ray,
n_sample)
rgb_feat_sampled = feat_sampled.permute(2, 3, 0, 1)
else:
rgb_feat_sampled = None
inbound = self.inbound(pixel_locations, h, w)
mask = (inbound * mask_in_front).float().permute(
1, 2, 0)[..., None] # [n_rays, n_samples, n_views, 1]
return rgb_feat_sampled, mask
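The projector above projects 3D points into each view, rescales the pixel coordinates to `[-1, 1]` for `F.grid_sample(..., align_corners=True)`, and masks points that fall outside the image or behind the camera. A minimal single-point sketch in plain Python, assuming a simple pinhole camera with hypothetical intrinsics `fx`, `fy`, `cx`, `cy` and a point already expressed in camera coordinates:

```python
def project(point, fx, fy, cx, cy):
    """Pinhole projection of a camera-space point (x, y, z)."""
    x, y, z = point
    z_safe = max(z, 1e-8)  # same guard as the clamp on the depth term
    u = fx * x / z_safe + cx
    v = fy * y / z_safe + cy
    return u, v, z > 0  # the flag marks points in front of the camera


def normalize_pixel(u, v, h, w):
    """Map pixel coordinates to [-1, 1] (align_corners=True convention)."""
    return 2.0 * u / (w - 1.0) - 1.0, 2.0 * v / (h - 1.0) - 1.0


def in_bounds(u, v, h, w):
    """True if the pixel lies inside an image of size h x w."""
    return 0 <= u <= w - 1 and 0 <= v <= h - 1


h, w = 240, 320
u, v, in_front = project((0.0, 0.0, 2.0), 300.0, 300.0, 159.5, 119.5)
print(normalize_pixel(u, v, h, w), in_bounds(u, v, h, w) and in_front)
```

A point on the optical axis lands at the image centre, which normalizes to approximately `(0, 0)`; the corner pixels `(0, 0)` and `(w - 1, h - 1)` map to `(-1, -1)` and `(1, 1)`.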
# Copyright (c) OpenMMLab. All rights reserved.
# Attention: This file is mainly modified based on the file with the same
# name in the original project. For more details, please refer to the
# original project.
from collections import OrderedDict
import numpy as np
import torch
import torch.nn.functional as F
rng = np.random.RandomState(234)
# helper functions for nerf ray rendering
def volume_sampling(sample_pts, features, aabb):
B, C, D, W, H = features.shape
assert B == 1
aabb = torch.Tensor(aabb).to(sample_pts.device)
N_rays, N_samples, coords = sample_pts.shape
sample_pts = sample_pts.view(1, N_rays * N_samples, 1, 1,
3).repeat(B, 1, 1, 1, 1)
aabbSize = aabb[1] - aabb[0]
invgridSize = 1.0 / aabbSize * 2
norm_pts = (sample_pts - aabb[0]) * invgridSize - 1
sample_features = F.grid_sample(
features, norm_pts, align_corners=True, padding_mode='border')
masks = ((norm_pts < 1) & (norm_pts > -1)).float().sum(dim=-1)
masks = (masks.view(N_rays, N_samples) == 3)
return sample_features.view(C, N_rays,
N_samples).permute(1, 2, 0).contiguous(), masks
def _compute_projection(img_meta):
views = len(img_meta['lidar2img']['extrinsic'])
intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:4, :4])
ratio = img_meta['ori_shape'][0] / img_meta['img_shape'][0]
intrinsic[:2] /= ratio
intrinsic = intrinsic.unsqueeze(0).view(1, 16).repeat(views, 1)
img_size = torch.Tensor(img_meta['img_shape'][:2]).to(intrinsic.device)
img_size = img_size.unsqueeze(0).repeat(views, 1)
extrinsics = []
for v in range(views):
extrinsics.append(
torch.Tensor(img_meta['lidar2img']['extrinsic'][v]).to(
intrinsic.device))
extrinsic = torch.stack(extrinsics).view(views, 16)
train_cameras = torch.cat([img_size, intrinsic, extrinsic], dim=-1)
return train_cameras.unsqueeze(0)
def compute_mask_points(feature, mask):
weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
mean = torch.sum(feature * weight, dim=2, keepdim=True)
var = torch.sum((feature - mean)**2, dim=2, keepdim=True)
var = var / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
var = torch.exp(-var)
return mean, var
def sample_pdf(bins, weights, N_samples, det=False):
"""Helper function used for sampling.
Args:
bins (tensor): Tensor of shape [N_rays, M+1], M is the number of bins.
weights (tensor): Tensor of shape [N_rays, M], M is the number of bins.
N_samples (int): Number of samples along each ray.
det (bool): If True, will perform deterministic sampling.
Returns:
samples (tensor): Tensor of shape [N_rays, N_samples].
"""
M = weights.shape[1]
# avoid modifying the caller's tensor in place
weights = weights + 1e-5
# Get pdf
pdf = weights / torch.sum(weights, dim=-1, keepdim=True)
cdf = torch.cumsum(pdf, dim=-1)
cdf = torch.cat([torch.zeros_like(cdf[:, 0:1]), cdf], dim=-1)
# Take uniform samples
if det:
u = torch.linspace(0., 1., N_samples, device=bins.device)
u = u.unsqueeze(0).repeat(bins.shape[0], 1)
else:
u = torch.rand(bins.shape[0], N_samples, device=bins.device)
# Invert CDF
above_inds = torch.zeros_like(u, dtype=torch.long)
for i in range(M):
above_inds += (u >= cdf[:, i:i + 1]).long()
# random sample inside each bin
below_inds = torch.clamp(above_inds - 1, min=0)
inds_g = torch.stack((below_inds, above_inds), dim=2)
cdf = cdf.unsqueeze(1).repeat(1, N_samples, 1)
cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)
bins = bins.unsqueeze(1).repeat(1, N_samples, 1)
bins_g = torch.gather(input=bins, dim=-1, index=inds_g)
denom = cdf_g[:, :, 1] - cdf_g[:, :, 0]
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
t = (u - cdf_g[:, :, 0]) / denom
samples = bins_g[:, :, 0] + t * (bins_g[:, :, 1] - bins_g[:, :, 0])
return samples
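`sample_pdf` draws depths by inverting the piecewise-linear CDF built from the per-bin weights. The sketch below restates that inversion for a single ray in plain Python, using `bisect` in place of the loop over bins; `sample_pdf_1d` is a hypothetical helper written for illustration and takes the uniform draws `us` explicitly so the result is reproducible:

```python
from bisect import bisect_right


def sample_pdf_1d(bins, weights, us):
    """Inverse-CDF sampling for one ray.

    bins: M + 1 bin edges, weights: M non-negative bin weights,
    us: values in [0, 1) to push through the inverse CDF.
    """
    weights = [w + 1e-5 for w in weights]  # avoid an all-zero pdf
    total = sum(weights)
    cdf = [0.0]
    for w in weights:
        cdf.append(cdf[-1] + w / total)
    samples = []
    for u in us:
        # Index of the first CDF entry above u, capped at the last edge.
        above = min(bisect_right(cdf, u), len(weights))
        below = max(above - 1, 0)
        denom = cdf[above] - cdf[below]
        if denom < 1e-5:
            denom = 1.0
        t = (u - cdf[below]) / denom
        samples.append(bins[below] + t * (bins[above] - bins[below]))
    return samples


# Almost all weight sits in the middle bin [1, 2].
print(sample_pdf_1d([0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 0.0], [0.1, 0.5, 0.9]))
```

Because nearly all of the probability mass lies in the middle bin, every sample falls inside `[1, 2]`.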
def sample_along_camera_ray(ray_o,
ray_d,
depth_range,
N_samples,
inv_uniform=False,
det=False):
"""Sampling along the camera ray.
Args:
ray_o (tensor): Origin of the ray in scene coordinate system;
tensor of shape [N_rays, 3]
ray_d (tensor): Homogeneous ray direction vectors in
scene coordinate system; tensor of shape [N_rays, 3]
depth_range (tuple): [near_depth, far_depth]
N_samples (int): Number of samples along each ray.
inv_uniform (bool): If True, sample uniformly in inverse depth.
det (bool): If True, will perform deterministic sampling.
Returns:
pts (tensor): Tensor of shape [N_rays, N_samples, 3]
z_vals (tensor): Tensor of shape [N_rays, N_samples]
"""
# will sample inside [near_depth, far_depth]
# assume the nearest possible depth is at least (min_ratio * depth)
near_depth_value = depth_range[0]
far_depth_value = depth_range[1]
assert near_depth_value > 0 and far_depth_value > 0 \
and far_depth_value > near_depth_value
near_depth = near_depth_value * torch.ones_like(ray_d[..., 0])
far_depth = far_depth_value * torch.ones_like(ray_d[..., 0])
if inv_uniform:
start = 1. / near_depth
step = (1. / far_depth - start) / (N_samples - 1)
inv_z_vals = torch.stack([start + i * step for i in range(N_samples)],
dim=1)
z_vals = 1. / inv_z_vals
else:
start = near_depth
step = (far_depth - near_depth) / (N_samples - 1)
z_vals = torch.stack([start + i * step for i in range(N_samples)],
dim=1)
if not det:
# get intervals between samples
mids = .5 * (z_vals[:, 1:] + z_vals[:, :-1])
upper = torch.cat([mids, z_vals[:, -1:]], dim=-1)
lower = torch.cat([z_vals[:, 0:1], mids], dim=-1)
# uniform samples in those intervals
t_rand = torch.rand_like(z_vals)
z_vals = lower + (upper - lower) * t_rand
ray_d = ray_d.unsqueeze(1).repeat(1, N_samples, 1)
ray_o = ray_o.unsqueeze(1).repeat(1, N_samples, 1)
pts = z_vals.unsqueeze(2) * ray_d + ray_o # [N_rays, N_samples, 3]
return pts, z_vals
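The depth sampling above places `N_samples` evenly spaced depths between the near and far planes and, when `det` is False, jitters each depth uniformly inside its own interval. A minimal sketch for one ray in plain Python (the helper name `stratified_depths` and the explicit `rng` argument are assumptions made for this illustration):

```python
import random


def stratified_depths(near, far, n_samples, det=True, rng=None):
    """Evenly spaced depths in [near, far], optionally jittered per bin."""
    step = (far - near) / (n_samples - 1)
    z_vals = [near + i * step for i in range(n_samples)]
    if det:
        return z_vals
    rng = rng or random.Random(0)
    # Interval boundaries are the midpoints between neighbouring depths.
    mids = [0.5 * (a + b) for a, b in zip(z_vals[:-1], z_vals[1:])]
    lower = [z_vals[0]] + mids
    upper = mids + [z_vals[-1]]
    return [lo + (hi - lo) * rng.random() for lo, hi in zip(lower, upper)]


print(stratified_depths(0.2, 8.0, 5))
print(stratified_depths(0.2, 8.0, 5, det=False))
```

Each jittered depth stays inside its own interval, so the samples remain ordered by depth, which the compositing step relies on.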
# ray rendering of nerf
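def raw2outputs(raw, z_vals, mask, white_bkgd=False):
"""Transform raw network outputs into rendered quantities.
Args:
raw (tensor): Raw network output. Tensor of shape
[N_rays, N_samples, 4].
z_vals (tensor): Depth of point samples along rays.
Tensor of shape [N_rays, N_samples].
mask (tensor, optional): Per-sample validity mask of shape
[N_rays, N_samples]. A ray is kept if more than 8 samples are valid.
white_bkgd (bool): If True, composite onto a white background.
Returns:
ret (dict):
- rgb (tensor): [N_rays, 3]
- depth (tensor): [N_rays,]
- weights (tensor): [N_rays, N_samples]
- mask (tensor): [N_rays,], or None if no mask is given
- alpha (tensor): [N_rays, N_samples]
- z_vals (tensor): [N_rays, N_samples]
- transparency (tensor): [N_rays, N_samples]
"""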
rgb = raw[:, :, :3] # [N_rays, N_samples, 3]
sigma = raw[:, :, 3] # [N_rays, N_samples]
# note: we did not use the intervals here,
# because in practice different scenes from COLMAP can have
# very different scales, and using interval can affect
# the model's generalization ability.
# Therefore we don't use the intervals for both training and evaluation.
sigma2alpha = lambda sigma, dists: 1. - torch.exp(-sigma) # noqa
# point samples are ordered with increasing depth
# interval between samples
dists = z_vals[:, 1:] - z_vals[:, :-1]
dists = torch.cat((dists, dists[:, -1:]), dim=-1)
alpha = sigma2alpha(sigma, dists)
T = torch.cumprod(1. - alpha + 1e-10, dim=-1)[:, :-1]
T = torch.cat((torch.ones_like(T[:, 0:1]), T), dim=-1)
# maths show weights, and summation of weights along a ray,
# are always inside [0, 1]
weights = alpha * T
rgb_map = torch.sum(weights.unsqueeze(2) * rgb, dim=1)
if white_bkgd:
rgb_map = rgb_map + (1. - torch.sum(weights, dim=-1, keepdim=True))
if mask is not None:
mask = mask.float().sum(dim=1) > 8
depth_map = torch.sum(
weights * z_vals, dim=-1) / (
torch.sum(weights, dim=-1) + 1e-8)
depth_map = torch.clamp(depth_map, z_vals.min(), z_vals.max())
ret = OrderedDict([('rgb', rgb_map), ('depth', depth_map),
('weights', weights), ('mask', mask), ('alpha', alpha),
('z_vals', z_vals), ('transparency', T)])
return ret
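`raw2outputs` turns per-sample densities into compositing weights: `alpha = 1 - exp(-sigma)` (the sample intervals are deliberately not used), the transmittance `T` is the running product of `1 - alpha` over the earlier samples, and the weight is `alpha * T`. A minimal single-ray sketch in plain Python, leaving out the `1e-10` stabilizer and the final depth clamp of the code above (`composite` is a hypothetical name):

```python
import math


def composite(sigmas, z_vals):
    """Alpha-composite one ray; returns (weights, expected depth)."""
    alphas = [1.0 - math.exp(-s) for s in sigmas]
    weights = []
    transmittance = 1.0  # fraction of light reaching the current sample
    for a in alphas:
        weights.append(a * transmittance)
        transmittance *= 1.0 - a
    weight_sum = sum(weights)
    depth = sum(w * z for w, z in zip(weights, z_vals)) / (weight_sum + 1e-8)
    return weights, depth


# Empty space, then a nearly opaque surface at depth 3.
weights, depth = composite([0.0, 0.0, 50.0, 1.0], [1.0, 2.0, 3.0, 4.0])
print(weights, depth)
```

The weights are non-negative and sum to at most 1. Here almost all of the weight lands on the third sample, so the rendered depth is close to 3 and the sample behind the surface contributes almost nothing.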
def render_rays_func(
ray_o,
ray_d,
mean_volume,
cov_volume,
features_2D,
img,
aabb,
near_far_range,
N_samples,
N_rand=4096,
nerf_mlp=None,
img_meta=None,
projector=None,
mode='volume', # 'volume' or 'image'
nerf_sample_view=3,
inv_uniform=False,
N_importance=0,
det=False,
is_train=True,
white_bkgd=False,
gt_rgb=None,
gt_depth=None):
ret = {
'outputs_coarse': None,
'outputs_fine': None,
'gt_rgb': gt_rgb,
'gt_depth': gt_depth
}
# pts: [N_rays, N_samples, 3]
# z_vals: [N_rays, N_samples]
pts, z_vals = sample_along_camera_ray(
ray_o=ray_o,
ray_d=ray_d,
depth_range=near_far_range,
N_samples=N_samples,
inv_uniform=inv_uniform,
det=det)
N_rays, N_samples = pts.shape[:2]
if mode == 'image':
img = img.permute(0, 2, 3, 1).unsqueeze(0)
train_camera = _compute_projection(img_meta).to(img.device)
rgb_feat, mask = projector.compute(
pts, img, train_camera, features_2D, grid_sample=True)
pixel_mask = mask[..., 0].sum(dim=2) > 1
mean, var = compute_mask_points(rgb_feat, mask)
globalfeat = torch.cat([mean, var], dim=-1).squeeze(2)
rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalfeat)
raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1)
ret['sigma'] = density_pts
elif mode == 'volume':
mean_pts, inbound_masks = volume_sampling(pts, mean_volume, aabb)
cov_pts, inbound_masks = volume_sampling(pts, cov_volume, aabb)
# inbound_masks marks the points that lie inside the aabb
img = img.permute(0, 2, 3, 1).unsqueeze(0)
train_camera = _compute_projection(img_meta).to(img.device)
_, view_mask = projector.compute(pts, img, train_camera, None)
pixel_mask = view_mask[..., 0].sum(dim=2) > 1
# plot_3D_vis(pts, aabb, img, train_camera)
# [N_rays, N_samples], each point should have at least 2 observations;
# pixel_mask marks the points that are visible in at least 2 views
globalpts = torch.cat([mean_pts, cov_pts], dim=-1)
rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalpts)
density_pts = density_pts * inbound_masks.unsqueeze(dim=-1)
raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1)
outputs_coarse = raw2outputs(
raw_coarse, z_vals, pixel_mask, white_bkgd=white_bkgd)
ret['outputs_coarse'] = outputs_coarse
return ret
def render_rays(
ray_batch,
mean_volume,
cov_volume,
features_2D,
img,
aabb,
near_far_range,
N_samples,
N_rand=4096,
nerf_mlp=None,
img_meta=None,
projector=None,
mode='volume', # 'volume' or 'image'
nerf_sample_view=3,
inv_uniform=False,
N_importance=0,
det=False,
is_train=True,
white_bkgd=False,
render_testing=False):
"""The function of the nerf rendering."""
ray_o = ray_batch['ray_o']
ray_d = ray_batch['ray_d']
gt_rgb = ray_batch['gt_rgb']
gt_depth = ray_batch['gt_depth']
nerf_sizes = ray_batch['nerf_sizes']
if is_train:
ray_o = ray_o.view(-1, 3)
ray_d = ray_d.view(-1, 3)
gt_rgb = gt_rgb.view(-1, 3)
if gt_depth.shape[1] != 0:
gt_depth = gt_depth.view(-1, 1)
non_zero_depth = (gt_depth > 0).squeeze(-1)
ray_o = ray_o[non_zero_depth]
ray_d = ray_d[non_zero_depth]
gt_rgb = gt_rgb[non_zero_depth]
gt_depth = gt_depth[non_zero_depth]
else:
gt_depth = None
total_rays = ray_d.shape[0]
select_inds = rng.choice(total_rays, size=(N_rand, ), replace=False)
ray_o = ray_o[select_inds]
ray_d = ray_d[select_inds]
gt_rgb = gt_rgb[select_inds]
if gt_depth is not None:
gt_depth = gt_depth[select_inds]
rets = render_rays_func(
ray_o,
ray_d,
mean_volume,
cov_volume,
features_2D,
img,
aabb,
near_far_range,
N_samples,
N_rand,
nerf_mlp,
img_meta,
projector,
mode, # volume and image
nerf_sample_view,
inv_uniform,
N_importance,
det,
is_train,
white_bkgd,
gt_rgb,
gt_depth)
elif render_testing:
nerf_size = nerf_sizes[0]
view_num = ray_o.shape[1]
H = nerf_size[0][0]
W = nerf_size[0][1]
ray_o = ray_o.view(-1, 3)
ray_d = ray_d.view(-1, 3)
gt_rgb = gt_rgb.view(-1, 3)
if len(gt_depth) != 0:
gt_depth = gt_depth.view(-1, 1)
else:
gt_depth = None
assert view_num * H * W == ray_o.shape[0]
num_rays = ray_o.shape[0]
results = []
rgbs = []
for i in range(0, num_rays, N_rand):
ray_o_chunk = ray_o[i:i + N_rand, :]
ray_d_chunk = ray_d[i:i + N_rand, :]
ret = render_rays_func(ray_o_chunk, ray_d_chunk, mean_volume,
cov_volume, features_2D, img, aabb,
near_far_range, N_samples, N_rand, nerf_mlp,
img_meta, projector, mode, nerf_sample_view,
inv_uniform, N_importance, True, is_train,
white_bkgd, gt_rgb, gt_depth)
results.append(ret)
rgbs = []
depths = []
if results[0]['outputs_coarse'] is not None:
for i in range(len(results)):
rgb = results[i]['outputs_coarse']['rgb']
rgbs.append(rgb)
depth = results[i]['outputs_coarse']['depth']
depths.append(depth)
rets = {
'outputs_coarse': {
'rgb': torch.cat(rgbs, dim=0).view(view_num, H, W, 3),
'depth': torch.cat(depths, dim=0).view(view_num, H, W, 1),
},
'gt_rgb':
gt_rgb.view(view_num, H, W, 3),
'gt_depth':
gt_depth.view(view_num, H, W, 1) if gt_depth is not None else None,
}
else:
rets = None
return rets
# Copyright (c) OpenMMLab. All rights reserved.
import os
import cv2
import numpy as np
import torch
from skimage.metrics import structural_similarity
def compute_psnr_from_mse(mse):
return -10.0 * torch.log(mse) / np.log(10.0)
def compute_psnr(pred, target, mask=None):
"""Compute psnr value (we assume the maximum pixel value is 1)."""
if mask is not None:
pred, target = pred[mask], target[mask]
mse = ((pred - target)**2).mean()
return compute_psnr_from_mse(mse).cpu().numpy()
def compute_ssim(pred, target, mask=None):
"""Computes Masked SSIM following the neuralbody paper."""
assert pred.shape == target.shape and pred.shape[-1] == 3
if mask is not None:
x, y, w, h = cv2.boundingRect(mask.cpu().numpy().astype(np.uint8))
pred = pred[y:y + h, x:x + w]
target = target[y:y + h, x:x + w]
try:
ssim = structural_similarity(
pred.cpu().numpy(), target.cpu().numpy(), channel_axis=-1)
except ValueError:
ssim = structural_similarity(
pred.cpu().numpy(), target.cpu().numpy(), multichannel=True)
return ssim
def save_rendered_img(img_meta, rendered_results):
filename = img_meta[0]['filename']
scenes = filename.split('/')[-2]
for ret in rendered_results:
depth = ret['outputs_coarse']['depth']
rgb = ret['outputs_coarse']['rgb']
gt = ret['gt_rgb']
gt_depth = ret['gt_depth']
# save images
psnr_total = 0
ssim_total = 0
rmse = 0
for v in range(gt.shape[0]):
# depth RMSE of this view (square root of the mean squared error)
rmse += torch.sqrt(((depth[v] - gt_depth[v])**2).mean()).item()
depth_ = ((depth[v] - depth[v].min()) /
(depth[v].max() - depth[v].min() + 1e-8)).repeat(1, 1, 3)
img_to_save = torch.cat([rgb[v], gt[v], depth_], dim=1)
image_path = os.path.join('nerf_vs_rebuttal', scenes)
os.makedirs(image_path, exist_ok=True)
save_dir = os.path.join(image_path, 'view_' + str(v) + '.png')
font = cv2.FONT_HERSHEY_SIMPLEX
org = (50, 50)
fontScale = 1
color = (255, 0, 0)
thickness = 2
image = np.uint8(img_to_save.cpu().numpy() * 255.0)
psnr = compute_psnr(rgb[v], gt[v], mask=None)
psnr_total += psnr
ssim = compute_ssim(rgb[v], gt[v], mask=None)
ssim_total += ssim
# reuse the PSNR computed above for the overlay text
image = cv2.putText(image, 'PSNR: ' + '%.2f' % psnr, org, font,
fontScale, color, thickness, cv2.LINE_AA)
cv2.imwrite(save_dir, image)
return psnr_total / gt.shape[0], ssim_total / gt.shape[0], rmse / gt.shape[
0]
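`compute_psnr` assumes pixel values in `[0, 1]`, so the PSNR reduces to `-10 * log10(mse)`. A minimal sketch on flat lists of pixel values, written in plain Python purely as an illustration (`psnr` is a hypothetical helper, not the function above):

```python
import math


def psnr(pred, target):
    """PSNR in dB for pixel values in [0, 1] (peak signal value of 1)."""
    mse = sum((p - t)**2 for p, t in zip(pred, target)) / len(pred)
    return -10.0 * math.log10(mse)


# A uniform error of 0.1 gives an MSE of 0.01, i.e. about 20 dB.
print(round(psnr([0.5, 0.5, 0.5, 0.5], [0.6, 0.6, 0.6, 0.6]), 2))
```

Each tenfold reduction of the MSE adds 10 dB, which is a convenient way to read the PSNR overlay written onto the saved images.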
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet3d.models.detectors import Base3DDetector
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import ConfigType, OptConfigType
from .nerf_utils.nerf_mlp import VanillaNeRF
from .nerf_utils.projection import Projector
from .nerf_utils.render_ray import render_rays
# from ..utils.nerf_utils.save_rendered_img import save_rendered_img
@MODELS.register_module()
class NerfDet(Base3DDetector):
r"""`NeRF-Det <https://arxiv.org/abs/2307.14620>`_.
Args:
backbone (:obj:`ConfigDict` or dict): The backbone config.
neck (:obj:`ConfigDict` or dict): The neck config.
neck_3d(:obj:`ConfigDict` or dict): The 3D neck config.
bbox_head(:obj:`ConfigDict` or dict): The bbox head config.
prior_generator (:obj:`ConfigDict` or dict): The prior generator
config.
n_voxels (list): Number of voxels along x, y, z axis.
voxel_size (list): The size of voxels. Each voxel represents
a cube of `voxel_size[0]` meters, `voxel_size[1]` meters,
`voxel_size[2]` meters.
train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
training hyper-parameters. Defaults to None.
test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
hyper-parameters. Defaults to None.
init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
config. Defaults to None.
render_testing (bool): Whether to render novel views at test time.
Set "render_testing = True" in the config to enable it.
Defaults to False.
The other args are the parameters of NeRF; the default values can be
used as they are.
"""
def __init__(
self,
backbone: ConfigType,
neck: ConfigType,
neck_3d: ConfigType,
bbox_head: ConfigType,
prior_generator: ConfigType,
n_voxels: List,
voxel_size: List,
head_2d: ConfigType = None,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptConfigType = None,
# pretrained,
aabb: Tuple = None,
near_far_range: List = None,
N_samples: int = 64,
N_rand: int = 2048,
depth_supervise: bool = False,
use_nerf_mask: bool = True,
nerf_sample_view: int = 3,
nerf_mode: str = 'volume',
squeeze_scale: int = 4,
rgb_supervision: bool = True,
nerf_density: bool = False,
render_testing: bool = False):
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
self.backbone = MODELS.build(backbone)
self.neck = MODELS.build(neck)
self.neck_3d = MODELS.build(neck_3d)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = MODELS.build(bbox_head)
self.head_2d = MODELS.build(head_2d) if head_2d is not None else None
self.n_voxels = n_voxels
self.prior_generator = TASK_UTILS.build(prior_generator)
self.voxel_size = voxel_size
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.aabb = aabb
self.near_far_range = near_far_range
self.N_samples = N_samples
self.N_rand = N_rand
self.depth_supervise = depth_supervise
self.projector = Projector()
self.squeeze_scale = squeeze_scale
self.use_nerf_mask = use_nerf_mask
self.rgb_supervision = rgb_supervision
nerf_feature_dim = neck['out_channels'] // squeeze_scale
self.nerf_mlp = VanillaNeRF(
net_depth=4, # The depth of the MLP
net_width=256, # The width of the MLP
skip_layer=3, # The layer to add skip layers to.
feature_dim=nerf_feature_dim + 6, # + RGB original imgs
net_depth_condition=1, # The depth of the second part of MLP
net_width_condition=128)
self.nerf_mode = nerf_mode
self.nerf_density = nerf_density
self.nerf_sample_view = nerf_sample_view
self.render_testing = render_testing
# hard code here, will deal with batch issue later.
self.cov = nn.Sequential(
nn.Conv3d(
neck['out_channels'],
neck['out_channels'],
kernel_size=3,
padding=1), nn.ReLU(inplace=True),
nn.Conv3d(
neck['out_channels'],
neck['out_channels'],
kernel_size=3,
padding=1), nn.ReLU(inplace=True),
nn.Conv3d(neck['out_channels'], 1, kernel_size=1))
self.mean_mapping = nn.Sequential(
nn.Conv3d(
neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))
self.cov_mapping = nn.Sequential(
nn.Conv3d(
neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))
self.mapping = nn.Sequential(
nn.Linear(neck['out_channels'], nerf_feature_dim // 2))
self.mapping_2d = nn.Sequential(
nn.Conv2d(
neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))
# self.overfit_nerfmlp = overfit_nerfmlp
# if self.overfit_nerfmlp:
# self. _finetuning_NeRF_MLP()
def extract_feat(self,
batch_inputs_dict: dict,
batch_data_samples: SampleList,
mode,
depth=None,
ray_batch=None):
"""Extract 3d features from the backbone -> fpn -> 3d projection
-> 3d neck.
Args:
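batch_inputs_dict (dict): The model input dict which includes
the 'imgs' key.
- imgs (torch.Tensor, optional): Image of each sample.
batch_data_samples (list[:obj:`DetDataSample`]): The batch
data samples. It usually includes information such
as `gt_instances` or `gt_panoptic_seg` or `gt_sem_seg`.
mode (str): The running mode, 'train' or 'test'.
depth (torch.Tensor, optional): Depth maps of the input views.
Defaults to None.
ray_batch (dict, optional): The ray data used for NeRF rendering.
Defaults to None.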
Returns:
Tuple:
- torch.Tensor: Features of shape (N, C_out, N_x, N_y, N_z).
- torch.Tensor: Valid mask of shape (N, 1, N_x, N_y, N_z).
- torch.Tensor: 2D features if needed.
- list[dict]: The NeRF rendering results, each including the
'outputs_coarse', 'gt_rgb' and 'gt_depth' keys.
"""
img = batch_inputs_dict['imgs']
img = img.float()
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
batch_size = img.shape[0]
if len(img.shape) > 4:
img = img.reshape([-1] + list(img.shape)[2:])
x = self.backbone(img)
x = self.neck(x)[0]
x = x.reshape([batch_size, -1] + list(x.shape[1:]))
else:
x = self.backbone(img)
x = self.neck(x)[0]
if depth is not None:
depth_bs = depth.shape[0]
assert depth_bs == batch_size
depth = batch_inputs_dict['depth']
depth = depth.reshape([-1] + list(depth.shape)[2:])
features_2d = self.head_2d.forward(x[-1], batch_img_metas) \
if self.head_2d is not None else None
stride = img.shape[-1] / x.shape[-1]
assert stride == 4
stride = int(stride)
volumes, valids = [], []
rgb_preds = []
for feature, img_meta in zip(x, batch_img_metas):
angles = features_2d[
0] if features_2d is not None and mode == 'test' else None
projection = self._compute_projection(img_meta, stride,
angles).to(x.device)
points = get_points(
n_voxels=torch.tensor(self.n_voxels),
voxel_size=torch.tensor(self.voxel_size),
origin=torch.tensor(img_meta['lidar2img']['origin'])).to(
x.device)
height = img_meta['img_shape'][0] // stride
width = img_meta['img_shape'][1] // stride
# Construct the volume space
# volume together with valid is the constructed scene
# volume represents V_i and valid represents M_p
volume, valid = backproject(feature[:, :, :height, :width], points,
projection, depth, self.voxel_size)
density = None
volume_sum = volume.sum(dim=0)
# cov_valid = valid.clone().detach()
valid = valid.sum(dim=0)
volume_mean = volume_sum / (valid + 1e-8)
volume_mean[:, valid[0] == 0] = .0
# volume_cov = (volume - volume_mean.unsqueeze(0)) ** 2 * cov_valid
# volume_cov = torch.sum(volume_cov, dim=0) / (valid + 1e-8)
volume_cov = torch.sum(
(volume - volume_mean.unsqueeze(0))**2, dim=0) / (
valid + 1e-8)
volume_cov[:, valid[0] == 0] = 1e6
volume_cov = torch.exp(-volume_cov) # default setting
# be careful here, the smaller the cov, the larger the weight.
n_channels, n_x_voxels, n_y_voxels, n_z_voxels = volume_mean.shape
if ray_batch is not None:
if self.nerf_mode == 'volume':
mean_volume = self.mean_mapping(volume_mean.unsqueeze(0))
cov_volume = self.cov_mapping(volume_cov.unsqueeze(0))
feature_2d = feature[:, :, :height, :width]
elif self.nerf_mode == 'image':
mean_volume = None
cov_volume = None
feature_2d = feature[:, :, :height, :width]
n_v, C, height, width = feature_2d.shape
feature_2d = feature_2d.view(n_v, C,
-1).permute(0, 2,
1).contiguous()
feature_2d = self.mapping(feature_2d).permute(
0, 2, 1).contiguous().view(n_v, -1, height, width)
denorm_images = ray_batch['denorm_images']
denorm_images = denorm_images.reshape(
[-1] + list(denorm_images.shape)[2:])
rgb_projection = self._compute_projection(
img_meta, stride=1, angles=None).to(x.device)
rgb_volume, _ = backproject(
denorm_images[:, :, :img_meta['img_shape'][0], :
img_meta['img_shape'][1]], points,
rgb_projection, depth, self.voxel_size)
ret = render_rays(
ray_batch,
mean_volume,
cov_volume,
feature_2d,
denorm_images,
self.aabb,
self.near_far_range,
self.N_samples,
self.N_rand,
self.nerf_mlp,
img_meta,
self.projector,
self.nerf_mode,
self.nerf_sample_view,
is_train=(mode == 'train'),
render_testing=self.render_testing)
rgb_preds.append(ret)
if self.nerf_density:
# would have 0 bias issue for mean_mapping.
n_v, C, n_x_voxels, n_y_voxels, n_z_voxels = volume.shape
volume = volume.view(n_v, C, -1).permute(0, 2,
1).contiguous()
mapping_volume = self.mapping(volume).permute(
0, 2, 1).contiguous().view(n_v, -1, n_x_voxels,
n_y_voxels, n_z_voxels)
mapping_volume = torch.cat([rgb_volume, mapping_volume],
dim=1)
mapping_volume_sum = mapping_volume.sum(dim=0)
mapping_volume_mean = mapping_volume_sum / (valid + 1e-8)
# mapping_volume_cov = (
# mapping_volume - mapping_volume_mean.unsqueeze(0)
# ) ** 2 * cov_valid
mapping_volume_cov = (mapping_volume -
mapping_volume_mean.unsqueeze(0))**2
mapping_volume_cov = torch.sum(
mapping_volume_cov, dim=0) / (
valid + 1e-8)
mapping_volume_cov[:, valid[0] == 0] = 1e6
mapping_volume_cov = torch.exp(
-mapping_volume_cov) # default setting
global_volume = torch.cat(
[mapping_volume_mean, mapping_volume_cov], dim=1)
global_volume = global_volume.view(
-1, n_x_voxels * n_y_voxels * n_z_voxels).permute(
1, 0).contiguous()
points = points.view(3, -1).permute(1, 0).contiguous()
density = self.nerf_mlp.query_density(
points, global_volume)
alpha = 1 - torch.exp(-density)
# density -> alpha
# (1, n_x_voxels, n_y_voxels, n_z_voxels)
volume = alpha.view(1, n_x_voxels, n_y_voxels,
n_z_voxels) * volume_mean
volume[:, valid[0] == 0] = .0
volumes.append(volume)
valids.append(valid)
x = torch.stack(volumes)
x = self.neck_3d(x)
return x, torch.stack(valids).float(), features_2d, rgb_preds
def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs) -> Union[dict, list]:
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs_dict (dict): The model input dict which includes
the 'imgs' key.
- imgs (torch.Tensor, optional): Image of each sample.
batch_data_samples (list[:obj:`DetDataSample`]): The batch
data samples. It usually includes information such
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
Returns:
dict: A dictionary of loss components.
"""
ray_batchs = {}
batch_images = []
batch_depths = []
if 'images' in batch_data_samples[0].gt_nerf_images:
for data_samples in batch_data_samples:
image = data_samples.gt_nerf_images['images']
batch_images.append(image)
batch_images = torch.stack(batch_images)
if 'depths' in batch_data_samples[0].gt_nerf_depths:
for data_samples in batch_data_samples:
depth = data_samples.gt_nerf_depths['depths']
batch_depths.append(depth)
batch_depths = torch.stack(batch_depths)
if 'raydirs' in batch_inputs_dict.keys():
ray_batchs['ray_o'] = batch_inputs_dict['lightpos']
ray_batchs['ray_d'] = batch_inputs_dict['raydirs']
ray_batchs['gt_rgb'] = batch_images
ray_batchs['gt_depth'] = batch_depths
ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes']
ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images']
x, valids, features_2d, rgb_preds = self.extract_feat(
batch_inputs_dict,
batch_data_samples,
'train',
depth=None,
ray_batch=ray_batchs)
else:
x, valids, features_2d, rgb_preds = self.extract_feat(
batch_inputs_dict, batch_data_samples, 'train')
x += (valids, )
losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
# if self.head_2d is not None:
# losses.update(
# self.head_2d.loss(*features_2d, batch_data_samples)
# )
if len(ray_batchs) != 0 and self.rgb_supervision:
losses.update(self.nvs_loss_func(rgb_preds))
if self.depth_supervise:
losses.update(self.depth_loss_func(rgb_preds))
return losses
def nvs_loss_func(self, rgb_pred):
loss = 0
for ret in rgb_pred:
rgb = ret['outputs_coarse']['rgb']
gt = ret['gt_rgb']
masks = ret['outputs_coarse']['mask']
if self.use_nerf_mask:
loss += torch.sum(masks.unsqueeze(-1) * (rgb - gt)**2) / (
masks.sum() + 1e-6)
else:
loss += torch.mean((rgb - gt)**2)
return dict(loss_nvs=loss)
def depth_loss_func(self, rgb_pred):
loss = 0
for ret in rgb_pred:
depth = ret['outputs_coarse']['depth']
gt = ret['gt_depth'].squeeze(-1)
masks = ret['outputs_coarse']['mask']
if self.use_nerf_mask:
loss += torch.sum(masks * torch.abs(depth - gt)) / (
masks.sum() + 1e-6)
else:
loss += torch.mean(torch.abs(depth - gt))
return dict(loss_depth=loss)
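The two loss helpers above share one pattern: a per-ray error is weighted by a validity mask and normalised by the number of valid rays plus a small epsilon. The following is a minimal pure-Python sketch of that pattern, not the project's implementation; the names `masked_mse` and `masked_l1` are hypothetical, and plain lists stand in for tensors.

```python
def masked_mse(pred_rgb, gt_rgb, mask, eps=1e-6):
    """Squared RGB error summed over channels, averaged over valid rays.

    Mirrors sum(mask[..., None] * (rgb - gt) ** 2) / (mask.sum() + eps).
    """
    total = 0.0
    for pred, gt, m in zip(pred_rgb, gt_rgb, mask):
        total += m * sum((p - g) ** 2 for p, g in zip(pred, gt))
    return total / (sum(mask) + eps)


def masked_l1(pred_depth, gt_depth, mask, eps=1e-6):
    """Absolute depth error averaged over valid rays."""
    total = sum(m * abs(p - g) for p, g, m in zip(pred_depth, gt_depth, mask))
    return total / (sum(mask) + eps)


# Only the first ray is valid, so the second ray's error is ignored.
rgb_loss = masked_mse([(1.0, 0.0, 0.0), (0.0, 0.0, 0.0)],
                      [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)], [1, 0])
depth_loss = masked_l1([2.0, 5.0], [1.5, 0.0], [1, 0])
```

The epsilon keeps the result finite (and equal to zero) when no ray is valid, which is why the denominator is `mask.sum() + 1e-6` rather than `mask.sum()`.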
def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
**kwargs) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs_dict (dict): The model input dict, which includes
the 'imgs' key.
- imgs (torch.Tensor, optional): Image of each sample.
batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
Returns:
list[:obj:`NeRFDet3DDataSample`]: Detection results of the
input images. Each NeRFDet3DDataSample usually contains
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C = 6.
"""
ray_batchs = {}
batch_images = []
batch_depths = []
if 'images' in batch_data_samples[0].gt_nerf_images:
for data_samples in batch_data_samples:
image = data_samples.gt_nerf_images['images']
batch_images.append(image)
batch_images = torch.stack(batch_images)
if 'depths' in batch_data_samples[0].gt_nerf_depths:
for data_samples in batch_data_samples:
depth = data_samples.gt_nerf_depths['depths']
batch_depths.append(depth)
batch_depths = torch.stack(batch_depths)
if 'raydirs' in batch_inputs_dict.keys():
ray_batchs['ray_o'] = batch_inputs_dict['lightpos']
ray_batchs['ray_d'] = batch_inputs_dict['raydirs']
ray_batchs['gt_rgb'] = batch_images
ray_batchs['gt_depth'] = batch_depths
ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes']
ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images']
x, valids, features_2d, rgb_preds = self.extract_feat(
batch_inputs_dict,
batch_data_samples,
'test',
depth=None,
ray_batch=ray_batchs)
else:
x, valids, features_2d, rgb_preds = self.extract_feat(
batch_inputs_dict, batch_data_samples, 'test')
x += (valids, )
results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
predictions = self.add_pred_to_datasample(batch_data_samples,
results_list)
return predictions
def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
*args, **kwargs) -> Tuple[List[torch.Tensor]]:
"""Network forward process. Usually includes backbone, neck and head
forward without any post-processing.
Args:
batch_inputs_dict (dict): The model input dict, which includes
the 'imgs' key.
- imgs (torch.Tensor, optional): Image of each sample.
batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`
Returns:
tuple[list]: A tuple of features from ``bbox_head`` forward
"""
ray_batchs = {}
batch_images = []
batch_depths = []
if 'images' in batch_data_samples[0].gt_nerf_images:
for data_samples in batch_data_samples:
image = data_samples.gt_nerf_images['images']
batch_images.append(image)
batch_images = torch.stack(batch_images)
if 'depths' in batch_data_samples[0].gt_nerf_depths:
for data_samples in batch_data_samples:
depth = data_samples.gt_nerf_depths['depths']
batch_depths.append(depth)
batch_depths = torch.stack(batch_depths)
if 'raydirs' in batch_inputs_dict.keys():
ray_batchs['ray_o'] = batch_inputs_dict['lightpos']
ray_batchs['ray_d'] = batch_inputs_dict['raydirs']
ray_batchs['gt_rgb'] = batch_images
ray_batchs['gt_depth'] = batch_depths
ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes']
ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images']
x, valids, features_2d, rgb_preds = self.extract_feat(
batch_inputs_dict,
batch_data_samples,
'train',
depth=None,
ray_batch=ray_batchs)
else:
x, valids, features_2d, rgb_preds = self.extract_feat(
batch_inputs_dict, batch_data_samples, 'train')
x += (valids, )
results = self.bbox_head.forward(x)
return results
def aug_test(self, batch_inputs_dict, batch_data_samples):
pass
def show_results(self, *args, **kwargs):
pass
@staticmethod
def _compute_projection(img_meta, stride, angles):
projection = []
intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:3, :3])
ratio = img_meta['ori_shape'][0] / (img_meta['img_shape'][0] / stride)
intrinsic[:2] /= ratio
# use predicted pitch and roll for SUNRGBDTotal test
if angles is not None:
extrinsics = []
for angle in angles:
extrinsics.append(get_extrinsics(angle).to(intrinsic.device))
else:
extrinsics = map(torch.tensor, img_meta['lidar2img']['extrinsic'])
for extrinsic in extrinsics:
projection.append(intrinsic @ extrinsic[:3])
return torch.stack(projection)
@torch.no_grad()
def get_points(n_voxels, voxel_size, origin):
# origin: point-cloud center.
points = torch.stack(
torch.meshgrid([
torch.arange(n_voxels[0]), # 40 W width, x
torch.arange(n_voxels[1]), # 40 D depth, y
torch.arange(n_voxels[2]) # 16 H Height, z
]))
new_origin = origin - n_voxels / 2. * voxel_size
points = points * voxel_size.view(3, 1, 1, 1) + new_origin.view(3, 1, 1, 1)
return points
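As a reading aid for `get_points`, here is a minimal pure-Python sketch of the same arithmetic: integer grid indices are scaled by the voxel size and shifted so that the grid is centred on `origin`. The function name `grid_points` is hypothetical, and nested lists replace the `(3, X, Y, Z)` tensor; the loop order is chosen so that x varies slowest and z fastest.

```python
def grid_points(n_voxels, voxel_size, origin):
    """Return one [x, y, z] coordinate per grid index (x slowest, z fastest)."""
    # Shift so that index 0 maps to the lower corner of a grid centred at origin.
    new_origin = [
        o - n / 2.0 * s for o, n, s in zip(origin, n_voxels, voxel_size)
    ]
    points = []
    for i in range(n_voxels[0]):
        for j in range(n_voxels[1]):
            for k in range(n_voxels[2]):
                points.append([
                    i * voxel_size[0] + new_origin[0],
                    j * voxel_size[1] + new_origin[1],
                    k * voxel_size[2] + new_origin[2],
                ])
    return points


pts = grid_points(n_voxels=(2, 2, 1), voxel_size=(0.5, 0.5, 1.0),
                  origin=(0.0, 0.0, 0.0))
# pts[0] is the lower corner of the grid, pts[-1] the last grid index.
```

Note that index 0 maps to `origin - n_voxels / 2 * voxel_size`, so the returned coordinates are grid nodes starting at the lower corner of the volume rather than voxel centres.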
# modified from https://github.com/magicleap/Atlas/blob/master/atlas/model.py
def backproject(features, points, projection, depth, voxel_size):
n_images, n_channels, height, width = features.shape
n_x_voxels, n_y_voxels, n_z_voxels = points.shape[-3:]
points = points.view(1, 3, -1).expand(n_images, 3, -1)
points = torch.cat((points, torch.ones_like(points[:, :1])), dim=1)
points_2d_3 = torch.bmm(projection, points)
x = (points_2d_3[:, 0] / points_2d_3[:, 2]).round().long()
y = (points_2d_3[:, 1] / points_2d_3[:, 2]).round().long()
z = points_2d_3[:, 2]
valid = (x >= 0) & (y >= 0) & (x < width) & (y < height) & (z > 0)
# below is using depth to sample feature
if depth is not None:
depth = F.interpolate(
depth.unsqueeze(1), size=(height, width),
mode='bilinear').squeeze(1)
for i in range(n_images):
z_mask = z.clone() > 0
z_mask[i, valid[i]] = \
(z[i, valid[i]] > depth[i, y[i, valid[i]], x[i, valid[i]]] - voxel_size[-1]) & \
(z[i, valid[i]] < depth[i, y[i, valid[i]], x[i, valid[i]]] + voxel_size[-1]) # noqa
valid = valid & z_mask
volume = torch.zeros((n_images, n_channels, points.shape[-1]),
device=features.device)
for i in range(n_images):
volume[i, :, valid[i]] = features[i, :, y[i, valid[i]], x[i, valid[i]]]
volume = volume.view(n_images, n_channels, n_x_voxels, n_y_voxels,
n_z_voxels)
valid = valid.view(n_images, 1, n_x_voxels, n_y_voxels, n_z_voxels)
return volume, valid
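The core of `backproject` is a per-point test: project a 3D grid point with a 3x4 matrix, round to the nearest pixel, and keep it only if the pixel lies inside the feature map and the point is in front of the camera. A minimal pure-Python sketch of that test for a single point follows. The name `project_point` and the example matrix are assumptions for illustration; unlike the tensor code, the sketch checks the depth before dividing.

```python
def project_point(projection, point, width, height):
    """Return the (col, row) pixel hit by `point`, or None if it is invalid.

    `projection` is a 3x4 matrix given as three rows of four numbers.
    """
    hom = (point[0], point[1], point[2], 1.0)
    u, v, w = (sum(p * h for p, h in zip(row, hom)) for row in projection)
    if w <= 0:  # behind (or on) the camera plane
        return None
    # Like torch.round, Python's round() rounds halves to the nearest even.
    col, row = round(u / w), round(v / w)
    if 0 <= col < width and 0 <= row < height:
        return col, row
    return None


# Hypothetical pinhole camera: focal length 100, principal point (50, 40),
# identity extrinsics.
P = [[100.0, 0.0, 50.0, 0.0],
     [0.0, 100.0, 40.0, 0.0],
     [0.0, 0.0, 1.0, 0.0]]
pixel = project_point(P, (0.1, -0.2, 2.0), width=100, height=80)
```

Points that fail the test contribute zeros to the volume, and the returned `valid` mask records which voxels were actually hit by each image.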
# for SUNRGBDTotal test
def get_extrinsics(angles):
yaw = angles.new_zeros(())
pitch, roll = angles
r = angles.new_zeros((3, 3))
r[0, 0] = torch.cos(yaw) * torch.cos(pitch)
r[0, 1] = torch.sin(yaw) * torch.sin(roll) - torch.cos(yaw) * torch.cos(
roll) * torch.sin(pitch)
r[0, 2] = torch.cos(roll) * torch.sin(yaw) + torch.cos(yaw) * torch.sin(
pitch) * torch.sin(roll)
r[1, 0] = torch.sin(pitch)
r[1, 1] = torch.cos(pitch) * torch.cos(roll)
r[1, 2] = -torch.cos(pitch) * torch.sin(roll)
r[2, 0] = -torch.cos(pitch) * torch.sin(yaw)
r[2, 1] = torch.cos(yaw) * torch.sin(roll) + torch.cos(roll) * torch.sin(
yaw) * torch.sin(pitch)
r[2, 2] = torch.cos(yaw) * torch.cos(roll) - torch.sin(yaw) * torch.sin(
pitch) * torch.sin(roll)
# follow Total3DUnderstanding
t = angles.new_tensor([[0., 0., 1.], [0., -1., 0.], [-1., 0., 0.]])
r = t @ r.T
# follow DepthInstance3DBoxes
r = r[:, [2, 0, 1]]
r[2] *= -1
extrinsic = angles.new_zeros((4, 4))
extrinsic[:3, :3] = r
extrinsic[3, 3] = 1.
return extrinsic
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
from mmcv.cnn import Scale
# from mmcv.ops import nms3d, nms3d_normal
from mmdet.models.utils import multi_apply
from mmdet.utils import reduce_mean
# from mmengine.config import ConfigDict
from mmengine.model import BaseModule, bias_init_with_prob, normal_init
from mmengine.structures import InstanceData
from torch import Tensor, nn
from mmdet3d.registry import MODELS, TASK_UTILS
# from mmdet3d.structures.bbox_3d.utils import rotation_3d_in_axis
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils.typing_utils import (ConfigType, InstanceList,
OptConfigType, OptInstanceList)
@torch.no_grad()
def get_points(n_voxels, voxel_size, origin):
# origin: point-cloud center.
points = torch.stack(
torch.meshgrid([
torch.arange(n_voxels[0]), # 40 W width, x
torch.arange(n_voxels[1]), # 40 D depth, y
torch.arange(n_voxels[2]) # 16 H Height, z
]))
new_origin = origin - n_voxels / 2. * voxel_size
points = points * voxel_size.view(3, 1, 1, 1) + new_origin.view(3, 1, 1, 1)
return points
@MODELS.register_module()
class NerfDetHead(BaseModule):
r"""NeRF-Det head for indoor datasets, adapted from the
`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_ head.
Args:
n_classes (int): Number of classes.
n_levels (int): Number of feature levels.
n_channels (int): Number of channels in input tensors.
n_reg_outs (int): Number of regression layer channels.
pts_assign_threshold (int): Min number of locations per box to
be assigned with.
pts_center_threshold (int): Max number of locations per box to
be assigned with.
prior_generator (dict): Config of the prior generator, built
through ``TASK_UTILS``.
center_loss (dict, optional): Config of centerness loss.
Default: dict(type='mmdet.CrossEntropyLoss', use_sigmoid=True).
bbox_loss (dict, optional): Config of bbox loss.
Default: dict(type='RotatedIoU3DLoss').
cls_loss (dict, optional): Config of classification loss.
Default: dict(type='mmdet.FocalLoss').
train_cfg (dict, optional): Config for train stage. Defaults to None.
test_cfg (dict, optional): Config for test stage. Defaults to None.
init_cfg (dict, optional): Config for weight initialization.
Defaults to None.
"""
def __init__(self,
n_classes: int,
n_levels: int,
n_channels: int,
n_reg_outs: int,
pts_assign_threshold: int,
pts_center_threshold: int,
prior_generator: ConfigType,
center_loss: ConfigType = dict(
type='mmdet.CrossEntropyLoss', use_sigmoid=True),
bbox_loss: ConfigType = dict(type='RotatedIoU3DLoss'),
cls_loss: ConfigType = dict(type='mmdet.FocalLoss'),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptConfigType = None):
super(NerfDetHead, self).__init__(init_cfg)
self.n_classes = n_classes
self.n_levels = n_levels
self.n_reg_outs = n_reg_outs
self.pts_assign_threshold = pts_assign_threshold
self.pts_center_threshold = pts_center_threshold
self.prior_generator = TASK_UTILS.build(prior_generator)
self.center_loss = MODELS.build(center_loss)
self.bbox_loss = MODELS.build(bbox_loss)
self.cls_loss = MODELS.build(cls_loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self._init_layers(n_channels, n_reg_outs, n_classes, n_levels)
def _init_layers(self, n_channels, n_reg_outs, n_classes, n_levels):
"""Initialize neural network layers of the head."""
self.conv_center = nn.Conv3d(n_channels, 1, 3, padding=1, bias=False)
self.conv_reg = nn.Conv3d(
n_channels, n_reg_outs, 3, padding=1, bias=False)
self.conv_cls = nn.Conv3d(n_channels, n_classes, 3, padding=1)
self.scales = nn.ModuleList([Scale(1.) for _ in range(n_levels)])
def init_weights(self):
"""Initialize all layer weights."""
normal_init(self.conv_center, std=.01)
normal_init(self.conv_reg, std=.01)
normal_init(self.conv_cls, std=.01, bias=bias_init_with_prob(.01))
def _forward_single(self, x: Tensor, scale: Scale):
"""Forward pass per level.
Args:
x (Tensor): Per level 3d neck output tensor.
scale (mmcv.cnn.Scale): Per level multiplication weight.
Returns:
tuple[Tensor]: Centerness, bbox and classification predictions.
"""
return (self.conv_center(x), torch.exp(scale(self.conv_reg(x))),
self.conv_cls(x))
def forward(self, x):
return multi_apply(self._forward_single, x, self.scales)
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
**kwargs) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
Returns:
dict: A dictionary of loss components.
"""
valid_pred = x[-1]
outs = self(x[:-1])
batch_gt_instances_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
loss_inputs = outs + (valid_pred, batch_gt_instances_3d,
batch_input_metas, batch_gt_instances_ignore)
losses = self.loss_by_feat(*loss_inputs)
return losses
def loss_by_feat(self,
center_preds: List[List[Tensor]],
bbox_preds: List[List[Tensor]],
cls_preds: List[List[Tensor]],
valid_pred: Tensor,
batch_gt_instances_3d: InstanceList,
batch_input_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None,
**kwargs) -> dict:
"""Per scene loss function.
Args:
center_preds (list[list[Tensor]]): Centerness predictions for
all scenes. The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
cls_preds (list[list[Tensor]]): Classification predictions for all
scenes. The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
valid_pred (Tensor): Valid mask prediction for all scenes.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instance_3d. It usually includes ``bboxes_3d``,
``labels_3d``, ``depths``, ``centers_2d`` and attributes.
batch_input_metas (list[dict]): Meta information of each image,
e.g., image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict: Centerness, bbox, and classification loss values.
"""
valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
center_losses, bbox_losses, cls_losses = [], [], []
for i in range(len(batch_input_metas)):
center_loss, bbox_loss, cls_loss = self._loss_by_feat_single(
center_preds=[x[i] for x in center_preds],
bbox_preds=[x[i] for x in bbox_preds],
cls_preds=[x[i] for x in cls_preds],
valid_preds=[x[i] for x in valid_preds],
input_meta=batch_input_metas[i],
gt_bboxes=batch_gt_instances_3d[i].bboxes_3d,
gt_labels=batch_gt_instances_3d[i].labels_3d)
center_losses.append(center_loss)
bbox_losses.append(bbox_loss)
cls_losses.append(cls_loss)
return dict(
center_loss=torch.mean(torch.stack(center_losses)),
bbox_loss=torch.mean(torch.stack(bbox_losses)),
cls_loss=torch.mean(torch.stack(cls_losses)))
def _loss_by_feat_single(self, center_preds, bbox_preds, cls_preds,
valid_preds, input_meta, gt_bboxes, gt_labels):
featmap_sizes = [featmap.size()[-3:] for featmap in center_preds]
points = self._get_points(
featmap_sizes=featmap_sizes,
origin=input_meta['lidar2img']['origin'],
device=gt_bboxes.device)
center_targets, bbox_targets, cls_targets = self._get_targets(
points, gt_bboxes, gt_labels)
center_preds = torch.cat(
[x.permute(1, 2, 3, 0).reshape(-1) for x in center_preds])
bbox_preds = torch.cat([
x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in bbox_preds
])
cls_preds = torch.cat(
[x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in cls_preds])
valid_preds = torch.cat(
[x.permute(1, 2, 3, 0).reshape(-1) for x in valid_preds])
points = torch.cat(points)
# cls loss
pos_inds = torch.nonzero(
torch.logical_and(cls_targets >= 0, valid_preds)).squeeze(1)
n_pos = points.new_tensor(len(pos_inds))
n_pos = max(reduce_mean(n_pos), 1.)
if torch.any(valid_preds):
cls_loss = self.cls_loss(
cls_preds[valid_preds],
cls_targets[valid_preds],
avg_factor=n_pos)
else:
cls_loss = cls_preds[valid_preds].sum()
# bbox and centerness losses
pos_center_preds = center_preds[pos_inds]
pos_bbox_preds = bbox_preds[pos_inds]
if len(pos_inds) > 0:
pos_center_targets = center_targets[pos_inds]
pos_bbox_targets = bbox_targets[pos_inds]
pos_points = points[pos_inds]
center_loss = self.center_loss(
pos_center_preds, pos_center_targets, avg_factor=n_pos)
bbox_loss = self.bbox_loss(
self._bbox_pred_to_bbox(pos_points, pos_bbox_preds),
pos_bbox_targets,
weight=pos_center_targets,
avg_factor=pos_center_targets.sum())
else:
center_loss = pos_center_preds.sum()
bbox_loss = pos_bbox_preds.sum()
return center_loss, bbox_loss, cls_loss
def predict(self,
x: Tuple[Tensor],
batch_data_samples: SampleList,
rescale: bool = False) -> InstanceList:
"""Perform forward propagation of the 3D detection head and predict
detection results on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, `gt_pts_panoptic_seg` and
`gt_pts_sem_seg`.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[:obj:`InstanceData`]: Detection results of each sample
after the post process.
Each item usually contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where
C >= 6.
"""
batch_input_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
valid_pred = x[-1]
outs = self(x[:-1])
predictions = self.predict_by_feat(
*outs,
valid_pred=valid_pred,
batch_input_metas=batch_input_metas,
rescale=rescale)
return predictions
def predict_by_feat(self, center_preds: List[List[Tensor]],
bbox_preds: List[List[Tensor]],
cls_preds: List[List[Tensor]], valid_pred: Tensor,
batch_input_metas: List[dict],
**kwargs) -> List[InstanceData]:
"""Generate boxes for all scenes.
Args:
center_preds (list[list[Tensor]]): Centerness predictions for
all scenes.
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
cls_preds (list[list[Tensor]]): Classification predictions for all
scenes.
valid_pred (Tensor): Valid mask prediction for all scenes.
batch_input_metas (list[dict]): Meta infos for all scenes.
Returns:
list[:obj:`InstanceData`]: Predicted bboxes, scores, and labels for
all scenes.
"""
valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
results = []
for i in range(len(batch_input_metas)):
results.append(
self._predict_by_feat_single(
center_preds=[x[i] for x in center_preds],
bbox_preds=[x[i] for x in bbox_preds],
cls_preds=[x[i] for x in cls_preds],
valid_preds=[x[i] for x in valid_preds],
input_meta=batch_input_metas[i]))
return results
def _predict_by_feat_single(self, center_preds: List[Tensor],
bbox_preds: List[Tensor],
cls_preds: List[Tensor],
valid_preds: List[Tensor],
input_meta: dict) -> InstanceData:
"""Generate boxes for single sample.
Args:
center_preds (list[Tensor]): Centerness predictions for all levels.
bbox_preds (list[Tensor]): Bbox predictions for all levels.
cls_preds (list[Tensor]): Classification predictions for all
levels.
valid_preds (tuple[Tensor]): Upsampled valid masks for all feature
levels.
input_meta (dict): Scene meta info.
Returns:
:obj:`InstanceData`: Predicted bounding boxes, scores and labels.
"""
featmap_sizes = [featmap.size()[-3:] for featmap in center_preds]
points = self._get_points(
featmap_sizes=featmap_sizes,
origin=input_meta['lidar2img']['origin'],
device=center_preds[0].device)
mlvl_bboxes, mlvl_scores = [], []
for center_pred, bbox_pred, cls_pred, valid_pred, point in zip(
center_preds, bbox_preds, cls_preds, valid_preds, points):
center_pred = center_pred.permute(1, 2, 3, 0).reshape(-1, 1)
bbox_pred = bbox_pred.permute(1, 2, 3,
0).reshape(-1, bbox_pred.shape[0])
cls_pred = cls_pred.permute(1, 2, 3,
0).reshape(-1, cls_pred.shape[0])
valid_pred = valid_pred.permute(1, 2, 3, 0).reshape(-1, 1)
scores = cls_pred.sigmoid() * center_pred.sigmoid() * valid_pred
max_scores, _ = scores.max(dim=1)
if len(scores) > self.test_cfg.nms_pre > 0:
_, ids = max_scores.topk(self.test_cfg.nms_pre)
bbox_pred = bbox_pred[ids]
scores = scores[ids]
point = point[ids]
bboxes = self._bbox_pred_to_bbox(point, bbox_pred)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
bboxes = torch.cat(mlvl_bboxes)
scores = torch.cat(mlvl_scores)
bboxes, scores, labels = self._nms(bboxes, scores, input_meta)
bboxes = input_meta['box_type_3d'](
bboxes, box_dim=6, with_yaw=False, origin=(.5, .5, .5))
results = InstanceData()
results.bboxes_3d = bboxes
results.scores_3d = scores
results.labels_3d = labels
return results
@staticmethod
def _upsample_valid_preds(valid_pred, features):
"""Upsample valid mask predictions.
Args:
valid_pred (Tensor): Valid mask prediction.
features (list[Tensor]): Feature tensors for all levels.
Returns:
list[Tensor]: Upsampled valid masks for all feature levels.
"""
return [
nn.Upsample(size=x.shape[-3:],
mode='trilinear')(valid_pred).round().bool()
for x in features
]
@torch.no_grad()
def _get_points(self, featmap_sizes, origin, device):
mlvl_points = []
tmp_voxel_size = [.16, .16, .2]
for i, featmap_size in enumerate(featmap_sizes):
mlvl_points.append(
get_points(
n_voxels=torch.tensor(featmap_size),
voxel_size=torch.tensor(tmp_voxel_size) * (2**i),
origin=torch.tensor(origin)).reshape(3, -1).transpose(
0, 1).to(device))
return mlvl_points
def _bbox_pred_to_bbox(self, points, bbox_pred):
return torch.stack([
points[:, 0] - bbox_pred[:, 0], points[:, 1] - bbox_pred[:, 2],
points[:, 2] - bbox_pred[:, 4], points[:, 0] + bbox_pred[:, 1],
points[:, 1] + bbox_pred[:, 3], points[:, 2] + bbox_pred[:, 5]
], -1)
def _bbox_pred_to_loss(self, points, bbox_preds):
return self._bbox_pred_to_bbox(points, bbox_preds)
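The head regresses six non-negative distances from a location to the box faces, ordered `(dx_min, dx_max, dy_min, dy_max, dz_min, dz_max)`. `_bbox_pred_to_bbox` turns them into corner form `(x1, y1, z1, x2, y2, z2)`, and `_nms` later converts corners into centre and size. A minimal pure-Python sketch of both conversions for a single box, with hypothetical function names:

```python
def distances_to_corners(point, dists):
    """(x, y, z) plus six face distances -> (x1, y1, z1, x2, y2, z2)."""
    x, y, z = point
    dx_min, dx_max, dy_min, dy_max, dz_min, dz_max = dists
    return (x - dx_min, y - dy_min, z - dz_min,
            x + dx_max, y + dy_max, z + dz_max)


def corners_to_center_size(box):
    """(x1, y1, z1, x2, y2, z2) -> (cx, cy, cz, dx, dy, dz)."""
    x1, y1, z1, x2, y2, z2 = box
    return ((x1 + x2) / 2.0, (y1 + y2) / 2.0, (z1 + z2) / 2.0,
            x2 - x1, y2 - y1, z2 - z1)


corners = distances_to_corners((1.0, 2.0, 3.0),
                               (0.5, 1.5, 1.0, 1.0, 0.25, 0.75))
center_size = corners_to_center_size(corners)
```

Because the regression output passes through `torch.exp` in `_forward_single`, the distances are always positive, so the decoded box always contains the location it was predicted from.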
# The function is directly copied from FCAF3DHead.
@staticmethod
def _get_face_distances(points, boxes):
"""Calculate distances from point to box faces.
Args:
points (Tensor): Final locations of shape (N_points, N_boxes, 3).
boxes (Tensor): 3D boxes of shape (N_points, N_boxes, 7)
Returns:
Tensor: Face distances of shape (N_points, N_boxes, 6),
(dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).
"""
dx_min = points[..., 0] - boxes[..., 0] + boxes[..., 3] / 2
dx_max = boxes[..., 0] + boxes[..., 3] / 2 - points[..., 0]
dy_min = points[..., 1] - boxes[..., 1] + boxes[..., 4] / 2
dy_max = boxes[..., 1] + boxes[..., 4] / 2 - points[..., 1]
dz_min = points[..., 2] - boxes[..., 2] + boxes[..., 5] / 2
dz_max = boxes[..., 2] + boxes[..., 5] / 2 - points[..., 2]
return torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max),
dim=-1)
@staticmethod
def _get_centerness(face_distances):
"""Compute point centerness w.r.t containing box.
Args:
face_distances (Tensor): Face distances of shape (B, N, 6),
(dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).
Returns:
Tensor: Centerness of shape (B, N).
"""
x_dims = face_distances[..., [0, 1]]
y_dims = face_distances[..., [2, 3]]
z_dims = face_distances[..., [4, 5]]
centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \
y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \
z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0]
return torch.sqrt(centerness_targets)
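For a single location, the centerness above is the square root of the product of the per-axis ratios `min / max` of the two face distances. A minimal pure-Python sketch follows; the name `centerness` is chosen for illustration, and the sketch assumes all six distances are positive, which holds for locations strictly inside the box.

```python
import math


def centerness(face_distances):
    """face_distances: (dx_min, dx_max, dy_min, dy_max, dz_min, dz_max)."""
    product = 1.0
    for axis in range(3):
        pair = face_distances[2 * axis:2 * axis + 2]
        product *= min(pair) / max(pair)
    return math.sqrt(product)


at_center = centerness((1.0, 1.0, 1.0, 1.0, 1.0, 1.0))
off_center = centerness((1.0, 3.0, 2.0, 2.0, 1.0, 1.0))
```

A location at the exact box centre scores 1, and the score decays towards 0 as the location approaches any face. `_get_targets` uses this score to keep only the most central locations per box.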
@torch.no_grad()
def _get_targets(self, points, gt_bboxes, gt_labels):
"""Compute targets for final locations for a single scene.
Args:
points (list[Tensor]): Final locations for all levels.
gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
gt_labels (Tensor): Ground truth labels.
Returns:
tuple[Tensor]: Centerness, bbox and classification
targets for all locations.
"""
float_max = 1e8
expanded_scales = [
points[i].new_tensor(i).expand(len(points[i])).to(gt_labels.device)
for i in range(len(points))
]
points = torch.cat(points, dim=0).to(gt_labels.device)
scales = torch.cat(expanded_scales, dim=0)
# below is based on FCOSHead._get_target_single
n_points = len(points)
n_boxes = len(gt_bboxes)
volumes = gt_bboxes.volume.to(points.device)
volumes = volumes.expand(n_points, n_boxes).contiguous()
gt_bboxes = torch.cat(
(gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:6]), dim=1)
gt_bboxes = gt_bboxes.to(points.device).expand(n_points, n_boxes, 6)
expanded_points = points.unsqueeze(1).expand(n_points, n_boxes, 3)
bbox_targets = self._get_face_distances(expanded_points, gt_bboxes)
# condition1: inside a gt bbox
inside_gt_bbox_mask = bbox_targets[..., :6].min(
-1)[0] > 0 # skip angle
# condition2: positive points per scale >= limit
# calculate positive points per scale
n_pos_points_per_scale = []
for i in range(self.n_levels):
n_pos_points_per_scale.append(
torch.sum(inside_gt_bbox_mask[scales == i], dim=0))
# find best scale
n_pos_points_per_scale = torch.stack(n_pos_points_per_scale, dim=0)
lower_limit_mask = n_pos_points_per_scale < self.pts_assign_threshold
# fix nondeterministic argmax for torch<1.7
extra = torch.arange(self.n_levels, 0, -1).unsqueeze(1).expand(
self.n_levels, n_boxes).to(lower_limit_mask.device)
lower_index = torch.argmax(lower_limit_mask.int() * extra, dim=0) - 1
lower_index = torch.where(lower_index < 0,
torch.zeros_like(lower_index), lower_index)
all_upper_limit_mask = torch.all(
torch.logical_not(lower_limit_mask), dim=0)
best_scale = torch.where(
all_upper_limit_mask,
torch.ones_like(all_upper_limit_mask) * self.n_levels - 1,
lower_index)
# keep only points with best scale
best_scale = torch.unsqueeze(best_scale, 0).expand(n_points, n_boxes)
scales = torch.unsqueeze(scales, 1).expand(n_points, n_boxes)
inside_best_scale_mask = best_scale == scales
# condition3: limit topk locations per box by centerness
centerness = self._get_centerness(bbox_targets)
centerness = torch.where(inside_gt_bbox_mask, centerness,
torch.ones_like(centerness) * -1)
centerness = torch.where(inside_best_scale_mask, centerness,
torch.ones_like(centerness) * -1)
top_centerness = torch.topk(
centerness, self.pts_center_threshold + 1, dim=0).values[-1]
inside_top_centerness_mask = centerness > top_centerness.unsqueeze(0)
# if there is still more than one object for a location,
# we choose the one with minimal volume
volumes = torch.where(inside_gt_bbox_mask, volumes,
torch.ones_like(volumes) * float_max)
volumes = torch.where(inside_best_scale_mask, volumes,
torch.ones_like(volumes) * float_max)
volumes = torch.where(inside_top_centerness_mask, volumes,
torch.ones_like(volumes) * float_max)
min_area, min_area_inds = volumes.min(dim=1)
labels = gt_labels[min_area_inds]
labels = torch.where(min_area == float_max,
torch.ones_like(labels) * -1, labels)
bbox_targets = bbox_targets[range(n_points), min_area_inds]
centerness_targets = self._get_centerness(bbox_targets)
return centerness_targets, self._bbox_pred_to_bbox(
points, bbox_targets), labels
def _nms(self, bboxes, scores, img_meta):
scores, labels = scores.max(dim=1)
ids = scores > self.test_cfg.score_thr
bboxes = bboxes[ids]
scores = scores[ids]
labels = labels[ids]
ids = self.aligned_3d_nms(bboxes, scores, labels,
self.test_cfg.iou_thr)
bboxes = bboxes[ids]
bboxes = torch.stack(
((bboxes[:, 0] + bboxes[:, 3]) / 2.,
(bboxes[:, 1] + bboxes[:, 4]) / 2.,
(bboxes[:, 2] + bboxes[:, 5]) / 2., bboxes[:, 3] - bboxes[:, 0],
bboxes[:, 4] - bboxes[:, 1], bboxes[:, 5] - bboxes[:, 2]),
dim=1)
return bboxes, scores[ids], labels[ids]
@staticmethod
def aligned_3d_nms(boxes, scores, classes, thresh):
"""3d nms for aligned boxes.
Args:
boxes (torch.Tensor): Aligned box with shape [n, 6].
scores (torch.Tensor): Scores of each box.
classes (torch.Tensor): Class of each box.
thresh (float): IoU threshold for nms.
Returns:
torch.Tensor: Indices of selected boxes.
"""
x1 = boxes[:, 0]
y1 = boxes[:, 1]
z1 = boxes[:, 2]
x2 = boxes[:, 3]
y2 = boxes[:, 4]
z2 = boxes[:, 5]
area = (x2 - x1) * (y2 - y1) * (z2 - z1)
zero = boxes.new_zeros(1, )
score_sorted = torch.argsort(scores)
pick = []
while (score_sorted.shape[0] != 0):
last = score_sorted.shape[0]
i = score_sorted[-1]
pick.append(i)
xx1 = torch.max(x1[i], x1[score_sorted[:last - 1]])
yy1 = torch.max(y1[i], y1[score_sorted[:last - 1]])
zz1 = torch.max(z1[i], z1[score_sorted[:last - 1]])
xx2 = torch.min(x2[i], x2[score_sorted[:last - 1]])
yy2 = torch.min(y2[i], y2[score_sorted[:last - 1]])
zz2 = torch.min(z2[i], z2[score_sorted[:last - 1]])
classes1 = classes[i]
classes2 = classes[score_sorted[:last - 1]]
inter_l = torch.max(zero, xx2 - xx1)
inter_w = torch.max(zero, yy2 - yy1)
inter_h = torch.max(zero, zz2 - zz1)
inter = inter_l * inter_w * inter_h
iou = inter / (area[i] + area[score_sorted[:last - 1]] - inter)
iou = iou * (classes1 == classes2).float()
score_sorted = score_sorted[torch.nonzero(
iou <= thresh, as_tuple=False).flatten()]
indices = boxes.new_tensor(pick, dtype=torch.long)
return indices
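The greedy procedure in `aligned_3d_nms` can be restated without tensors: repeatedly take the highest-scoring remaining box, then drop every remaining box of the same class whose IoU with it exceeds the threshold. The following pure-Python sketch is a re-statement for illustration, not the project's code. The name `aligned_nms` is hypothetical, and the sketch assumes boxes with non-zero volume so that the IoU denominator is never zero.

```python
def aligned_nms(boxes, scores, classes, thresh):
    """Greedy class-aware NMS for axis-aligned (x1, y1, z1, x2, y2, z2) boxes."""

    def volume(b):
        return (b[3] - b[0]) * (b[4] - b[1]) * (b[5] - b[2])

    # Ascending by score, so the best remaining box is always the last one.
    order = sorted(range(len(boxes)), key=lambda idx: scores[idx])
    keep = []
    while order:
        i = order.pop()
        keep.append(i)
        remaining = []
        for j in order:
            inter = 1.0
            for axis in range(3):
                lo = max(boxes[i][axis], boxes[j][axis])
                hi = min(boxes[i][axis + 3], boxes[j][axis + 3])
                inter *= max(0.0, hi - lo)
            iou = inter / (volume(boxes[i]) + volume(boxes[j]) - inter)
            if classes[i] != classes[j]:
                iou = 0.0  # boxes of different classes never suppress each other
            if iou <= thresh:
                remaining.append(j)
        order = remaining
    return keep


boxes = [(0, 0, 0, 2, 2, 2), (0, 0, 0, 2, 2, 1),
         (0, 0, 0, 2, 2, 1), (5, 5, 5, 6, 6, 6)]
kept = aligned_nms(boxes, scores=[0.9, 0.8, 0.7, 0.6],
                   classes=[0, 0, 1, 0], thresh=0.25)
```

In this example box 1 overlaps box 0 with an IoU of 0.5 and shares its class, so it is suppressed. Box 2 has the same geometry but a different class, so it survives.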
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from os import path as osp
from typing import Callable, List, Optional, Union
import numpy as np
from mmdet3d.datasets import Det3DDataset
from mmdet3d.registry import DATASETS
from mmdet3d.structures import DepthInstance3DBoxes
@DATASETS.register_module()
class MultiViewScanNetDataset(Det3DDataset):
r"""Multi-View ScanNet Dataset for NeRF-detection Task
This class serves as the API for experiments on the ScanNet Dataset.
Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
for data downloading.
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
pipeline (List[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_camera=True, use_lidar=False).
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
in its original format and then convert it to `box_type_3d`.
Defaults to 'Depth' in this dataset. Available options include:
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool): Whether to filter the data with empty GT.
If it's set to be True, the example with empty annotations after
data pipeline will be dropped and a random example will be chosen
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
METAINFO = {
'classes':
('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin')
}
def __init__(self,
data_root: str,
ann_file: str,
metainfo: Optional[dict] = None,
pipeline: List[Union[dict, Callable]] = [],
modality: dict = dict(use_camera=True, use_lidar=False),
box_type_3d: str = 'Depth',
filter_empty_gt: bool = True,
remove_dontcare: bool = False,
test_mode: bool = False,
**kwargs) -> None:
self.remove_dontcare = remove_dontcare
super().__init__(
data_root=data_root,
ann_file=ann_file,
metainfo=metainfo,
pipeline=pipeline,
modality=modality,
box_type_3d=box_type_3d,
filter_empty_gt=filter_empty_gt,
test_mode=test_mode,
**kwargs)
assert 'use_camera' in self.modality and \
'use_lidar' in self.modality
assert self.modality['use_camera'] or self.modality['use_lidar']
@staticmethod
def _get_axis_align_matrix(info: dict) -> np.ndarray:
"""Get axis_align_matrix from info, or the identity matrix if absent.
Args:
info (dict): Info of a single sample data.
Returns:
np.ndarray: 4x4 transformation matrix.
"""
if 'axis_align_matrix' in info:
return np.array(info['axis_align_matrix'])
else:
warnings.warn(
'axis_align_matrix is not found in ScanNet data info, please '
'use new pre-process scripts to re-generate ScanNet data')
return np.eye(4).astype(np.float32)
def parse_data_info(self, info: dict) -> dict:
"""Process the raw data info.
Convert the relative paths of the required modality data files to
absolute paths.
Args:
info (dict): Raw info dict.
Returns:
dict: Processed info with all paths converted to absolute paths. It
contains `ann_info` in the training stage.
"""
# These keys are absent from the default modality, so fall back to False.
if self.modality.get('use_depth', False):
info['depth_info'] = []
if self.modality.get('use_neuralrecon_depth', False):
info['depth_info'] = []
if self.modality['use_lidar']:
# implement lidar processing in the future
raise NotImplementedError(
'Please modify '
'`MultiViewPipeline` to support lidar processing')
info['axis_align_matrix'] = self._get_axis_align_matrix(info)
info['img_info'] = []
info['lidar2img'] = []
info['c2w'] = []
info['camrotc2w'] = []
info['lightpos'] = []
# load img and depth_img
for i in range(len(info['img_paths'])):
img_filename = osp.join(self.data_root, info['img_paths'][i])
info['img_info'].append(dict(filename=img_filename))
if 'depth_info' in info.keys():
if self.modality.get('use_neuralrecon_depth', False):
info['depth_info'].append(
dict(filename=img_filename[:-4] + '.npy'))
else:
info['depth_info'].append(
dict(filename=img_filename[:-4] + '.png'))
# implement lidar_info in input.keys() in the future.
extrinsic = np.linalg.inv(
info['axis_align_matrix'] @ info['lidar2cam'][i])
info['lidar2img'].append(extrinsic.astype(np.float32))
if self.modality.get('use_ray', False):
c2w = (
info['axis_align_matrix'] @ info['lidar2cam'][i]).astype(
np.float32) # noqa
info['c2w'].append(c2w)
info['camrotc2w'].append(c2w[0:3, 0:3])
info['lightpos'].append(c2w[0:3, 3])
origin = np.array([.0, .0, .5])
info['lidar2img'] = dict(
extrinsic=info['lidar2img'],
intrinsic=info['cam2img'].astype(np.float32),
origin=origin.astype(np.float32))
if self.modality.get('use_ray', False):
info['ray_info'] = []
if not self.test_mode:
info['ann_info'] = self.parse_ann_info(info)
if self.test_mode and self.load_eval_anns:
info['ann_info'] = self.parse_ann_info(info)
info['eval_ann_info'] = self._remove_dontcare(info['ann_info'])
return info
def parse_ann_info(self, info: dict) -> dict:
"""Process the `instances` in data info to `ann_info`.
Args:
info (dict): Info dict.
Returns:
dict: Processed `ann_info`.
"""
ann_info = super().parse_ann_info(info)
if self.remove_dontcare:
ann_info = self._remove_dontcare(ann_info)
# empty gt
if ann_info is None:
ann_info = dict()
ann_info['gt_bboxes_3d'] = np.zeros((0, 6), dtype=np.float32)
ann_info['gt_labels_3d'] = np.zeros((0, ), dtype=np.int64)
ann_info['gt_bboxes_3d'] = DepthInstance3DBoxes(
ann_info['gt_bboxes_3d'],
box_dim=ann_info['gt_bboxes_3d'].shape[-1],
with_yaw=False,
origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)
# count the numbers
for label in ann_info['gt_labels_3d']:
if label != -1:
cat_name = self.metainfo['classes'][label]
self.num_ins_per_cat[cat_name] += 1
return ann_info
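The per-view pose handling in `parse_data_info` can be looked at in isolation. The following is a minimal sketch rather than project code: the helper name `split_pose` and the input matrices are made up for illustration, and it only mirrors how the matrix the source calls `c2w`, its rotation block `camrotc2w`, its translation `lightpos`, and the inverse stored under `lidar2img['extrinsic']` relate to each other.

```python
import numpy as np


def split_pose(axis_align_matrix, lidar2cam):
    """Hypothetical helper mirroring the per-view math in parse_data_info."""
    # Axis-aligned pose, named `c2w` in the dataset code.
    c2w = (axis_align_matrix @ lidar2cam).astype(np.float32)
    # Its inverse is what the dataset appends to info['lidar2img'].
    extrinsic = np.linalg.inv(axis_align_matrix @ lidar2cam).astype(
        np.float32)
    return dict(
        c2w=c2w,
        camrotc2w=c2w[0:3, 0:3],  # 3x3 rotation block
        lightpos=c2w[0:3, 3],  # translation column
        extrinsic=extrinsic)


# Made-up inputs: a 90-degree rotation about z as the axis-align matrix
# and a pose that is a pure translation by (1, 2, 3).
axis_align = np.array([[0., -1., 0., 0.],
                       [1., 0., 0., 0.],
                       [0., 0., 1., 0.],
                       [0., 0., 0., 1.]])
cam_pose = np.eye(4)
cam_pose[:3, 3] = [1., 2., 3.]

pose = split_pose(axis_align, cam_pose)
print(pose['lightpos'])
```

Because `extrinsic` is the inverse of `c2w`, multiplying the two should give the identity up to floating-point error, which is a cheap sanity check on converted info files.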
# Copyright (c) OpenMMLab. All rights reserved.
"""Prepare the dataset for NeRF-Det.
Example:
python projects/NeRF-Det/prepare_infos.py
--root-path ./data/scannet
--out-dir ./data/scannet
"""
import argparse
import time
from os import path as osp
from pathlib import Path
import mmengine
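# NOTE: absolute imports; the repository root must be on PYTHONPATH.
# Relative imports fail when this file is run directly as a script.
from tools.dataset_converters import indoor_converter as indoor
from tools.dataset_converters.update_infos_to_v2 import (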
clear_data_info_unused_keys, clear_instance_unused_keys,
get_empty_instance, get_empty_standard_data_info)
def update_scannet_infos_nerfdet(pkl_path, out_dir):
"""Update the original pkl to the new format used by NeRF-Det.
Args:
pkl_path (str): Path of the original pkl.
out_dir (str): Output directory of the generated info file.
Returns:
None: The converted pkl is written to `out_dir`; the original pkl is
overwritten if it lives in the same directory.
The new pkl is a dict containing two keys:
metainfo: Some base information of the pkl
data_list (list): A list containing all the information of the scenes.
"""
print('The new refactored process is running.')
print(f'{pkl_path} will be modified.')
if out_dir in pkl_path:
print('Warning, you may be overwriting '
f'the original data {pkl_path}.')
time.sleep(5)
METAINFO = {
'classes':
('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin')
}
print(f'Reading from input file: {pkl_path}.')
data_list = mmengine.load(pkl_path)
print('Start updating:')
converted_list = []
# Collect ignored class names across all scenes, not only the last one.
ignore_class_name = set()
for ori_info_dict in mmengine.track_iter_progress(data_list):
temp_data_info = get_empty_standard_data_info()
# intrinsics, extrinsics and imgs
temp_data_info['cam2img'] = ori_info_dict['intrinsics']
temp_data_info['lidar2cam'] = ori_info_dict['extrinsics']
temp_data_info['img_paths'] = ori_info_dict['img_paths']
# annotation information
anns = ori_info_dict.get('annos', None)
if anns is not None:
temp_data_info['axis_align_matrix'] = anns[
'axis_align_matrix'].tolist()
if anns['gt_num'] == 0:
instance_list = []
else:
num_instances = len(anns['name'])
instance_list = []
for instance_id in range(num_instances):
empty_instance = get_empty_instance()
empty_instance['bbox_3d'] = anns['gt_boxes_upright_depth'][
instance_id].tolist()
if anns['name'][instance_id] in METAINFO['classes']:
empty_instance['bbox_label_3d'] = METAINFO[
'classes'].index(anns['name'][instance_id])
else:
ignore_class_name.add(anns['name'][instance_id])
empty_instance['bbox_label_3d'] = -1
empty_instance = clear_instance_unused_keys(empty_instance)
instance_list.append(empty_instance)
temp_data_info['instances'] = instance_list
temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
converted_list.append(temp_data_info)
pkl_name = Path(pkl_path).name
out_path = osp.join(out_dir, pkl_name)
print(f'Writing to output file: {out_path}.')
print(f'ignore classes: {ignore_class_name}')
# dataset metainfo
metainfo = dict()
metainfo['categories'] = {k: i for i, k in enumerate(METAINFO['classes'])}
if ignore_class_name:
for ignore_class in ignore_class_name:
metainfo['categories'][ignore_class] = -1
metainfo['dataset'] = 'scannet'
metainfo['info_version'] = '1.1'
converted_data_info = dict(metainfo=metainfo, data_list=converted_list)
mmengine.dump(converted_data_info, out_path, 'pkl')
def scannet_data_prep(root_path, info_prefix, out_dir, workers):
"""Prepare the info file for scannet dataset.
Args:
root_path (str): Path of dataset root.
info_prefix (str): The prefix of info filenames.
out_dir (str): Output directory of the generated info file.
workers (int): Number of threads to be used.
"""
indoor.create_indoor_info_file(
root_path, info_prefix, out_dir, workers=workers)
info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl')
info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl')
info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_train_path)
update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_val_path)
update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_test_path)
parser = argparse.ArgumentParser(description='Data converter arg parser')
parser.add_argument(
'--root-path',
type=str,
default='./data/scannet',
help='specify the root path of dataset')
parser.add_argument(
'--out-dir',
type=str,
default='./data/scannet',
required=False,
help='output directory of the info pkl files')
parser.add_argument(
'--extra-tag',
type=str,
default='scannet',
help='prefix of the info pkl filenames')
parser.add_argument(
'--workers', type=int, default=4, help='number of threads to be used')
args = parser.parse_args()
if __name__ == '__main__':
from mmdet3d.utils import register_all_modules
register_all_modules()
scannet_data_prep(
root_path=args.root_path,
info_prefix=args.extra_tag,
out_dir=args.out_dir,
workers=args.workers)
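The label handling in `update_scannet_infos_nerfdet` can be sketched without any dataset files. The snippet below is a simplified illustration under stated assumptions: the helper `build_labels` and the example name `'lamp'` are hypothetical, while the class tuple and the `metainfo` keys are copied from the script above. It shows how annotated names map to `bbox_label_3d` indices, how names outside the class tuple get `-1`, and how those ignored names end up in `metainfo['categories']`.

```python
CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
           'bookshelf', 'picture', 'counter', 'desk', 'curtain',
           'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
           'garbagebin')


def build_labels(names, classes=CLASSES):
    """Hypothetical helper: map class names to label ids, -1 if unknown."""
    labels, ignored = [], set()
    for name in names:
        if name in classes:
            labels.append(classes.index(name))
        else:
            # Unknown classes are kept but marked as "don't care".
            ignored.add(name)
            labels.append(-1)
    return labels, ignored


# 'lamp' is a made-up name that is not part of CLASSES.
labels, ignored = build_labels(['chair', 'table', 'lamp'])

# Same layout as the metainfo written by the conversion script.
categories = {k: i for i, k in enumerate(CLASSES)}
categories.update({name: -1 for name in ignored})
metainfo = dict(categories=categories, dataset='scannet', info_version='1.1')
print(labels, sorted(ignored))
```

Instances labelled `-1` stay in the info file; whether they are dropped later is decided by the dataset's `remove_dontcare` option.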