Unverified Commit fe25f7a5 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Merge pull request #2867 from open-mmlab/dev-1.x

Bump version to 1.4.0
parents 5c0613be 0ef13b83
Pipeline #2710 failed with stages
in 0 seconds
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform, Compose
from PIL import Image
from mmdet3d.registry import TRANSFORMS
def get_dtu_raydir(pixelcoords, intrinsic, rot, dir_norm=None):
    """Compute per-pixel ray directions in world coordinates.

    Args:
        pixelcoords: H x W x 2 array of pixel coordinates.
        intrinsic: Camera intrinsic matrix (at least 3x3).
        rot: Camera-to-world (c2w) rotation matrix.
        dir_norm: If truthy, normalize each direction to unit length.

    Returns:
        np.ndarray: H x W x 3 array of ray directions.
    """
    fx, fy = intrinsic[0, 0], intrinsic[1, 1]
    cx, cy = intrinsic[0, 2], intrinsic[1, 2]
    # Unproject pixel centers (+0.5) onto the z=1 camera plane.
    cam_x = (pixelcoords[..., 0] + 0.5 - cx) / fx
    cam_y = (pixelcoords[..., 1] + 0.5 - cy) / fy
    cam_z = np.ones_like(cam_x)
    dirs = np.stack([cam_x, cam_y, cam_z], axis=-1)
    # Rotate camera-frame directions into the world frame.
    dirs = dirs @ rot[:, :].T
    if dir_norm:
        norms = np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5
        dirs = dirs / norms
    return dirs
@TRANSFORMS.register_module()
class MultiViewPipeline(BaseTransform):
    """MultiViewPipeline used in nerfdet.

    Required Keys:

    - depth_info
    - img_prefix
    - img_info
    - lidar2img
    - c2w
    - cammrotc2w
    - lightpos
    - ray_info

    Modified Keys:

    - lidar2img

    Added Keys:

    - img
    - denorm_images
    - depth
    - c2w
    - camrotc2w
    - lightpos
    - pixels
    - raydirs
    - gt_images
    - gt_depths
    - nerf_sizes
    - depth_range

    Args:
        transforms (list[dict]): The transform pipeline
            used to process the imgs.
        n_images (int): The number of sampled views.
        mean (array): The mean values used in normalization.
        std (array): The variance values used in normalization.
        margin (int): The margin value. Defaults to 10.
        depth_range (array): The range of the depth.
            Defaults to [0.5, 5.5].
        loading (str): The mode of loading. Defaults to 'random'.
        nerf_target_views (int): The number of novel views.
        sample_freq (int): The frequency of sampling.
    """

    def __init__(self,
                 transforms: dict,
                 n_images: int,
                 mean: tuple = [123.675, 116.28, 103.53],
                 std: tuple = [58.395, 57.12, 57.375],
                 margin: int = 10,
                 depth_range: tuple = [0.5, 5.5],
                 loading: str = 'random',
                 nerf_target_views: int = 0,
                 sample_freq: int = 3):
        self.transforms = Compose(transforms)
        # NOTE(review): assumes transforms[1] is the loading transform that
        # can also be applied to depth maps — confirm against the config.
        self.depth_transforms = Compose(transforms[1])
        self.n_images = n_images
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.margin = margin
        self.depth_range = depth_range
        self.loading = loading
        self.sample_freq = sample_freq
        self.nerf_target_views = nerf_target_views

    def transform(self, results: dict) -> dict:
        """Nerfdet transform function.

        Args:
            results (dict): Result dict from loading pipeline

        Returns:
            dict: The result dict containing the processed results.
            Updated key and value are described below.

                - img (list): The loaded origin image.
                - denorm_images (list): The denormalized image.
                - depth (list): The origin depth image.
                - c2w (list): The c2w matrixes.
                - camrotc2w (list): The rotation matrixes.
                - lightpos (list): The transform parameters of the camera.
                - pixels (list): Some pixel information.
                - raydirs (list): The ray-directions.
                - gt_images (list): The groundtruth images.
                - gt_depths (list): The groundtruth depth images.
                - nerf_sizes (array): The size of the groundtruth images.
                - depth_range (array): The range of the depth.

        Here we give a detailed explanation of some keys mentioned above.
        Let P_c be the coordinate of camera, P_w be the coordinate of world.
        There is such a conversion relationship: P_c = R @ P_w + T.
        The 'camrotc2w' mentioned above corresponds to the R matrix here.
        The 'lightpos' corresponds to the T matrix here. And if you put
        R and T together, you can get the camera extrinsics matrix. It
        corresponds to the 'c2w' mentioned above.
        """
        imgs = []
        depths = []
        extrinsics = []
        c2ws = []
        camrotc2ws = []
        lightposes = []
        pixels = []
        raydirs = []
        gt_images = []
        gt_depths = []
        denorm_imgs_list = []
        nerf_sizes = []

        if self.loading == 'random':
            # Randomly choose n_images views; allow replacement only when
            # fewer views exist than requested.
            ids = np.arange(len(results['img_info']))
            replace = True if self.n_images > len(ids) else False
            ids = np.random.choice(ids, self.n_images, replace=replace)
            if self.nerf_target_views != 0:
                # Hold out some of the sampled views as NeRF novel-view
                # targets; they are removed from the source-view list.
                target_id = np.random.choice(
                    ids, self.nerf_target_views, replace=False)
                ids = np.setdiff1d(ids, target_id)
                ids = ids.tolist()
                target_id = target_id.tolist()
        else:
            # Deterministic, evenly spaced view sampling.
            ids = np.arange(len(results['img_info']))
            begin_id = 0
            ids = np.arange(begin_id,
                            begin_id + self.n_images * self.sample_freq,
                            self.sample_freq)
            if self.nerf_target_views != 0:
                # In this mode the source views double as targets.
                target_id = ids

        ratio = 0
        size = (240, 320)  # fixed (H, W) every view is padded to
        for i in ids:
            _results = dict()
            _results['img_path'] = results['img_info'][i]['filename']
            _results = self.transforms(_results)
            imgs.append(_results['img'])
            # normalize
            for key in _results.get('img_fields', ['img']):
                _results[key] = mmcv.imnormalize(_results[key], self.mean,
                                                 self.std, True)
            _results['img_norm_cfg'] = dict(
                mean=self.mean, std=self.std, to_rgb=True)
            # pad
            for key in _results.get('img_fields', ['img']):
                padded_img = mmcv.impad(_results[key], shape=size, pad_val=0)
                _results[key] = padded_img
            _results['pad_shape'] = padded_img.shape
            _results['pad_fixed_size'] = size
            ori_shape = _results['ori_shape']
            aft_shape = _results['img_shape']
            # Downscale ratio between the original and processed image;
            # used later to rescale the camera intrinsics.
            ratio = ori_shape[0] / aft_shape[0]

            # prepare the depth information
            if 'depth_info' in results.keys():
                if '.npy' in results['depth_info'][i]['filename']:
                    _results['depth'] = np.load(
                        results['depth_info'][i]['filename'])
                else:
                    # Image-format depth maps are stored in millimeters.
                    _results['depth'] = np.asarray((Image.open(
                        results['depth_info'][i]['filename']))) / 1000
                    _results['depth'] = mmcv.imresize(
                        _results['depth'], (aft_shape[1], aft_shape[0]))
                depths.append(_results['depth'])

            # Keep an un-normalized [0, 1] copy for the NeRF branch.
            denorm_img = mmcv.imdenormalize(
                _results['img'], self.mean, self.std, to_bgr=True).astype(
                    np.uint8) / 255.0
            denorm_imgs_list.append(denorm_img)
            height, width = padded_img.shape[:2]
            extrinsics.append(results['lidar2img']['extrinsic'][i])

        # prepare the nerf information
        if 'ray_info' in results.keys():
            intrinsics_nerf = results['lidar2img']['intrinsic'].copy()
            # Rescale intrinsics to the processed image resolution.
            intrinsics_nerf[:2] = intrinsics_nerf[:2] / ratio
            assert self.nerf_target_views > 0
            for i in target_id:
                c2ws.append(results['c2w'][i])
                camrotc2ws.append(results['camrotc2w'][i])
                lightposes.append(results['lightpos'][i])
                # Every pixel inside the margin becomes a candidate ray.
                px, py = np.meshgrid(
                    np.arange(self.margin,
                              width - self.margin).astype(np.float32),
                    np.arange(self.margin,
                              height - self.margin).astype(np.float32))
                pixelcoords = np.stack((px, py),
                                       axis=-1).astype(np.float32)  # H x W x 2
                pixels.append(pixelcoords)
                raydir = get_dtu_raydir(pixelcoords, intrinsics_nerf,
                                        results['camrotc2w'][i])
                raydirs.append(np.reshape(raydir.astype(np.float32), (-1, 3)))
                # read target images
                temp_results = dict()
                temp_results['img_path'] = results['img_info'][i]['filename']
                # NOTE(review): the pipeline appears to mutate temp_results
                # in place (temp_results and temp_results_ are used
                # interchangeably below) — verify.
                temp_results_ = self.transforms(temp_results)
                # normalize
                for key in temp_results.get('img_fields', ['img']):
                    temp_results[key] = mmcv.imnormalize(
                        temp_results[key], self.mean, self.std, True)
                temp_results['img_norm_cfg'] = dict(
                    mean=self.mean, std=self.std, to_rgb=True)
                # pad
                for key in temp_results.get('img_fields', ['img']):
                    padded_img = mmcv.impad(
                        temp_results[key], shape=size, pad_val=0)
                    temp_results[key] = padded_img
                temp_results['pad_shape'] = padded_img.shape
                temp_results['pad_fixed_size'] = size
                # denormalize target_images.
                denorm_imgs = mmcv.imdenormalize(
                    temp_results_['img'], self.mean, self.std,
                    to_bgr=True).astype(np.uint8)
                gt_rgb_shape = denorm_imgs.shape
                # Gather ground-truth colors at the sampled pixel grid.
                gt_image = denorm_imgs[py.astype(np.int32),
                                       px.astype(np.int32), :]
                nerf_sizes.append(np.array(gt_image.shape))
                gt_image = np.reshape(gt_image, (-1, 3))
                gt_images.append(gt_image / 255.0)
                if 'depth_info' in results.keys():
                    if '.npy' in results['depth_info'][i]['filename']:
                        _results['depth'] = np.load(
                            results['depth_info'][i]['filename'])
                    else:
                        depth_image = Image.open(
                            results['depth_info'][i]['filename'])
                        _results['depth'] = np.asarray(depth_image) / 1000
                        _results['depth'] = mmcv.imresize(
                            _results['depth'],
                            (gt_rgb_shape[1], gt_rgb_shape[0]))
                    _results['depth'] = _results['depth']
                    # Ground-truth depths at the same sampled pixel grid.
                    gt_depth = _results['depth'][py.astype(np.int32),
                                                 px.astype(np.int32)]
                    gt_depths.append(gt_depth)

        # Copy per-view metadata of the last processed view back into the
        # main results dict (shapes, pad info, norm cfg, ...).
        for key in _results.keys():
            if key not in ['img', 'img_info']:
                results[key] = _results[key]
        results['img'] = imgs

        if 'ray_info' in results.keys():
            results['c2w'] = c2ws
            results['camrotc2w'] = camrotc2ws
            results['lightpos'] = lightposes
            results['pixels'] = pixels
            results['raydirs'] = raydirs
            results['gt_images'] = gt_images
            results['gt_depths'] = gt_depths
            results['nerf_sizes'] = nerf_sizes
            results['denorm_images'] = denorm_imgs_list
            results['depth_range'] = np.array([self.depth_range])

        if len(depths) != 0:
            results['depth'] = depths
        results['lidar2img']['extrinsic'] = extrinsics
        return results
@TRANSFORMS.register_module()
class RandomShiftOrigin(BaseTransform):
    """Randomly jitter the scene origin with Gaussian noise.

    Args:
        std: Per-axis standard deviation of the normal noise added to
            ``results['lidar2img']['origin']``.
    """

    def __init__(self, std):
        self.std = std

    def transform(self, results):
        """Add a random 3-vector offset to the lidar2img origin."""
        offset = np.random.normal(.0, self.std, 3)
        results['lidar2img']['origin'] += offset
        return results
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmengine.structures import InstanceData
from mmdet3d.structures import Det3DDataSample
class NeRFDet3DDataSample(Det3DDataSample):
    """A data structure interface inherited from Det3DDataSample. Some new
    attributes are added to match the NeRF-Det project.

    The attributes added in ``NeRFDet3DDataSample`` are divided into two parts:

        - ``gt_nerf_images`` (InstanceData): Ground truth of the images which
          will be used in the NeRF branch.
        - ``gt_nerf_depths`` (InstanceData): Ground truth of the depth images
          which will be used in the NeRF branch if needed.

    For more details and examples, please refer to the 'Det3DDataSample' file.
    """

    @property
    def gt_nerf_images(self) -> InstanceData:
        # Backed by set_field so the value type is validated on assignment.
        return self._gt_nerf_images

    @gt_nerf_images.setter
    def gt_nerf_images(self, value: InstanceData) -> None:
        self.set_field(value, '_gt_nerf_images', dtype=InstanceData)

    @gt_nerf_images.deleter
    def gt_nerf_images(self) -> None:
        del self._gt_nerf_images

    @property
    def gt_nerf_depths(self) -> InstanceData:
        return self._gt_nerf_depths

    @gt_nerf_depths.setter
    def gt_nerf_depths(self, value: InstanceData) -> None:
        self.set_field(value, '_gt_nerf_depths', dtype=InstanceData)

    @gt_nerf_depths.deleter
    def gt_nerf_depths(self) -> None:
        del self._gt_nerf_depths
# A batch of NeRF-Det data samples passed between model components.
SampleList = List[NeRFDet3DDataSample]
# A sample list that may be absent (e.g. tensor-only forward).
OptSampleList = Optional[SampleList]
# Possible return types of a model forward pass.
ForwardResults = Union[Dict[str, torch.Tensor], List[NeRFDet3DDataSample],
                       Tuple[torch.Tensor], torch.Tensor]
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """The MLP module used in NerfDet.

    Args:
        input_dim (int): The number of input tensor channels.
        output_dim (int, optional): The number of output tensor channels.
            Required when ``output_enabled`` is True. When the output layer
            is disabled, the width of the last hidden layer is exposed as
            ``self.output_dim`` instead.
        net_depth (int): The depth of the MLP. Defaults to 8.
        net_width (int): The width of the MLP. Defaults to 256.
        skip_layer (int, optional): Every ``skip_layer`` hidden layers
            (except the first), the raw input is concatenated back onto the
            activations, NeRF-style. ``None`` disables skips. Defaults to 4.
        hidden_init (Callable): The initialize method of the hidden layers.
        hidden_activation (Callable): The activation function of hidden
            layers, defaults to ReLU.
        output_enabled (bool): If true, the output layer will be used.
            Defaults to True.
        output_init (Optional): The initialize method of the output layer.
        output_activation (Optional): The activation function of the output
            layer. Defaults to identity.
        bias_enabled (bool): If true, the bias will be used. Defaults to True.
        bias_init (Callable): The initialize method of the bias.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: Optional[int] = None,
        net_depth: int = 8,
        net_width: int = 256,
        skip_layer: Optional[int] = 4,
        hidden_init: Callable = nn.init.xavier_uniform_,
        hidden_activation: Callable = nn.ReLU(),
        output_enabled: bool = True,
        output_init: Optional[Callable] = nn.init.xavier_uniform_,
        output_activation: Optional[Callable] = nn.Identity(),
        bias_enabled: bool = True,
        bias_init: Callable = nn.init.zeros_,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.net_depth = net_depth
        self.net_width = net_width
        self.skip_layer = skip_layer
        self.hidden_init = hidden_init
        self.hidden_activation = hidden_activation
        self.output_enabled = output_enabled
        self.output_init = output_init
        self.output_activation = output_activation
        self.bias_enabled = bias_enabled
        self.bias_init = bias_init

        self.hidden_layers = nn.ModuleList()
        in_features = self.input_dim
        for i in range(self.net_depth):
            self.hidden_layers.append(
                nn.Linear(in_features, self.net_width, bias=bias_enabled))
            # After a skip layer, the next layer also consumes the raw
            # input concatenated in forward().
            if (self.skip_layer is not None) and (i % self.skip_layer
                                                  == 0) and (i > 0):
                in_features = self.net_width + self.input_dim
            else:
                in_features = self.net_width
        if self.output_enabled:
            if self.output_dim is None:
                # Fail early with a clear message instead of an opaque
                # error inside nn.Linear.
                raise ValueError(
                    'output_dim must be specified when output_enabled=True')
            self.output_layer = nn.Linear(
                in_features, self.output_dim, bias=bias_enabled)
        else:
            # Expose the trunk's feature width to downstream modules.
            self.output_dim = in_features

        self.initialize()

    def initialize(self):
        """Apply the configured weight/bias initializers."""

        def init_func_hidden(m):
            if isinstance(m, nn.Linear):
                if self.hidden_init is not None:
                    self.hidden_init(m.weight)
                if self.bias_enabled and self.bias_init is not None:
                    self.bias_init(m.bias)

        self.hidden_layers.apply(init_func_hidden)
        if self.output_enabled:

            def init_func_output(m):
                if isinstance(m, nn.Linear):
                    if self.output_init is not None:
                        self.output_init(m.weight)
                    if self.bias_enabled and self.bias_init is not None:
                        self.bias_init(m.bias)

            self.output_layer.apply(init_func_output)

    def forward(self, x):
        """Run the MLP; input shape [..., input_dim]."""
        inputs = x
        for i in range(self.net_depth):
            x = self.hidden_layers[i](x)
            x = self.hidden_activation(x)
            # NeRF-style skip connection: re-inject the raw input.
            if (self.skip_layer is not None) and (i % self.skip_layer
                                                  == 0) and (i > 0):
                x = torch.cat([x, inputs], dim=-1)
        if self.output_enabled:
            x = self.output_layer(x)
            x = self.output_activation(x)
        return x
class DenseLayer(MLP):
    """A single fully-connected layer expressed as a zero-depth MLP.

    With ``net_depth=0`` the parent builds no hidden layers, so only the
    output projection (plus the output activation) remains.
    """

    def __init__(self, input_dim, output_dim, **kwargs):
        super().__init__(
            input_dim=input_dim,
            output_dim=output_dim,
            net_depth=0,  # no hidden layers
            **kwargs,
        )
class NerfMLP(nn.Module):
    """The Nerf-MLP Module.

    A shared trunk predicts density; when a view-direction condition is
    given, a bottleneck plus a second MLP predicts color.

    Args:
        input_dim (int): The number of input tensor channels.
        condition_dim (int): The number of condition tensor channels.
        feature_dim (int): The number of feature channels. Defaults to 0.
        net_depth (int): The depth of the MLP. Defaults to 8.
        net_width (int): The width of the MLP. Defaults to 256.
        skip_layer (int): The layer to add skip layers to. Defaults to 4.
        net_depth_condition (int): The depth of the second part of MLP.
            Defaults to 1.
        net_width_condition (int): The width of the second part of MLP.
            Defaults to 128.
    """

    def __init__(
        self,
        input_dim: int,
        condition_dim: int,
        feature_dim: int = 0,
        net_depth: int = 8,
        net_width: int = 256,
        skip_layer: int = 4,
        net_depth_condition: int = 1,
        net_width_condition: int = 128,
    ):
        super().__init__()
        self.base = MLP(
            input_dim=input_dim + feature_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            output_enabled=False,
        )
        hidden_features = self.base.output_dim
        self.sigma_layer = DenseLayer(hidden_features, 1)

        if condition_dim > 0:
            self.bottleneck_layer = DenseLayer(hidden_features, net_width)
            self.rgb_layer = MLP(
                input_dim=net_width + condition_dim,
                output_dim=3,
                net_depth=net_depth_condition,
                net_width=net_width_condition,
                skip_layer=None,
            )
        else:
            self.rgb_layer = DenseLayer(hidden_features, 3)

    def _encode(self, x, features):
        """Run the trunk, optionally with extra features appended."""
        if features is None:
            return self.base(x)
        return self.base(torch.cat([x, features], dim=-1))

    def query_density(self, x, features=None):
        """Calculate the raw sigma."""
        return self.sigma_layer(self._encode(x, features))

    def forward(self, x, condition=None, features=None):
        hidden = self._encode(x, features)
        raw_sigma = self.sigma_layer(hidden)
        if condition is not None:
            if condition.shape[:-1] != hidden.shape[:-1]:
                # Broadcast the per-ray condition across the sample dim.
                num_rays, n_dim = condition.shape
                view_shape = ([num_rays] +
                              [1] * (hidden.dim() - condition.dim()) +
                              [n_dim])
                condition = condition.view(view_shape).expand(
                    list(hidden.shape[:-1]) + [n_dim])
            bottleneck = self.bottleneck_layer(hidden)
            hidden = torch.cat([bottleneck, condition], dim=-1)
        raw_rgb = self.rgb_layer(hidden)
        return raw_rgb, raw_sigma
class SinusoidalEncoder(nn.Module):
    """Sinusodial Positional Encoder used in NeRF.

    Each input channel is mapped to sin/cos pairs at frequencies
    ``2**min_deg .. 2**(max_deg - 1)``; the raw input is optionally kept.
    """

    def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
        super().__init__()
        self.x_dim = x_dim
        self.min_deg = min_deg
        self.max_deg = max_deg
        self.use_identity = use_identity
        freqs = [2**deg for deg in range(min_deg, max_deg)]
        self.register_buffer('scales', torch.tensor(freqs))

    @property
    def latent_dim(self) -> int:
        """Number of output channels produced per input point."""
        num_bands = self.max_deg - self.min_deg
        return (int(self.use_identity) + 2 * num_bands) * self.x_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.max_deg == self.min_deg:
            # No frequency bands requested: identity mapping.
            return x
        num_feats = (self.max_deg - self.min_deg) * self.x_dim
        scaled = x[Ellipsis, None, :] * self.scales[:, None]
        xb = torch.reshape(scaled, list(x.shape[:-1]) + [num_feats])
        # sin shifted by pi/2 is cos, giving the sin/cos pairs in one call.
        latent = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1))
        if self.use_identity:
            latent = torch.cat([x] + [latent], dim=-1)
        return latent
class VanillaNeRF(nn.Module):
    """The Nerf-MLP with the positional encoder.

    Args:
        net_depth (int): The depth of the MLP. Defaults to 8.
        net_width (int): The width of the MLP. Defaults to 256.
        skip_layer (int): The layer to add skip layers to. Defaults to 4.
        feature_dim (int): The number of feature channels. Defaults to 0.
        net_depth_condition (int): The depth of the second part of MLP.
            Defaults to 1.
        net_width_condition (int): The width of the second part of MLP.
            Defaults to 128.
    """

    def __init__(self,
                 net_depth: int = 8,
                 net_width: int = 256,
                 skip_layer: int = 4,
                 feature_dim: int = 0,
                 net_depth_condition: int = 1,
                 net_width_condition: int = 128):
        super().__init__()
        # 10 frequency octaves for positions, 4 for view directions.
        self.posi_encoder = SinusoidalEncoder(3, 0, 10, True)
        self.view_encoder = SinusoidalEncoder(3, 0, 4, True)
        self.mlp = NerfMLP(
            input_dim=self.posi_encoder.latent_dim,
            condition_dim=self.view_encoder.latent_dim,
            feature_dim=feature_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            net_depth_condition=net_depth_condition,
            net_width_condition=net_width_condition,
        )

    def query_density(self, x, features=None):
        """Density-only query; returns a non-negative sigma."""
        raw_sigma = self.mlp.query_density(self.posi_encoder(x), features)
        return F.relu(raw_sigma)

    def forward(self, x, condition=None, features=None):
        """Return (rgb in [0, 1], non-negative sigma)."""
        x = self.posi_encoder(x)
        if condition is not None:
            condition = self.view_encoder(condition)
        rgb, sigma = self.mlp(x, condition=condition, features=features)
        return torch.sigmoid(rgb), F.relu(sigma)
# Copyright (c) OpenMMLab. All rights reserved.
# Attention: This file is mainly modified based on the file with the same
# name in the original project. For more details, please refer to the
# origin project.
import torch
import torch.nn.functional as F
class Projector():
    """Projects 3D sample points into source views and gathers features.

    Each camera is packed as a 34-vector ``[h, w, intrinsic(16), pose(16)]``
    (see ``_compute_projection`` in this project).
    """

    def __init__(self, device='cuda'):
        self.device = device

    def inbound(self, pixel_locations, h, w):
        """check if the pixel locations are in valid range."""
        return (pixel_locations[..., 0] <= w - 1.) & \
               (pixel_locations[..., 0] >= 0) & \
               (pixel_locations[..., 1] <= h - 1.) & \
               (pixel_locations[..., 1] >= 0)

    def normalize(self, pixel_locations, h, w):
        """Map pixel coordinates into [-1, 1] for ``F.grid_sample``."""
        resize_factor = torch.tensor([w - 1., h - 1.
                                      ]).to(pixel_locations.device)[None,
                                                                    None, :]
        normalized_pixel_locations = 2 * pixel_locations / resize_factor - 1.
        return normalized_pixel_locations

    def compute_projections(self, xyz, train_cameras):
        """project 3D points into cameras.

        Returns pixel locations per view plus a mask of points that lie in
        front of each camera (positive projected depth).
        """
        original_shape = xyz.shape[:2]
        xyz = xyz.reshape(-1, 3)
        num_views = len(train_cameras)
        train_intrinsics = train_cameras[:, 2:18].reshape(-1, 4, 4)
        train_poses = train_cameras[:, -16:].reshape(-1, 4, 4)
        xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1)
        # The poses were already inverted in the dataloader, so no
        # torch.inverse(train_poses) is needed here.
        projections = train_intrinsics.bmm(train_poses) \
            .bmm(xyz_h.t()[None, ...].repeat(num_views, 1, 1))
        projections = projections.permute(0, 2, 1)
        # Perspective divide, guarded against tiny/negative depths.
        pixel_locations = projections[..., :2] / torch.clamp(
            projections[..., 2:3], min=1e-8)
        pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6)
        mask = projections[..., 2] > 0  # point is in front of the camera
        return pixel_locations.reshape((num_views, ) + original_shape + (2, )), \
            mask.reshape((num_views, ) + original_shape)  # noqa

    def compute_angle(self, xyz, query_camera, train_cameras):
        """Ray-direction difference between the query and source views.

        Returns:
            Tensor of shape [n_views, n_rays, n_samples, 4]: unit direction
            difference concatenated with the dot product.
        """
        original_shape = xyz.shape[:2]
        xyz = xyz.reshape(-1, 3)
        train_poses = train_cameras[:, -16:].reshape(-1, 4, 4)
        num_views = len(train_poses)
        query_pose = query_camera[-16:].reshape(-1, 4,
                                                4).repeat(num_views, 1, 1)
        ray2tar_pose = (query_pose[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
        ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6)
        ray2train_pose = (
            train_poses[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
        ray2train_pose /= (
            torch.norm(ray2train_pose, dim=-1, keepdim=True) + 1e-6)
        ray_diff = ray2tar_pose - ray2train_pose
        ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
        ray_diff_dot = torch.sum(
            ray2tar_pose * ray2train_pose, dim=-1, keepdim=True)
        ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
        ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
        ray_diff = ray_diff.reshape((num_views, ) + original_shape + (4, ))
        return ray_diff

    def compute(self,
                xyz,
                train_imgs,
                train_cameras,
                featmaps=None,
                grid_sample=True):
        """Gather per-view RGB (and optional deep) features for 3D points.

        Args:
            xyz: Sample points, [n_rays, n_samples, 3].
            train_imgs: Source images, [1, n_views, h, w, 3].
            train_cameras: Packed cameras, [1, n_views, 34].
            featmaps: Optional per-view feature maps, [n_views, C, f_h, f_w].
            grid_sample: If True, bilinearly sample featmaps and concatenate
                with RGB; otherwise gather nearest valid feature pixels.

        Returns:
            tuple: sampled features (or None when ``featmaps`` is None) and
            a visibility mask of shape [n_rays, n_samples, n_views, 1].
        """
        assert (train_imgs.shape[0] == 1) \
            and (train_cameras.shape[0] == 1)
        # only support batch_size=1 for now
        train_imgs = train_imgs.squeeze(0)
        train_cameras = train_cameras.squeeze(0)
        train_imgs = train_imgs.permute(0, 3, 1, 2)  # [n_views, 3, h, w]
        h, w = train_cameras[0][:2]

        # compute the projection of the query points to each reference image
        pixel_locations, mask_in_front = self.compute_projections(
            xyz, train_cameras)
        normalized_pixel_locations = self.normalize(pixel_locations, h, w)

        # rgb sampling
        rgbs_sampled = F.grid_sample(
            train_imgs, normalized_pixel_locations, align_corners=True)
        rgb_sampled = rgbs_sampled.permute(2, 3, 0, 1)

        # deep feature sampling
        if featmaps is not None:
            if grid_sample:
                feat_sampled = F.grid_sample(
                    featmaps, normalized_pixel_locations, align_corners=True)
                feat_sampled = feat_sampled.permute(
                    2, 3, 0, 1)  # [n_rays, n_samples, n_views, d]
                rgb_feat_sampled = torch.cat(
                    [rgb_sampled, feat_sampled],
                    dim=-1)  # [n_rays, n_samples, n_views, d+3]
            else:
                n_images, n_channels, f_h, f_w = featmaps.shape
                # NOTE(review): `f_w / w - 1.` is a suspicious pixel->featmap
                # scale (cf. `normalize`, which divides by w - 1); verify
                # against the original NeRF-Det implementation.
                resize_factor = torch.tensor([f_w / w - 1., f_h / h - 1.]).to(
                    pixel_locations.device)[None, None, :]
                sample_location = (pixel_locations *
                                   resize_factor).round().long()
                n_images, n_ray, n_sample, _ = sample_location.shape
                sample_x = sample_location[..., 0].view(n_images, -1)
                sample_y = sample_location[..., 1].view(n_images, -1)
                valid = (sample_x >= 0) & (sample_y >= 0) & \
                    (sample_x < f_w) & (sample_y < f_h)
                valid = valid * mask_in_front.view(n_images, -1)
                feat_sampled = torch.zeros(
                    (n_images, n_channels, sample_x.shape[-1]),
                    device=featmaps.device)
                for i in range(n_images):
                    # BUGFIX: index feature maps as [row=y, col=x]; the
                    # original used sample_y for BOTH spatial dimensions.
                    feat_sampled[i, :, valid[i]] = featmaps[
                        i, :, sample_y[i, valid[i]], sample_x[i, valid[i]]]
                feat_sampled = feat_sampled.view(n_images, n_channels, n_ray,
                                                 n_sample)
                rgb_feat_sampled = feat_sampled.permute(2, 3, 0, 1)
        else:
            rgb_feat_sampled = None

        inbound = self.inbound(pixel_locations, h, w)
        mask = (inbound * mask_in_front).float().permute(
            1, 2, 0)[..., None]  # [n_rays, n_samples, n_views, 1]
        return rgb_feat_sampled, mask
# Copyright (c) OpenMMLab. All rights reserved.
# Attention: This file is mainly modified based on the file with the same
# name in the original project. For more details, please refer to the
# origin project.
from collections import OrderedDict
import numpy as np
import torch
import torch.nn.functional as F
# Fixed-seed RNG so ray subsampling is reproducible across runs.
rng = np.random.RandomState(234)
# helper functions for nerf ray rendering
def volume_sampling(sample_pts, features, aabb):
    """Trilinearly sample a feature volume at 3D scene points.

    Args:
        sample_pts: Points of shape [N_rays, N_samples, 3].
        features: Feature volume of shape [1, C, D, W, H].
        aabb: Axis-aligned bounding box [min_corner, max_corner], each 3D.

    Returns:
        tuple: features of shape [N_rays, N_samples, C] and a boolean mask
        marking points whose normalized coords fall strictly inside the box.
    """
    batch, channels, _, _, _ = features.shape
    assert batch == 1
    aabb = torch.Tensor(aabb).to(sample_pts.device)
    n_rays, n_samples, _ = sample_pts.shape
    grid_pts = sample_pts.view(1, n_rays * n_samples, 1, 1,
                               3).repeat(batch, 1, 1, 1, 1)
    box_size = aabb[1] - aabb[0]
    # Map scene coordinates into grid_sample's [-1, 1] range.
    norm_pts = (grid_pts - aabb[0]) * (1.0 / box_size * 2) - 1
    sampled = F.grid_sample(
        features, norm_pts, align_corners=True, padding_mode='border')
    # A point is inside iff all three normalized coords are in (-1, 1).
    inside_count = ((norm_pts < 1) & (norm_pts > -1)).float().sum(dim=-1)
    masks = (inside_count.view(n_rays, n_samples) == 3)
    out = sampled.view(channels, n_rays,
                       n_samples).permute(1, 2, 0).contiguous()
    return out, masks
def _compute_projection(img_meta):
views = len(img_meta['lidar2img']['extrinsic'])
intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:4, :4])
ratio = img_meta['ori_shape'][0] / img_meta['img_shape'][0]
intrinsic[:2] /= ratio
intrinsic = intrinsic.unsqueeze(0).view(1, 16).repeat(views, 1)
img_size = torch.Tensor(img_meta['img_shape'][:2]).to(intrinsic.device)
img_size = img_size.unsqueeze(0).repeat(views, 1)
extrinsics = []
for v in range(views):
extrinsics.append(
torch.Tensor(img_meta['lidar2img']['extrinsic'][v]).to(
intrinsic.device))
extrinsic = torch.stack(extrinsics).view(views, 16)
train_cameras = torch.cat([img_size, intrinsic, extrinsic], dim=-1)
return train_cameras.unsqueeze(0)
def compute_mask_points(feature, mask):
    """Masked mean and exp(-variance) of features across views (dim=2).

    Args:
        feature: Per-view features, e.g. [n_rays, n_samples, n_views, C].
        mask: Visibility weights broadcastable to ``feature``.

    Returns:
        tuple: (mean, exp(-variance)), both with dim 2 kept (size 1).
    """
    denom = torch.sum(mask, dim=2, keepdim=True) + 1e-8
    weight = mask / denom
    mean = torch.sum(feature * weight, dim=2, keepdim=True)
    # Variance over views, mapped through exp(-.) so that agreement
    # across views yields values near 1.
    var = torch.sum((feature - mean)**2, dim=2, keepdim=True) / denom
    return mean, torch.exp(-var)
def sample_pdf(bins, weights, N_samples, det=False):
    """Helper function used for sampling.

    Args:
        bins (tensor): Tensor of shape [N_rays, M+1], M is the number of bins
        weights (tensor): Tensor of shape [N_rays, M], one weight per bin.
        N_samples (int): Number of samples along each ray
        det (bool): If True, will perform deterministic sampling

    Returns:
        samples (tensor): [N_rays, N_samples]
    """
    M = weights.shape[1]
    # BUGFIX: add the epsilon out-of-place; the original `weights += 1e-5`
    # silently mutated the caller's tensor.
    weights = weights + 1e-5
    # Get pdf
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[:, 0:1]), cdf], dim=-1)
    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., N_samples, device=bins.device)
        u = u.unsqueeze(0).repeat(bins.shape[0], 1)
    else:
        u = torch.rand(bins.shape[0], N_samples, device=bins.device)
    # Invert CDF: count how many cdf entries each sample exceeds.
    above_inds = torch.zeros_like(u, dtype=torch.long)
    for i in range(M):
        above_inds += (u >= cdf[:, i:i + 1]).long()
    # random sample inside each bin
    below_inds = torch.clamp(above_inds - 1, min=0)
    inds_g = torch.stack((below_inds, above_inds), dim=2)
    cdf = cdf.unsqueeze(1).repeat(1, N_samples, 1)
    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)
    bins = bins.unsqueeze(1).repeat(1, N_samples, 1)
    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)
    # Linear interpolation within the selected bin; guard tiny denominators
    # so degenerate bins do not divide by ~zero.
    denom = cdf_g[:, :, 1] - cdf_g[:, :, 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[:, :, 0]) / denom
    samples = bins_g[:, :, 0] + t * (bins_g[:, :, 1] - bins_g[:, :, 0])
    return samples
def sample_along_camera_ray(ray_o,
                            ray_d,
                            depth_range,
                            N_samples,
                            inv_uniform=False,
                            det=False):
    """Sampling along the camera ray.

    Args:
        ray_o (tensor): Origin of the ray in scene coordinate system;
            tensor of shape [N_rays, 3]
        ray_d (tensor): Homogeneous ray direction vectors in
            scene coordinate system; tensor of shape [N_rays, 3]
        depth_range (tuple): [near_depth, far_depth]
        N_samples (int): Number of samples per ray.
        inv_uniform (bool): If True, uniformly sample inverse depth.
        det (bool): If True, will perform deterministic sampling.

    Returns:
        pts (tensor): Tensor of shape [N_rays, N_samples, 3]
        z_vals (tensor): Tensor of shape [N_rays, N_samples]
    """
    near, far = depth_range[0], depth_range[1]
    assert near > 0 and far > 0 and far > near
    near_depth = near * torch.ones_like(ray_d[..., 0])
    far_depth = far * torch.ones_like(ray_d[..., 0])

    if inv_uniform:
        # Uniform steps in inverse depth (disparity).
        start = 1. / near_depth
        step = (1. / far_depth - start) / (N_samples - 1)
        inv_z = torch.stack([start + k * step for k in range(N_samples)],
                            dim=1)
        z_vals = 1. / inv_z
    else:
        # Uniform steps in depth.
        step = (far_depth - near_depth) / (N_samples - 1)
        z_vals = torch.stack(
            [near_depth + k * step for k in range(N_samples)], dim=1)

    if not det:
        # Jitter each sample uniformly within its interval.
        mids = .5 * (z_vals[:, 1:] + z_vals[:, :-1])
        upper = torch.cat([mids, z_vals[:, -1:]], dim=-1)
        lower = torch.cat([z_vals[:, 0:1], mids], dim=-1)
        z_vals = lower + (upper - lower) * torch.rand_like(z_vals)

    ray_d = ray_d.unsqueeze(1).repeat(1, N_samples, 1)
    ray_o = ray_o.unsqueeze(1).repeat(1, N_samples, 1)
    pts = z_vals.unsqueeze(2) * ray_d + ray_o  # [N_rays, N_samples, 3]
    return pts, z_vals
# ray rendering of nerf
def raw2outputs(raw, z_vals, mask, white_bkgd=False):
"""Transform raw data to outputs:
Args:
raw(tensor):Raw network output.Tensor of shape [N_rays, N_samples, 4]
z_vals(tensor):Depth of point samples along rays.
Tensor of shape [N_rays, N_samples]
ray_d(tensor):[N_rays, 3]
Returns:
ret(dict):
-rgb(tensor):[N_rays, 3]
-depth(tensor):[N_rays,]
-weights(tensor):[N_rays,]
-depth_std(tensor):[N_rays,]
"""
rgb = raw[:, :, :3] # [N_rays, N_samples, 3]
sigma = raw[:, :, 3] # [N_rays, N_samples]
# note: we did not use the intervals here,
# because in practice different scenes from COLMAP can have
# very different scales, and using interval can affect
# the model's generalization ability.
# Therefore we don't use the intervals for both training and evaluation.
sigma2alpha = lambda sigma, dists: 1. - torch.exp(-sigma) # noqa
# point samples are ordered with increasing depth
# interval between samples
dists = z_vals[:, 1:] - z_vals[:, :-1]
dists = torch.cat((dists, dists[:, -1:]), dim=-1)
alpha = sigma2alpha(sigma, dists)
T = torch.cumprod(1. - alpha + 1e-10, dim=-1)[:, :-1]
T = torch.cat((torch.ones_like(T[:, 0:1]), T), dim=-1)
# maths show weights, and summation of weights along a ray,
# are always inside [0, 1]
weights = alpha * T
rgb_map = torch.sum(weights.unsqueeze(2) * rgb, dim=1)
if white_bkgd:
rgb_map = rgb_map + (1. - torch.sum(weights, dim=-1, keepdim=True))
if mask is not None:
mask = mask.float().sum(dim=1) > 8
depth_map = torch.sum(
weights * z_vals, dim=-1) / (
torch.sum(weights, dim=-1) + 1e-8)
depth_map = torch.clamp(depth_map, z_vals.min(), z_vals.max())
ret = OrderedDict([('rgb', rgb_map), ('depth', depth_map),
('weights', weights), ('mask', mask), ('alpha', alpha),
('z_vals', z_vals), ('transparency', T)])
return ret
def render_rays_func(
        ray_o,
        ray_d,
        mean_volume,
        cov_volume,
        features_2D,
        img,
        aabb,
        near_far_range,
        N_samples,
        N_rand=4096,
        nerf_mlp=None,
        img_meta=None,
        projector=None,
        mode='volume',  # volume and image
        nerf_sample_view=3,
        inv_uniform=False,
        N_importance=0,
        det=False,
        is_train=True,
        white_bkgd=False,
        gt_rgb=None,
        gt_depth=None):
    """Render one chunk of rays through the scene representation.

    Args:
        ray_o: Ray origins, [N_rays, 3].
        ray_d: Ray directions, [N_rays, 3].
        mean_volume / cov_volume: Mean and variance feature volumes, used
            only in 'volume' mode.
        features_2D: Per-view 2D feature maps, used only in 'image' mode.
        img: Source images the features were computed from.
        aabb: Scene axis-aligned bounding box for volume sampling.
        near_far_range: [near, far] depth sampling range.
        N_samples (int): Number of samples per ray.
        N_rand (int): Unused here; the caller pre-chunks the rays.
        nerf_mlp: MLP mapping (pts, dirs, features) to (rgb, sigma).
        img_meta (dict): Image meta info used to build camera matrices.
        projector: Projects 3D points into the source views.
        mode (str): 'volume' or 'image' feature aggregation.
        white_bkgd (bool): Composite onto a white background if True.
        gt_rgb / gt_depth: Ground truths passed through in the output.

    Returns:
        dict: 'outputs_coarse' rendering results plus the ground truths.
    """
    ret = {
        'outputs_coarse': None,
        'outputs_fine': None,
        'gt_rgb': gt_rgb,
        'gt_depth': gt_depth
    }

    # pts: [N_rays, N_samples, 3]
    # z_vals: [N_rays, N_samples]
    pts, z_vals = sample_along_camera_ray(
        ray_o=ray_o,
        ray_d=ray_d,
        depth_range=near_far_range,
        N_samples=N_samples,
        inv_uniform=inv_uniform,
        det=det)
    N_rays, N_samples = pts.shape[:2]

    if mode == 'image':
        # Aggregate per-view 2D features at each sample point.
        img = img.permute(0, 2, 3, 1).unsqueeze(0)
        train_camera = _compute_projection(img_meta).to(img.device)
        rgb_feat, mask = projector.compute(
            pts, img, train_camera, features_2D, grid_sample=True)
        # Keep rays whose samples are seen by at least 2 views.
        pixel_mask = mask[..., 0].sum(dim=2) > 1
        mean, var = compute_mask_points(rgb_feat, mask)
        globalfeat = torch.cat([mean, var], dim=-1).squeeze(2)
        rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalfeat)
        raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1)
        ret['sigma'] = density_pts
    elif mode == 'volume':
        mean_pts, inbound_masks = volume_sampling(pts, mean_volume, aabb)
        cov_pts, inbound_masks = volume_sampling(pts, cov_volume, aabb)
        # This masks is for indicating which points outside of aabb
        img = img.permute(0, 2, 3, 1).unsqueeze(0)
        train_camera = _compute_projection(img_meta).to(img.device)
        _, view_mask = projector.compute(pts, img, train_camera, None)
        pixel_mask = view_mask[..., 0].sum(dim=2) > 1
        # plot_3D_vis(pts, aabb, img, train_camera)
        # [N_rays, N_samples], should at least have 2 observations
        # This mask is for indicating which points do not have projected point
        globalpts = torch.cat([mean_pts, cov_pts], dim=-1)
        rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalpts)
        # Zero out densities for samples that fall outside the aabb.
        density_pts = density_pts * inbound_masks.unsqueeze(dim=-1)
        raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1)

    outputs_coarse = raw2outputs(
        raw_coarse, z_vals, pixel_mask, white_bkgd=white_bkgd)
    ret['outputs_coarse'] = outputs_coarse
    return ret
def render_rays(
        ray_batch,
        mean_volume,
        cov_volume,
        features_2D,
        img,
        aabb,
        near_far_range,
        N_samples,
        N_rand=4096,
        nerf_mlp=None,
        img_meta=None,
        projector=None,
        mode='volume',  # volume and image
        nerf_sample_view=3,
        inv_uniform=False,
        N_importance=0,
        det=False,
        is_train=True,
        white_bkgd=False,
        render_testing=False):
    """The function of the nerf rendering.

    Training (``is_train=True``): flattens the rays, keeps only rays that
    carry positive depth supervision (when depth GT is present), randomly
    samples ``N_rand`` of them and renders once through
    ``render_rays_func``. Testing (``render_testing=True``): renders every
    ray chunk by chunk and reassembles full-resolution per-view images.

    Returns:
        dict | None: Rendering results with 'outputs_coarse', 'gt_rgb' and
        'gt_depth' keys, or None when no coarse output was produced
        (or when neither branch was taken).
    """
    # default so the function never returns an unbound name
    rets = None
    ray_o = ray_batch['ray_o']
    ray_d = ray_batch['ray_d']
    gt_rgb = ray_batch['gt_rgb']
    gt_depth = ray_batch['gt_depth']
    nerf_sizes = ray_batch['nerf_sizes']
    if is_train:
        ray_o = ray_o.view(-1, 3)
        ray_d = ray_d.view(-1, 3)
        gt_rgb = gt_rgb.view(-1, 3)
        if gt_depth.shape[1] != 0:
            gt_depth = gt_depth.view(-1, 1)
            # keep only rays with valid (positive) depth supervision
            non_zero_depth = (gt_depth > 0).squeeze(-1)
            ray_o = ray_o[non_zero_depth]
            ray_d = ray_d[non_zero_depth]
            gt_rgb = gt_rgb[non_zero_depth]
            gt_depth = gt_depth[non_zero_depth]
        else:
            gt_depth = None
        total_rays = ray_d.shape[0]
        # ``rng`` is a module-level generator; sample without replacement
        select_inds = rng.choice(total_rays, size=(N_rand, ), replace=False)
        ray_o = ray_o[select_inds]
        ray_d = ray_d[select_inds]
        gt_rgb = gt_rgb[select_inds]
        if gt_depth is not None:
            gt_depth = gt_depth[select_inds]
        rets = render_rays_func(
            ray_o,
            ray_d,
            mean_volume,
            cov_volume,
            features_2D,
            img,
            aabb,
            near_far_range,
            N_samples,
            N_rand,
            nerf_mlp,
            img_meta,
            projector,
            mode,  # volume and image
            nerf_sample_view,
            inv_uniform,
            N_importance,
            det,
            is_train,
            white_bkgd,
            gt_rgb,
            gt_depth)
    elif render_testing:
        nerf_size = nerf_sizes[0]
        view_num = ray_o.shape[1]
        H = nerf_size[0][0]
        W = nerf_size[0][1]
        ray_o = ray_o.view(-1, 3)
        ray_d = ray_d.view(-1, 3)
        gt_rgb = gt_rgb.view(-1, 3)
        if len(gt_depth) != 0:
            gt_depth = gt_depth.view(-1, 1)
        else:
            gt_depth = None
        assert view_num * H * W == ray_o.shape[0]
        num_rays = ray_o.shape[0]
        results = []
        # render in chunks of N_rand rays to bound peak memory
        for i in range(0, num_rays, N_rand):
            ray_o_chunck = ray_o[i:i + N_rand, :]
            ray_d_chunck = ray_d[i:i + N_rand, :]

            ret = render_rays_func(ray_o_chunck, ray_d_chunck, mean_volume,
                                   cov_volume, features_2D, img, aabb,
                                   near_far_range, N_samples, N_rand, nerf_mlp,
                                   img_meta, projector, mode, nerf_sample_view,
                                   inv_uniform, N_importance, True, is_train,
                                   white_bkgd, gt_rgb, gt_depth)
            results.append(ret)
        rgbs = []
        depths = []

        if results[0]['outputs_coarse'] is not None:
            for i in range(len(results)):
                rgb = results[i]['outputs_coarse']['rgb']
                rgbs.append(rgb)
                depth = results[i]['outputs_coarse']['depth']
                depths.append(depth)
            rets = {
                'outputs_coarse': {
                    'rgb': torch.cat(rgbs, dim=0).view(view_num, H, W, 3),
                    'depth': torch.cat(depths, dim=0).view(view_num, H, W, 1),
                },
                'gt_rgb':
                gt_rgb.view(view_num, H, W, 3),
                'gt_depth':
                gt_depth.view(view_num, H, W, 1)
                if gt_depth is not None else None,
            }
        else:
            rets = None
    return rets
# Copyright (c) OpenMMLab. All rights reserved.
import os
import cv2
import numpy as np
import torch
from skimage.metrics import structural_similarity
def compute_psnr_from_mse(mse):
    """Convert a mean-squared error to PSNR in dB (peak value 1)."""
    ln10 = np.log(10.0)
    return -10.0 * torch.log(mse) / ln10
def compute_psnr(pred, target, mask=None):
    """Compute psnr value (we assume the maximum pixel value is 1).

    When ``mask`` is given, only the selected elements contribute.
    """
    if mask is not None:
        pred = pred[mask]
        target = target[mask]
    mse = torch.mean((pred - target)**2)
    return compute_psnr_from_mse(mse).cpu().numpy()
def compute_ssim(pred, target, mask=None):
    """Computes Masked SSIM following the neuralbody paper.

    If ``mask`` is given, SSIM is evaluated on the mask's bounding box.
    """
    assert pred.shape == target.shape and pred.shape[-1] == 3
    if mask is not None:
        x, y, w, h = cv2.boundingRect(mask.cpu().numpy().astype(np.uint8))
        pred = pred[y:y + h, x:x + w]
        target = target[y:y + h, x:x + w]
    pred_np = pred.cpu().numpy()
    target_np = target.cpu().numpy()
    try:
        # modern scikit-image API
        return structural_similarity(pred_np, target_np, channel_axis=-1)
    except ValueError:
        # fallback for older scikit-image versions
        return structural_similarity(pred_np, target_np, multichannel=True)
def save_rendered_img(img_meta, rendered_results):
    """Dump rendered vs. ground-truth views to disk and return metrics.

    For each view a side-by-side image (rendered RGB | GT RGB | normalized
    depth) annotated with its PSNR is written to
    ``nerf_vs_rebuttal/<scene>/view_<v>.png``.

    Args:
        img_meta (list[dict]): Image metas; only ``filename`` of the first
            entry is used to derive the scene name.
        rendered_results (list[dict]): Render outputs with
            'outputs_coarse' ('rgb', 'depth'), 'gt_rgb' and 'gt_depth'.

    Returns:
        tuple: Mean PSNR, mean SSIM and mean squared depth error over
        views. NOTE(review): despite the name, ``rsme`` is a mean squared
        error (no square root is taken) — kept for backward compatibility.
    """
    filename = img_meta[0]['filename']
    scenes = filename.split('/')[-2]

    # only the last result's tensors are kept, as in the original code
    for ret in rendered_results:
        depth = ret['outputs_coarse']['depth']
        rgb = ret['outputs_coarse']['rgb']
        gt = ret['gt_rgb']
        gt_depth = ret['gt_depth']

    # save images
    psnr_total = 0
    ssim_total = 0
    rsme = 0
    for v in range(gt.shape[0]):
        rsme += ((depth[v] - gt_depth[v])**2).cpu().numpy()
        # min-max normalize depth for visualization and tile to 3 channels
        depth_ = ((depth[v] - depth[v].min()) /
                  (depth[v].max() - depth[v].min() + 1e-8)).repeat(1, 1, 3)
        img_to_save = torch.cat([rgb[v], gt[v], depth_], dim=1)
        image_path = os.path.join('nerf_vs_rebuttal', scenes)
        # exist_ok avoids a race between the exists() check and makedirs()
        os.makedirs(image_path, exist_ok=True)
        save_dir = os.path.join(image_path, 'view_' + str(v) + '.png')
        font = cv2.FONT_HERSHEY_SIMPLEX
        org = (50, 50)
        fontScale = 1
        color = (255, 0, 0)
        thickness = 2
        image = np.uint8(img_to_save.cpu().numpy() * 255.0)
        psnr = compute_psnr(rgb[v], gt[v], mask=None)
        psnr_total += psnr
        ssim = compute_ssim(rgb[v], gt[v], mask=None)
        ssim_total += ssim
        # reuse the PSNR computed above instead of recomputing it
        image = cv2.putText(
            image, 'PSNR: ' + '%.2f' % psnr,
            org, font, fontScale, color, thickness, cv2.LINE_AA)

        cv2.imwrite(save_dir, image)
    return psnr_total / gt.shape[0], ssim_total / gt.shape[0], rsme / gt.shape[
        0]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet3d.models.detectors import Base3DDetector
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import ConfigType, OptConfigType
from .nerf_utils.nerf_mlp import VanillaNeRF
from .nerf_utils.projection import Projector
from .nerf_utils.render_ray import render_rays
# from ..utils.nerf_utils.save_rendered_img import save_rendered_img
@MODELS.register_module()
class NerfDet(Base3DDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2307.14620>`_.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        neck_3d(:obj:`ConfigDict` or dict): The 3D neck config.
        bbox_head(:obj:`ConfigDict` or dict): The bbox head config.
        prior_generator (:obj:`ConfigDict` or dict): The prior generator
            config.
        n_voxels (list): Number of voxels along x, y, z axis.
        voxel_size (list): The size of voxels. Each voxel represents
            a cube of ``voxel_size[0]`` x ``voxel_size[1]`` x
            ``voxel_size[2]`` meters.
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
        render_testing (bool): If you want to render novel view, please set
            "render_testing = True" in config
        The other args are the parameters of NeRF, you can just use the
        default values.
    """

    def __init__(
            self,
            backbone: ConfigType,
            neck: ConfigType,
            neck_3d: ConfigType,
            bbox_head: ConfigType,
            prior_generator: ConfigType,
            n_voxels: List,
            voxel_size: List,
            head_2d: ConfigType = None,
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            data_preprocessor: OptConfigType = None,
            init_cfg: OptConfigType = None,
            # pretrained,
            aabb: Tuple = None,
            near_far_range: List = None,
            N_samples: int = 64,
            N_rand: int = 2048,
            depth_supervise: bool = False,
            use_nerf_mask: bool = True,
            nerf_sample_view: int = 3,
            nerf_mode: str = 'volume',
            squeeze_scale: int = 4,
            rgb_supervision: bool = True,
            nerf_density: bool = False,
            render_testing: bool = False):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.neck_3d = MODELS.build(neck_3d)
        # the bbox head shares the detector's train/test configs
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)
        self.head_2d = MODELS.build(head_2d) if head_2d is not None else None
        self.n_voxels = n_voxels
        self.prior_generator = TASK_UTILS.build(prior_generator)
        self.voxel_size = voxel_size
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.aabb = aabb
        self.near_far_range = near_far_range
        self.N_samples = N_samples
        self.N_rand = N_rand
        self.depth_supervise = depth_supervise
        self.projector = Projector()
        self.squeeze_scale = squeeze_scale
        self.use_nerf_mask = use_nerf_mask
        self.rgb_supervision = rgb_supervision
        # channel width of the features fed into the NeRF MLP
        nerf_feature_dim = neck['out_channels'] // squeeze_scale
        self.nerf_mlp = VanillaNeRF(
            net_depth=4,  # The depth of the MLP
            net_width=256,  # The width of the MLP
            skip_layer=3,  # The layer to add skip layers to.
            feature_dim=nerf_feature_dim + 6,  # + RGB original imgs
            net_depth_condition=1,  # The depth of the second part of MLP
            net_width_condition=128)
        self.nerf_mode = nerf_mode
        self.nerf_density = nerf_density
        self.nerf_sample_view = nerf_sample_view
        self.render_testing = render_testing
        # hard code here, will deal with batch issue later.
        self.cov = nn.Sequential(
            nn.Conv3d(
                neck['out_channels'],
                neck['out_channels'],
                kernel_size=3,
                padding=1), nn.ReLU(inplace=True),
            nn.Conv3d(
                neck['out_channels'],
                neck['out_channels'],
                kernel_size=3,
                padding=1), nn.ReLU(inplace=True),
            nn.Conv3d(neck['out_channels'], 1, kernel_size=1))
        # 1x1 convs squeezing the 3D mean/cov volumes to the NeRF feature dim
        self.mean_mapping = nn.Sequential(
            nn.Conv3d(
                neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))
        self.cov_mapping = nn.Sequential(
            nn.Conv3d(
                neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))
        self.mapping = nn.Sequential(
            nn.Linear(neck['out_channels'], nerf_feature_dim // 2))
        self.mapping_2d = nn.Sequential(
            nn.Conv2d(
                neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))
        # self.overfit_nerfmlp = overfit_nerfmlp
        # if self.overfit_nerfmlp:
        #     self. _finetuning_NeRF_MLP()
        # NOTE(review): duplicated assignment — ``render_testing`` was
        # already stored a few lines above; harmless but redundant.
        self.render_testing = render_testing

    def extract_feat(self,
                     batch_inputs_dict: dict,
                     batch_data_samples: SampleList,
                     mode,
                     depth=None,
                     ray_batch=None):
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        -> 3d neck -> bbox_head.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instances` of `gt_panoptic_seg` or `gt_sem_seg`
            mode (str): 'train' or 'test'; in 'test' mode predicted angles
                from the 2D head (when present) feed the projection.
            depth (Tensor, optional): Per-view depth used to gate the
                back-projection. Defaults to None.
            ray_batch (dict, optional): Ray information; when given, NeRF
                rendering results are appended to ``rgb_preds``.
                Defaults to None.

        Returns:
            Tuple:
            - torch.Tensor: Features of shape (N, C_out, N_x, N_y, N_z).
            - torch.Tensor: Valid mask of shape (N, 1, N_x, N_y, N_z).
            - torch.Tensor: 2D features if needed.
            - dict: The nerf rendered information including the
                'output_coarse', 'gt_rgb' and 'gt_depth' keys.
        """
        img = batch_inputs_dict['imgs']
        img = img.float()
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_size = img.shape[0]
        # multi-view input is (B, V, C, H, W): fold views into the batch
        # dim for the 2D backbone, then unfold afterwards
        if len(img.shape) > 4:
            img = img.reshape([-1] + list(img.shape)[2:])
            x = self.backbone(img)
            x = self.neck(x)[0]
            x = x.reshape([batch_size, -1] + list(x.shape[1:]))
        else:
            x = self.backbone(img)
            x = self.neck(x)[0]
        if depth is not None:
            depth_bs = depth.shape[0]
            assert depth_bs == batch_size
            depth = batch_inputs_dict['depth']
            depth = depth.reshape([-1] + list(depth.shape)[2:])
        features_2d = self.head_2d.forward(x[-1], batch_img_metas) \
            if self.head_2d is not None else None
        # total stride of the 2D feature map w.r.t. the input image
        stride = img.shape[-1] / x.shape[-1]
        assert stride == 4
        stride = int(stride)
        volumes, valids = [], []
        rgb_preds = []
        for feature, img_meta in zip(x, batch_img_metas):
            angles = features_2d[
                0] if features_2d is not None and mode == 'test' else None
            projection = self._compute_projection(img_meta, stride,
                                                  angles).to(x.device)
            points = get_points(
                n_voxels=torch.tensor(self.n_voxels),
                voxel_size=torch.tensor(self.voxel_size),
                origin=torch.tensor(img_meta['lidar2img']['origin'])).to(
                    x.device)
            height = img_meta['img_shape'][0] // stride
            width = img_meta['img_shape'][1] // stride
            # Construct the volume space
            # volume together with valid is the constructed scene
            # volume represents V_i and valid represents M_p
            volume, valid = backproject(feature[:, :, :height, :width], points,
                                        projection, depth, self.voxel_size)
            density = None
            volume_sum = volume.sum(dim=0)
            # cov_valid = valid.clone().detach()
            valid = valid.sum(dim=0)
            # average features over the views that observe each voxel
            volume_mean = volume_sum / (valid + 1e-8)
            volume_mean[:, valid[0] == 0] = .0
            # volume_cov = (volume - volume_mean.unsqueeze(0)) ** 2 * cov_valid
            # volume_cov = torch.sum(volume_cov, dim=0) / (valid + 1e-8)
            volume_cov = torch.sum(
                (volume - volume_mean.unsqueeze(0))**2, dim=0) / (
                    valid + 1e-8)
            volume_cov[:, valid[0] == 0] = 1e6
            volume_cov = torch.exp(-volume_cov)  # default setting
            # be careful here, the smaller the cov, the larger the weight.
            n_channels, n_x_voxels, n_y_voxels, n_z_voxels = volume_mean.shape
            if ray_batch is not None:
                if self.nerf_mode == 'volume':
                    mean_volume = self.mean_mapping(volume_mean.unsqueeze(0))
                    cov_volume = self.cov_mapping(volume_cov.unsqueeze(0))
                    feature_2d = feature[:, :, :height, :width]
                elif self.nerf_mode == 'image':
                    mean_volume = None
                    cov_volume = None
                    feature_2d = feature[:, :, :height, :width]
                    n_v, C, height, width = feature_2d.shape
                    feature_2d = feature_2d.view(n_v, C,
                                                 -1).permute(0, 2,
                                                             1).contiguous()
                    feature_2d = self.mapping(feature_2d).permute(
                        0, 2, 1).contiguous().view(n_v, -1, height, width)
                denorm_images = ray_batch['denorm_images']
                denorm_images = denorm_images.reshape(
                    [-1] + list(denorm_images.shape)[2:])
                # full-resolution projection (stride=1) for the RGB volume
                rgb_projection = self._compute_projection(
                    img_meta, stride=1, angles=None).to(x.device)
                rgb_volume, _ = backproject(
                    denorm_images[:, :, :img_meta['img_shape'][0], :
                                  img_meta['img_shape'][1]], points,
                    rgb_projection, depth, self.voxel_size)
                ret = render_rays(
                    ray_batch,
                    mean_volume,
                    cov_volume,
                    feature_2d,
                    denorm_images,
                    self.aabb,
                    self.near_far_range,
                    self.N_samples,
                    self.N_rand,
                    self.nerf_mlp,
                    img_meta,
                    self.projector,
                    self.nerf_mode,
                    self.nerf_sample_view,
                    is_train=True if mode == 'train' else False,
                    render_testing=self.render_testing)
                rgb_preds.append(ret)
                if self.nerf_density:
                    # would have 0 bias issue for mean_mapping.
                    n_v, C, n_x_voxels, n_y_voxels, n_z_voxels = volume.shape
                    volume = volume.view(n_v, C, -1).permute(0, 2,
                                                             1).contiguous()
                    mapping_volume = self.mapping(volume).permute(
                        0, 2, 1).contiguous().view(n_v, -1, n_x_voxels,
                                                   n_y_voxels, n_z_voxels)
                    # prepend the back-projected RGB to the mapped features
                    mapping_volume = torch.cat([rgb_volume, mapping_volume],
                                               dim=1)
                    mapping_volume_sum = mapping_volume.sum(dim=0)
                    mapping_volume_mean = mapping_volume_sum / (valid + 1e-8)
                    # mapping_volume_cov = (
                    #     mapping_volume - mapping_volume_mean.unsqueeze(0)
                    # ) ** 2 * cov_valid
                    mapping_volume_cov = (mapping_volume -
                                          mapping_volume_mean.unsqueeze(0))**2
                    mapping_volume_cov = torch.sum(
                        mapping_volume_cov, dim=0) / (
                            valid + 1e-8)
                    mapping_volume_cov[:, valid[0] == 0] = 1e6
                    mapping_volume_cov = torch.exp(
                        -mapping_volume_cov)  # default setting
                    global_volume = torch.cat(
                        [mapping_volume_mean, mapping_volume_cov], dim=1)
                    global_volume = global_volume.view(
                        -1, n_x_voxels * n_y_voxels * n_z_voxels).permute(
                            1, 0).contiguous()
                    points = points.view(3, -1).permute(1, 0).contiguous()
                    density = self.nerf_mlp.query_density(
                        points, global_volume)
                    alpha = 1 - torch.exp(-density)
                    # density -> alpha
                    # (1, n_x_voxels, n_y_voxels, n_z_voxels)
                    # weight the mean volume by per-voxel opacity
                    volume = alpha.view(1, n_x_voxels, n_y_voxels,
                                        n_z_voxels) * volume_mean
                    volume[:, valid[0] == 0] = .0
            volumes.append(volume)
            valids.append(valid)
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x, torch.stack(valids).float(), features_2d, rgb_preds

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj: `DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        ray_batchs = {}
        batch_images = []
        batch_depths = []
        # gather per-sample novel-view GT images/depths for NeRF supervision
        if 'images' in batch_data_samples[0].gt_nerf_images:
            for data_samples in batch_data_samples:
                image = data_samples.gt_nerf_images['images']
                batch_images.append(image)
            batch_images = torch.stack(batch_images)
        if 'depths' in batch_data_samples[0].gt_nerf_depths:
            for data_samples in batch_data_samples:
                depth = data_samples.gt_nerf_depths['depths']
                batch_depths.append(depth)
            batch_depths = torch.stack(batch_depths)
        if 'raydirs' in batch_inputs_dict.keys():
            ray_batchs['ray_o'] = batch_inputs_dict['lightpos']
            ray_batchs['ray_d'] = batch_inputs_dict['raydirs']
            ray_batchs['gt_rgb'] = batch_images
            ray_batchs['gt_depth'] = batch_depths
            ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes']
            ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images']
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict,
                batch_data_samples,
                'train',
                depth=None,
                ray_batch=ray_batchs)
        else:
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict, batch_data_samples, 'train')
        x += (valids, )
        losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
        # if self.head_2d is not None:
        #     losses.update(
        #         self.head_2d.loss(*features_2d, batch_data_samples)
        #     )
        if len(ray_batchs) != 0 and self.rgb_supervision:
            losses.update(self.nvs_loss_func(rgb_preds))
        if self.depth_supervise:
            losses.update(self.depth_loss_func(rgb_preds))
        return losses

    def nvs_loss_func(self, rgb_pred):
        """Novel-view RGB reconstruction loss (MSE, optionally masked)."""
        loss = 0
        for ret in rgb_pred:
            rgb = ret['outputs_coarse']['rgb']
            gt = ret['gt_rgb']
            masks = ret['outputs_coarse']['mask']
            if self.use_nerf_mask:
                loss += torch.sum(masks.unsqueeze(-1) * (rgb - gt)**2) / (
                    masks.sum() + 1e-6)
            else:
                loss += torch.mean((rgb - gt)**2)
        return dict(loss_nvs=loss)

    def depth_loss_func(self, rgb_pred):
        """Novel-view depth loss (L1, optionally masked)."""
        loss = 0
        for ret in rgb_pred:
            depth = ret['outputs_coarse']['depth']
            gt = ret['gt_depth'].squeeze(-1)
            masks = ret['outputs_coarse']['mask']
            if self.use_nerf_mask:
                loss += torch.sum(masks * torch.abs(depth - gt)) / (
                    masks.sum() + 1e-6)
            else:
                loss += torch.mean(torch.abs(depth - gt))
        return dict(loss_depth=loss)

    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            list[:obj:`NeRFDet3DDataSample`]: Detection results of the
            input images. Each NeRFDet3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
                (num_instance, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - bboxes_3d (Tensor): Contains a tensor with shape
                (num_instances, C) where C = 6.
        """
        ray_batchs = {}
        batch_images = []
        batch_depths = []
        # same ray-batch assembly as in ``loss``, but in 'test' mode
        if 'images' in batch_data_samples[0].gt_nerf_images:
            for data_samples in batch_data_samples:
                image = data_samples.gt_nerf_images['images']
                batch_images.append(image)
            batch_images = torch.stack(batch_images)
        if 'depths' in batch_data_samples[0].gt_nerf_depths:
            for data_samples in batch_data_samples:
                depth = data_samples.gt_nerf_depths['depths']
                batch_depths.append(depth)
            batch_depths = torch.stack(batch_depths)
        if 'raydirs' in batch_inputs_dict.keys():
            ray_batchs['ray_o'] = batch_inputs_dict['lightpos']
            ray_batchs['ray_d'] = batch_inputs_dict['raydirs']
            ray_batchs['gt_rgb'] = batch_images
            ray_batchs['gt_depth'] = batch_depths
            ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes']
            ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images']
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict,
                batch_data_samples,
                'test',
                depth=None,
                ray_batch=ray_batchs)
        else:
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict, batch_data_samples, 'test')
        x += (valids, )
        results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
        predictions = self.add_pred_to_datasample(batch_data_samples,
                                                  results_list)
        return predictions

    def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                 *args, **kwargs) -> Tuple[List[torch.Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`

        Returns:
            tuple[list]: A tuple of features from ``bbox_head`` forward
        """
        ray_batchs = {}
        batch_images = []
        batch_depths = []
        if 'images' in batch_data_samples[0].gt_nerf_images:
            for data_samples in batch_data_samples:
                image = data_samples.gt_nerf_images['images']
                batch_images.append(image)
            batch_images = torch.stack(batch_images)
        if 'depths' in batch_data_samples[0].gt_nerf_depths:
            for data_samples in batch_data_samples:
                depth = data_samples.gt_nerf_depths['depths']
                batch_depths.append(depth)
            batch_depths = torch.stack(batch_depths)
        if 'raydirs' in batch_inputs_dict.keys():
            ray_batchs['ray_o'] = batch_inputs_dict['lightpos']
            ray_batchs['ray_d'] = batch_inputs_dict['raydirs']
            ray_batchs['gt_rgb'] = batch_images
            ray_batchs['gt_depth'] = batch_depths
            ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes']
            ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images']
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict,
                batch_data_samples,
                'train',
                depth=None,
                ray_batch=ray_batchs)
        else:
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict, batch_data_samples, 'train')
        x += (valids, )
        results = self.bbox_head.forward(x)
        return results

    def aug_test(self, batch_inputs_dict, batch_data_samples):
        """Augmented testing; not implemented for NerfDet."""
        pass

    def show_results(self, *args, **kwargs):
        """Result visualization; not implemented for NerfDet."""
        pass

    @staticmethod
    def _compute_projection(img_meta, stride, angles):
        """Build per-view 3x4 projection matrices from meta information.

        The intrinsic matrix is rescaled by the ratio between the original
        image size and the (strided) feature-map size.
        """
        projection = []
        intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:3, :3])
        ratio = img_meta['ori_shape'][0] / (img_meta['img_shape'][0] / stride)
        intrinsic[:2] /= ratio
        # use predict pitch and roll for SUNRGBDTotal test
        if angles is not None:
            extrinsics = []
            for angle in angles:
                extrinsics.append(get_extrinsics(angle).to(intrinsic.device))
        else:
            extrinsics = map(torch.tensor, img_meta['lidar2img']['extrinsic'])
        for extrinsic in extrinsics:
            projection.append(intrinsic @ extrinsic[:3])
        return torch.stack(projection)
@torch.no_grad()
def get_points(n_voxels, voxel_size, origin):
    """Return voxel-center coordinates of shape (3, nx, ny, nz).

    ``origin`` is the point-cloud center; the grid is laid out so that it
    is centered on ``origin``.
    """
    axes = [torch.arange(n_voxels[i]) for i in range(3)]
    grid = torch.stack(torch.meshgrid(axes))
    # shift so the whole grid is centered on ``origin``
    corner = origin - n_voxels / 2. * voxel_size
    return grid * voxel_size.view(3, 1, 1, 1) + corner.view(3, 1, 1, 1)
# modify from https://github.com/magicleap/Atlas/blob/master/atlas/model.py
def backproject(features, points, projection, depth, voxel_size):
    """Back-project per-view 2D features onto a shared 3D voxel grid.

    Args:
        features (Tensor): Per-view feature maps of shape
            (n_images, n_channels, height, width).
        points (Tensor): Voxel-center coordinates of shape
            (3, n_x_voxels, n_y_voxels, n_z_voxels).
        projection (Tensor): Per-view projection matrices of shape
            (n_images, 3, 4), mapping homogeneous 3D points to pixels.
        depth (Tensor | None): Optional per-view depth maps; when given, a
            voxel only receives features if its projected depth lies within
            ``voxel_size[-1]`` of the measured depth at that pixel.
        voxel_size (Tensor): Voxel extents; only the last entry is used
            here, as the depth tolerance.

    Returns:
        tuple:
            - Tensor: Sampled features of shape
              (n_images, n_channels, n_x_voxels, n_y_voxels, n_z_voxels).
            - Tensor: Boolean visibility mask of shape
              (n_images, 1, n_x_voxels, n_y_voxels, n_z_voxels).
    """
    n_images, n_channels, height, width = features.shape
    n_x_voxels, n_y_voxels, n_z_voxels = points.shape[-3:]
    # homogeneous voxel centers, broadcast to every view
    points = points.view(1, 3, -1).expand(n_images, 3, -1)
    points = torch.cat((points, torch.ones_like(points[:, :1])), dim=1)
    points_2d_3 = torch.bmm(projection, points)
    # perspective divide -> integer pixel coordinates
    x = (points_2d_3[:, 0] / points_2d_3[:, 2]).round().long()
    y = (points_2d_3[:, 1] / points_2d_3[:, 2]).round().long()
    z = points_2d_3[:, 2]
    # inside the image and in front of the camera
    valid = (x >= 0) & (y >= 0) & (x < width) & (y < height) & (z > 0)
    # below is using depth to sample feature
    if depth is not None:
        depth = F.interpolate(
            depth.unsqueeze(1), size=(height, width),
            mode='bilinear').squeeze(1)
        for i in range(n_images):
            z_mask = z.clone() > 0
            # keep voxels whose camera-space depth agrees with the depth
            # map at the projected pixel (± one voxel of tolerance)
            z_mask[i, valid[i]] = \
                (z[i, valid[i]] > depth[i, y[i, valid[i]], x[i, valid[i]]] - voxel_size[-1]) & \
                (z[i, valid[i]] < depth[i, y[i, valid[i]], x[i, valid[i]]] + voxel_size[-1])  # noqa
            valid = valid & z_mask
    volume = torch.zeros((n_images, n_channels, points.shape[-1]),
                         device=features.device)
    for i in range(n_images):
        volume[i, :, valid[i]] = features[i, :, y[i, valid[i]], x[i, valid[i]]]
    volume = volume.view(n_images, n_channels, n_x_voxels, n_y_voxels,
                         n_z_voxels)
    valid = valid.view(n_images, 1, n_x_voxels, n_y_voxels, n_z_voxels)
    return volume, valid
# for SUNRGBDTotal test
def get_extrinsics(angles):
    """Build a 4x4 extrinsic matrix from predicted (pitch, roll) angles.

    Yaw is fixed to zero. The rotation is converted with the same axis
    conventions as Total3DUnderstanding and DepthInstance3DBoxes.
    """
    yaw = angles.new_zeros(())
    pitch, roll = angles
    cy, sy = torch.cos(yaw), torch.sin(yaw)
    cp, sp = torch.cos(pitch), torch.sin(pitch)
    cr, sr = torch.cos(roll), torch.sin(roll)
    r = angles.new_zeros((3, 3))
    r[0, 0] = cy * cp
    r[0, 1] = sy * sr - cy * cr * sp
    r[0, 2] = cr * sy + cy * sp * sr
    r[1, 0] = sp
    r[1, 1] = cp * cr
    r[1, 2] = -cp * sr
    r[2, 0] = -cp * sy
    r[2, 1] = cy * sr + cr * sy * sp
    r[2, 2] = cy * cr - sy * sp * sr

    # follow Total3DUnderstanding
    t = angles.new_tensor([[0., 0., 1.], [0., -1., 0.], [-1., 0., 0.]])
    r = t @ r.T
    # follow DepthInstance3DBoxes
    r = r[:, [2, 0, 1]]
    r[2] *= -1
    extrinsic = angles.new_zeros((4, 4))
    extrinsic[:3, :3] = r
    extrinsic[3, 3] = 1.
    return extrinsic
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
from mmcv.cnn import Scale
# from mmcv.ops import nms3d, nms3d_normal
from mmdet.models.utils import multi_apply
from mmdet.utils import reduce_mean
# from mmengine.config import ConfigDict
from mmengine.model import BaseModule, bias_init_with_prob, normal_init
from mmengine.structures import InstanceData
from torch import Tensor, nn
from mmdet3d.registry import MODELS, TASK_UTILS
# from mmdet3d.structures.bbox_3d.utils import rotation_3d_in_axis
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils.typing_utils import (ConfigType, InstanceList,
OptConfigType, OptInstanceList)
@torch.no_grad()
def get_points(n_voxels, voxel_size, origin):
    """Compute voxel-center coordinates, shape (3, nx, ny, nz).

    The grid is centered on ``origin`` (the point-cloud center).
    """
    idx = torch.meshgrid(
        [torch.arange(n_voxels[0]),  # x (width)
         torch.arange(n_voxels[1]),  # y (depth)
         torch.arange(n_voxels[2])])  # z (height)
    grid = torch.stack(idx)
    lower_corner = origin - n_voxels / 2. * voxel_size
    return grid * voxel_size.view(3, 1, 1, 1) + lower_corner.view(3, 1, 1, 1)
@MODELS.register_module()
class NerfDetHead(BaseModule):
r"""`ImVoxelNet<https://arxiv.org/abs/2106.01178>`_ head for indoor
datasets.
Args:
n_classes (int): Number of classes.
n_levels (int): Number of feature levels.
n_channels (int): Number of channels in input tensors.
n_reg_outs (int): Number of regression layer channels.
pts_assign_threshold (int): Min number of location per box to
be assigned with.
pts_center_threshold (int): Max number of locations per box to
be assigned with.
center_loss (dict, optional): Config of centerness loss.
Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
bbox_loss (dict, optional): Config of bbox loss.
Default: dict(type='RotatedIoU3DLoss').
cls_loss (dict, optional): Config of classification loss.
Default: dict(type='FocalLoss').
train_cfg (dict, optional): Config for train stage. Defaults to None.
test_cfg (dict, optional): Config for test stage. Defaults to None.
init_cfg (dict, optional): Config for weight initialization.
Defaults to None.
"""
def __init__(self,
n_classes: int,
n_levels: int,
n_channels: int,
n_reg_outs: int,
pts_assign_threshold: int,
pts_center_threshold: int,
prior_generator: ConfigType,
center_loss: ConfigType = dict(
type='mmdet.CrossEntropyLoss', use_sigmoid=True),
bbox_loss: ConfigType = dict(type='RotatedIoU3DLoss'),
cls_loss: ConfigType = dict(type='mmdet.FocalLoss'),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptConfigType = None):
super(NerfDetHead, self).__init__(init_cfg)
self.n_classes = n_classes
self.n_levels = n_levels
self.n_reg_outs = n_reg_outs
self.pts_assign_threshold = pts_assign_threshold
self.pts_center_threshold = pts_center_threshold
self.prior_generator = TASK_UTILS.build(prior_generator)
self.center_loss = MODELS.build(center_loss)
self.bbox_loss = MODELS.build(bbox_loss)
self.cls_loss = MODELS.build(cls_loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self._init_layers(n_channels, n_reg_outs, n_classes, n_levels)
def _init_layers(self, n_channels, n_reg_outs, n_classes, n_levels):
"""Initialize neural network layers of the head."""
self.conv_center = nn.Conv3d(n_channels, 1, 3, padding=1, bias=False)
self.conv_reg = nn.Conv3d(
n_channels, n_reg_outs, 3, padding=1, bias=False)
self.conv_cls = nn.Conv3d(n_channels, n_classes, 3, padding=1)
self.scales = nn.ModuleList([Scale(1.) for _ in range(n_levels)])
def init_weights(self):
"""Initialize all layer weights."""
normal_init(self.conv_center, std=.01)
normal_init(self.conv_reg, std=.01)
normal_init(self.conv_cls, std=.01, bias=bias_init_with_prob(.01))
def _forward_single(self, x: Tensor, scale: Scale):
"""Forward pass per level.
Args:
x (Tensor): Per level 3d neck output tensor.
scale (mmcv.cnn.Scale): Per level multiplication weight.
Returns:
tuple[Tensor]: Centerness, bbox and classification predictions.
"""
return (self.conv_center(x), torch.exp(scale(self.conv_reg(x))),
self.conv_cls(x))
    def forward(self, x):
        """Run the head over every feature level.

        Args:
            x (tuple[Tensor]): Per-level 3d neck output tensors.

        Returns:
            tuple[list[Tensor]]: Per-level centerness, bbox and
            classification predictions (via ``multi_apply``).
        """
        return multi_apply(self._forward_single, x, self.scales)
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
**kwargs) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
Returns:
dict: A dictionary of loss components.
"""
valid_pred = x[-1]
outs = self(x[:-1])
batch_gt_instances_3d = []
batch_gt_instances_ignore = []
batch_input_metas = []
for data_sample in batch_data_samples:
batch_input_metas.append(data_sample.metainfo)
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
batch_gt_instances_ignore.append(
data_sample.get('ignored_instances', None))
loss_inputs = outs + (valid_pred, batch_gt_instances_3d,
batch_input_metas, batch_gt_instances_ignore)
losses = self.loss_by_feat(*loss_inputs)
return losses
def loss_by_feat(self,
center_preds: List[List[Tensor]],
bbox_preds: List[List[Tensor]],
cls_preds: List[List[Tensor]],
valid_pred: Tensor,
batch_gt_instances_3d: InstanceList,
batch_input_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None,
**kwargs) -> dict:
"""Per scene loss function.
Args:
center_preds (list[list[Tensor]]): Centerness predictions for
all scenes. The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
cls_preds (list[list[Tensor]]): Classification predictions for all
scenes. The first list contains predictions from different
levels. The second list contains predictions in a mini-batch.
valid_pred (Tensor): Valid mask prediction for all scenes.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instance_3d. It usually includes ``bboxes_3d``、`
`labels_3d``、``depths``、``centers_2d`` and attributes.
batch_input_metas (list[dict]): Meta information of each image,
e.g., image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict: Centerness, bbox, and classification loss values.
"""
valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
center_losses, bbox_losses, cls_losses = [], [], []
for i in range(len(batch_input_metas)):
center_loss, bbox_loss, cls_loss = self._loss_by_feat_single(
center_preds=[x[i] for x in center_preds],
bbox_preds=[x[i] for x in bbox_preds],
cls_preds=[x[i] for x in cls_preds],
valid_preds=[x[i] for x in valid_preds],
input_meta=batch_input_metas[i],
gt_bboxes=batch_gt_instances_3d[i].bboxes_3d,
gt_labels=batch_gt_instances_3d[i].labels_3d)
center_losses.append(center_loss)
bbox_losses.append(bbox_loss)
cls_losses.append(cls_loss)
return dict(
center_loss=torch.mean(torch.stack(center_losses)),
bbox_loss=torch.mean(torch.stack(bbox_losses)),
cls_loss=torch.mean(torch.stack(cls_losses)))
    def _loss_by_feat_single(self, center_preds, bbox_preds, cls_preds,
                             valid_preds, input_meta, gt_bboxes, gt_labels):
        """Compute centerness, bbox and classification losses for one scene.

        Args:
            center_preds (list[Tensor]): Centerness predictions per level.
            bbox_preds (list[Tensor]): Bbox predictions per level.
            cls_preds (list[Tensor]): Classification predictions per level.
            valid_preds (list[Tensor]): Upsampled valid masks per level.
            input_meta (dict): Scene meta info; ``lidar2img['origin']`` is
                used to build the point grid.
            gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
            gt_labels (Tensor): Ground truth labels.

        Returns:
            tuple[Tensor]: Centerness, bbox and classification losses.
        """
        featmap_sizes = [featmap.size()[-3:] for featmap in center_preds]
        points = self._get_points(
            featmap_sizes=featmap_sizes,
            origin=input_meta['lidar2img']['origin'],
            device=gt_bboxes.device)
        center_targets, bbox_targets, cls_targets = self._get_targets(
            points, gt_bboxes, gt_labels)
        # Move channels (dim 0) last and flatten spatial dims so each
        # prediction becomes (n_points,) or (n_points, C), aligned with the
        # concatenated targets.
        center_preds = torch.cat(
            [x.permute(1, 2, 3, 0).reshape(-1) for x in center_preds])
        bbox_preds = torch.cat([
            x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in bbox_preds
        ])
        cls_preds = torch.cat(
            [x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in cls_preds])
        valid_preds = torch.cat(
            [x.permute(1, 2, 3, 0).reshape(-1) for x in valid_preds])
        points = torch.cat(points)
        # cls loss
        # Positive points have an assigned class (target >= 0) AND fall in
        # the valid mask; n_pos is shared via reduce_mean (presumably across
        # distributed ranks) to normalize the losses consistently.
        pos_inds = torch.nonzero(
            torch.logical_and(cls_targets >= 0, valid_preds)).squeeze(1)
        n_pos = points.new_tensor(len(pos_inds))
        n_pos = max(reduce_mean(n_pos), 1.)
        if torch.any(valid_preds):
            cls_loss = self.cls_loss(
                cls_preds[valid_preds],
                cls_targets[valid_preds],
                avg_factor=n_pos)
        else:
            # No valid voxels: a zero-valued sum keeps the graph connected.
            cls_loss = cls_preds[valid_preds].sum()
        # bbox and centerness losses
        pos_center_preds = center_preds[pos_inds]
        pos_bbox_preds = bbox_preds[pos_inds]
        if len(pos_inds) > 0:
            pos_center_targets = center_targets[pos_inds]
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_points = points[pos_inds]
            center_loss = self.center_loss(
                pos_center_preds, pos_center_targets, avg_factor=n_pos)
            # Bbox loss is weighted by centerness targets so central points
            # contribute more.
            bbox_loss = self.bbox_loss(
                self._bbox_pred_to_bbox(pos_points, pos_bbox_preds),
                pos_bbox_targets,
                weight=pos_center_targets,
                avg_factor=pos_center_targets.sum())
        else:
            center_loss = pos_center_preds.sum()
            bbox_loss = pos_bbox_preds.sum()
        return center_loss, bbox_loss, cls_loss
def predict(self,
x: Tuple[Tensor],
batch_data_samples: SampleList,
rescale: bool = False) -> InstanceList:
"""Perform forward propagation of the 3D detection head and predict
detection results on the features of the upstream network.
Args:
x (tuple[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, `gt_pts_panoptic_seg` and
`gt_pts_sem_seg`.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[:obj:`InstanceData`]: Detection results of each sample
after the post process.
Each item usually contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
contains a tensor with shape (num_instances, C), where
C >= 6.
"""
batch_input_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
valid_pred = x[-1]
outs = self(x[:-1])
predictions = self.predict_by_feat(
*outs,
valid_pred=valid_pred,
batch_input_metas=batch_input_metas,
rescale=rescale)
return predictions
def predict_by_feat(self, center_preds: List[List[Tensor]],
bbox_preds: List[List[Tensor]],
cls_preds: List[List[Tensor]], valid_pred: Tensor,
batch_input_metas: List[dict],
**kwargs) -> List[InstanceData]:
"""Generate boxes for all scenes.
Args:
center_preds (list[list[Tensor]]): Centerness predictions for
all scenes.
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
cls_preds (list[list[Tensor]]): Classification predictions for all
scenes.
valid_pred (Tensor): Valid mask prediction for all scenes.
batch_input_metas (list[dict]): Meta infos for all scenes.
Returns:
list[tuple[Tensor]]: Predicted bboxes, scores, and labels for
all scenes.
"""
valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
results = []
for i in range(len(batch_input_metas)):
results.append(
self._predict_by_feat_single(
center_preds=[x[i] for x in center_preds],
bbox_preds=[x[i] for x in bbox_preds],
cls_preds=[x[i] for x in cls_preds],
valid_preds=[x[i] for x in valid_preds],
input_meta=batch_input_metas[i]))
return results
    def _predict_by_feat_single(self, center_preds: List[Tensor],
                                bbox_preds: List[Tensor],
                                cls_preds: List[Tensor],
                                valid_preds: List[Tensor],
                                input_meta: dict) -> InstanceData:
        """Generate boxes for single sample.

        Args:
            center_preds (list[Tensor]): Centerness predictions for all levels.
            bbox_preds (list[Tensor]): Bbox predictions for all levels.
            cls_preds (list[Tensor]): Classification predictions for all
                levels.
            valid_preds (tuple[Tensor]): Upsampled valid masks for all feature
                levels.
            input_meta (dict): Scene meta info.

        Returns:
            tuple[Tensor]: Predicted bounding boxes, scores and labels.
        """
        featmap_sizes = [featmap.size()[-3:] for featmap in center_preds]
        points = self._get_points(
            featmap_sizes=featmap_sizes,
            origin=input_meta['lidar2img']['origin'],
            device=center_preds[0].device)
        mlvl_bboxes, mlvl_scores = [], []
        for center_pred, bbox_pred, cls_pred, valid_pred, point in zip(
                center_preds, bbox_preds, cls_preds, valid_preds, points):
            # Move channels (dim 0) last and flatten spatial dims to
            # (n_points, C) for per-point scoring and decoding.
            center_pred = center_pred.permute(1, 2, 3, 0).reshape(-1, 1)
            bbox_pred = bbox_pred.permute(1, 2, 3,
                                          0).reshape(-1, bbox_pred.shape[0])
            cls_pred = cls_pred.permute(1, 2, 3,
                                        0).reshape(-1, cls_pred.shape[0])
            valid_pred = valid_pred.permute(1, 2, 3, 0).reshape(-1, 1)
            # Fuse class and centerness confidences; invalid voxels get 0.
            scores = cls_pred.sigmoid() * center_pred.sigmoid() * valid_pred
            max_scores, _ = scores.max(dim=1)
            # Keep only the top nms_pre candidates per level, if configured.
            if len(scores) > self.test_cfg.nms_pre > 0:
                _, ids = max_scores.topk(self.test_cfg.nms_pre)
                bbox_pred = bbox_pred[ids]
                scores = scores[ids]
                point = point[ids]
            bboxes = self._bbox_pred_to_bbox(point, bbox_pred)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
        bboxes = torch.cat(mlvl_bboxes)
        scores = torch.cat(mlvl_scores)
        bboxes, scores, labels = self._nms(bboxes, scores, input_meta)
        # Wrap into the dataset-specific box type; boxes are axis-aligned
        # (with_yaw=False) with the origin at the box center.
        bboxes = input_meta['box_type_3d'](
            bboxes, box_dim=6, with_yaw=False, origin=(.5, .5, .5))
        results = InstanceData()
        results.bboxes_3d = bboxes
        results.scores_3d = scores
        results.labels_3d = labels
        return results
@staticmethod
def _upsample_valid_preds(valid_pred, features):
"""Upsample valid mask predictions.
Args:
valid_pred (Tensor): Valid mask prediction.
features (Tensor): Feature tensor.
Returns:
tuple[Tensor]: Upsampled valid masks for all feature levels.
"""
return [
nn.Upsample(size=x.shape[-3:],
mode='trilinear')(valid_pred).round().bool()
for x in features
]
@torch.no_grad()
def _get_points(self, featmap_sizes, origin, device):
mlvl_points = []
tmp_voxel_size = [.16, .16, .2]
for i, featmap_size in enumerate(featmap_sizes):
mlvl_points.append(
get_points(
n_voxels=torch.tensor(featmap_size),
voxel_size=torch.tensor(tmp_voxel_size) * (2**i),
origin=torch.tensor(origin)).reshape(3, -1).transpose(
0, 1).to(device))
return mlvl_points
def _bbox_pred_to_bbox(self, points, bbox_pred):
return torch.stack([
points[:, 0] - bbox_pred[:, 0], points[:, 1] - bbox_pred[:, 2],
points[:, 2] - bbox_pred[:, 4], points[:, 0] + bbox_pred[:, 1],
points[:, 1] + bbox_pred[:, 3], points[:, 2] + bbox_pred[:, 5]
], -1)
    def _bbox_pred_to_loss(self, points, bbox_preds):
        """Decode bbox predictions for the loss computation.

        Simply delegates to :meth:`_bbox_pred_to_bbox` since the loss is
        computed on decoded axis-aligned boxes.

        Args:
            points (Tensor): Point locations of shape (N, 3).
            bbox_preds (Tensor): Predicted face distances of shape (N, 6).

        Returns:
            Tensor: Decoded boxes of shape (N, 6).
        """
        return self._bbox_pred_to_bbox(points, bbox_preds)
# The function is directly copied from FCAF3DHead.
@staticmethod
def _get_face_distances(points, boxes):
"""Calculate distances from point to box faces.
Args:
points (Tensor): Final locations of shape (N_points, N_boxes, 3).
boxes (Tensor): 3D boxes of shape (N_points, N_boxes, 7)
Returns:
Tensor: Face distances of shape (N_points, N_boxes, 6),
(dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).
"""
dx_min = points[..., 0] - boxes[..., 0] + boxes[..., 3] / 2
dx_max = boxes[..., 0] + boxes[..., 3] / 2 - points[..., 0]
dy_min = points[..., 1] - boxes[..., 1] + boxes[..., 4] / 2
dy_max = boxes[..., 1] + boxes[..., 4] / 2 - points[..., 1]
dz_min = points[..., 2] - boxes[..., 2] + boxes[..., 5] / 2
dz_max = boxes[..., 2] + boxes[..., 5] / 2 - points[..., 2]
return torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max),
dim=-1)
@staticmethod
def _get_centerness(face_distances):
"""Compute point centerness w.r.t containing box.
Args:
face_distances (Tensor): Face distances of shape (B, N, 6),
(dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).
Returns:
Tensor: Centerness of shape (B, N).
"""
x_dims = face_distances[..., [0, 1]]
y_dims = face_distances[..., [2, 3]]
z_dims = face_distances[..., [4, 5]]
centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \
y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \
z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0]
return torch.sqrt(centerness_targets)
    @torch.no_grad()
    def _get_targets(self, points, gt_bboxes, gt_labels):
        """Compute targets for final locations for a single scene.

        Args:
            points (list[Tensor]): Final locations for all levels.
            gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
            gt_labels (Tensor): Ground truth labels.

        Returns:
            tuple[Tensor]: Centerness, bbox and classification
                targets for all locations.
        """
        float_max = 1e8
        # Tag every point with the index of the level it came from.
        expanded_scales = [
            points[i].new_tensor(i).expand(len(points[i])).to(gt_labels.device)
            for i in range(len(points))
        ]
        points = torch.cat(points, dim=0).to(gt_labels.device)
        scales = torch.cat(expanded_scales, dim=0)
        # below is based on FCOSHead._get_target_single
        n_points = len(points)
        n_boxes = len(gt_bboxes)
        volumes = gt_bboxes.volume.to(points.device)
        volumes = volumes.expand(n_points, n_boxes).contiguous()
        # Represent boxes as (cx, cy, cz, dx, dy, dz); yaw is ignored.
        gt_bboxes = torch.cat(
            (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:6]), dim=1)
        gt_bboxes = gt_bboxes.to(points.device).expand(n_points, n_boxes, 6)
        expanded_points = points.unsqueeze(1).expand(n_points, n_boxes, 3)
        bbox_targets = self._get_face_distances(expanded_points, gt_bboxes)
        # condition1: inside a gt bbox
        # A point is inside a box iff all six face distances are positive.
        inside_gt_bbox_mask = bbox_targets[..., :6].min(
            -1)[0] > 0  # skip angle
        # condition2: positive points per scale >= limit
        # calculate positive points per scale
        n_pos_points_per_scale = []
        for i in range(self.n_levels):
            n_pos_points_per_scale.append(
                torch.sum(inside_gt_bbox_mask[scales == i], dim=0))
        # find best scale
        n_pos_points_per_scale = torch.stack(n_pos_points_per_scale, dim=0)
        lower_limit_mask = n_pos_points_per_scale < self.pts_assign_threshold
        # fix nondeterministic argmax for torch<1.7
        extra = torch.arange(self.n_levels, 0, -1).unsqueeze(1).expand(
            self.n_levels, n_boxes).to(lower_limit_mask.device)
        # Select the level just below the first one whose positive count
        # falls under the threshold (clamped at level 0).
        lower_index = torch.argmax(lower_limit_mask.int() * extra, dim=0) - 1
        lower_index = torch.where(lower_index < 0,
                                  torch.zeros_like(lower_index), lower_index)
        all_upper_limit_mask = torch.all(
            torch.logical_not(lower_limit_mask), dim=0)
        # If every level has enough positives, fall back to the last level.
        best_scale = torch.where(
            all_upper_limit_mask,
            torch.ones_like(all_upper_limit_mask) * self.n_levels - 1,
            lower_index)
        # keep only points with best scale
        best_scale = torch.unsqueeze(best_scale, 0).expand(n_points, n_boxes)
        scales = torch.unsqueeze(scales, 1).expand(n_points, n_boxes)
        inside_best_scale_mask = best_scale == scales
        # condition3: limit topk locations per box by centerness
        centerness = self._get_centerness(bbox_targets)
        # Points failing the earlier conditions get centerness -1 so they
        # can never survive the top-k selection below.
        centerness = torch.where(inside_gt_bbox_mask, centerness,
                                 torch.ones_like(centerness) * -1)
        centerness = torch.where(inside_best_scale_mask, centerness,
                                 torch.ones_like(centerness) * -1)
        top_centerness = torch.topk(
            centerness, self.pts_center_threshold + 1, dim=0).values[-1]
        inside_top_centerness_mask = centerness > top_centerness.unsqueeze(0)
        # if there are still more than one objects for a location,
        # we choose the one with minimal area
        volumes = torch.where(inside_gt_bbox_mask, volumes,
                              torch.ones_like(volumes) * float_max)
        volumes = torch.where(inside_best_scale_mask, volumes,
                              torch.ones_like(volumes) * float_max)
        volumes = torch.where(inside_top_centerness_mask, volumes,
                              torch.ones_like(volumes) * float_max)
        min_area, min_area_inds = volumes.min(dim=1)
        labels = gt_labels[min_area_inds]
        # Points assigned to no box at all keep label -1 (background).
        labels = torch.where(min_area == float_max,
                             torch.ones_like(labels) * -1, labels)
        bbox_targets = bbox_targets[range(n_points), min_area_inds]
        centerness_targets = self._get_centerness(bbox_targets)
        return centerness_targets, self._bbox_pred_to_bbox(
            points, bbox_targets), labels
def _nms(self, bboxes, scores, img_meta):
scores, labels = scores.max(dim=1)
ids = scores > self.test_cfg.score_thr
bboxes = bboxes[ids]
scores = scores[ids]
labels = labels[ids]
ids = self.aligned_3d_nms(bboxes, scores, labels,
self.test_cfg.iou_thr)
bboxes = bboxes[ids]
bboxes = torch.stack(
((bboxes[:, 0] + bboxes[:, 3]) / 2.,
(bboxes[:, 1] + bboxes[:, 4]) / 2.,
(bboxes[:, 2] + bboxes[:, 5]) / 2., bboxes[:, 3] - bboxes[:, 0],
bboxes[:, 4] - bboxes[:, 1], bboxes[:, 5] - bboxes[:, 2]),
dim=1)
return bboxes, scores[ids], labels[ids]
    @staticmethod
    def aligned_3d_nms(boxes, scores, classes, thresh):
        """3d nms for aligned boxes.

        Greedy NMS: repeatedly keep the highest-scoring remaining box and
        drop same-class boxes whose IoU with it exceeds ``thresh``.

        Args:
            boxes (torch.Tensor): Aligned box with shape [n, 6].
            scores (torch.Tensor): Scores of each box.
            classes (torch.Tensor): Class of each box.
            thresh (float): Iou threshold for nms.

        Returns:
            torch.Tensor: Indices of selected boxes.
        """
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        z1 = boxes[:, 2]
        x2 = boxes[:, 3]
        y2 = boxes[:, 4]
        z2 = boxes[:, 5]
        area = (x2 - x1) * (y2 - y1) * (z2 - z1)
        zero = boxes.new_zeros(1, )
        # Ascending sort: the current best box sits at the end.
        score_sorted = torch.argsort(scores)
        pick = []
        while (score_sorted.shape[0] != 0):
            last = score_sorted.shape[0]
            i = score_sorted[-1]
            pick.append(i)
            # Intersection of the picked box with all remaining boxes.
            xx1 = torch.max(x1[i], x1[score_sorted[:last - 1]])
            yy1 = torch.max(y1[i], y1[score_sorted[:last - 1]])
            zz1 = torch.max(z1[i], z1[score_sorted[:last - 1]])
            xx2 = torch.min(x2[i], x2[score_sorted[:last - 1]])
            yy2 = torch.min(y2[i], y2[score_sorted[:last - 1]])
            zz2 = torch.min(z2[i], z2[score_sorted[:last - 1]])
            classes1 = classes[i]
            classes2 = classes[score_sorted[:last - 1]]
            inter_l = torch.max(zero, xx2 - xx1)
            inter_w = torch.max(zero, yy2 - yy1)
            inter_h = torch.max(zero, zz2 - zz1)
            inter = inter_l * inter_w * inter_h
            iou = inter / (area[i] + area[score_sorted[:last - 1]] - inter)
            # Cross-class pairs are zeroed so suppression is class-aware.
            iou = iou * (classes1 == classes2).float()
            # Keep only boxes that do not overlap the picked one too much.
            score_sorted = score_sorted[torch.nonzero(
                iou <= thresh, as_tuple=False).flatten()]
        indices = boxes.new_tensor(pick, dtype=torch.long)
        return indices
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from os import path as osp
from typing import Callable, List, Optional, Union
import numpy as np
from mmdet3d.datasets import Det3DDataset
from mmdet3d.registry import DATASETS
from mmdet3d.structures import DepthInstance3DBoxes
@DATASETS.register_module()
class MultiViewScanNetDataset(Det3DDataset):
    r"""Multi-View ScanNet Dataset for NeRF-detection Task

    This class serves as the API for experiments on the ScanNet Dataset.

    Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
    for data downloading.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_camera=True, use_lidar=False).
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            to its original format then converted them to `box_type_3d`.
            Defaults to 'Depth' in this dataset. Available options includes:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool): Whether to filter the data with empty GT.
            If it's set to be True, the example with empty annotations after
            data pipeline will be dropped and a random example will be chosen
            in `__getitem__`. Defaults to True.
        remove_dontcare (bool): Whether to remove instances labeled with
            -1 (don't-care) from the parsed annotations. Defaults to False.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    # The 18 ScanNet benchmark classes used for NeRF-Det.
    METAINFO = {
        'classes':
        ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
         'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
         'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin')
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_camera=True, use_lidar=False),
                 box_type_3d: str = 'Depth',
                 filter_empty_gt: bool = True,
                 remove_dontcare: bool = False,
                 test_mode: bool = False,
                 **kwargs) -> None:
        # Set before super().__init__ because the base class may parse
        # annotations during initialization, which reads this flag.
        self.remove_dontcare = remove_dontcare
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            pipeline=pipeline,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)
        assert 'use_camera' in self.modality and \
            'use_lidar' in self.modality
        assert self.modality['use_camera'] or self.modality['use_lidar']

    @staticmethod
    def _get_axis_align_matrix(info: dict) -> np.ndarray:
        """Get axis_align_matrix from info. If not exist, return identity mat.

        Args:
            info (dict): Info of a single sample data.

        Returns:
            np.ndarray: 4x4 transformation matrix.
        """
        if 'axis_align_matrix' in info:
            return np.array(info['axis_align_matrix'])
        else:
            warnings.warn(
                'axis_align_matrix is not found in ScanNet data info, please '
                'use new pre-process scripts to re-generate ScanNet data')
            return np.eye(4).astype(np.float32)

    def parse_data_info(self, info: dict) -> dict:
        """Process the raw data info.

        Convert all relative path of needed modality data file to
        the absolute path.

        Args:
            info (dict): Raw info dict.

        Returns:
            dict: Has `ann_info` in training stage. And
            all path has been converted to absolute path.
        """
        # NOTE(review): the keys 'use_depth', 'use_neuralrecon_depth' and
        # 'use_ray' are assumed to exist in self.modality, although the
        # default modality dict above does not define them — confirm that
        # configs always provide them.
        if self.modality['use_depth']:
            info['depth_info'] = []
        if self.modality['use_neuralrecon_depth']:
            info['depth_info'] = []
        if self.modality['use_lidar']:
            # implement lidar processing in the future
            raise NotImplementedError(
                'Please modified '
                '`MultiViewPipeline` to support lidar processing')
        info['axis_align_matrix'] = self._get_axis_align_matrix(info)
        info['img_info'] = []
        info['lidar2img'] = []
        info['c2w'] = []
        info['camrotc2w'] = []
        info['lightpos'] = []
        # load img and depth_img
        for i in range(len(info['img_paths'])):
            img_filename = osp.join(self.data_root, info['img_paths'][i])
            info['img_info'].append(dict(filename=img_filename))
            if 'depth_info' in info.keys():
                # Depth files share the image basename; NeuralRecon depth is
                # stored as .npy, otherwise as .png.
                if self.modality['use_neuralrecon_depth']:
                    info['depth_info'].append(
                        dict(filename=img_filename[:-4] + '.npy'))
                else:
                    info['depth_info'].append(
                        dict(filename=img_filename[:-4] + '.png'))
            # implement lidar_info in input.keys() in the future.
            extrinsic = np.linalg.inv(
                info['axis_align_matrix'] @ info['lidar2cam'][i])
            info['lidar2img'].append(extrinsic.astype(np.float32))
            if self.modality['use_ray']:
                # NOTE(review): named c2w but computed without inverting the
                # (axis-aligned) lidar2cam matrix — verify this matches the
                # convention the ray-generation pipeline expects.
                c2w = (
                    info['axis_align_matrix'] @ info['lidar2cam'][i]).astype(
                        np.float32)  # noqa
                info['c2w'].append(c2w)
                info['camrotc2w'].append(c2w[0:3, 0:3])
                info['lightpos'].append(c2w[0:3, 3])
        origin = np.array([.0, .0, .5])
        info['lidar2img'] = dict(
            extrinsic=info['lidar2img'],
            intrinsic=info['cam2img'].astype(np.float32),
            origin=origin.astype(np.float32))
        if self.modality['use_ray']:
            info['ray_info'] = []
        if not self.test_mode:
            info['ann_info'] = self.parse_ann_info(info)
        if self.test_mode and self.load_eval_anns:
            info['ann_info'] = self.parse_ann_info(info)
            info['eval_ann_info'] = self._remove_dontcare(info['ann_info'])
        return info

    def parse_ann_info(self, info: dict) -> dict:
        """Process the `instances` in data info to `ann_info`.

        Args:
            info (dict): Info dict.

        Returns:
            dict: Processed `ann_info`.
        """
        ann_info = super().parse_ann_info(info)
        if self.remove_dontcare:
            ann_info = self._remove_dontcare(ann_info)
        # empty gt
        if ann_info is None:
            ann_info = dict()
            ann_info['gt_bboxes_3d'] = np.zeros((0, 6), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros((0, ), dtype=np.int64)
        # Boxes are axis-aligned (no yaw), origin at the box center.
        ann_info['gt_bboxes_3d'] = DepthInstance3DBoxes(
            ann_info['gt_bboxes_3d'],
            box_dim=ann_info['gt_bboxes_3d'].shape[-1],
            with_yaw=False,
            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)
        # count the numbers
        # Per-class instance counts used for dataset statistics.
        for label in ann_info['gt_labels_3d']:
            if label != -1:
                cat_name = self.metainfo['classes'][label]
                self.num_ins_per_cat[cat_name] += 1
        return ann_info
# Copyright (c) OpenMMLab. All rights reserved.
"""Prepare the dataset for NeRF-Det.
Example:
python projects/NeRF-Det/prepare_infos.py
--root-path ./data/scannet
--out-dir ./data/scannet
"""
import argparse
import time
from os import path as osp
from pathlib import Path
import mmengine
from ...tools.dataset_converters import indoor_converter as indoor
from ...tools.dataset_converters.update_infos_to_v2 import (
clear_data_info_unused_keys, clear_instance_unused_keys,
get_empty_instance, get_empty_standard_data_info)
def update_scannet_infos_nerfdet(pkl_path, out_dir):
    """Update the origin pkl to the new format which will be used in nerf-det.

    Args:
        pkl_path (str): Path of the origin pkl.
        out_dir (str): Output directory of the generated info file.

    Returns:
        The pkl will be overwritten.

        The new pkl is a dict containing two keys:
        metainfo: Some base information of the pkl
        data_list (list): A list containing all the information of the scenes.
    """
    print('The new refactored process is running.')
    print(f'{pkl_path} will be modified.')
    if out_dir in pkl_path:
        print(f'Warning, you may overwriting '
              f'the original data {pkl_path}.')
        time.sleep(5)
    METAINFO = {
        'classes':
        ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
         'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
         'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin')
    }
    print(f'Reading from input file: {pkl_path}.')
    data_list = mmengine.load(pkl_path)
    print('Start updating:')
    converted_list = []
    # Collect unknown class names across ALL scenes. This set used to be
    # re-initialized inside the loop, so only the last scene's ignored
    # classes survived, and an empty ``data_list`` raised a NameError at
    # the summary print below.
    ignore_class_name = set()
    for ori_info_dict in mmengine.track_iter_progress(data_list):
        temp_data_info = get_empty_standard_data_info()
        # intrinsics, extrinsics and imgs
        temp_data_info['cam2img'] = ori_info_dict['intrinsics']
        temp_data_info['lidar2cam'] = ori_info_dict['extrinsics']
        temp_data_info['img_paths'] = ori_info_dict['img_paths']
        # annotation information
        anns = ori_info_dict.get('annos', None)
        if anns is not None:
            temp_data_info['axis_align_matrix'] = anns[
                'axis_align_matrix'].tolist()
            if anns['gt_num'] == 0:
                instance_list = []
            else:
                num_instances = len(anns['name'])
                instance_list = []
                for instance_id in range(num_instances):
                    empty_instance = get_empty_instance()
                    empty_instance['bbox_3d'] = anns['gt_boxes_upright_depth'][
                        instance_id].tolist()
                    # Map known class names to label indices; anything else
                    # is recorded and labeled -1 (ignored).
                    if anns['name'][instance_id] in METAINFO['classes']:
                        empty_instance['bbox_label_3d'] = METAINFO[
                            'classes'].index(anns['name'][instance_id])
                    else:
                        ignore_class_name.add(anns['name'][instance_id])
                        empty_instance['bbox_label_3d'] = -1
                    empty_instance = clear_instance_unused_keys(empty_instance)
                    instance_list.append(empty_instance)
            temp_data_info['instances'] = instance_list
        temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
        converted_list.append(temp_data_info)
    pkl_name = Path(pkl_path).name
    out_path = osp.join(out_dir, pkl_name)
    print(f'Writing to output file: {out_path}.')
    print(f'ignore classes: {ignore_class_name}')
    # dataset metainfo
    metainfo = dict()
    metainfo['categories'] = {k: i for i, k in enumerate(METAINFO['classes'])}
    if ignore_class_name:
        for ignore_class in ignore_class_name:
            metainfo['categories'][ignore_class] = -1
    metainfo['dataset'] = 'scannet'
    metainfo['info_version'] = '1.1'
    converted_data_info = dict(metainfo=metainfo, data_list=converted_list)
    mmengine.dump(converted_data_info, out_path, 'pkl')
def scannet_data_prep(root_path, info_prefix, out_dir, workers):
    """Prepare the info file for scannet dataset.

    Generates the indoor info files, then converts the train/val/test pkls
    to the nerfdet format in place.

    Args:
        root_path (str): Path of dataset root.
        info_prefix (str): The prefix of info filenames.
        out_dir (str): Output directory of the generated info file.
        workers (int): Number of threads to be used.
    """
    indoor.create_indoor_info_file(
        root_path, info_prefix, out_dir, workers=workers)
    info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl')
    info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl')
    info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
    update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_train_path)
    update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_val_path)
    update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_test_path)
# Command-line interface of the NeRF-Det data preparation script.
# NOTE(review): arguments are parsed at import time; importing this module
# from other code would consume sys.argv — confirm it is only run as a
# script.
parser = argparse.ArgumentParser(description='Data converter arg parser')
parser.add_argument(
    '--root-path',
    type=str,
    default='./data/scannet',
    help='specify the root path of dataset')
parser.add_argument(
    '--out-dir',
    type=str,
    default='./data/scannet',
    required=False,
    help='name of info pkl')
# Prefix used for the generated info pkl filenames.
parser.add_argument('--extra-tag', type=str, default='scannet')
parser.add_argument(
    '--workers', type=int, default=4, help='number of threads to be used')
args = parser.parse_args()
if __name__ == '__main__':
    # Register all mmdet3d modules so registry lookups resolve during the
    # conversion below.
    from mmdet3d.utils import register_all_modules
    register_all_modules()
    scannet_data_prep(
        root_path=args.root_path,
        info_prefix=args.extra_tag,
        out_dir=args.out_dir,
        workers=args.workers)
......@@ -16,7 +16,7 @@ This is an implementation of *PETR*.
In MMDet3D's root directory, run the following command to train the model:
```bash
python tools/train.py projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py
python tools/train.py projects/PETR/configs/petr_vovnet_gridmask_p4_800x320.py
```
### Testing commands
......@@ -24,7 +24,7 @@ python tools/train.py projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.
In MMDet3D's root directory, run the following command to test the model:
```bash
python tools/test.py projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py ${CHECKPOINT_PATH}
python tools/test.py projects/PETR/configs/petr_vovnet_gridmask_p4_800x320.py ${CHECKPOINT_PATH}
```
## Results
......
......@@ -446,7 +446,7 @@ class PETRHead(AnchorFreeHead):
masks = x.new_ones((batch_size, num_cams, input_img_h, input_img_w))
for img_id in range(batch_size):
for cam_id in range(num_cams):
img_h, img_w, _ = img_metas[img_id]['img_shape'][cam_id]
img_h, img_w = img_metas[img_id]['img_shape'][cam_id]
masks[img_id, cam_id, :img_h, :img_w] = 0
x = self.input_proj(x.flatten(0, 1))
x = x.view(batch_size, num_cams, *x.shape[-3:])
......
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.transforms.base import BaseTransform
from mmengine.registry import TRANSFORMS
from mmengine.structures import InstanceData
from mmdet3d.datasets import WaymoDataset
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes
def _generate_waymo_dataset_config():
    """Build a minimal Waymo dataset config for the tests below.

    Returns:
        tuple: ``(data_root, ann_file, classes, data_prefix, pipeline,
        modality)`` suitable for constructing a ``WaymoDataset``.
    """
    data_root = 'tests/data/waymo/kitti_format'
    ann_file = 'waymo_infos_train.pkl'
    classes = ['Car', 'Pedestrian', 'Cyclist']
    # wait for pipeline refactor
    if 'Identity' not in TRANSFORMS:

        @TRANSFORMS.register_module()
        class Identity(BaseTransform):
            """Pass-through transform that packs gt labels into a
            Det3DDataSample."""

            def transform(self, info):
                if 'ann_info' in info:
                    info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']
                data_sample = Det3DDataSample()
                gt_instances_3d = InstanceData()
                gt_instances_3d.labels_3d = info['gt_labels_3d']
                data_sample.gt_instances_3d = gt_instances_3d
                info['data_samples'] = data_sample
                return info

    pipeline = [
        dict(type='Identity'),
    ]
    modality = dict(use_lidar=True, use_camera=True)
    # Fixed: was accidentally duplicated as
    # ``data_prefix = data_prefix = dict(...)``.
    data_prefix = dict(
        pts='training/velodyne', CAM_FRONT='training/image_0')
    return data_root, ann_file, classes, data_prefix, pipeline, modality
def test_getitem():
    """Smoke-test WaymoDataset loading, path resolution and annotations."""
    data_root, ann_file, classes, data_prefix, \
        pipeline, modality, = _generate_waymo_dataset_config()
    waymo_dataset = WaymoDataset(
        data_root,
        ann_file,
        data_prefix=data_prefix,
        pipeline=pipeline,
        metainfo=dict(classes=classes),
        modality=modality)
    waymo_dataset.prepare_data(0)
    input_dict = waymo_dataset.get_data_info(0)
    waymo_dataset[0]
    # assert that the paths contain data_prefix and data_root
    assert data_prefix['pts'] in input_dict['lidar_points']['lidar_path']
    assert data_root in input_dict['lidar_points']['lidar_path']
    for cam_id, img_info in input_dict['images'].items():
        if 'img_path' in img_info:
            assert data_prefix['CAM_FRONT'] in img_info['img_path']
            assert data_root in img_info['img_path']
    ann_info = waymo_dataset.parse_ann_info(input_dict)
    # only one instance
    assert 'gt_labels_3d' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64
    assert 'gt_bboxes_3d' in ann_info
    assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
    # NOTE(review): 43.3103 is a fixture-specific checksum of the gt box
    # tensor — regenerate it if the test data changes.
    assert torch.allclose(ann_info['gt_bboxes_3d'].tensor.sum(),
                          torch.tensor(43.3103))
    assert 'centers_2d' in ann_info
    assert ann_info['centers_2d'].dtype == np.float32
    assert 'depths' in ann_info
    assert ann_info['depths'].dtype == np.float32
......@@ -2,6 +2,8 @@
import argparse
from os import path as osp
from mmengine import print_log
from tools.dataset_converters import indoor_converter as indoor
from tools.dataset_converters import kitti_converter as kitti
from tools.dataset_converters import lyft_converter as lyft_converter
......@@ -171,8 +173,19 @@ def waymo_data_prep(root_path,
version,
out_dir,
workers,
max_sweeps=5):
"""Prepare the info file for waymo dataset.
max_sweeps=10,
only_gt_database=False,
save_senor_data=False,
skip_cam_instances_infos=False):
"""Prepare waymo dataset. There are 3 steps as follows:
Step 1. Extract camera images and lidar point clouds from waymo raw
    data in '*.tfrecord' and save as kitti format.
Step 2. Generate waymo train/val/test infos and save as pickle file.
Step 3. Generate waymo ground truth database (point clouds within
each 3D bounding box) for data augmentation in training.
Steps 1 and 2 will be done in Waymo2KITTI, and step 3 will be done in
GTDatabaseCreater.
Args:
root_path (str): Path of dataset root.
......@@ -180,44 +193,55 @@ def waymo_data_prep(root_path,
out_dir (str): Output directory of the generated info file.
workers (int): Number of threads to be used.
max_sweeps (int, optional): Number of input consecutive frames.
Default: 5. Here we store pose information of these frames
for later use.
Default to 10. Here we store ego2global information of these
frames for later use.
only_gt_database (bool, optional): Whether to only generate ground
truth database. Default to False.
        save_senor_data (bool, optional): Whether to save
            image and lidar data. Default to False.
skip_cam_instances_infos (bool, optional): Whether to skip
gathering cam_instances infos in Step 2. Default to False.
"""
from tools.dataset_converters import waymo_converter as waymo
splits = [
'training', 'validation', 'testing', 'testing_3d_camera_only_detection'
]
for i, split in enumerate(splits):
load_dir = osp.join(root_path, 'waymo_format', split)
if split == 'validation':
save_dir = osp.join(out_dir, 'kitti_format', 'training')
else:
save_dir = osp.join(out_dir, 'kitti_format', split)
converter = waymo.Waymo2KITTI(
load_dir,
save_dir,
prefix=str(i),
workers=workers,
test_mode=(split
in ['testing', 'testing_3d_camera_only_detection']))
converter.convert()
from tools.dataset_converters.waymo_converter import \
create_ImageSets_img_ids
create_ImageSets_img_ids(osp.join(out_dir, 'kitti_format'), splits)
# Generate waymo infos
if version == 'v1.4':
splits = [
'training', 'validation', 'testing',
'testing_3d_camera_only_detection'
]
elif version == 'v1.4-mini':
splits = ['training', 'validation']
else:
raise NotImplementedError(f'Unsupported Waymo version {version}!')
out_dir = osp.join(out_dir, 'kitti_format')
kitti.create_waymo_info_file(
out_dir, info_prefix, max_sweeps=max_sweeps, workers=workers)
info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl')
info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl')
info_trainval_path = osp.join(out_dir, f'{info_prefix}_infos_trainval.pkl')
info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_train_path)
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_val_path)
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_trainval_path)
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_test_path)
if not only_gt_database:
for i, split in enumerate(splits):
load_dir = osp.join(root_path, 'waymo_format', split)
if split == 'validation':
save_dir = osp.join(out_dir, 'training')
else:
save_dir = osp.join(out_dir, split)
converter = waymo.Waymo2KITTI(
load_dir,
save_dir,
prefix=str(i),
workers=workers,
test_mode=(split
in ['testing', 'testing_3d_camera_only_detection']),
info_prefix=info_prefix,
max_sweeps=max_sweeps,
split=split,
save_senor_data=save_senor_data,
save_cam_instances=not skip_cam_instances_infos)
converter.convert()
if split == 'validation':
converter.merge_trainval_infos()
from tools.dataset_converters.waymo_converter import \
create_ImageSets_img_ids
create_ImageSets_img_ids(out_dir, splits)
GTDatabaseCreater(
'WaymoDataset',
out_dir,
......@@ -227,6 +251,8 @@ def waymo_data_prep(root_path,
with_mask=False,
num_worker=workers).create()
print_log('Successfully preparing Waymo Open Dataset')
def semantickitti_data_prep(info_prefix, out_dir):
"""Prepare the info file for SemanticKITTI dataset.
......@@ -274,12 +300,23 @@ parser.add_argument(
parser.add_argument(
'--only-gt-database',
action='store_true',
help='Whether to only generate ground truth database.')
help='''Whether to only generate ground truth database.
Only used when dataset is NuScenes or Waymo!''')
parser.add_argument(
'--skip-cam_instances-infos',
action='store_true',
help='''Whether to skip gathering cam_instances infos.
Only used when dataset is Waymo!''')
parser.add_argument(
'--skip-saving-sensor-data',
action='store_true',
help='''Whether to skip saving image and lidar.
Only used when dataset is Waymo!''')
args = parser.parse_args()
if __name__ == '__main__':
from mmdet3d.utils import register_all_modules
register_all_modules()
from mmengine.registry import init_default_scope
init_default_scope('mmdet3d')
if args.dataset == 'kitti':
if args.only_gt_database:
......@@ -334,6 +371,17 @@ if __name__ == '__main__':
dataset_name='NuScenesDataset',
out_dir=args.out_dir,
max_sweeps=args.max_sweeps)
elif args.dataset == 'waymo':
waymo_data_prep(
root_path=args.root_path,
info_prefix=args.extra_tag,
version=args.version,
out_dir=args.out_dir,
workers=args.workers,
max_sweeps=args.max_sweeps,
only_gt_database=args.only_gt_database,
save_senor_data=not args.skip_saving_sensor_data,
skip_cam_instances_infos=args.skip_cam_instances_infos)
elif args.dataset == 'lyft':
train_version = f'{args.version}-train'
lyft_data_prep(
......@@ -347,14 +395,6 @@ if __name__ == '__main__':
info_prefix=args.extra_tag,
version=test_version,
max_sweeps=args.max_sweeps)
elif args.dataset == 'waymo':
waymo_data_prep(
root_path=args.root_path,
info_prefix=args.extra_tag,
version=args.version,
out_dir=args.out_dir,
workers=args.workers,
max_sweeps=args.max_sweeps)
elif args.dataset == 'scannet':
scannet_data_prep(
root_path=args.root_path,
......
......@@ -6,10 +6,11 @@ export PYTHONPATH=`pwd`:$PYTHONPATH
PARTITION=$1
JOB_NAME=$2
DATASET=$3
WORKERS=$4
GPUS=${GPUS:-1}
GPUS_PER_NODE=${GPUS_PER_NODE:-1}
SRUN_ARGS=${SRUN_ARGS:-""}
JOB_NAME=create_data
PY_ARGS=${@:5}
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
......@@ -21,4 +22,6 @@ srun -p ${PARTITION} \
python -u tools/create_data.py ${DATASET} \
--root-path ./data/${DATASET} \
--out-dir ./data/${DATASET} \
--extra-tag ${DATASET}
--workers ${WORKERS} \
--extra-tag ${DATASET} \
${PY_ARGS}
......@@ -7,7 +7,7 @@ import mmengine
import numpy as np
from mmcv.ops import roi_align
from mmdet.evaluation import bbox_overlaps
from mmengine import track_iter_progress
from mmengine import print_log, track_iter_progress
from pycocotools import mask as maskUtils
from pycocotools.coco import COCO
......@@ -504,7 +504,9 @@ class GTDatabaseCreater:
return single_db_infos
def create(self):
print(f'Create GT Database of {self.dataset_class_name}')
print_log(
f'Create GT Database of {self.dataset_class_name}',
logger='current')
dataset_cfg = dict(
type=self.dataset_class_name,
data_root=self.data_path,
......@@ -610,12 +612,19 @@ class GTDatabaseCreater:
input_dict['box_mode_3d'] = self.dataset.box_mode_3d
return input_dict
multi_db_infos = mmengine.track_parallel_progress(
self.create_single,
((loop_dataset(i)
for i in range(len(self.dataset))), len(self.dataset)),
self.num_worker)
print('Make global unique group id')
if self.num_worker == 0:
multi_db_infos = mmengine.track_progress(
self.create_single,
((loop_dataset(i)
for i in range(len(self.dataset))), len(self.dataset)))
else:
multi_db_infos = mmengine.track_parallel_progress(
self.create_single,
((loop_dataset(i)
for i in range(len(self.dataset))), len(self.dataset)),
self.num_worker,
chunksize=1000)
print_log('Make global unique group id', logger='current')
group_counter_offset = 0
all_db_infos = dict()
for single_db_infos in track_iter_progress(multi_db_infos):
......@@ -630,7 +639,8 @@ class GTDatabaseCreater:
group_counter_offset += (group_id + 1)
for k, v in all_db_infos.items():
print(f'load {len(v)} {k} database infos')
print_log(f'load {len(v)} {k} database infos', logger='current')
print_log(f'Saving GT database infos into {self.db_info_save_path}')
with open(self.db_info_save_path, 'wb') as f:
pickle.dump(all_db_infos, f)
......@@ -9,23 +9,33 @@ except ImportError:
raise ImportError('Please run "pip install waymo-open-dataset-tf-2-6-0" '
'>1.4.5 to install the official devkit first.')
import copy
import os
import os.path as osp
from glob import glob
from io import BytesIO
from os.path import exists, join
import mmengine
import numpy as np
import tensorflow as tf
from mmengine import print_log
from nuscenes.utils.geometry_utils import view_points
from PIL import Image
from waymo_open_dataset.utils import range_image_utils, transform_utils
from waymo_open_dataset.utils.frame_utils import \
parse_range_image_and_camera_projection
from mmdet3d.datasets.convert_utils import post_process_coords
from mmdet3d.structures import Box3DMode, LiDARInstance3DBoxes, points_cam2img
class Waymo2KITTI(object):
"""Waymo to KITTI converter.
"""Waymo to KITTI converter. There are 2 steps as follows:
This class serves as the converter to change the waymo raw data to KITTI
format.
Step 1. Extract camera images and lidar point clouds from waymo raw data in
'*.tfrecord' and save as kitti format.
Step 2. Generate waymo train/val/test infos and save as pickle file.
Args:
load_dir (str): Directory to load waymo raw data.
......@@ -36,8 +46,16 @@ class Waymo2KITTI(object):
Defaults to 64.
test_mode (bool, optional): Whether in the test_mode.
Defaults to False.
save_cam_sync_labels (bool, optional): Whether to save cam sync labels.
Defaults to True.
save_senor_data (bool, optional): Whether to save image and lidar
data. Defaults to True.
save_cam_sync_instances (bool, optional): Whether to save cam sync
instances. Defaults to True.
save_cam_instances (bool, optional): Whether to save cam instances.
Defaults to False.
info_prefix (str, optional): Prefix of info filename.
Defaults to 'waymo'.
max_sweeps (int, optional): Max length of sweeps. Defaults to 10.
split (str, optional): Split of the data. Defaults to 'training'.
"""
def __init__(self,
......@@ -46,18 +64,12 @@ class Waymo2KITTI(object):
prefix,
workers=64,
test_mode=False,
save_cam_sync_labels=True):
self.filter_empty_3dboxes = True
self.filter_no_label_zone_points = True
self.selected_waymo_classes = ['VEHICLE', 'PEDESTRIAN', 'CYCLIST']
# Only data collected in specific locations will be converted
# If set None, this filter is disabled
# Available options: location_sf (main dataset)
self.selected_waymo_locations = None
self.save_track_id = False
save_senor_data=True,
save_cam_sync_instances=True,
save_cam_instances=True,
info_prefix='waymo',
max_sweeps=10,
split='training'):
# turn on eager execution for older tensorflow versions
if int(tf.__version__.split('.')[0]) < 2:
tf.enable_eager_execution()
......@@ -74,12 +86,21 @@ class Waymo2KITTI(object):
self.type_list = [
'UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST'
]
self.waymo_to_kitti_class_map = {
'UNKNOWN': 'DontCare',
'PEDESTRIAN': 'Pedestrian',
'VEHICLE': 'Car',
'CYCLIST': 'Cyclist',
'SIGN': 'Sign' # not in kitti
# MMDetection3D unified camera keys & class names
self.camera_types = [
'CAM_FRONT',
'CAM_FRONT_LEFT',
'CAM_FRONT_RIGHT',
'CAM_SIDE_LEFT',
'CAM_SIDE_RIGHT',
]
self.selected_waymo_classes = ['VEHICLE', 'PEDESTRIAN', 'CYCLIST']
self.info_map = {
'training': '_infos_train.pkl',
'validation': '_infos_val.pkl',
'testing': '_infos_test.pkl',
'testing_3d_camera_only_detection': '_infos_test_cam_only.pkl'
}
self.load_dir = load_dir
......@@ -87,61 +108,87 @@ class Waymo2KITTI(object):
self.prefix = prefix
self.workers = int(workers)
self.test_mode = test_mode
self.save_cam_sync_labels = save_cam_sync_labels
self.save_senor_data = save_senor_data
self.save_cam_sync_instances = save_cam_sync_instances
self.save_cam_instances = save_cam_instances
self.info_prefix = info_prefix
self.max_sweeps = max_sweeps
self.split = split
# TODO: Discuss filter_empty_3dboxes and filter_no_label_zone_points
self.filter_empty_3dboxes = True
self.filter_no_label_zone_points = True
self.save_track_id = False
self.tfrecord_pathnames = sorted(
glob(join(self.load_dir, '*.tfrecord')))
self.label_save_dir = f'{self.save_dir}/label_'
self.label_all_save_dir = f'{self.save_dir}/label_all'
self.image_save_dir = f'{self.save_dir}/image_'
self.calib_save_dir = f'{self.save_dir}/calib'
self.point_cloud_save_dir = f'{self.save_dir}/velodyne'
self.pose_save_dir = f'{self.save_dir}/pose'
self.timestamp_save_dir = f'{self.save_dir}/timestamp'
if self.save_cam_sync_labels:
self.cam_sync_label_save_dir = f'{self.save_dir}/cam_sync_label_'
self.cam_sync_label_all_save_dir = \
f'{self.save_dir}/cam_sync_label_all'
self.create_folder()
# Create folder for saving KITTI format camera images and
# lidar point clouds.
if 'testing_3d_camera_only_detection' not in self.load_dir:
mmengine.mkdir_or_exist(self.point_cloud_save_dir)
for i in range(5):
mmengine.mkdir_or_exist(f'{self.image_save_dir}{str(i)}')
def convert(self):
"""Convert action."""
print('Start converting ...')
mmengine.track_parallel_progress(self.convert_one, range(len(self)),
self.workers)
print('\nFinished ...')
print_log(f'Start converting {self.split} dataset', logger='current')
if self.workers == 0:
data_infos = mmengine.track_progress(self.convert_one,
range(len(self)))
else:
data_infos = mmengine.track_parallel_progress(
self.convert_one, range(len(self)), self.workers)
data_list = []
for data_info in data_infos:
data_list.extend(data_info)
metainfo = dict()
metainfo['dataset'] = 'waymo'
metainfo['version'] = 'waymo_v1.4'
metainfo['info_version'] = 'mmdet3d_v1.4'
waymo_infos = dict(data_list=data_list, metainfo=metainfo)
filenames = osp.join(
osp.dirname(self.save_dir),
f'{self.info_prefix + self.info_map[self.split]}')
print_log(f'Saving {self.split} dataset infos into {filenames}')
mmengine.dump(waymo_infos, filenames)
def convert_one(self, file_idx):
"""Convert action for single file.
"""Convert one '*.tfrecord' file to kitti format. Each file stores all
the frames (about 200 frames) in current scene. We treat each frame as
a sample, save their images and point clouds in kitti format, and then
create info for all frames.
Args:
file_idx (int): Index of the file to be converted.
Returns:
List[dict]: Waymo infos for all frames in current file.
"""
pathname = self.tfrecord_pathnames[file_idx]
dataset = tf.data.TFRecordDataset(pathname, compression_type='')
# NOTE: file_infos is not shared between processes, only stores frame
# infos within the current file.
file_infos = []
for frame_idx, data in enumerate(dataset):
frame = dataset_pb2.Frame()
frame.ParseFromString(bytearray(data.numpy()))
if (self.selected_waymo_locations is not None
and frame.context.stats.location
not in self.selected_waymo_locations):
continue
self.save_image(frame, file_idx, frame_idx)
self.save_calib(frame, file_idx, frame_idx)
self.save_lidar(frame, file_idx, frame_idx)
self.save_pose(frame, file_idx, frame_idx)
self.save_timestamp(frame, file_idx, frame_idx)
# Step 1. Extract camera images and lidar point clouds from waymo
# raw data in '*.tfreord' and save as kitti format.
if self.save_senor_data:
self.save_image(frame, file_idx, frame_idx)
self.save_lidar(frame, file_idx, frame_idx)
if not self.test_mode:
# TODO save the depth image for waymo challenge solution.
self.save_label(frame, file_idx, frame_idx)
if self.save_cam_sync_labels:
self.save_label(frame, file_idx, frame_idx, cam_sync=True)
# Step 2. Generate waymo train/val/test infos and save as pkl file.
# TODO save the depth image for waymo challenge solution.
self.create_waymo_info_file(frame, file_idx, frame_idx, file_infos)
return file_infos
def __len__(self):
"""Length of the filename list."""
......@@ -162,62 +209,6 @@ class Waymo2KITTI(object):
with open(img_path, 'wb') as fp:
fp.write(img.image)
    def save_calib(self, frame, file_idx, frame_idx):
        """Parse and save the calibration data in KITTI calib txt format.

        The written file contains the 3x4 intrinsic matrices P0-P4 of the
        five cameras, a rectification matrix R0_rect (identity here, since
        the waymo images need no rectification) and the lidar-to-camera
        extrinsics Tr_velo_to_cam_0..4.

        Args:
            frame (:obj:`Frame`): Open dataset frame proto.
            file_idx (int): Current file index.
            frame_idx (int): Current frame index.
        """
        # waymo front camera to kitti reference camera
        T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0],
                                       [1.0, 0.0, 0.0]])
        camera_calibs = []
        R0_rect = [f'{i:e}' for i in np.eye(3).flatten()]
        Tr_velo_to_cams = []
        calib_context = ''

        for camera in frame.context.camera_calibrations:
            # extrinsic parameters
            T_cam_to_vehicle = np.array(camera.extrinsic.transform).reshape(
                4, 4)
            T_vehicle_to_cam = np.linalg.inv(T_cam_to_vehicle)
            Tr_velo_to_cam = \
                self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam
            if camera.name == 1:  # FRONT = 1, see dataset.proto for details
                # Cached on the instance so that save_label can later project
                # boxes into the virtual front-camera reference frame.
                self.T_velo_to_front_cam = Tr_velo_to_cam.copy()
            Tr_velo_to_cam = Tr_velo_to_cam[:3, :].reshape((12, ))
            Tr_velo_to_cams.append([f'{i:e}' for i in Tr_velo_to_cam])

            # intrinsic parameters
            camera_calib = np.zeros((3, 4))
            camera_calib[0, 0] = camera.intrinsic[0]
            camera_calib[1, 1] = camera.intrinsic[1]
            camera_calib[0, 2] = camera.intrinsic[2]
            camera_calib[1, 2] = camera.intrinsic[3]
            camera_calib[2, 2] = 1
            camera_calib = list(camera_calib.reshape(12))
            camera_calib = [f'{i:e}' for i in camera_calib]
            camera_calibs.append(camera_calib)

        # all camera ids are saved as id-1 in the result because
        # camera 0 is unknown in the proto
        for i in range(5):
            calib_context += 'P' + str(i) + ': ' + \
                ' '.join(camera_calibs[i]) + '\n'
        calib_context += 'R0_rect' + ': ' + ' '.join(R0_rect) + '\n'
        for i in range(5):
            calib_context += 'Tr_velo_to_cam_' + str(i) + ': ' + \
                ' '.join(Tr_velo_to_cams[i]) + '\n'

        with open(
                f'{self.calib_save_dir}/{self.prefix}' +
                f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt',
                'w+') as fp_calib:
            fp_calib.write(calib_context)
            fp_calib.close()
def save_lidar(self, frame, file_idx, frame_idx):
"""Parse and save the lidar data in psd format.
......@@ -275,194 +266,6 @@ class Waymo2KITTI(object):
f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.bin'
point_cloud.astype(np.float32).tofile(pc_path)
    def save_label(self, frame, file_idx, frame_idx, cam_sync=False):
        """Parse and save the label data in txt format.

        Writes one KITTI-style label file per camera (routed by the camera
        id the 2D box was projected into) plus a combined ``label_all``
        file for the frame.

        The relation between waymo and kitti coordinates is noteworthy:
        1. x, y, z correspond to l, w, h (waymo) -> l, h, w (kitti)
        2. x-y-z: front-left-up (waymo) -> right-down-front(kitti)
        3. bbox origin at volumetric center (waymo) -> bottom center (kitti)
        4. rotation: +x around y-axis (kitti) -> +x around z-axis (waymo)

        Args:
            frame (:obj:`Frame`): Open dataset frame proto.
            file_idx (int): Current file index.
            frame_idx (int): Current frame index.
            cam_sync (bool, optional): Whether to save the cam sync labels.
                Defaults to False.
        """
        label_all_path = f'{self.label_all_save_dir}/{self.prefix}' + \
            f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'
        if cam_sync:
            label_all_path = label_all_path.replace('label_',
                                                    'cam_sync_label_')
        fp_label_all = open(label_all_path, 'w+')
        # Map each projected lidar label id to its 2D bbox and 0-based
        # camera id so laser labels can be matched below.
        id_to_bbox = dict()
        id_to_name = dict()
        for labels in frame.projected_lidar_labels:
            name = labels.name
            for label in labels.labels:
                # TODO: need a workaround as bbox may not belong to front cam
                bbox = [
                    label.box.center_x - label.box.length / 2,
                    label.box.center_y - label.box.width / 2,
                    label.box.center_x + label.box.length / 2,
                    label.box.center_y + label.box.width / 2
                ]
                id_to_bbox[label.id] = bbox
                id_to_name[label.id] = name - 1

        for obj in frame.laser_labels:
            bounding_box = None
            name = None
            id = obj.id
            for proj_cam in self.cam_list:
                if id + proj_cam in id_to_bbox:
                    bounding_box = id_to_bbox.get(id + proj_cam)
                    name = str(id_to_name.get(id + proj_cam))
                    break

            # NOTE: the 2D labels do not have strict correspondence with
            # the projected 2D lidar labels
            # e.g.: the projected 2D labels can be in camera 2
            # while the most_visible_camera can have id 4
            if cam_sync:
                if obj.most_visible_camera_name:
                    name = str(
                        self.cam_list.index(
                            f'_{obj.most_visible_camera_name}'))
                    box3d = obj.camera_synced_box
                else:
                    continue
            else:
                box3d = obj.box

            # Objects without a matched 2D box fall back to camera 0 and a
            # degenerate bbox.
            if bounding_box is None or name is None:
                name = '0'
                bounding_box = (0, 0, 0, 0)

            my_type = self.type_list[obj.type]

            if my_type not in self.selected_waymo_classes:
                continue

            if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1:
                continue

            my_type = self.waymo_to_kitti_class_map[my_type]

            height = box3d.height
            width = box3d.width
            length = box3d.length

            # KITTI stores the bottom center of the box, waymo the
            # volumetric center.
            x = box3d.center_x
            y = box3d.center_y
            z = box3d.center_z - height / 2

            # project bounding box to the virtual reference frame
            pt_ref = self.T_velo_to_front_cam @ \
                np.array([x, y, z, 1]).reshape((4, 1))
            x, y, z, _ = pt_ref.flatten().tolist()

            rotation_y = -box3d.heading - np.pi / 2
            track_id = obj.id

            # not available
            truncated = 0
            occluded = 0
            alpha = -10

            line = my_type + \
                ' {} {} {} {} {} {} {} {} {} {} {} {} {} {}\n'.format(
                    round(truncated, 2), occluded, round(alpha, 2),
                    round(bounding_box[0], 2), round(bounding_box[1], 2),
                    round(bounding_box[2], 2), round(bounding_box[3], 2),
                    round(height, 2), round(width, 2), round(length, 2),
                    round(x, 2), round(y, 2), round(z, 2),
                    round(rotation_y, 2))

            if self.save_track_id:
                line_all = line[:-1] + ' ' + name + ' ' + track_id + '\n'
            else:
                line_all = line[:-1] + ' ' + name + '\n'

            # Per-camera label file is opened in append mode: several objects
            # of the same frame land in the same file.
            label_path = f'{self.label_save_dir}{name}/{self.prefix}' + \
                f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'
            if cam_sync:
                label_path = label_path.replace('label_', 'cam_sync_label_')
            fp_label = open(label_path, 'a')
            fp_label.write(line)
            fp_label.close()

            fp_label_all.write(line_all)
        fp_label_all.close()
def save_pose(self, frame, file_idx, frame_idx):
"""Parse and save the pose data.
Note that SDC's own pose is not included in the regular training
of KITTI dataset. KITTI raw dataset contains ego motion files
but are not often used. Pose is important for algorithms that
take advantage of the temporal information.
Args:
frame (:obj:`Frame`): Open dataset frame proto.
file_idx (int): Current file index.
frame_idx (int): Current frame index.
"""
pose = np.array(frame.pose.transform).reshape(4, 4)
np.savetxt(
join(f'{self.pose_save_dir}/{self.prefix}' +
f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'),
pose)
def save_timestamp(self, frame, file_idx, frame_idx):
"""Save the timestamp data in a separate file instead of the
pointcloud.
Note that SDC's own pose is not included in the regular training
of KITTI dataset. KITTI raw dataset contains ego motion files
but are not often used. Pose is important for algorithms that
take advantage of the temporal information.
Args:
frame (:obj:`Frame`): Open dataset frame proto.
file_idx (int): Current file index.
frame_idx (int): Current frame index.
"""
with open(
join(f'{self.timestamp_save_dir}/{self.prefix}' +
f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'),
'w') as f:
f.write(str(frame.timestamp_micros))
def create_folder(self):
"""Create folder for data preprocessing."""
if not self.test_mode:
dir_list1 = [
self.label_all_save_dir,
self.calib_save_dir,
self.pose_save_dir,
self.timestamp_save_dir,
]
dir_list2 = [self.label_save_dir, self.image_save_dir]
if self.save_cam_sync_labels:
dir_list1.append(self.cam_sync_label_all_save_dir)
dir_list2.append(self.cam_sync_label_save_dir)
else:
dir_list1 = [
self.calib_save_dir, self.pose_save_dir,
self.timestamp_save_dir
]
dir_list2 = [self.image_save_dir]
if 'testing_3d_camera_only_detection' not in self.load_dir:
dir_list1.append(self.point_cloud_save_dir)
for d in dir_list1:
mmengine.mkdir_or_exist(d)
for d in dir_list2:
for i in range(5):
mmengine.mkdir_or_exist(f'{d}{str(i)}')
def convert_range_image_to_point_cloud(self,
frame,
range_images,
......@@ -604,29 +407,317 @@ class Waymo2KITTI(object):
raise ValueError(mat.shape)
return ret
    def create_waymo_info_file(self, frame, file_idx, frame_idx, file_infos):
        r"""Generate waymo train/val/test infos.

        Gathers per-frame meta data (sample index, timestamp, ego pose),
        camera calibrations and sizes, lidar info, up to ``self.max_sweeps``
        previous lidar/image sweeps, and (outside test mode) the instance
        annotations, then appends the resulting dict to ``file_infos``.

        For more details about infos, please refer to:
        https://mmdetection3d.readthedocs.io/en/latest/advanced_guides/datasets/waymo.html

        Args:
            frame (:obj:`Frame`): Open dataset frame proto.
            file_idx (int): Current file index.
            frame_idx (int): Current frame index.
            file_infos (list[dict]): Infos of the previously processed
                frames of this file; the current frame's info is appended
                in place (no return value).
        """  # noqa: E501
        frame_infos = dict()

        # Gather frame infos
        sample_idx = \
            f'{self.prefix}{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}'
        frame_infos['sample_idx'] = int(sample_idx)
        frame_infos['timestamp'] = frame.timestamp_micros
        frame_infos['ego2global'] = np.array(frame.pose.transform).reshape(
            4, 4).astype(np.float32).tolist()
        frame_infos['context_name'] = frame.context.name

        # Gather camera infos
        frame_infos['images'] = dict()
        # waymo front camera to kitti reference camera
        T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0],
                                       [1.0, 0.0, 0.0]])
        camera_calibs = []
        Tr_velo_to_cams = []
        for camera in frame.context.camera_calibrations:
            # extrinsic parameters
            T_cam_to_vehicle = np.array(camera.extrinsic.transform).reshape(
                4, 4)
            T_vehicle_to_cam = np.linalg.inv(T_cam_to_vehicle)
            Tr_velo_to_cam = \
                self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam
            Tr_velo_to_cams.append(Tr_velo_to_cam)

            # intrinsic parameters
            camera_calib = np.zeros((3, 4))
            camera_calib[0, 0] = camera.intrinsic[0]
            camera_calib[1, 1] = camera.intrinsic[1]
            camera_calib[0, 2] = camera.intrinsic[2]
            camera_calib[1, 2] = camera.intrinsic[3]
            camera_calib[2, 2] = 1
            camera_calibs.append(camera_calib)

        # NOTE(review): assumes frame.context.camera_calibrations is ordered
        # to match self.camera_types — confirm against dataset.proto.
        for i, (cam_key, camera_calib, Tr_velo_to_cam) in enumerate(
                zip(self.camera_types, camera_calibs, Tr_velo_to_cams)):
            cam_infos = dict()
            cam_infos['img_path'] = str(sample_idx) + '.jpg'
            # NOTE: frames.images order is different
            for img in frame.images:
                if img.name == i + 1:
                    width, height = Image.open(BytesIO(img.image)).size
            cam_infos['height'] = height
            cam_infos['width'] = width
            cam_infos['lidar2cam'] = Tr_velo_to_cam.astype(np.float32).tolist()
            cam_infos['cam2img'] = camera_calib.astype(np.float32).tolist()
            cam_infos['lidar2img'] = (camera_calib @ Tr_velo_to_cam).astype(
                np.float32).tolist()
            frame_infos['images'][cam_key] = cam_infos

        # Gather lidar infos
        lidar_infos = dict()
        lidar_infos['lidar_path'] = str(sample_idx) + '.bin'
        lidar_infos['num_pts_feats'] = 6
        frame_infos['lidar_points'] = lidar_infos

        # Gather lidar sweeps and camera sweeps infos
        # TODO: Add lidar2img in image sweeps infos when we need it.
        # TODO: Consider merging lidar sweeps infos and image sweeps infos.
        lidar_sweeps_infos, image_sweeps_infos = [], []
        # Negative offsets index file_infos from the end, i.e. the most
        # recently processed frames of the current file.
        for prev_offset in range(-1, -self.max_sweeps - 1, -1):
            prev_lidar_infos = dict()
            prev_image_infos = dict()
            if frame_idx + prev_offset >= 0:
                prev_frame_infos = file_infos[prev_offset]
                prev_lidar_infos['timestamp'] = prev_frame_infos['timestamp']
                prev_lidar_infos['ego2global'] = prev_frame_infos['ego2global']
                prev_lidar_infos['lidar_points'] = dict()
                lidar_path = prev_frame_infos['lidar_points']['lidar_path']
                prev_lidar_infos['lidar_points']['lidar_path'] = lidar_path
                lidar_sweeps_infos.append(prev_lidar_infos)

                prev_image_infos['timestamp'] = prev_frame_infos['timestamp']
                prev_image_infos['ego2global'] = prev_frame_infos['ego2global']
                prev_image_infos['images'] = dict()
                for cam_key in self.camera_types:
                    prev_image_infos['images'][cam_key] = dict()
                    img_path = prev_frame_infos['images'][cam_key]['img_path']
                    prev_image_infos['images'][cam_key]['img_path'] = img_path
                image_sweeps_infos.append(prev_image_infos)
        if lidar_sweeps_infos:
            frame_infos['lidar_sweeps'] = lidar_sweeps_infos
        if image_sweeps_infos:
            frame_infos['image_sweeps'] = image_sweeps_infos

        if not self.test_mode:
            # Gather instances infos which is used for lidar-based 3D detection
            frame_infos['instances'] = self.gather_instance_info(frame)
            # Gather cam_sync_instances infos which is used for image-based
            # (multi-view) 3D detection.
            if self.save_cam_sync_instances:
                frame_infos['cam_sync_instances'] = self.gather_instance_info(
                    frame, cam_sync=True)
            # Gather cam_instances infos which is used for image-based
            # (monocular) 3D detection (optional).
            # TODO: Should we use cam_sync_instances to generate cam_instances?
            if self.save_cam_instances:
                frame_infos['cam_instances'] = self.gather_cam_instance_info(
                    copy.deepcopy(frame_infos['instances']),
                    frame_infos['images'])
        file_infos.append(frame_infos)
    def gather_instance_info(self, frame, cam_sync=False):
        """Generate instances and cam_sync_instances infos.

        For more details about infos, please refer to:
        https://mmdetection3d.readthedocs.io/en/latest/advanced_guides/datasets/waymo.html

        Args:
            frame (:obj:`Frame`): Open dataset frame proto.
            cam_sync (bool, optional): Whether to use the camera-synced
                boxes (``obj.camera_synced_box``) instead of the raw laser
                boxes. Defaults to False.

        Returns:
            list[dict]: One info dict per kept laser label.
        """  # noqa: E501
        # Map each projected lidar label id to its 2D bbox and 0-based
        # camera id for the matching loop below.
        id_to_bbox = dict()
        id_to_name = dict()
        for labels in frame.projected_lidar_labels:
            name = labels.name
            for label in labels.labels:
                # TODO: need a workaround as bbox may not belong to front cam
                bbox = [
                    label.box.center_x - label.box.length / 2,
                    label.box.center_y - label.box.width / 2,
                    label.box.center_x + label.box.length / 2,
                    label.box.center_y + label.box.width / 2
                ]
                id_to_bbox[label.id] = bbox
                id_to_name[label.id] = name - 1

        group_id = 0
        instance_infos = []
        for obj in frame.laser_labels:
            instance_info = dict()
            bounding_box = None
            name = None
            # NOTE(review): `id` and `label` below shadow Python builtins.
            id = obj.id
            for proj_cam in self.cam_list:
                if id + proj_cam in id_to_bbox:
                    bounding_box = id_to_bbox.get(id + proj_cam)
                    name = id_to_name.get(id + proj_cam)
                    break

            # NOTE: the 2D labels do not have strict correspondence with
            # the projected 2D lidar labels
            # e.g.: the projected 2D labels can be in camera 2
            # while the most_visible_camera can have id 4
            if cam_sync:
                if obj.most_visible_camera_name:
                    name = self.cam_list.index(
                        f'_{obj.most_visible_camera_name}')
                    box3d = obj.camera_synced_box
                else:
                    continue
            else:
                box3d = obj.box

            # Objects with no matched 2D box fall back to camera 0 and a
            # degenerate bbox.
            if bounding_box is None or name is None:
                name = 0
                bounding_box = [0.0, 0.0, 0.0, 0.0]

            my_type = self.type_list[obj.type]

            if my_type not in self.selected_waymo_classes:
                continue
            else:
                label = self.selected_waymo_classes.index(my_type)

            if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1:
                continue

            # group_id only counts the kept objects, so it stays contiguous.
            group_id += 1
            instance_info['group_id'] = group_id
            instance_info['camera_id'] = name
            instance_info['bbox'] = bounding_box
            instance_info['bbox_label'] = label

            height = box3d.height
            width = box3d.width
            length = box3d.length

            # NOTE: We save the bottom center of 3D bboxes.
            x = box3d.center_x
            y = box3d.center_y
            z = box3d.center_z - height / 2
            rotation_y = box3d.heading

            instance_info['bbox_3d'] = np.array(
                [x, y, z, length, width, height,
                 rotation_y]).astype(np.float32).tolist()
            instance_info['bbox_label_3d'] = label
            instance_info['num_lidar_pts'] = obj.num_lidar_points_in_box

            if self.save_track_id:
                instance_info['track_id'] = obj.id
            instance_infos.append(instance_info)
        return instance_infos
def gather_cam_instance_info(self, instances: dict, images: dict):
"""Generate cam_instances infos.
For more details about infos, please refer to:
https://mmdetection3d.readthedocs.io/en/latest/advanced_guides/datasets/waymo.html
""" # noqa: E501
cam_instances = dict()
for cam_type in self.camera_types:
lidar2cam = np.array(images[cam_type]['lidar2cam'])
cam2img = np.array(images[cam_type]['cam2img'])
cam_instances[cam_type] = []
for instance in instances:
cam_instance = dict()
gt_bboxes_3d = np.array(instance['bbox_3d'])
# Convert lidar coordinates to camera coordinates
gt_bboxes_3d = LiDARInstance3DBoxes(
gt_bboxes_3d[None, :]).convert_to(
Box3DMode.CAM, lidar2cam, correct_yaw=True)
corners_3d = gt_bboxes_3d.corners.numpy()
corners_3d = corners_3d[0].T # (1, 8, 3) -> (3, 8)
in_camera = np.argwhere(corners_3d[2, :] > 0).flatten()
corners_3d = corners_3d[:, in_camera]
# Project 3d box to 2d.
corner_coords = view_points(corners_3d, cam2img,
True).T[:, :2].tolist()
# Keep only corners that fall within the image.
# TODO: imsize should be determined by the current image size
# CAM_FRONT: (1920, 1280)
# CAM_FRONT_LEFT: (1920, 1280)
# CAM_SIDE_LEFT: (1920, 886)
final_coords = post_process_coords(
corner_coords,
imsize=(images['CAM_FRONT']['width'],
images['CAM_FRONT']['height']))
# Skip if the convex hull of the re-projected corners
# does not intersect the image canvas.
if final_coords is None:
continue
else:
min_x, min_y, max_x, max_y = final_coords
cam_instance['bbox'] = [min_x, min_y, max_x, max_y]
cam_instance['bbox_label'] = instance['bbox_label']
cam_instance['bbox_3d'] = gt_bboxes_3d.numpy().squeeze(
).astype(np.float32).tolist()
cam_instance['bbox_label_3d'] = instance['bbox_label_3d']
center_3d = gt_bboxes_3d.gravity_center.numpy()
center_2d_with_depth = points_cam2img(
center_3d, cam2img, with_depth=True)
center_2d_with_depth = center_2d_with_depth.squeeze().tolist()
# normalized center2D + depth
# if samples with depth < 0 will be removed
if center_2d_with_depth[2] <= 0:
continue
cam_instance['center_2d'] = center_2d_with_depth[:2]
cam_instance['depth'] = center_2d_with_depth[2]
# TODO: Discuss whether following info is necessary
cam_instance['bbox_3d_isvalid'] = True
cam_instance['velocity'] = -1
cam_instances[cam_type].append(cam_instance)
return cam_instances
def merge_trainval_infos(self):
"""Merge training and validation infos into a single file."""
train_infos_path = osp.join(
osp.dirname(self.save_dir), f'{self.info_prefix}_infos_train.pkl')
val_infos_path = osp.join(
osp.dirname(self.save_dir), f'{self.info_prefix}_infos_val.pkl')
train_infos = mmengine.load(train_infos_path)
val_infos = mmengine.load(val_infos_path)
trainval_infos = dict(
metainfo=train_infos['metainfo'],
data_list=train_infos['data_list'] + val_infos['data_list'])
mmengine.dump(
trainval_infos,
osp.join(
osp.dirname(self.save_dir),
f'{self.info_prefix}_infos_trainval.pkl'))
def create_ImageSets_img_ids(root_dir, splits):
    """Create txt files indicating what to collect in each split.

    Scans the converted ``<split>/image_0`` folders, collects the sample
    indices from the saved ``*.jpg`` filenames and routes every index to
    its split bucket by the leading prefix digit of the filename
    (0: train, 1: val, 2: test, 3: camera-only test).

    Args:
        root_dir (str): Root of the kitti_format output directory.
        splits (list[str]): Names of the converted splits.
    """
    save_dir = join(root_dir, 'ImageSets/')
    if not exists(save_dir):
        os.mkdir(save_dir)

    idx_all = [[] for _ in splits]
    for i, split in enumerate(splits):
        path = join(root_dir, split, 'image_0')
        if not exists(path):
            raw_names = []
        else:
            raw_names = os.listdir(path)
        for name in raw_names:
            if name.endswith('.jpg'):
                idx = name.replace('.jpg', '\n')
                # Route by filename prefix digit, not by the folder the
                # image was found in.
                idx_all[int(idx[0])].append(idx)
        idx_all[i].sort()

    open(save_dir + 'train.txt', 'w').writelines(idx_all[0])
    open(save_dir + 'val.txt', 'w').writelines(idx_all[1])
    open(save_dir + 'trainval.txt', 'w').writelines(idx_all[0] + idx_all[1])
    # Test splits only exist for the full dataset; the mini dataset has
    # just training and validation, so guard the extra writes.
    if len(idx_all) >= 3:
        open(save_dir + 'test.txt', 'w').writelines(idx_all[2])
    if len(idx_all) >= 4:
        open(save_dir + 'test_cam_only.txt', 'w').writelines(idx_all[3])
    print('created txt files indicating what to collect in ', splits)
......@@ -21,6 +21,12 @@ def parse_args():
action='store_true',
default=False,
help='enable automatic-mixed-precision training')
parser.add_argument(
'--sync_bn',
choices=['none', 'torch', 'mmcv'],
default='none',
help='convert all BatchNorm layers in the model to SyncBatchNorm '
'(SyncBN) or mmcv.ops.sync_bn.SyncBatchNorm (MMSyncBN) layers.')
parser.add_argument(
'--auto-scale-lr',
action='store_true',
......@@ -98,6 +104,10 @@ def main():
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic'
# convert BatchNorm layers
if args.sync_bn != 'none':
cfg.sync_bn = args.sync_bn
# enable automatically scaling LR
if args.auto_scale_lr:
if 'auto_scale_lr' in cfg and \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment