OpenDAS / mmdetection3d — Commits — fe25f7a5

Commit fe25f7a5 (unverified), authored Jan 08, 2024 by Wenwei Zhang, committed by GitHub on Jan 08, 2024

Merge pull request #2867 from open-mmlab/dev-1.x

Bump version to 1.4.0

Parents: 5c0613be, 0ef13b83
Pipeline #2710 failed with stages in 0 seconds
Changes: 80 | Pipelines: 3
Showing 20 changed files with 3,488 additions and 364 deletions.
projects/NeRF-Det/nerfdet/multiview_pipeline.py             +297  -0
projects/NeRF-Det/nerfdet/nerf_det3d_data_sample.py         +52   -0
projects/NeRF-Det/nerfdet/nerf_utils/nerf_mlp.py            +277  -0
projects/NeRF-Det/nerfdet/nerf_utils/projection.py          +140  -0
projects/NeRF-Det/nerfdet/nerf_utils/render_ray.py          +431  -0
projects/NeRF-Det/nerfdet/nerf_utils/save_rendered_img.py   +79   -0
projects/NeRF-Det/nerfdet/nerfdet.py                        +632  -0
projects/NeRF-Det/nerfdet/nerfdet_head.py                   +629  -0
projects/NeRF-Det/nerfdet/scannet_multiview_dataset.py      +202  -0
projects/NeRF-Det/prepare_infos.py                          +151  -0
projects/PETR/README.md                                     +2    -2
projects/PETR/petr/petr_head.py                             +1    -1
tests/data/waymo/kitti_format/waymo_infos_train.pkl         +0    -0
tests/data/waymo/kitti_format/waymo_infos_val.pkl           +0    -0
tests/test_datasets/test_waymo_dataset.py                   +80   -0
tools/create_data.py                                        +87   -47
tools/create_data.sh                                        +5    -2
tools/dataset_converters/create_gt_database.py              +19   -9
tools/dataset_converters/waymo_converter.py                 +394  -303
tools/train.py                                              +10   -0
projects/NeRF-Det/nerfdet/multiview_pipeline.py  (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform, Compose
from PIL import Image

from mmdet3d.registry import TRANSFORMS


def get_dtu_raydir(pixelcoords, intrinsic, rot, dir_norm=None):
    # rot is c2w
    # pixelcoords: H x W x 2
    x = (pixelcoords[..., 0] + 0.5 - intrinsic[0, 2]) / intrinsic[0, 0]
    y = (pixelcoords[..., 1] + 0.5 - intrinsic[1, 2]) / intrinsic[1, 1]
    z = np.ones_like(x)
    dirs = np.stack([x, y, z], axis=-1)
    # dirs = np.sum(dirs[...,None,:] * rot[:,:], axis=-1)  # h*w*1*3 x 3*3
    dirs = dirs @ rot[:, :].T
    if dir_norm:
        dirs = dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5)
    return dirs


@TRANSFORMS.register_module()
class MultiViewPipeline(BaseTransform):
    """MultiViewPipeline used in nerfdet.

    Required Keys:

    - depth_info
    - img_prefix
    - img_info
    - lidar2img
    - c2w
    - cammrotc2w
    - lightpos
    - ray_info

    Modified Keys:

    - lidar2img

    Added Keys:

    - img
    - denorm_images
    - depth
    - c2w
    - camrotc2w
    - lightpos
    - pixels
    - raydirs
    - gt_images
    - gt_depths
    - nerf_sizes
    - depth_range

    Args:
        transforms (list[dict]): The transform pipeline
            used to process the imgs.
        n_images (int): The number of sampled views.
        mean (array): The mean values used in normalization.
        std (array): The variance values used in normalization.
        margin (int): The margin value. Defaults to 10.
        depth_range (array): The range of the depth.
            Defaults to [0.5, 5.5].
        loading (str): The mode of loading. Defaults to 'random'.
        nerf_target_views (int): The number of novel views.
        sample_freq (int): The frequency of sampling.
    """

    def __init__(self,
                 transforms: dict,
                 n_images: int,
                 mean: tuple = [123.675, 116.28, 103.53],
                 std: tuple = [58.395, 57.12, 57.375],
                 margin: int = 10,
                 depth_range: tuple = [0.5, 5.5],
                 loading: str = 'random',
                 nerf_target_views: int = 0,
                 sample_freq: int = 3):
        self.transforms = Compose(transforms)
        self.depth_transforms = Compose(transforms[1])
        self.n_images = n_images
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.margin = margin
        self.depth_range = depth_range
        self.loading = loading
        self.sample_freq = sample_freq
        self.nerf_target_views = nerf_target_views

    def transform(self, results: dict) -> dict:
        """Nerfdet transform function.

        Args:
            results (dict): Result dict from loading pipeline

        Returns:
            dict: The result dict containing the processed results.
            Updated key and value are described below.

                - img (list): The loaded origin image.
                - denorm_images (list): The denormalized image.
                - depth (list): The origin depth image.
                - c2w (list): The c2w matrixes.
                - camrotc2w (list): The rotation matrixes.
                - lightpos (list): The transform parameters of the camera.
                - pixels (list): Some pixel information.
                - raydirs (list): The ray-directions.
                - gt_images (list): The groundtruth images.
                - gt_depths (list): The groundtruth depth images.
                - nerf_sizes (array): The size of the groundtruth images.
                - depth_range (array): The range of the depth.

        Here we give a detailed explanation of some keys mentioned above.
        Let P_c be the coordinate of camera, P_w be the coordinate of world.
        There is such a conversion relationship: P_c = R @ P_w + T.
        The 'camrotc2w' mentioned above corresponds to the R matrix here.
        The 'lightpos' corresponds to the T matrix here. And if you put
        R and T together, you can get the camera extrinsics matrix. It
        corresponds to the 'c2w' mentioned above.
        """
        imgs = []
        depths = []
        extrinsics = []
        c2ws = []
        camrotc2ws = []
        lightposes = []
        pixels = []
        raydirs = []
        gt_images = []
        gt_depths = []
        denorm_imgs_list = []
        nerf_sizes = []
        if self.loading == 'random':
            ids = np.arange(len(results['img_info']))
            replace = True if self.n_images > len(ids) else False
            ids = np.random.choice(ids, self.n_images, replace=replace)
            if self.nerf_target_views != 0:
                target_id = np.random.choice(
                    ids, self.nerf_target_views, replace=False)
                ids = np.setdiff1d(ids, target_id)
                ids = ids.tolist()
                target_id = target_id.tolist()
        else:
            ids = np.arange(len(results['img_info']))
            begin_id = 0
            ids = np.arange(begin_id,
                            begin_id + self.n_images * self.sample_freq,
                            self.sample_freq)
            if self.nerf_target_views != 0:
                target_id = ids
        ratio = 0
        size = (240, 320)
        for i in ids:
            _results = dict()
            _results['img_path'] = results['img_info'][i]['filename']
            _results = self.transforms(_results)
            imgs.append(_results['img'])
            # normalize
            for key in _results.get('img_fields', ['img']):
                _results[key] = mmcv.imnormalize(_results[key], self.mean,
                                                 self.std, True)
            _results['img_norm_cfg'] = dict(
                mean=self.mean, std=self.std, to_rgb=True)
            # pad
            for key in _results.get('img_fields', ['img']):
                padded_img = mmcv.impad(_results[key], shape=size, pad_val=0)
                _results[key] = padded_img
            _results['pad_shape'] = padded_img.shape
            _results['pad_fixed_size'] = size
            ori_shape = _results['ori_shape']
            aft_shape = _results['img_shape']
            ratio = ori_shape[0] / aft_shape[0]
            # prepare the depth information
            if 'depth_info' in results.keys():
                if '.npy' in results['depth_info'][i]['filename']:
                    _results['depth'] = np.load(
                        results['depth_info'][i]['filename'])
                else:
                    _results['depth'] = np.asarray((Image.open(
                        results['depth_info'][i]['filename']))) / 1000
                    _results['depth'] = mmcv.imresize(
                        _results['depth'], (aft_shape[1], aft_shape[0]))
                depths.append(_results['depth'])

            denorm_img = mmcv.imdenormalize(
                _results['img'], self.mean, self.std,
                to_bgr=True).astype(np.uint8) / 255.0
            denorm_imgs_list.append(denorm_img)
            height, width = padded_img.shape[:2]
            extrinsics.append(results['lidar2img']['extrinsic'][i])

        # prepare the nerf information
        if 'ray_info' in results.keys():
            intrinsics_nerf = results['lidar2img']['intrinsic'].copy()
            intrinsics_nerf[:2] = intrinsics_nerf[:2] / ratio
            assert self.nerf_target_views > 0
            for i in target_id:
                c2ws.append(results['c2w'][i])
                camrotc2ws.append(results['camrotc2w'][i])
                lightposes.append(results['lightpos'][i])
                px, py = np.meshgrid(
                    np.arange(self.margin,
                              width - self.margin).astype(np.float32),
                    np.arange(self.margin,
                              height - self.margin).astype(np.float32))
                pixelcoords = np.stack((px, py),
                                       axis=-1).astype(np.float32)  # H x W x 2
                pixels.append(pixelcoords)
                raydir = get_dtu_raydir(pixelcoords, intrinsics_nerf,
                                        results['camrotc2w'][i])
                raydirs.append(np.reshape(raydir.astype(np.float32), (-1, 3)))
                # read target images
                temp_results = dict()
                temp_results['img_path'] = results['img_info'][i]['filename']
                temp_results_ = self.transforms(temp_results)
                # normalize
                for key in temp_results.get('img_fields', ['img']):
                    temp_results[key] = mmcv.imnormalize(
                        temp_results[key], self.mean, self.std, True)
                temp_results['img_norm_cfg'] = dict(
                    mean=self.mean, std=self.std, to_rgb=True)
                # pad
                for key in temp_results.get('img_fields', ['img']):
                    padded_img = mmcv.impad(
                        temp_results[key], shape=size, pad_val=0)
                    temp_results[key] = padded_img
                temp_results['pad_shape'] = padded_img.shape
                temp_results['pad_fixed_size'] = size
                # denormalize target_images.
                denorm_imgs = mmcv.imdenormalize(
                    temp_results_['img'], self.mean, self.std,
                    to_bgr=True).astype(np.uint8)
                gt_rgb_shape = denorm_imgs.shape

                gt_image = denorm_imgs[py.astype(np.int32),
                                       px.astype(np.int32), :]
                nerf_sizes.append(np.array(gt_image.shape))
                gt_image = np.reshape(gt_image, (-1, 3))
                gt_images.append(gt_image / 255.0)
                if 'depth_info' in results.keys():
                    if '.npy' in results['depth_info'][i]['filename']:
                        _results['depth'] = np.load(
                            results['depth_info'][i]['filename'])
                    else:
                        depth_image = Image.open(
                            results['depth_info'][i]['filename'])
                        _results['depth'] = np.asarray(depth_image) / 1000
                        _results['depth'] = mmcv.imresize(
                            _results['depth'],
                            (gt_rgb_shape[1], gt_rgb_shape[0]))

                    _results['depth'] = _results['depth']
                    gt_depth = _results['depth'][py.astype(np.int32),
                                                 px.astype(np.int32)]
                    gt_depths.append(gt_depth)

        for key in _results.keys():
            if key not in ['img', 'img_info']:
                results[key] = _results[key]
        results['img'] = imgs

        if 'ray_info' in results.keys():
            results['c2w'] = c2ws
            results['camrotc2w'] = camrotc2ws
            results['lightpos'] = lightposes
            results['pixels'] = pixels
            results['raydirs'] = raydirs
            results['gt_images'] = gt_images
            results['gt_depths'] = gt_depths
            results['nerf_sizes'] = nerf_sizes
            results['denorm_images'] = denorm_imgs_list
            results['depth_range'] = np.array([self.depth_range])

        if len(depths) != 0:
            results['depth'] = depths
        results['lidar2img']['extrinsic'] = extrinsics
        return results


@TRANSFORMS.register_module()
class RandomShiftOrigin(BaseTransform):

    def __init__(self, std):
        self.std = std

    def transform(self, results):
        shift = np.random.normal(.0, self.std, 3)
        results['lidar2img']['origin'] += shift
        return results
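The transform above is registered in this commit, but the diff does not show how it is wired into a dataset config. As a rough, hypothetical sketch of a pipeline entry (the nested LoadImageFromFile transform and the view counts are illustrative assumptions, not values taken from this commit):

# Hypothetical config entry for MultiViewPipeline; numbers are assumptions.
train_pipeline_entry = dict(
    type='MultiViewPipeline',
    n_images=48,                       # number of sampled source views (assumed)
    transforms=[dict(type='LoadImageFromFile')],
    mean=[123.675, 116.28, 103.53],    # defaults from the class above
    std=[58.395, 57.12, 57.375],
    margin=10,
    depth_range=[0.5, 5.5],
    loading='random',
    nerf_target_views=1,               # one held-out novel view for the NeRF branch
    sample_freq=3)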
projects/NeRF-Det/nerfdet/nerf_det3d_data_sample.py  (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmengine.structures import InstanceData

from mmdet3d.structures import Det3DDataSample


class NeRFDet3DDataSample(Det3DDataSample):
    """A data structure interface inherited from Det3DDataSample. Some new
    attributes are added to match the NeRF-Det project.

    The attributes added in ``NeRFDet3DDataSample`` are divided into two parts:

        - ``gt_nerf_images`` (InstanceData): Ground truth of the images which
          will be used in the NeRF branch.
        - ``gt_nerf_depths`` (InstanceData): Ground truth of the depth images
          which will be used in the NeRF branch if needed.

    For more details and examples, please refer to the 'Det3DDataSample' file.
    """

    @property
    def gt_nerf_images(self) -> InstanceData:
        return self._gt_nerf_images

    @gt_nerf_images.setter
    def gt_nerf_images(self, value: InstanceData) -> None:
        self.set_field(value, '_gt_nerf_images', dtype=InstanceData)

    @gt_nerf_images.deleter
    def gt_nerf_images(self) -> None:
        del self._gt_nerf_images

    @property
    def gt_nerf_depths(self) -> InstanceData:
        return self._gt_nerf_depths

    @gt_nerf_depths.setter
    def gt_nerf_depths(self, value: InstanceData) -> None:
        self.set_field(value, '_gt_nerf_depths', dtype=InstanceData)

    @gt_nerf_depths.deleter
    def gt_nerf_depths(self) -> None:
        del self._gt_nerf_depths


SampleList = List[NeRFDet3DDataSample]
OptSampleList = Optional[SampleList]
ForwardResults = Union[Dict[str, torch.Tensor], List[NeRFDet3DDataSample],
                       Tuple[torch.Tensor], torch.Tensor]
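A minimal sketch of how the two new fields might be populated, assuming the usual InstanceData semantics from mmengine (the tensor shape is illustrative, not mandated by this commit):

import torch
from mmengine.structures import InstanceData

# Attach NeRF ground truth to a sample; the setter enforces the InstanceData type.
sample = NeRFDet3DDataSample()
gt_imgs = InstanceData()
gt_imgs.images = torch.rand(1, 4096, 3)   # assumed per-ray RGB target layout
sample.gt_nerf_images = gt_imgs
assert 'images' in sample.gt_nerf_images  # the detector later checks membership this way
del sample.gt_nerf_images                  # the deleter removes the private field again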
projects/NeRF-Det/nerfdet/nerf_utils/nerf_mlp.py  (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """The MLP module used in NerfDet.

    Args:
        input_dim (int): The number of input tensor channels.
        output_dim (int): The number of output tensor channels.
        net_depth (int): The depth of the MLP. Defaults to 8.
        net_width (int): The width of the MLP. Defaults to 256.
        skip_layer (int): The layer to add skip layers to. Defaults to 4.
        hidden_init (Callable): The initialize method of the hidden layers.
        hidden_activation (Callable): The activation function of hidden
            layers, defaults to ReLU.
        output_enabled (bool): If true, the output layers will be used.
            Defaults to True.
        output_init (Optional): The initialize method of the output layer.
        output_activation(Optional): The activation function of output layers.
        bias_enabled (bool): If true, the bias will be used. Defaults to True.
        bias_init (Callable): The initialize method of the bias.
            Defaults to ``nn.init.zeros_``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = None,
        net_depth: int = 8,
        net_width: int = 256,
        skip_layer: int = 4,
        hidden_init: Callable = nn.init.xavier_uniform_,
        hidden_activation: Callable = nn.ReLU(),
        output_enabled: bool = True,
        output_init: Optional[Callable] = nn.init.xavier_uniform_,
        output_activation: Optional[Callable] = nn.Identity(),
        bias_enabled: bool = True,
        bias_init: Callable = nn.init.zeros_,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.net_depth = net_depth
        self.net_width = net_width
        self.skip_layer = skip_layer
        self.hidden_init = hidden_init
        self.hidden_activation = hidden_activation
        self.output_enabled = output_enabled
        self.output_init = output_init
        self.output_activation = output_activation
        self.bias_enabled = bias_enabled
        self.bias_init = bias_init

        self.hidden_layers = nn.ModuleList()
        in_features = self.input_dim
        for i in range(self.net_depth):
            self.hidden_layers.append(
                nn.Linear(in_features, self.net_width, bias=bias_enabled))
            if (self.skip_layer is not None) and (i % self.skip_layer
                                                  == 0) and (i > 0):
                in_features = self.net_width + self.input_dim
            else:
                in_features = self.net_width
        if self.output_enabled:
            self.output_layer = nn.Linear(
                in_features, self.output_dim, bias=bias_enabled)
        else:
            self.output_dim = in_features

        self.initialize()

    def initialize(self):

        def init_func_hidden(m):
            if isinstance(m, nn.Linear):
                if self.hidden_init is not None:
                    self.hidden_init(m.weight)
                if self.bias_enabled and self.bias_init is not None:
                    self.bias_init(m.bias)

        self.hidden_layers.apply(init_func_hidden)
        if self.output_enabled:

            def init_func_output(m):
                if isinstance(m, nn.Linear):
                    if self.output_init is not None:
                        self.output_init(m.weight)
                    if self.bias_enabled and self.bias_init is not None:
                        self.bias_init(m.bias)

            self.output_layer.apply(init_func_output)

    def forward(self, x):
        inputs = x
        for i in range(self.net_depth):
            x = self.hidden_layers[i](x)
            x = self.hidden_activation(x)
            if (self.skip_layer is not None) and (i % self.skip_layer
                                                  == 0) and (i > 0):
                x = torch.cat([x, inputs], dim=-1)
        if self.output_enabled:
            x = self.output_layer(x)
            x = self.output_activation(x)
        return x


class DenseLayer(MLP):

    def __init__(self, input_dim, output_dim, **kwargs):
        super().__init__(
            input_dim=input_dim,
            output_dim=output_dim,
            net_depth=0,  # no hidden layers
            **kwargs,
        )


class NerfMLP(nn.Module):
    """The Nerf-MLP Module.

    Args:
        input_dim (int): The number of input tensor channels.
        condition_dim (int): The number of condition tensor channels.
        feature_dim (int): The number of feature channels. Defaults to 0.
        net_depth (int): The depth of the MLP. Defaults to 8.
        net_width (int): The width of the MLP. Defaults to 256.
        skip_layer (int): The layer to add skip layers to. Defaults to 4.
        net_depth_condition (int): The depth of the second part of MLP.
            Defaults to 1.
        net_width_condition (int): The width of the second part of MLP.
            Defaults to 128.
    """

    def __init__(
        self,
        input_dim: int,
        condition_dim: int,
        feature_dim: int = 0,
        net_depth: int = 8,
        net_width: int = 256,
        skip_layer: int = 4,
        net_depth_condition: int = 1,
        net_width_condition: int = 128,
    ):
        super().__init__()
        self.base = MLP(
            input_dim=input_dim + feature_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            output_enabled=False,
        )
        hidden_features = self.base.output_dim
        self.sigma_layer = DenseLayer(hidden_features, 1)

        if condition_dim > 0:
            self.bottleneck_layer = DenseLayer(hidden_features, net_width)
            self.rgb_layer = MLP(
                input_dim=net_width + condition_dim,
                output_dim=3,
                net_depth=net_depth_condition,
                net_width=net_width_condition,
                skip_layer=None,
            )
        else:
            self.rgb_layer = DenseLayer(hidden_features, 3)

    def query_density(self, x, features=None):
        """Calculate the raw sigma."""
        if features is not None:
            x = self.base(torch.cat([x, features], dim=-1))
        else:
            x = self.base(x)
        raw_sigma = self.sigma_layer(x)
        return raw_sigma

    def forward(self, x, condition=None, features=None):
        if features is not None:
            x = self.base(torch.cat([x, features], dim=-1))
        else:
            x = self.base(x)
        raw_sigma = self.sigma_layer(x)

        if condition is not None:
            if condition.shape[:-1] != x.shape[:-1]:
                num_rays, n_dim = condition.shape
                condition = condition.view(
                    [num_rays] + [1] * (x.dim() - condition.dim()) +
                    [n_dim]).expand(list(x.shape[:-1]) + [n_dim])
            bottleneck = self.bottleneck_layer(x)
            x = torch.cat([bottleneck, condition], dim=-1)
        raw_rgb = self.rgb_layer(x)
        return raw_rgb, raw_sigma


class SinusoidalEncoder(nn.Module):
    """Sinusoidal Positional Encoder used in NeRF."""

    def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
        super().__init__()
        self.x_dim = x_dim
        self.min_deg = min_deg
        self.max_deg = max_deg
        self.use_identity = use_identity
        self.register_buffer(
            'scales', torch.tensor([2**i for i in range(min_deg, max_deg)]))

    @property
    def latent_dim(self) -> int:
        return (int(self.use_identity) +
                (self.max_deg - self.min_deg) * 2) * self.x_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.max_deg == self.min_deg:
            return x
        xb = torch.reshape(
            (x[Ellipsis, None, :] * self.scales[:, None]),
            list(x.shape[:-1]) + [(self.max_deg - self.min_deg) * self.x_dim],
        )
        latent = torch.sin(torch.cat([xb, xb + 0.5 * math.pi], dim=-1))
        if self.use_identity:
            latent = torch.cat([x] + [latent], dim=-1)
        return latent


class VanillaNeRF(nn.Module):
    """The Nerf-MLP with the positional encoder.

    Args:
        net_depth (int): The depth of the MLP. Defaults to 8.
        net_width (int): The width of the MLP. Defaults to 256.
        skip_layer (int): The layer to add skip layers to. Defaults to 4.
        feature_dim (int): The number of feature channels. Defaults to 0.
        net_depth_condition (int): The depth of the second part of MLP.
            Defaults to 1.
        net_width_condition (int): The width of the second part of MLP.
            Defaults to 128.
    """

    def __init__(self,
                 net_depth: int = 8,
                 net_width: int = 256,
                 skip_layer: int = 4,
                 feature_dim: int = 0,
                 net_depth_condition: int = 1,
                 net_width_condition: int = 128):
        super().__init__()
        self.posi_encoder = SinusoidalEncoder(3, 0, 10, True)
        self.view_encoder = SinusoidalEncoder(3, 0, 4, True)
        self.mlp = NerfMLP(
            input_dim=self.posi_encoder.latent_dim,
            condition_dim=self.view_encoder.latent_dim,
            feature_dim=feature_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            net_depth_condition=net_depth_condition,
            net_width_condition=net_width_condition,
        )

    def query_density(self, x, features=None):
        x = self.posi_encoder(x)
        sigma = self.mlp.query_density(x, features)
        return F.relu(sigma)

    def forward(self, x, condition=None, features=None):
        x = self.posi_encoder(x)
        if condition is not None:
            condition = self.view_encoder(condition)
        rgb, sigma = self.mlp(x, condition=condition, features=features)
        return torch.sigmoid(rgb), F.relu(sigma)
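To make the tensor contract of VanillaNeRF concrete, here is a small shape-check sketch under assumed ray and sample counts. The net_depth/net_width/skip_layer values follow what NerfDet passes later in this commit, but feature_dim is set to 0 here so no voxel features are needed; nothing else is taken from the diff.

import torch

# Shape-check sketch (counts are arbitrary assumptions).
model = VanillaNeRF(net_depth=4, net_width=256, skip_layer=3, feature_dim=0)
pts = torch.rand(1024, 64, 3)     # [N_rays, N_samples, 3] sample positions
viewdirs = torch.rand(1024, 3)    # [N_rays, 3] per-ray view directions
rgb, sigma = model(pts, condition=viewdirs)
print(rgb.shape, sigma.shape)     # expected: [1024, 64, 3] and [1024, 64, 1]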
projects/NeRF-Det/nerfdet/nerf_utils/projection.py  (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
# Attention: This file is mainly modified based on the file with the same
# name in the original project. For more details, please refer to the
# origin project.
import torch
import torch.nn.functional as F


class Projector():

    def __init__(self, device='cuda'):
        self.device = device

    def inbound(self, pixel_locations, h, w):
        """check if the pixel locations are in valid range."""
        return (pixel_locations[..., 0] <= w - 1.) & \
               (pixel_locations[..., 0] >= 0) & \
               (pixel_locations[..., 1] <= h - 1.) & \
               (pixel_locations[..., 1] >= 0)

    def normalize(self, pixel_locations, h, w):
        resize_factor = torch.tensor([w - 1., h - 1.]).to(
            pixel_locations.device)[None, None, :]
        normalized_pixel_locations = 2 * pixel_locations / resize_factor - 1.
        return normalized_pixel_locations

    def compute_projections(self, xyz, train_cameras):
        """project 3D points into cameras."""
        original_shape = xyz.shape[:2]
        xyz = xyz.reshape(-1, 3)
        num_views = len(train_cameras)
        train_intrinsics = train_cameras[:, 2:18].reshape(-1, 4, 4)
        train_poses = train_cameras[:, -16:].reshape(-1, 4, 4)
        xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1)
        # projections = train_intrinsics.bmm(torch.inverse(train_poses))
        # we have inverse the pose in dataloader so
        # do not need to inverse here.
        projections = train_intrinsics.bmm(train_poses) \
            .bmm(xyz_h.t()[None, ...].repeat(num_views, 1, 1))
        projections = projections.permute(0, 2, 1)
        pixel_locations = projections[..., :2] / torch.clamp(
            projections[..., 2:3], min=1e-8)
        pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6)
        mask = projections[..., 2] > 0
        return pixel_locations.reshape((num_views, ) + original_shape + (2, )), \
            mask.reshape((num_views, ) + original_shape)  # noqa

    def compute_angle(self, xyz, query_camera, train_cameras):
        original_shape = xyz.shape[:2]
        xyz = xyz.reshape(-1, 3)
        train_poses = train_cameras[:, -16:].reshape(-1, 4, 4)
        num_views = len(train_poses)
        query_pose = query_camera[-16:].reshape(-1, 4, 4).repeat(
            num_views, 1, 1)
        ray2tar_pose = (query_pose[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
        ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6)
        ray2train_pose = (
            train_poses[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
        ray2train_pose /= (
            torch.norm(ray2train_pose, dim=-1, keepdim=True) + 1e-6)
        ray_diff = ray2tar_pose - ray2train_pose
        ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
        ray_diff_dot = torch.sum(
            ray2tar_pose * ray2train_pose, dim=-1, keepdim=True)
        ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
        ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
        ray_diff = ray_diff.reshape((num_views, ) + original_shape + (4, ))
        return ray_diff

    def compute(self,
                xyz,
                train_imgs,
                train_cameras,
                featmaps=None,
                grid_sample=True):
        assert (train_imgs.shape[0] == 1) \
            and (train_cameras.shape[0] == 1)
        # only support batch_size=1 for now
        train_imgs = train_imgs.squeeze(0)
        train_cameras = train_cameras.squeeze(0)
        train_imgs = train_imgs.permute(0, 3, 1, 2)
        h, w = train_cameras[0][:2]

        # compute the projection of the query points to each reference image
        pixel_locations, mask_in_front = self.compute_projections(
            xyz, train_cameras)
        normalized_pixel_locations = self.normalize(pixel_locations, h, w)
        # rgb sampling
        rgbs_sampled = F.grid_sample(
            train_imgs, normalized_pixel_locations, align_corners=True)
        rgb_sampled = rgbs_sampled.permute(2, 3, 0, 1)

        # deep feature sampling
        if featmaps is not None:
            if grid_sample:
                feat_sampled = F.grid_sample(
                    featmaps, normalized_pixel_locations, align_corners=True)
                feat_sampled = feat_sampled.permute(
                    2, 3, 0, 1)  # [n_rays, n_samples, n_views, d]
                rgb_feat_sampled = torch.cat(
                    [rgb_sampled, feat_sampled],
                    dim=-1)  # [n_rays, n_samples, n_views, d+3]
                # rgb_feat_sampled = feat_sampled
            else:
                n_images, n_channels, f_h, f_w = featmaps.shape
                resize_factor = torch.tensor([f_w / w - 1., f_h / h - 1.]).to(
                    pixel_locations.device)[None, None, :]
                sample_location = (pixel_locations *
                                   resize_factor).round().long()
                n_images, n_ray, n_sample, _ = sample_location.shape
                sample_x = sample_location[..., 0].view(n_images, -1)
                sample_y = sample_location[..., 1].view(n_images, -1)
                valid = (sample_x >= 0) & (sample_y >= 0) & \
                        (sample_x < f_w) & (sample_y < f_h)
                valid = valid * mask_in_front.view(n_images, -1)
                feat_sampled = torch.zeros(
                    (n_images, n_channels, sample_x.shape[-1]),
                    device=featmaps.device)
                for i in range(n_images):
                    feat_sampled[i, :, valid[i]] = featmaps[
                        i, :, sample_y[i, valid[i]], sample_y[i, valid[i]]]
                feat_sampled = feat_sampled.view(n_images, n_channels, n_ray,
                                                 n_sample)
                rgb_feat_sampled = feat_sampled.permute(2, 3, 0, 1)
        else:
            rgb_feat_sampled = None

        inbound = self.inbound(pixel_locations, h, w)
        mask = (inbound * mask_in_front).float().permute(
            1, 2, 0)[..., None]  # [n_rays, n_samples, n_views, 1]
        return rgb_feat_sampled, mask
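The 34-number-per-view camera encoding that compute() slices ([:, 2:18] for the flattened intrinsic, [:, -16:] for the flattened pose) is only implicit in the code above; it matches what _compute_projection() in render_ray.py builds. A dummy-data sketch of that layout, assuming identity cameras and random query points:

import torch

# Per view: [h, w] + flattened 4x4 intrinsic + flattened 4x4 pose -> 34 values.
views, h, w = 2, 240, 320
img_size = torch.tensor([[h, w]] * views, dtype=torch.float32)
intrinsic = torch.eye(4).view(1, 16).repeat(views, 1)        # dummy intrinsics
extrinsic = torch.eye(4).view(1, 16).repeat(views, 1)        # dummy poses
train_cameras = torch.cat([img_size, intrinsic, extrinsic], dim=-1)[None]  # [1, views, 34]

projector = Projector(device='cpu')
xyz = torch.rand(128, 64, 3)           # [n_rays, n_samples, 3] query points
imgs = torch.rand(1, views, h, w, 3)   # [1, n_views, H, W, 3] source images
rgb_feat, mask = projector.compute(xyz, imgs, train_cameras, featmaps=None)
# rgb_feat is None when featmaps is None; mask has shape [n_rays, n_samples, n_views, 1].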
projects/NeRF-Det/nerfdet/nerf_utils/render_ray.py  (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
# Attention: This file is mainly modified based on the file with the same
# name in the original project. For more details, please refer to the
# origin project.
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.RandomState(234)


# helper functions for nerf ray rendering
def volume_sampling(sample_pts, features, aabb):
    B, C, D, W, H = features.shape
    assert B == 1
    aabb = torch.Tensor(aabb).to(sample_pts.device)
    N_rays, N_samples, coords = sample_pts.shape
    sample_pts = sample_pts.view(1, N_rays * N_samples, 1, 1,
                                 3).repeat(B, 1, 1, 1, 1)
    aabbSize = aabb[1] - aabb[0]
    invgridSize = 1.0 / aabbSize * 2
    norm_pts = (sample_pts - aabb[0]) * invgridSize - 1
    sample_features = F.grid_sample(
        features, norm_pts, align_corners=True, padding_mode='border')
    masks = ((norm_pts < 1) & (norm_pts > -1)).float().sum(dim=-1)
    masks = (masks.view(N_rays, N_samples) == 3)
    return sample_features.view(C, N_rays, N_samples).permute(
        1, 2, 0).contiguous(), masks


def _compute_projection(img_meta):
    views = len(img_meta['lidar2img']['extrinsic'])
    intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:4, :4])
    ratio = img_meta['ori_shape'][0] / img_meta['img_shape'][0]
    intrinsic[:2] /= ratio
    intrinsic = intrinsic.unsqueeze(0).view(1, 16).repeat(views, 1)
    img_size = torch.Tensor(img_meta['img_shape'][:2]).to(intrinsic.device)
    img_size = img_size.unsqueeze(0).repeat(views, 1)
    extrinsics = []
    for v in range(views):
        extrinsics.append(
            torch.Tensor(img_meta['lidar2img']['extrinsic'][v]).to(
                intrinsic.device))
    extrinsic = torch.stack(extrinsics).view(views, 16)
    train_cameras = torch.cat([img_size, intrinsic, extrinsic], dim=-1)
    return train_cameras.unsqueeze(0)


def compute_mask_points(feature, mask):
    weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
    mean = torch.sum(feature * weight, dim=2, keepdim=True)
    var = torch.sum((feature - mean)**2, dim=2, keepdim=True)
    var = var / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
    var = torch.exp(-var)
    return mean, var


def sample_pdf(bins, weights, N_samples, det=False):
    """Helper function used for sampling.

    Args:
        bins (tensor):Tensor of shape [N_rays, M+1], M is the number of bins
        weights (tensor):Tensor of shape [N_rays, M+1], M is the number of bins
        N_samples (int):Number of samples along each ray
        det (bool):If True, will perform deterministic sampling

    Returns:
        samples (tuple): [N_rays, N_samples]
    """
    M = weights.shape[1]
    weights += 1e-5
    # Get pdf
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[:, 0:1]), cdf], dim=-1)

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., N_samples, device=bins.device)
        u = u.unsqueeze(0).repeat(bins.shape[0], 1)
    else:
        u = torch.rand(bins.shape[0], N_samples, device=bins.device)

    # Invert CDF
    above_inds = torch.zeros_like(u, dtype=torch.long)
    for i in range(M):
        above_inds += (u >= cdf[:, i:i + 1]).long()

    # random sample inside each bin
    below_inds = torch.clamp(above_inds - 1, min=0)
    inds_g = torch.stack((below_inds, above_inds), dim=2)

    cdf = cdf.unsqueeze(1).repeat(1, N_samples, 1)
    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)

    bins = bins.unsqueeze(1).repeat(1, N_samples, 1)
    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)

    denom = cdf_g[:, :, 1] - cdf_g[:, :, 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[:, :, 0]) / denom

    samples = bins_g[:, :, 0] + t * (bins_g[:, :, 1] - bins_g[:, :, 0])
    return samples


def sample_along_camera_ray(ray_o,
                            ray_d,
                            depth_range,
                            N_samples,
                            inv_uniform=False,
                            det=False):
    """Sampling along the camera ray.

    Args:
        ray_o (tensor): Origin of the ray in scene coordinate system;
            tensor of shape [N_rays, 3]
        ray_d (tensor): Homogeneous ray direction vectors in
            scene coordinate system; tensor of shape [N_rays, 3]
        depth_range (tuple): [near_depth, far_depth]
        inv_uniform (bool): If True,uniformly sampling inverse depth.
        det (bool): If True, will perform deterministic sampling.

    Returns:
        pts (tensor): Tensor of shape [N_rays, N_samples, 3]
        z_vals (tensor): Tensor of shape [N_rays, N_samples]
    """
    # will sample inside [near_depth, far_depth]
    # assume the nearest possible depth is at least (min_ratio * depth)
    near_depth_value = depth_range[0]
    far_depth_value = depth_range[1]
    assert near_depth_value > 0 and far_depth_value > 0 \
        and far_depth_value > near_depth_value

    near_depth = near_depth_value * torch.ones_like(ray_d[..., 0])
    far_depth = far_depth_value * torch.ones_like(ray_d[..., 0])

    if inv_uniform:
        start = 1. / near_depth
        step = (1. / far_depth - start) / (N_samples - 1)
        inv_z_vals = torch.stack(
            [start + i * step for i in range(N_samples)], dim=1)
        z_vals = 1. / inv_z_vals
    else:
        start = near_depth
        step = (far_depth - near_depth) / (N_samples - 1)
        z_vals = torch.stack(
            [start + i * step for i in range(N_samples)], dim=1)

    if not det:
        # get intervals between samples
        mids = .5 * (z_vals[:, 1:] + z_vals[:, :-1])
        upper = torch.cat([mids, z_vals[:, -1:]], dim=-1)
        lower = torch.cat([z_vals[:, 0:1], mids], dim=-1)
        # uniform samples in those intervals
        t_rand = torch.rand_like(z_vals)
        z_vals = lower + (upper - lower) * t_rand

    ray_d = ray_d.unsqueeze(1).repeat(1, N_samples, 1)
    ray_o = ray_o.unsqueeze(1).repeat(1, N_samples, 1)
    pts = z_vals.unsqueeze(2) * ray_d + ray_o  # [N_rays, N_samples, 3]
    return pts, z_vals


# ray rendering of nerf
def raw2outputs(raw, z_vals, mask, white_bkgd=False):
    """Transform raw data to outputs:

    Args:
        raw(tensor):Raw network output.Tensor of shape [N_rays, N_samples, 4]
        z_vals(tensor):Depth of point samples along rays.
            Tensor of shape [N_rays, N_samples]
        ray_d(tensor):[N_rays, 3]

    Returns:
        ret(dict):
            -rgb(tensor):[N_rays, 3]
            -depth(tensor):[N_rays,]
            -weights(tensor):[N_rays,]
            -depth_std(tensor):[N_rays,]
    """
    rgb = raw[:, :, :3]  # [N_rays, N_samples, 3]
    sigma = raw[:, :, 3]  # [N_rays, N_samples]

    # note: we did not use the intervals here,
    # because in practice different scenes from COLMAP can have
    # very different scales, and using interval can affect
    # the model's generalization ability.
    # Therefore we don't use the intervals for both training and evaluation.
    sigma2alpha = lambda sigma, dists: 1. - torch.exp(-sigma)  # noqa

    # point samples are ordered with increasing depth
    # interval between samples
    dists = z_vals[:, 1:] - z_vals[:, :-1]
    dists = torch.cat((dists, dists[:, -1:]), dim=-1)

    alpha = sigma2alpha(sigma, dists)

    T = torch.cumprod(1. - alpha + 1e-10, dim=-1)[:, :-1]
    T = torch.cat((torch.ones_like(T[:, 0:1]), T), dim=-1)

    # maths show weights, and summation of weights along a ray,
    # are always inside [0, 1]
    weights = alpha * T
    rgb_map = torch.sum(weights.unsqueeze(2) * rgb, dim=1)

    if white_bkgd:
        rgb_map = rgb_map + (1. - torch.sum(weights, dim=-1, keepdim=True))

    if mask is not None:
        mask = mask.float().sum(dim=1) > 8

    depth_map = torch.sum(weights * z_vals, dim=-1) / (
        torch.sum(weights, dim=-1) + 1e-8)
    depth_map = torch.clamp(depth_map, z_vals.min(), z_vals.max())

    ret = OrderedDict([('rgb', rgb_map), ('depth', depth_map),
                       ('weights', weights), ('mask', mask), ('alpha', alpha),
                       ('z_vals', z_vals), ('transparency', T)])
    return ret


def render_rays_func(
        ray_o,
        ray_d,
        mean_volume,
        cov_volume,
        features_2D,
        img,
        aabb,
        near_far_range,
        N_samples,
        N_rand=4096,
        nerf_mlp=None,
        img_meta=None,
        projector=None,
        mode='volume',  # volume and image
        nerf_sample_view=3,
        inv_uniform=False,
        N_importance=0,
        det=False,
        is_train=True,
        white_bkgd=False,
        gt_rgb=None,
        gt_depth=None):

    ret = {
        'outputs_coarse': None,
        'outputs_fine': None,
        'gt_rgb': gt_rgb,
        'gt_depth': gt_depth
    }

    # pts: [N_rays, N_samples, 3]
    # z_vals: [N_rays, N_samples]
    pts, z_vals = sample_along_camera_ray(
        ray_o=ray_o,
        ray_d=ray_d,
        depth_range=near_far_range,
        N_samples=N_samples,
        inv_uniform=inv_uniform,
        det=det)
    N_rays, N_samples = pts.shape[:2]

    if mode == 'image':
        img = img.permute(0, 2, 3, 1).unsqueeze(0)
        train_camera = _compute_projection(img_meta).to(img.device)
        rgb_feat, mask = projector.compute(
            pts, img, train_camera, features_2D, grid_sample=True)
        pixel_mask = mask[..., 0].sum(dim=2) > 1
        mean, var = compute_mask_points(rgb_feat, mask)
        globalfeat = torch.cat([mean, var], dim=-1).squeeze(2)
        rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalfeat)
        raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1)
        ret['sigma'] = density_pts

    elif mode == 'volume':
        mean_pts, inbound_masks = volume_sampling(pts, mean_volume, aabb)
        cov_pts, inbound_masks = volume_sampling(pts, cov_volume, aabb)
        # This masks is for indicating which points outside of aabb
        img = img.permute(0, 2, 3, 1).unsqueeze(0)
        train_camera = _compute_projection(img_meta).to(img.device)
        _, view_mask = projector.compute(pts, img, train_camera, None)
        pixel_mask = view_mask[..., 0].sum(dim=2) > 1
        # plot_3D_vis(pts, aabb, img, train_camera)
        # [N_rays, N_samples], should at least have 2 observations
        # This mask is for indicating which points do not have projected point
        globalpts = torch.cat([mean_pts, cov_pts], dim=-1)
        rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalpts)
        density_pts = density_pts * inbound_masks.unsqueeze(dim=-1)

        raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1)

    outputs_coarse = raw2outputs(
        raw_coarse, z_vals, pixel_mask, white_bkgd=white_bkgd)
    ret['outputs_coarse'] = outputs_coarse

    return ret


def render_rays(
        ray_batch,
        mean_volume,
        cov_volume,
        features_2D,
        img,
        aabb,
        near_far_range,
        N_samples,
        N_rand=4096,
        nerf_mlp=None,
        img_meta=None,
        projector=None,
        mode='volume',  # volume and image
        nerf_sample_view=3,
        inv_uniform=False,
        N_importance=0,
        det=False,
        is_train=True,
        white_bkgd=False,
        render_testing=False):
    """The function of the nerf rendering."""

    ray_o = ray_batch['ray_o']
    ray_d = ray_batch['ray_d']
    gt_rgb = ray_batch['gt_rgb']
    gt_depth = ray_batch['gt_depth']
    nerf_sizes = ray_batch['nerf_sizes']
    if is_train:
        ray_o = ray_o.view(-1, 3)
        ray_d = ray_d.view(-1, 3)
        gt_rgb = gt_rgb.view(-1, 3)
        if gt_depth.shape[1] != 0:
            gt_depth = gt_depth.view(-1, 1)
            non_zero_depth = (gt_depth > 0).squeeze(-1)
            ray_o = ray_o[non_zero_depth]
            ray_d = ray_d[non_zero_depth]
            gt_rgb = gt_rgb[non_zero_depth]
            gt_depth = gt_depth[non_zero_depth]
        else:
            gt_depth = None
        total_rays = ray_d.shape[0]
        select_inds = rng.choice(total_rays, size=(N_rand, ), replace=False)
        ray_o = ray_o[select_inds]
        ray_d = ray_d[select_inds]
        gt_rgb = gt_rgb[select_inds]
        if gt_depth is not None:
            gt_depth = gt_depth[select_inds]

        rets = render_rays_func(
            ray_o,
            ray_d,
            mean_volume,
            cov_volume,
            features_2D,
            img,
            aabb,
            near_far_range,
            N_samples,
            N_rand,
            nerf_mlp,
            img_meta,
            projector,
            mode,  # volume and image
            nerf_sample_view,
            inv_uniform,
            N_importance,
            det,
            is_train,
            white_bkgd,
            gt_rgb,
            gt_depth)

    elif render_testing:
        nerf_size = nerf_sizes[0]
        view_num = ray_o.shape[1]
        H = nerf_size[0][0]
        W = nerf_size[0][1]
        ray_o = ray_o.view(-1, 3)
        ray_d = ray_d.view(-1, 3)
        gt_rgb = gt_rgb.view(-1, 3)
        print(gt_rgb.shape)
        if len(gt_depth) != 0:
            gt_depth = gt_depth.view(-1, 1)
        else:
            gt_depth = None
        assert view_num * H * W == ray_o.shape[0]
        num_rays = ray_o.shape[0]
        results = []
        rgbs = []
        for i in range(0, num_rays, N_rand):
            ray_o_chunck = ray_o[i:i + N_rand, :]
            ray_d_chunck = ray_d[i:i + N_rand, :]

            ret = render_rays_func(ray_o_chunck, ray_d_chunck, mean_volume,
                                   cov_volume, features_2D, img, aabb,
                                   near_far_range, N_samples, N_rand,
                                   nerf_mlp, img_meta, projector, mode,
                                   nerf_sample_view, inv_uniform,
                                   N_importance, True, is_train, white_bkgd,
                                   gt_rgb, gt_depth)
            results.append(ret)
        rgbs = []
        depths = []

        if results[0]['outputs_coarse'] is not None:
            for i in range(len(results)):
                rgb = results[i]['outputs_coarse']['rgb']
                rgbs.append(rgb)
                depth = results[i]['outputs_coarse']['depth']
                depths.append(depth)
            rets = {
                'outputs_coarse': {
                    'rgb': torch.cat(rgbs, dim=0).view(view_num, H, W, 3),
                    'depth': torch.cat(depths, dim=0).view(view_num, H, W, 1),
                },
                'gt_rgb':
                gt_rgb.view(view_num, H, W, 3),
                'gt_depth':
                gt_depth.view(view_num, H, W, 1)
                if gt_depth is not None else None,
            }
        else:
            rets = None
    return rets
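As a quick illustration of how the two sampling/compositing helpers above fit together, here is a dummy-data sketch with deterministic sampling; the random tensor simply stands in for the [rgb, sigma] channels that raw2outputs expects, and assumes the definitions above are in scope:

import torch
import torch.nn.functional as F

# Sketch: sample points along a few rays, then composite a fake network output.
ray_o = torch.zeros(16, 3)
ray_d = F.normalize(torch.rand(16, 3), dim=-1)
pts, z_vals = sample_along_camera_ray(
    ray_o, ray_d, depth_range=[0.5, 5.5], N_samples=64, det=True)
raw = torch.rand(16, 64, 4)                   # fake [rgb, sigma] per sample
out = raw2outputs(raw, z_vals, mask=None)
print(out['rgb'].shape, out['depth'].shape)   # expected: [16, 3] and [16]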
projects/NeRF-Det/nerfdet/nerf_utils/save_rendered_img.py  (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import os

import cv2
import numpy as np
import torch
from skimage.metrics import structural_similarity


def compute_psnr_from_mse(mse):
    return -10.0 * torch.log(mse) / np.log(10.0)


def compute_psnr(pred, target, mask=None):
    """Compute psnr value (we assume the maximum pixel value is 1)."""
    if mask is not None:
        pred, target = pred[mask], target[mask]
    mse = ((pred - target)**2).mean()
    return compute_psnr_from_mse(mse).cpu().numpy()


def compute_ssim(pred, target, mask=None):
    """Computes Masked SSIM following the neuralbody paper."""
    assert pred.shape == target.shape and pred.shape[-1] == 3
    if mask is not None:
        x, y, w, h = cv2.boundingRect(mask.cpu().numpy().astype(np.uint8))
        pred = pred[y:y + h, x:x + w]
        target = target[y:y + h, x:x + w]
    try:
        ssim = structural_similarity(
            pred.cpu().numpy(), target.cpu().numpy(), channel_axis=-1)
    except ValueError:
        ssim = structural_similarity(
            pred.cpu().numpy(), target.cpu().numpy(), multichannel=True)
    return ssim


def save_rendered_img(img_meta, rendered_results):
    filename = img_meta[0]['filename']
    scenes = filename.split('/')[-2]

    for ret in rendered_results:
        depth = ret['outputs_coarse']['depth']
        rgb = ret['outputs_coarse']['rgb']
        gt = ret['gt_rgb']
        gt_depth = ret['gt_depth']

    # save images
    psnr_total = 0
    ssim_total = 0
    rsme = 0
    for v in range(gt.shape[0]):
        rsme += ((depth[v] - gt_depth[v])**2).cpu().numpy()
        depth_ = ((depth[v] - depth[v].min()) /
                  (depth[v].max() - depth[v].min() + 1e-8)).repeat(1, 1, 3)
        img_to_save = torch.cat([rgb[v], gt[v], depth_], dim=1)
        image_path = os.path.join('nerf_vs_rebuttal', scenes)
        if not os.path.exists(image_path):
            os.makedirs(image_path)
        save_dir = os.path.join(image_path, 'view_' + str(v) + '.png')

        font = cv2.FONT_HERSHEY_SIMPLEX
        org = (50, 50)
        fontScale = 1
        color = (255, 0, 0)
        thickness = 2
        image = np.uint8(img_to_save.cpu().numpy() * 255.0)
        psnr = compute_psnr(rgb[v], gt[v], mask=None)
        psnr_total += psnr
        ssim = compute_ssim(rgb[v], gt[v], mask=None)
        ssim_total += ssim
        image = cv2.putText(
            image, 'PSNR: ' + '%.2f' % compute_psnr(rgb[v], gt[v], mask=None),
            org, font, fontScale, color, thickness, cv2.LINE_AA)

        cv2.imwrite(save_dir, image)

    return psnr_total / gt.shape[0], ssim_total / gt.shape[0], \
        rsme / gt.shape[0]
projects/NeRF-Det/nerfdet/nerfdet.py  (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet3d.models.detectors import Base3DDetector
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import ConfigType, OptConfigType
from .nerf_utils.nerf_mlp import VanillaNeRF
from .nerf_utils.projection import Projector
from .nerf_utils.render_ray import render_rays

# from ..utils.nerf_utils.save_rendered_img import save_rendered_img


@MODELS.register_module()
class NerfDet(Base3DDetector):
    r"""`ImVoxelNet <https://arxiv.org/abs/2307.14620>`_.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        neck_3d(:obj:`ConfigDict` or dict): The 3D neck config.
        bbox_head(:obj:`ConfigDict` or dict): The bbox head config.
        prior_generator (:obj:`ConfigDict` or dict): The prior generator
            config.
        n_voxels (list): Number of voxels along x, y, z axis.
        voxel_size (list): The size of voxels.Each voxel represents
            a cube of `voxel_size[0]` meters, `voxel_size[1]` meters,
            ``
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
        render_testing (bool): If you want to render novel view, please set
            "render_testing = True" in config

        The other args are the parameters of NeRF, you can just use the
        default values.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 neck_3d: ConfigType,
                 bbox_head: ConfigType,
                 prior_generator: ConfigType,
                 n_voxels: List,
                 voxel_size: List,
                 head_2d: ConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptConfigType = None,
                 # pretrained,
                 aabb: Tuple = None,
                 near_far_range: List = None,
                 N_samples: int = 64,
                 N_rand: int = 2048,
                 depth_supervise: bool = False,
                 use_nerf_mask: bool = True,
                 nerf_sample_view: int = 3,
                 nerf_mode: str = 'volume',
                 squeeze_scale: int = 4,
                 rgb_supervision: bool = True,
                 nerf_density: bool = False,
                 render_testing: bool = False):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.neck_3d = MODELS.build(neck_3d)
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)
        self.head_2d = MODELS.build(head_2d) if head_2d is not None else None
        self.n_voxels = n_voxels
        self.prior_generator = TASK_UTILS.build(prior_generator)
        self.voxel_size = voxel_size
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.aabb = aabb
        self.near_far_range = near_far_range
        self.N_samples = N_samples
        self.N_rand = N_rand
        self.depth_supervise = depth_supervise
        self.projector = Projector()
        self.squeeze_scale = squeeze_scale
        self.use_nerf_mask = use_nerf_mask
        self.rgb_supervision = rgb_supervision
        nerf_feature_dim = neck['out_channels'] // squeeze_scale
        self.nerf_mlp = VanillaNeRF(
            net_depth=4,  # The depth of the MLP
            net_width=256,  # The width of the MLP
            skip_layer=3,  # The layer to add skip layers to.
            feature_dim=nerf_feature_dim + 6,  # + RGB original imgs
            net_depth_condition=1,  # The depth of the second part of MLP
            net_width_condition=128)
        self.nerf_mode = nerf_mode
        self.nerf_density = nerf_density
        self.nerf_sample_view = nerf_sample_view
        self.render_testing = render_testing

        # hard code here, will deal with batch issue later.
        self.cov = nn.Sequential(
            nn.Conv3d(
                neck['out_channels'],
                neck['out_channels'],
                kernel_size=3,
                padding=1), nn.ReLU(inplace=True),
            nn.Conv3d(
                neck['out_channels'],
                neck['out_channels'],
                kernel_size=3,
                padding=1), nn.ReLU(inplace=True),
            nn.Conv3d(neck['out_channels'], 1, kernel_size=1))

        self.mean_mapping = nn.Sequential(
            nn.Conv3d(
                neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))

        self.cov_mapping = nn.Sequential(
            nn.Conv3d(
                neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))

        self.mapping = nn.Sequential(
            nn.Linear(neck['out_channels'], nerf_feature_dim // 2))

        self.mapping_2d = nn.Sequential(
            nn.Conv2d(
                neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))

        # self.overfit_nerfmlp = overfit_nerfmlp
        # if self.overfit_nerfmlp:
        #     self. _finetuning_NeRF_MLP()
        self.render_testing = render_testing

    def extract_feat(self,
                     batch_inputs_dict: dict,
                     batch_data_samples: SampleList,
                     mode,
                     depth=None,
                     ray_batch=None):
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        -> 3d neck -> bbox_head.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instances` of `gt_panoptic_seg` or `gt_sem_seg`

        Returns:
            Tuple:
            - torch.Tensor: Features of shape (N, C_out, N_x, N_y, N_z).
            - torch.Tensor: Valid mask of shape (N, 1, N_x, N_y, N_z).
            - torch.Tensor: 2D features if needed.
            - dict: The nerf rendered information including the
                'output_coarse', 'gt_rgb' and 'gt_depth' keys.
        """
        img = batch_inputs_dict['imgs']
        img = img.float()
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_size = img.shape[0]

        if len(img.shape) > 4:
            img = img.reshape([-1] + list(img.shape)[2:])
            x = self.backbone(img)
            x = self.neck(x)[0]
            x = x.reshape([batch_size, -1] + list(x.shape[1:]))
        else:
            x = self.backbone(img)
            x = self.neck(x)[0]

        if depth is not None:
            depth_bs = depth.shape[0]
            assert depth_bs == batch_size
            depth = batch_inputs_dict['depth']
            depth = depth.reshape([-1] + list(depth.shape)[2:])

        features_2d = self.head_2d.forward(x[-1], batch_img_metas) \
            if self.head_2d is not None else None

        stride = img.shape[-1] / x.shape[-1]
        assert stride == 4
        stride = int(stride)

        volumes, valids = [], []
        rgb_preds = []

        for feature, img_meta in zip(x, batch_img_metas):
            angles = features_2d[
                0] if features_2d is not None and mode == 'test' else None
            projection = self._compute_projection(img_meta, stride,
                                                  angles).to(x.device)
            points = get_points(
                n_voxels=torch.tensor(self.n_voxels),
                voxel_size=torch.tensor(self.voxel_size),
                origin=torch.tensor(
                    img_meta['lidar2img']['origin'])).to(x.device)

            height = img_meta['img_shape'][0] // stride
            width = img_meta['img_shape'][1] // stride
            # Construct the volume space
            # volume together with valid is the constructed scene
            # volume represents V_i and valid represents M_p
            volume, valid = backproject(feature[:, :, :height, :width],
                                        points, projection, depth,
                                        self.voxel_size)
            density = None
            volume_sum = volume.sum(dim=0)
            # cov_valid = valid.clone().detach()
            valid = valid.sum(dim=0)
            volume_mean = volume_sum / (valid + 1e-8)
            volume_mean[:, valid[0] == 0] = .0
            # volume_cov = (volume - volume_mean.unsqueeze(0)) ** 2 * cov_valid
            # volume_cov = torch.sum(volume_cov, dim=0) / (valid + 1e-8)
            volume_cov = torch.sum(
                (volume - volume_mean.unsqueeze(0))**2, dim=0) / (valid + 1e-8)
            volume_cov[:, valid[0] == 0] = 1e6
            volume_cov = torch.exp(-volume_cov)  # default setting
            # be careful here, the smaller the cov, the larger the weight.
            n_channels, n_x_voxels, n_y_voxels, n_z_voxels = volume_mean.shape

            if ray_batch is not None:
                if self.nerf_mode == 'volume':
                    mean_volume = self.mean_mapping(volume_mean.unsqueeze(0))
                    cov_volume = self.cov_mapping(volume_cov.unsqueeze(0))
                    feature_2d = feature[:, :, :height, :width]

                elif self.nerf_mode == 'image':
                    mean_volume = None
                    cov_volume = None
                    feature_2d = feature[:, :, :height, :width]
                    n_v, C, height, width = feature_2d.shape
                    feature_2d = feature_2d.view(n_v, C, -1).permute(
                        0, 2, 1).contiguous()
                    feature_2d = self.mapping(feature_2d).permute(
                        0, 2, 1).contiguous().view(n_v, -1, height, width)

                denorm_images = ray_batch['denorm_images']
                denorm_images = denorm_images.reshape(
                    [-1] + list(denorm_images.shape)[2:])
                rgb_projection = self._compute_projection(
                    img_meta, stride=1, angles=None).to(x.device)

                rgb_volume, _ = backproject(
                    denorm_images[:, :, :img_meta['img_shape'][0], :
                                  img_meta['img_shape'][1]], points,
                    rgb_projection, depth, self.voxel_size)

                ret = render_rays(
                    ray_batch,
                    mean_volume,
                    cov_volume,
                    feature_2d,
                    denorm_images,
                    self.aabb,
                    self.near_far_range,
                    self.N_samples,
                    self.N_rand,
                    self.nerf_mlp,
                    img_meta,
                    self.projector,
                    self.nerf_mode,
                    self.nerf_sample_view,
                    is_train=True if mode == 'train' else False,
                    render_testing=self.render_testing)

                rgb_preds.append(ret)

                if self.nerf_density:
                    # would have 0 bias issue for mean_mapping.
                    n_v, C, n_x_voxels, n_y_voxels, n_z_voxels = volume.shape
                    volume = volume.view(n_v, C, -1).permute(0, 2,
                                                             1).contiguous()
                    mapping_volume = self.mapping(volume).permute(
                        0, 2, 1).contiguous().view(n_v, -1, n_x_voxels,
                                                   n_y_voxels, n_z_voxels)

                    mapping_volume = torch.cat([rgb_volume, mapping_volume],
                                               dim=1)
                    mapping_volume_sum = mapping_volume.sum(dim=0)
                    mapping_volume_mean = mapping_volume_sum / (valid + 1e-8)
                    # mapping_volume_cov = (
                    #     mapping_volume - mapping_volume_mean.unsqueeze(0)
                    # ) ** 2 * cov_valid
                    mapping_volume_cov = (
                        mapping_volume - mapping_volume_mean.unsqueeze(0))**2
                    mapping_volume_cov = torch.sum(
                        mapping_volume_cov, dim=0) / (valid + 1e-8)
                    mapping_volume_cov[:, valid[0] == 0] = 1e6
                    mapping_volume_cov = torch.exp(
                        -mapping_volume_cov)  # default setting
                    global_volume = torch.cat(
                        [mapping_volume_mean, mapping_volume_cov], dim=1)
                    global_volume = global_volume.view(
                        -1, n_x_voxels * n_y_voxels * n_z_voxels).permute(
                            1, 0).contiguous()
                    points = points.view(3, -1).permute(1, 0).contiguous()
                    density = self.nerf_mlp.query_density(
                        points, global_volume)
                    alpha = 1 - torch.exp(-density)
                    # density -> alpha
                    # (1, n_x_voxels, n_y_voxels, n_z_voxels)
                    volume = alpha.view(1, n_x_voxels, n_y_voxels,
                                        n_z_voxels) * volume_mean
                    volume[:, valid[0] == 0] = .0

            volumes.append(volume)
            valids.append(valid)
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x, torch.stack(valids).float(), features_2d, rgb_preds

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj: `DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        ray_batchs = {}
        batch_images = []
        batch_depths = []
        if 'images' in batch_data_samples[0].gt_nerf_images:
            for data_samples in batch_data_samples:
                image = data_samples.gt_nerf_images['images']
                batch_images.append(image)
            batch_images = torch.stack(batch_images)
        if 'depths' in batch_data_samples[0].gt_nerf_depths:
            for data_samples in batch_data_samples:
                depth = data_samples.gt_nerf_depths['depths']
                batch_depths.append(depth)
            batch_depths = torch.stack(batch_depths)
        if 'raydirs' in batch_inputs_dict.keys():
            ray_batchs['ray_o'] = batch_inputs_dict['lightpos']
            ray_batchs['ray_d'] = batch_inputs_dict['raydirs']
            ray_batchs['gt_rgb'] = batch_images
            ray_batchs['gt_depth'] = batch_depths
            ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes']
            ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images']
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict,
                batch_data_samples,
                'train',
                depth=None,
                ray_batch=ray_batchs)
        else:
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict, batch_data_samples, 'train')
        x += (valids, )
        losses = self.bbox_head.loss(x, batch_data_samples, **kwargs)
        # if self.head_2d is not None:
        #     losses.update(
        #         self.head_2d.loss(*features_2d, batch_data_samples)
        #     )
        if len(ray_batchs) != 0 and self.rgb_supervision:
            losses.update(self.nvs_loss_func(rgb_preds))
        if self.depth_supervise:
            losses.update(self.depth_loss_func(rgb_preds))
        return losses

    def nvs_loss_func(self, rgb_pred):
        loss = 0
        for ret in rgb_pred:
            rgb = ret['outputs_coarse']['rgb']
            gt = ret['gt_rgb']
            masks = ret['outputs_coarse']['mask']
            if self.use_nerf_mask:
                loss += torch.sum(
                    masks.unsqueeze(-1) * (rgb - gt)**2) / (masks.sum() + 1e-6)
            else:
                loss += torch.mean((rgb - gt)**2)
        return dict(loss_nvs=loss)

    def depth_loss_func(self, rgb_pred):
        loss = 0
        for ret in rgb_pred:
            depth = ret['outputs_coarse']['depth']
            gt = ret['gt_depth'].squeeze(-1)
            masks = ret['outputs_coarse']['mask']
            if self.use_nerf_mask:
                loss += torch.sum(
                    masks * torch.abs(depth - gt)) / (masks.sum() + 1e-6)
            else:
                loss += torch.mean(torch.abs(depth - gt))
        return dict(loss_depth=loss)

    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            list[:obj:`NeRFDet3DDataSample`]: Detection results of the
            input images. Each NeRFDet3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (Tensor): Contains a tensor with shape
              (num_instances, C) where C = 6.
        """
        ray_batchs = {}
        batch_images = []
        batch_depths = []
        if 'images' in batch_data_samples[0].gt_nerf_images:
            for data_samples in batch_data_samples:
                image = data_samples.gt_nerf_images['images']
                batch_images.append(image)
            batch_images = torch.stack(batch_images)
        if 'depths' in batch_data_samples[0].gt_nerf_depths:
            for data_samples in batch_data_samples:
                depth = data_samples.gt_nerf_depths['depths']
                batch_depths.append(depth)
            batch_depths = torch.stack(batch_depths)

        if 'raydirs' in batch_inputs_dict.keys():
            ray_batchs['ray_o'] = batch_inputs_dict['lightpos']
            ray_batchs['ray_d'] = batch_inputs_dict['raydirs']
            ray_batchs['gt_rgb'] = batch_images
            ray_batchs['gt_depth'] = batch_depths
            ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes']
            ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images']
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict,
                batch_data_samples,
                'test',
                depth=None,
                ray_batch=ray_batchs)
        else:
            x, valids, features_2d, rgb_preds = self.extract_feat(
                batch_inputs_dict, batch_data_samples, 'test')
        x += (valids, )
        results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
        predictions = self.add_pred_to_datasample(batch_data_samples,
                                                  results_list)
        return predictions

    def _forward(self, batch_inputs_dict: dict,
                 batch_data_samples: SampleList, *args,
                 **kwargs) -> Tuple[List[torch.Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
forward without any post-processing.
Args:
batch_inputs_dict (dict): The model input dict which include
the 'imgs' key.
- imgs (torch.Tensor, optional): Image of each sample.
batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`
Returns:
tuple[list]: A tuple of features from ``bbox_head`` forward
"""
ray_batchs
=
{}
batch_images
=
[]
batch_depths
=
[]
if
'images'
in
batch_data_samples
[
0
].
gt_nerf_images
:
for
data_samples
in
batch_data_samples
:
image
=
data_samples
.
gt_nerf_images
[
'images'
]
batch_images
.
append
(
image
)
batch_images
=
torch
.
stack
(
batch_images
)
if
'depths'
in
batch_data_samples
[
0
].
gt_nerf_depths
:
for
data_samples
in
batch_data_samples
:
depth
=
data_samples
.
gt_nerf_depths
[
'depths'
]
batch_depths
.
append
(
depth
)
batch_depths
=
torch
.
stack
(
batch_depths
)
if
'raydirs'
in
batch_inputs_dict
.
keys
():
ray_batchs
[
'ray_o'
]
=
batch_inputs_dict
[
'lightpos'
]
ray_batchs
[
'ray_d'
]
=
batch_inputs_dict
[
'raydirs'
]
ray_batchs
[
'gt_rgb'
]
=
batch_images
ray_batchs
[
'gt_depth'
]
=
batch_depths
ray_batchs
[
'nerf_sizes'
]
=
batch_inputs_dict
[
'nerf_sizes'
]
ray_batchs
[
'denorm_images'
]
=
batch_inputs_dict
[
'denorm_images'
]
x
,
valids
,
features_2d
,
rgb_preds
=
self
.
extract_feat
(
batch_inputs_dict
,
batch_data_samples
,
'train'
,
depth
=
None
,
ray_batch
=
ray_batchs
)
else
:
x
,
valids
,
features_2d
,
rgb_preds
=
self
.
extract_feat
(
batch_inputs_dict
,
batch_data_samples
,
'train'
)
x
+=
(
valids
,
)
results
=
self
.
bbox_head
.
forward
(
x
)
return
results
def
aug_test
(
self
,
batch_inputs_dict
,
batch_data_samples
):
pass
def
show_results
(
self
,
*
args
,
**
kwargs
):
pass
@
staticmethod
def
_compute_projection
(
img_meta
,
stride
,
angles
):
projection
=
[]
intrinsic
=
torch
.
tensor
(
img_meta
[
'lidar2img'
][
'intrinsic'
][:
3
,
:
3
])
ratio
=
img_meta
[
'ori_shape'
][
0
]
/
(
img_meta
[
'img_shape'
][
0
]
/
stride
)
intrinsic
[:
2
]
/=
ratio
# use predict pitch and roll for SUNRGBDTotal test
if
angles
is
not
None
:
extrinsics
=
[]
for
angle
in
angles
:
extrinsics
.
append
(
get_extrinsics
(
angle
).
to
(
intrinsic
.
device
))
else
:
extrinsics
=
map
(
torch
.
tensor
,
img_meta
[
'lidar2img'
][
'extrinsic'
])
for
extrinsic
in
extrinsics
:
projection
.
append
(
intrinsic
@
extrinsic
[:
3
])
return
torch
.
stack
(
projection
)
@
torch
.
no_grad
()
def
get_points
(
n_voxels
,
voxel_size
,
origin
):
# origin: point-cloud center.
points
=
torch
.
stack
(
torch
.
meshgrid
([
torch
.
arange
(
n_voxels
[
0
]),
# 40 W width, x
torch
.
arange
(
n_voxels
[
1
]),
# 40 D depth, y
torch
.
arange
(
n_voxels
[
2
])
# 16 H Height, z
]))
new_origin
=
origin
-
n_voxels
/
2.
*
voxel_size
points
=
points
*
voxel_size
.
view
(
3
,
1
,
1
,
1
)
+
new_origin
.
view
(
3
,
1
,
1
,
1
)
return
points
# modify from https://github.com/magicleap/Atlas/blob/master/atlas/model.py
def
backproject
(
features
,
points
,
projection
,
depth
,
voxel_size
):
n_images
,
n_channels
,
height
,
width
=
features
.
shape
n_x_voxels
,
n_y_voxels
,
n_z_voxels
=
points
.
shape
[
-
3
:]
points
=
points
.
view
(
1
,
3
,
-
1
).
expand
(
n_images
,
3
,
-
1
)
points
=
torch
.
cat
((
points
,
torch
.
ones_like
(
points
[:,
:
1
])),
dim
=
1
)
points_2d_3
=
torch
.
bmm
(
projection
,
points
)
x
=
(
points_2d_3
[:,
0
]
/
points_2d_3
[:,
2
]).
round
().
long
()
y
=
(
points_2d_3
[:,
1
]
/
points_2d_3
[:,
2
]).
round
().
long
()
z
=
points_2d_3
[:,
2
]
valid
=
(
x
>=
0
)
&
(
y
>=
0
)
&
(
x
<
width
)
&
(
y
<
height
)
&
(
z
>
0
)
# below is using depth to sample feature
if
depth
is
not
None
:
depth
=
F
.
interpolate
(
depth
.
unsqueeze
(
1
),
size
=
(
height
,
width
),
mode
=
'bilinear'
).
squeeze
(
1
)
for
i
in
range
(
n_images
):
z_mask
=
z
.
clone
()
>
0
z_mask
[
i
,
valid
[
i
]]
=
\
(
z
[
i
,
valid
[
i
]]
>
depth
[
i
,
y
[
i
,
valid
[
i
]],
x
[
i
,
valid
[
i
]]]
-
voxel_size
[
-
1
])
&
\
(
z
[
i
,
valid
[
i
]]
<
depth
[
i
,
y
[
i
,
valid
[
i
]],
x
[
i
,
valid
[
i
]]]
+
voxel_size
[
-
1
])
# noqa
valid
=
valid
&
z_mask
volume
=
torch
.
zeros
((
n_images
,
n_channels
,
points
.
shape
[
-
1
]),
device
=
features
.
device
)
for
i
in
range
(
n_images
):
volume
[
i
,
:,
valid
[
i
]]
=
features
[
i
,
:,
y
[
i
,
valid
[
i
]],
x
[
i
,
valid
[
i
]]]
volume
=
volume
.
view
(
n_images
,
n_channels
,
n_x_voxels
,
n_y_voxels
,
n_z_voxels
)
valid
=
valid
.
view
(
n_images
,
1
,
n_x_voxels
,
n_y_voxels
,
n_z_voxels
)
return
volume
,
valid
# for SUNRGBDTotal test
def
get_extrinsics
(
angles
):
yaw
=
angles
.
new_zeros
(())
pitch
,
roll
=
angles
r
=
angles
.
new_zeros
((
3
,
3
))
r
[
0
,
0
]
=
torch
.
cos
(
yaw
)
*
torch
.
cos
(
pitch
)
r
[
0
,
1
]
=
torch
.
sin
(
yaw
)
*
torch
.
sin
(
roll
)
-
torch
.
cos
(
yaw
)
*
torch
.
cos
(
roll
)
*
torch
.
sin
(
pitch
)
r
[
0
,
2
]
=
torch
.
cos
(
roll
)
*
torch
.
sin
(
yaw
)
+
torch
.
cos
(
yaw
)
*
torch
.
sin
(
pitch
)
*
torch
.
sin
(
roll
)
r
[
1
,
0
]
=
torch
.
sin
(
pitch
)
r
[
1
,
1
]
=
torch
.
cos
(
pitch
)
*
torch
.
cos
(
roll
)
r
[
1
,
2
]
=
-
torch
.
cos
(
pitch
)
*
torch
.
sin
(
roll
)
r
[
2
,
0
]
=
-
torch
.
cos
(
pitch
)
*
torch
.
sin
(
yaw
)
r
[
2
,
1
]
=
torch
.
cos
(
yaw
)
*
torch
.
sin
(
roll
)
+
torch
.
cos
(
roll
)
*
torch
.
sin
(
yaw
)
*
torch
.
sin
(
pitch
)
r
[
2
,
2
]
=
torch
.
cos
(
yaw
)
*
torch
.
cos
(
roll
)
-
torch
.
sin
(
yaw
)
*
torch
.
sin
(
pitch
)
*
torch
.
sin
(
roll
)
# follow Total3DUnderstanding
t
=
angles
.
new_tensor
([[
0.
,
0.
,
1.
],
[
0.
,
-
1.
,
0.
],
[
-
1.
,
0.
,
0.
]])
r
=
t
@
r
.
T
# follow DepthInstance3DBoxes
r
=
r
[:,
[
2
,
0
,
1
]]
r
[
2
]
*=
-
1
extrinsic
=
angles
.
new_zeros
((
4
,
4
))
extrinsic
[:
3
,
:
3
]
=
r
extrinsic
[
3
,
3
]
=
1.
return
extrinsic
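Not part of the file above: a minimal, hedged sketch of how `get_points` and `backproject` fit together, assuming the two functions above are in scope. Build the voxel grid once, project it with a per-view 3x4 projection matrix (intrinsic @ extrinsic[:3]), and splat image features into the voxels. The voxel counts, feature map size and the toy projection below are placeholder values, not the ones used by the NeRF-Det configs, and most voxels will simply fall outside this toy camera.

import torch

# Hypothetical sizes; the real values come from the model config.
n_voxels = torch.tensor([40, 40, 16])
voxel_size = torch.tensor([.16, .16, .2])
origin = torch.tensor([.0, .0, .5])

points = get_points(n_voxels, voxel_size, origin)  # shape (3, 40, 40, 16)

# One dummy view: 64-channel features on a 60x80 map and a toy projection.
features = torch.rand(1, 64, 60, 80)
projection = torch.eye(4)[:3].unsqueeze(0)  # (1, 3, 4)
projection[0, 0, 0] = 40.   # fake focal lengths
projection[0, 1, 1] = 40.
projection[0, 0, 2] = 40.   # fake principal point
projection[0, 1, 2] = 30.

# depth=None, so the depth-based visibility check is skipped.
volume, valid = backproject(features, points, projection, None, voxel_size)
print(volume.shape, valid.shape)  # (1, 64, 40, 40, 16) (1, 1, 40, 40, 16)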
projects/NeRF-Det/nerfdet/nerfdet_head.py
0 → 100644
View file @
fe25f7a5
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from mmcv.cnn import Scale
# from mmcv.ops import nms3d, nms3d_normal
from mmdet.models.utils import multi_apply
from mmdet.utils import reduce_mean
# from mmengine.config import ConfigDict
from mmengine.model import BaseModule, bias_init_with_prob, normal_init
from mmengine.structures import InstanceData
from torch import Tensor, nn

from mmdet3d.registry import MODELS, TASK_UTILS
# from mmdet3d.structures.bbox_3d.utils import rotation_3d_in_axis
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils.typing_utils import (ConfigType, InstanceList,
                                        OptConfigType, OptInstanceList)


@torch.no_grad()
def get_points(n_voxels, voxel_size, origin):
    # origin: point-cloud center.
    points = torch.stack(
        torch.meshgrid([
            torch.arange(n_voxels[0]),  # 40 W width, x
            torch.arange(n_voxels[1]),  # 40 D depth, y
            torch.arange(n_voxels[2])  # 16 H height, z
        ]))
    new_origin = origin - n_voxels / 2. * voxel_size
    points = points * voxel_size.view(3, 1, 1, 1) + new_origin.view(3, 1, 1, 1)
    return points


@MODELS.register_module()
class NerfDetHead(BaseModule):
    r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_ head for indoor
    datasets.

    Args:
        n_classes (int): Number of classes.
        n_levels (int): Number of feature levels.
        n_channels (int): Number of channels in input tensors.
        n_reg_outs (int): Number of regression layer channels.
        pts_assign_threshold (int): Min number of locations per box to
            be assigned with.
        pts_center_threshold (int): Max number of locations per box to
            be assigned with.
        center_loss (dict, optional): Config of centerness loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
        bbox_loss (dict, optional): Config of bbox loss.
            Default: dict(type='RotatedIoU3DLoss').
        cls_loss (dict, optional): Config of classification loss.
            Default: dict(type='FocalLoss').
        train_cfg (dict, optional): Config for train stage. Defaults to None.
        test_cfg (dict, optional): Config for test stage. Defaults to None.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 n_classes: int,
                 n_levels: int,
                 n_channels: int,
                 n_reg_outs: int,
                 pts_assign_threshold: int,
                 pts_center_threshold: int,
                 prior_generator: ConfigType,
                 center_loss: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss', use_sigmoid=True),
                 bbox_loss: ConfigType = dict(type='RotatedIoU3DLoss'),
                 cls_loss: ConfigType = dict(type='mmdet.FocalLoss'),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super(NerfDetHead, self).__init__(init_cfg)
        self.n_classes = n_classes
        self.n_levels = n_levels
        self.n_reg_outs = n_reg_outs
        self.pts_assign_threshold = pts_assign_threshold
        self.pts_center_threshold = pts_center_threshold
        self.prior_generator = TASK_UTILS.build(prior_generator)
        self.center_loss = MODELS.build(center_loss)
        self.bbox_loss = MODELS.build(bbox_loss)
        self.cls_loss = MODELS.build(cls_loss)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self._init_layers(n_channels, n_reg_outs, n_classes, n_levels)

    def _init_layers(self, n_channels, n_reg_outs, n_classes, n_levels):
        """Initialize neural network layers of the head."""
        self.conv_center = nn.Conv3d(n_channels, 1, 3, padding=1, bias=False)
        self.conv_reg = nn.Conv3d(
            n_channels, n_reg_outs, 3, padding=1, bias=False)
        self.conv_cls = nn.Conv3d(n_channels, n_classes, 3, padding=1)
        self.scales = nn.ModuleList([Scale(1.) for _ in range(n_levels)])

    def init_weights(self):
        """Initialize all layer weights."""
        normal_init(self.conv_center, std=.01)
        normal_init(self.conv_reg, std=.01)
        normal_init(self.conv_cls, std=.01, bias=bias_init_with_prob(.01))

    def _forward_single(self, x: Tensor, scale: Scale):
        """Forward pass per level.

        Args:
            x (Tensor): Per level 3d neck output tensor.
            scale (mmcv.cnn.Scale): Per level multiplication weight.

        Returns:
            tuple[Tensor]: Centerness, bbox and classification predictions.
        """
        return (self.conv_center(x), torch.exp(scale(self.conv_reg(x))),
                self.conv_cls(x))

    def forward(self, x):
        return multi_apply(self._forward_single, x, self.scales)

    def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
             **kwargs) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        valid_pred = x[-1]
        outs = self(x[:-1])

        batch_gt_instances_3d = []
        batch_gt_instances_ignore = []
        batch_input_metas = []
        for data_sample in batch_data_samples:
            batch_input_metas.append(data_sample.metainfo)
            batch_gt_instances_3d.append(data_sample.gt_instances_3d)
            batch_gt_instances_ignore.append(
                data_sample.get('ignored_instances', None))

        loss_inputs = outs + (valid_pred, batch_gt_instances_3d,
                              batch_input_metas, batch_gt_instances_ignore)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_by_feat(self,
                     center_preds: List[List[Tensor]],
                     bbox_preds: List[List[Tensor]],
                     cls_preds: List[List[Tensor]],
                     valid_pred: Tensor,
                     batch_gt_instances_3d: InstanceList,
                     batch_input_metas: List[dict],
                     batch_gt_instances_ignore: OptInstanceList = None,
                     **kwargs) -> dict:
        """Per scene loss function.

        Args:
            center_preds (list[list[Tensor]]): Centerness predictions for
                all scenes. The first list contains predictions from different
                levels. The second list contains predictions in a mini-batch.
            bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
                The first list contains predictions from different
                levels. The second list contains predictions in a mini-batch.
            cls_preds (list[list[Tensor]]): Classification predictions for all
                scenes. The first list contains predictions from different
                levels. The second list contains predictions in a mini-batch.
            valid_pred (Tensor): Valid mask prediction for all scenes.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d. It usually includes ``bboxes_3d``,
                ``labels_3d``, ``depths``, ``centers_2d`` and attributes.
            batch_input_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict: Centerness, bbox, and classification loss values.
        """
        valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
        center_losses, bbox_losses, cls_losses = [], [], []
        for i in range(len(batch_input_metas)):
            center_loss, bbox_loss, cls_loss = self._loss_by_feat_single(
                center_preds=[x[i] for x in center_preds],
                bbox_preds=[x[i] for x in bbox_preds],
                cls_preds=[x[i] for x in cls_preds],
                valid_preds=[x[i] for x in valid_preds],
                input_meta=batch_input_metas[i],
                gt_bboxes=batch_gt_instances_3d[i].bboxes_3d,
                gt_labels=batch_gt_instances_3d[i].labels_3d)
            center_losses.append(center_loss)
            bbox_losses.append(bbox_loss)
            cls_losses.append(cls_loss)
        return dict(
            center_loss=torch.mean(torch.stack(center_losses)),
            bbox_loss=torch.mean(torch.stack(bbox_losses)),
            cls_loss=torch.mean(torch.stack(cls_losses)))

    def _loss_by_feat_single(self, center_preds, bbox_preds, cls_preds,
                             valid_preds, input_meta, gt_bboxes, gt_labels):
        featmap_sizes = [featmap.size()[-3:] for featmap in center_preds]
        points = self._get_points(
            featmap_sizes=featmap_sizes,
            origin=input_meta['lidar2img']['origin'],
            device=gt_bboxes.device)
        center_targets, bbox_targets, cls_targets = self._get_targets(
            points, gt_bboxes, gt_labels)

        center_preds = torch.cat(
            [x.permute(1, 2, 3, 0).reshape(-1) for x in center_preds])
        bbox_preds = torch.cat([
            x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in bbox_preds
        ])
        cls_preds = torch.cat(
            [x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in cls_preds])
        valid_preds = torch.cat(
            [x.permute(1, 2, 3, 0).reshape(-1) for x in valid_preds])
        points = torch.cat(points)

        # cls loss
        pos_inds = torch.nonzero(
            torch.logical_and(cls_targets >= 0, valid_preds)).squeeze(1)
        n_pos = points.new_tensor(len(pos_inds))
        n_pos = max(reduce_mean(n_pos), 1.)
        if torch.any(valid_preds):
            cls_loss = self.cls_loss(
                cls_preds[valid_preds],
                cls_targets[valid_preds],
                avg_factor=n_pos)
        else:
            cls_loss = cls_preds[valid_preds].sum()

        # bbox and centerness losses
        pos_center_preds = center_preds[pos_inds]
        pos_bbox_preds = bbox_preds[pos_inds]
        if len(pos_inds) > 0:
            pos_center_targets = center_targets[pos_inds]
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_points = points[pos_inds]
            center_loss = self.center_loss(
                pos_center_preds, pos_center_targets, avg_factor=n_pos)
            bbox_loss = self.bbox_loss(
                self._bbox_pred_to_bbox(pos_points, pos_bbox_preds),
                pos_bbox_targets,
                weight=pos_center_targets,
                avg_factor=pos_center_targets.sum())
        else:
            center_loss = pos_center_preds.sum()
            bbox_loss = pos_bbox_preds.sum()
        return center_loss, bbox_loss, cls_loss

    def predict(self,
                x: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the 3D detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_pts_panoptic_seg` and
                `gt_pts_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each sample
            after the post process.
            Each item usually contains the following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
              contains a tensor with shape (num_instances, C), where
              C >= 6.
        """
        batch_input_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        valid_pred = x[-1]
        outs = self(x[:-1])
        predictions = self.predict_by_feat(
            *outs,
            valid_pred=valid_pred,
            batch_input_metas=batch_input_metas,
            rescale=rescale)
        return predictions

    def predict_by_feat(self, center_preds: List[List[Tensor]],
                        bbox_preds: List[List[Tensor]],
                        cls_preds: List[List[Tensor]], valid_pred: Tensor,
                        batch_input_metas: List[dict],
                        **kwargs) -> List[InstanceData]:
        """Generate boxes for all scenes.

        Args:
            center_preds (list[list[Tensor]]): Centerness predictions for
                all scenes.
            bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes.
            cls_preds (list[list[Tensor]]): Classification predictions for all
                scenes.
            valid_pred (Tensor): Valid mask prediction for all scenes.
            batch_input_metas (list[dict]): Meta infos for all scenes.

        Returns:
            list[tuple[Tensor]]: Predicted bboxes, scores, and labels for
            all scenes.
        """
        valid_preds = self._upsample_valid_preds(valid_pred, center_preds)
        results = []
        for i in range(len(batch_input_metas)):
            results.append(
                self._predict_by_feat_single(
                    center_preds=[x[i] for x in center_preds],
                    bbox_preds=[x[i] for x in bbox_preds],
                    cls_preds=[x[i] for x in cls_preds],
                    valid_preds=[x[i] for x in valid_preds],
                    input_meta=batch_input_metas[i]))
        return results

    def _predict_by_feat_single(self, center_preds: List[Tensor],
                                bbox_preds: List[Tensor],
                                cls_preds: List[Tensor],
                                valid_preds: List[Tensor],
                                input_meta: dict) -> InstanceData:
        """Generate boxes for a single sample.

        Args:
            center_preds (list[Tensor]): Centerness predictions for all
                levels.
            bbox_preds (list[Tensor]): Bbox predictions for all levels.
            cls_preds (list[Tensor]): Classification predictions for all
                levels.
            valid_preds (tuple[Tensor]): Upsampled valid masks for all feature
                levels.
            input_meta (dict): Scene meta info.

        Returns:
            tuple[Tensor]: Predicted bounding boxes, scores and labels.
        """
        featmap_sizes = [featmap.size()[-3:] for featmap in center_preds]
        points = self._get_points(
            featmap_sizes=featmap_sizes,
            origin=input_meta['lidar2img']['origin'],
            device=center_preds[0].device)
        mlvl_bboxes, mlvl_scores = [], []
        for center_pred, bbox_pred, cls_pred, valid_pred, point in zip(
                center_preds, bbox_preds, cls_preds, valid_preds, points):
            center_pred = center_pred.permute(1, 2, 3, 0).reshape(-1, 1)
            bbox_pred = bbox_pred.permute(1, 2, 3,
                                          0).reshape(-1, bbox_pred.shape[0])
            cls_pred = cls_pred.permute(1, 2, 3,
                                        0).reshape(-1, cls_pred.shape[0])
            valid_pred = valid_pred.permute(1, 2, 3, 0).reshape(-1, 1)
            scores = cls_pred.sigmoid() * center_pred.sigmoid() * valid_pred
            max_scores, _ = scores.max(dim=1)

            if len(scores) > self.test_cfg.nms_pre > 0:
                _, ids = max_scores.topk(self.test_cfg.nms_pre)
                bbox_pred = bbox_pred[ids]
                scores = scores[ids]
                point = point[ids]

            bboxes = self._bbox_pred_to_bbox(point, bbox_pred)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        bboxes = torch.cat(mlvl_bboxes)
        scores = torch.cat(mlvl_scores)
        bboxes, scores, labels = self._nms(bboxes, scores, input_meta)

        bboxes = input_meta['box_type_3d'](
            bboxes, box_dim=6, with_yaw=False, origin=(.5, .5, .5))

        results = InstanceData()
        results.bboxes_3d = bboxes
        results.scores_3d = scores
        results.labels_3d = labels
        return results

    @staticmethod
    def _upsample_valid_preds(valid_pred, features):
        """Upsample valid mask predictions.

        Args:
            valid_pred (Tensor): Valid mask prediction.
            features (Tensor): Feature tensor.

        Returns:
            tuple[Tensor]: Upsampled valid masks for all feature levels.
        """
        return [
            nn.Upsample(size=x.shape[-3:],
                        mode='trilinear')(valid_pred).round().bool()
            for x in features
        ]

    @torch.no_grad()
    def _get_points(self, featmap_sizes, origin, device):
        mlvl_points = []
        tmp_voxel_size = [.16, .16, .2]
        for i, featmap_size in enumerate(featmap_sizes):
            mlvl_points.append(
                get_points(
                    n_voxels=torch.tensor(featmap_size),
                    voxel_size=torch.tensor(tmp_voxel_size) * (2**i),
                    origin=torch.tensor(origin)).reshape(3, -1).transpose(
                        0, 1).to(device))
        return mlvl_points

    def _bbox_pred_to_bbox(self, points, bbox_pred):
        return torch.stack([
            points[:, 0] - bbox_pred[:, 0], points[:, 1] - bbox_pred[:, 2],
            points[:, 2] - bbox_pred[:, 4], points[:, 0] + bbox_pred[:, 1],
            points[:, 1] + bbox_pred[:, 3], points[:, 2] + bbox_pred[:, 5]
        ], -1)

    def _bbox_pred_to_loss(self, points, bbox_preds):
        return self._bbox_pred_to_bbox(points, bbox_preds)

    # The function is directly copied from FCAF3DHead.
    @staticmethod
    def _get_face_distances(points, boxes):
        """Calculate distances from point to box faces.

        Args:
            points (Tensor): Final locations of shape (N_points, N_boxes, 3).
            boxes (Tensor): 3D boxes of shape (N_points, N_boxes, 7).

        Returns:
            Tensor: Face distances of shape (N_points, N_boxes, 6),
            (dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).
        """
        dx_min = points[..., 0] - boxes[..., 0] + boxes[..., 3] / 2
        dx_max = boxes[..., 0] + boxes[..., 3] / 2 - points[..., 0]
        dy_min = points[..., 1] - boxes[..., 1] + boxes[..., 4] / 2
        dy_max = boxes[..., 1] + boxes[..., 4] / 2 - points[..., 1]
        dz_min = points[..., 2] - boxes[..., 2] + boxes[..., 5] / 2
        dz_max = boxes[..., 2] + boxes[..., 5] / 2 - points[..., 2]
        return torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max),
                           dim=-1)

    @staticmethod
    def _get_centerness(face_distances):
        """Compute point centerness w.r.t. the containing box.

        Args:
            face_distances (Tensor): Face distances of shape (B, N, 6),
                (dx_min, dx_max, dy_min, dy_max, dz_min, dz_max).

        Returns:
            Tensor: Centerness of shape (B, N).
        """
        x_dims = face_distances[..., [0, 1]]
        y_dims = face_distances[..., [2, 3]]
        z_dims = face_distances[..., [4, 5]]
        centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \
            y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \
            z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0]
        return torch.sqrt(centerness_targets)

    @torch.no_grad()
    def _get_targets(self, points, gt_bboxes, gt_labels):
        """Compute targets for final locations of a single scene.

        Args:
            points (list[Tensor]): Final locations for all levels.
            gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
            gt_labels (Tensor): Ground truth labels.

        Returns:
            tuple[Tensor]: Centerness, bbox and classification
            targets for all locations.
        """
        float_max = 1e8
        expanded_scales = [
            points[i].new_tensor(i).expand(len(points[i])).to(
                gt_labels.device) for i in range(len(points))
        ]
        points = torch.cat(points, dim=0).to(gt_labels.device)
        scales = torch.cat(expanded_scales, dim=0)

        # below is based on FCOSHead._get_target_single
        n_points = len(points)
        n_boxes = len(gt_bboxes)
        volumes = gt_bboxes.volume.to(points.device)
        volumes = volumes.expand(n_points, n_boxes).contiguous()
        gt_bboxes = torch.cat(
            (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:6]), dim=1)
        gt_bboxes = gt_bboxes.to(points.device).expand(n_points, n_boxes, 6)
        expanded_points = points.unsqueeze(1).expand(n_points, n_boxes, 3)
        bbox_targets = self._get_face_distances(expanded_points, gt_bboxes)

        # condition1: inside a gt bbox
        inside_gt_bbox_mask = bbox_targets[..., :6].min(-1)[0] > 0  # skip angle

        # condition2: positive points per scale >= limit
        # calculate positive points per scale
        n_pos_points_per_scale = []
        for i in range(self.n_levels):
            n_pos_points_per_scale.append(
                torch.sum(inside_gt_bbox_mask[scales == i], dim=0))
        # find best scale
        n_pos_points_per_scale = torch.stack(n_pos_points_per_scale, dim=0)
        lower_limit_mask = n_pos_points_per_scale < self.pts_assign_threshold
        # fix nondeterministic argmax for torch<1.7
        extra = torch.arange(self.n_levels, 0, -1).unsqueeze(1).expand(
            self.n_levels, n_boxes).to(lower_limit_mask.device)
        lower_index = torch.argmax(lower_limit_mask.int() * extra, dim=0) - 1
        lower_index = torch.where(lower_index < 0,
                                  torch.zeros_like(lower_index), lower_index)
        all_upper_limit_mask = torch.all(
            torch.logical_not(lower_limit_mask), dim=0)
        best_scale = torch.where(
            all_upper_limit_mask,
            torch.ones_like(all_upper_limit_mask) * self.n_levels - 1,
            lower_index)
        # keep only points with best scale
        best_scale = torch.unsqueeze(best_scale, 0).expand(n_points, n_boxes)
        scales = torch.unsqueeze(scales, 1).expand(n_points, n_boxes)
        inside_best_scale_mask = best_scale == scales

        # condition3: limit topk locations per box by centerness
        centerness = self._get_centerness(bbox_targets)
        centerness = torch.where(inside_gt_bbox_mask, centerness,
                                 torch.ones_like(centerness) * -1)
        centerness = torch.where(inside_best_scale_mask, centerness,
                                 torch.ones_like(centerness) * -1)
        top_centerness = torch.topk(
            centerness, self.pts_center_threshold + 1, dim=0).values[-1]
        inside_top_centerness_mask = centerness > top_centerness.unsqueeze(0)

        # if there are still more than one objects for a location,
        # we choose the one with minimal area
        volumes = torch.where(inside_gt_bbox_mask, volumes,
                              torch.ones_like(volumes) * float_max)
        volumes = torch.where(inside_best_scale_mask, volumes,
                              torch.ones_like(volumes) * float_max)
        volumes = torch.where(inside_top_centerness_mask, volumes,
                              torch.ones_like(volumes) * float_max)
        min_area, min_area_inds = volumes.min(dim=1)

        labels = gt_labels[min_area_inds]
        labels = torch.where(min_area == float_max,
                             torch.ones_like(labels) * -1, labels)
        bbox_targets = bbox_targets[range(n_points), min_area_inds]
        centerness_targets = self._get_centerness(bbox_targets)

        return centerness_targets, self._bbox_pred_to_bbox(
            points, bbox_targets), labels

    def _nms(self, bboxes, scores, img_meta):
        scores, labels = scores.max(dim=1)
        ids = scores > self.test_cfg.score_thr
        bboxes = bboxes[ids]
        scores = scores[ids]
        labels = labels[ids]
        ids = self.aligned_3d_nms(bboxes, scores, labels,
                                  self.test_cfg.iou_thr)
        bboxes = bboxes[ids]
        bboxes = torch.stack(
            ((bboxes[:, 0] + bboxes[:, 3]) / 2.,
             (bboxes[:, 1] + bboxes[:, 4]) / 2.,
             (bboxes[:, 2] + bboxes[:, 5]) / 2., bboxes[:, 3] - bboxes[:, 0],
             bboxes[:, 4] - bboxes[:, 1], bboxes[:, 5] - bboxes[:, 2]),
            dim=1)
        return bboxes, scores[ids], labels[ids]

    @staticmethod
    def aligned_3d_nms(boxes, scores, classes, thresh):
        """3D NMS for aligned boxes.

        Args:
            boxes (torch.Tensor): Aligned boxes with shape [n, 6].
            scores (torch.Tensor): Scores of each box.
            classes (torch.Tensor): Class of each box.
            thresh (float): IoU threshold for NMS.

        Returns:
            torch.Tensor: Indices of selected boxes.
        """
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        z1 = boxes[:, 2]
        x2 = boxes[:, 3]
        y2 = boxes[:, 4]
        z2 = boxes[:, 5]
        area = (x2 - x1) * (y2 - y1) * (z2 - z1)
        zero = boxes.new_zeros(1, )

        score_sorted = torch.argsort(scores)
        pick = []
        while (score_sorted.shape[0] != 0):
            last = score_sorted.shape[0]
            i = score_sorted[-1]
            pick.append(i)

            xx1 = torch.max(x1[i], x1[score_sorted[:last - 1]])
            yy1 = torch.max(y1[i], y1[score_sorted[:last - 1]])
            zz1 = torch.max(z1[i], z1[score_sorted[:last - 1]])
            xx2 = torch.min(x2[i], x2[score_sorted[:last - 1]])
            yy2 = torch.min(y2[i], y2[score_sorted[:last - 1]])
            zz2 = torch.min(z2[i], z2[score_sorted[:last - 1]])
            classes1 = classes[i]
            classes2 = classes[score_sorted[:last - 1]]
            inter_l = torch.max(zero, xx2 - xx1)
            inter_w = torch.max(zero, yy2 - yy1)
            inter_h = torch.max(zero, zz2 - zz1)

            inter = inter_l * inter_w * inter_h
            iou = inter / (area[i] + area[score_sorted[:last - 1]] - inter)
            iou = iou * (classes1 == classes2).float()
            score_sorted = score_sorted[torch.nonzero(
                iou <= thresh, as_tuple=False).flatten()]

        indices = boxes.new_tensor(pick, dtype=torch.long)
        return indices
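Not part of the file above: a small sanity-check sketch of the class-aware `aligned_3d_nms` static method, assuming `NerfDetHead` as defined above is importable. Two heavily overlapping boxes of the same class and one distant box of another class are fed in, and only the indices of the kept boxes come back. The numbers are purely illustrative.

import torch

# Three axis-aligned boxes as (x1, y1, z1, x2, y2, z2).
boxes = torch.tensor([
    [0., 0., 0., 1., 1., 1.],      # box A
    [0.05, 0.05, 0., 1., 1., 1.],  # box B, almost identical to A
    [2., 2., 2., 3., 3., 3.],      # box C, far away
])
scores = torch.tensor([0.9, 0.8, 0.7])
classes = torch.tensor([0, 0, 1])

keep = NerfDetHead.aligned_3d_nms(boxes, scores, classes, thresh=0.5)
# Expect A and C to survive; B is suppressed by A (same class, IoU > 0.5).
print(keep)  # tensor([0, 2])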
projects/NeRF-Det/nerfdet/scannet_multiview_dataset.py
0 → 100644
View file @
fe25f7a5
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from os import path as osp
from typing import Callable, List, Optional, Union

import numpy as np

from mmdet3d.datasets import Det3DDataset
from mmdet3d.registry import DATASETS
from mmdet3d.structures import DepthInstance3DBoxes


@DATASETS.register_module()
class MultiViewScanNetDataset(Det3DDataset):
    r"""Multi-View ScanNet Dataset for the NeRF-detection task.

    This class serves as the API for experiments on the ScanNet Dataset.

    Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
    for data downloading.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_camera=True, use_lidar=False).
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'Depth' in this dataset. Available options include:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor datasets.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool): Whether to filter the data with empty GT.
            If it's set to True, the example with empty annotations after
            the data pipeline will be dropped and a random example will be
            chosen in `__getitem__`. Defaults to True.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    METAINFO = {
        'classes':
        ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
         'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
         'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin')
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_camera=True, use_lidar=False),
                 box_type_3d: str = 'Depth',
                 filter_empty_gt: bool = True,
                 remove_dontcare: bool = False,
                 test_mode: bool = False,
                 **kwargs) -> None:
        self.remove_dontcare = remove_dontcare
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            pipeline=pipeline,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)
        assert 'use_camera' in self.modality and \
            'use_lidar' in self.modality
        assert self.modality['use_camera'] or self.modality['use_lidar']

    @staticmethod
    def _get_axis_align_matrix(info: dict) -> np.ndarray:
        """Get axis_align_matrix from info. If not exist, return identity mat.

        Args:
            info (dict): Info of a single sample data.

        Returns:
            np.ndarray: 4x4 transformation matrix.
        """
        if 'axis_align_matrix' in info:
            return np.array(info['axis_align_matrix'])
        else:
            warnings.warn(
                'axis_align_matrix is not found in ScanNet data info, please '
                'use new pre-process scripts to re-generate ScanNet data')
            return np.eye(4).astype(np.float32)

    def parse_data_info(self, info: dict) -> dict:
        """Process the raw data info.

        Convert all relative paths of needed modality data files to
        absolute paths.

        Args:
            info (dict): Raw info dict.

        Returns:
            dict: Has `ann_info` in training stage. And
            all paths have been converted to absolute paths.
        """
        if self.modality['use_depth']:
            info['depth_info'] = []
        if self.modality['use_neuralrecon_depth']:
            info['depth_info'] = []
        if self.modality['use_lidar']:
            # implement lidar processing in the future
            raise NotImplementedError(
                'Please modify `MultiViewPipeline` to support lidar '
                'processing')

        info['axis_align_matrix'] = self._get_axis_align_matrix(info)
        info['img_info'] = []
        info['lidar2img'] = []
        info['c2w'] = []
        info['camrotc2w'] = []
        info['lightpos'] = []
        # load img and depth_img
        for i in range(len(info['img_paths'])):
            img_filename = osp.join(self.data_root, info['img_paths'][i])
            info['img_info'].append(dict(filename=img_filename))
            if 'depth_info' in info.keys():
                if self.modality['use_neuralrecon_depth']:
                    info['depth_info'].append(
                        dict(filename=img_filename[:-4] + '.npy'))
                else:
                    info['depth_info'].append(
                        dict(filename=img_filename[:-4] + '.png'))
            # implement lidar_info in input.keys() in the future.
            extrinsic = np.linalg.inv(
                info['axis_align_matrix'] @ info['lidar2cam'][i])
            info['lidar2img'].append(extrinsic.astype(np.float32))
            if self.modality['use_ray']:
                c2w = (info['axis_align_matrix']
                       @ info['lidar2cam'][i]).astype(np.float32)  # noqa
                info['c2w'].append(c2w)
                info['camrotc2w'].append(c2w[0:3, 0:3])
                info['lightpos'].append(c2w[0:3, 3])
        origin = np.array([.0, .0, .5])
        info['lidar2img'] = dict(
            extrinsic=info['lidar2img'],
            intrinsic=info['cam2img'].astype(np.float32),
            origin=origin.astype(np.float32))

        if self.modality['use_ray']:
            info['ray_info'] = []

        if not self.test_mode:
            info['ann_info'] = self.parse_ann_info(info)
        if self.test_mode and self.load_eval_anns:
            info['ann_info'] = self.parse_ann_info(info)
            info['eval_ann_info'] = self._remove_dontcare(info['ann_info'])
        return info

    def parse_ann_info(self, info: dict) -> dict:
        """Process the `instances` in data info to `ann_info`.

        Args:
            info (dict): Info dict.

        Returns:
            dict: Processed `ann_info`.
        """
        ann_info = super().parse_ann_info(info)

        if self.remove_dontcare:
            ann_info = self._remove_dontcare(ann_info)
        # empty gt
        if ann_info is None:
            ann_info = dict()
            ann_info['gt_bboxes_3d'] = np.zeros((0, 6), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros((0, ), dtype=np.int64)

        ann_info['gt_bboxes_3d'] = DepthInstance3DBoxes(
            ann_info['gt_bboxes_3d'],
            box_dim=ann_info['gt_bboxes_3d'].shape[-1],
            with_yaw=False,
            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        # count the numbers
        for label in ann_info['gt_labels_3d']:
            if label != -1:
                cat_name = self.metainfo['classes'][label]
                self.num_ins_per_cat[cat_name] += 1

        return ann_info
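Not part of the file above: a hedged instantiation sketch. Note that `parse_data_info` reads `use_depth`, `use_neuralrecon_depth` and `use_ray` from `modality`, so a working config has to provide those keys in addition to the documented defaults; all paths below are placeholders, and the real configs live under projects/NeRF-Det/configs.

dataset = MultiViewScanNetDataset(
    data_root='data/scannet',                # placeholder path
    ann_file='scannet_infos_train_new.pkl',  # placeholder file name
    pipeline=[],                             # usually a MultiViewPipeline
    modality=dict(
        use_camera=True,
        use_lidar=False,
        use_depth=False,
        use_neuralrecon_depth=False,
        use_ray=True),
    box_type_3d='Depth',
    remove_dontcare=True,
    test_mode=False)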
projects/NeRF-Det/prepare_infos.py
0 → 100644
View file @
fe25f7a5
# Copyright (c) OpenMMLab. All rights reserved.
"""Prepare the dataset for NeRF-Det.

Example:
    python projects/NeRF-Det/prepare_infos.py
        --root-path ./data/scannet
        --out-dir ./data/scannet
"""
import argparse
import time
from os import path as osp
from pathlib import Path

import mmengine

from ...tools.dataset_converters import indoor_converter as indoor
from ...tools.dataset_converters.update_infos_to_v2 import (
    clear_data_info_unused_keys, clear_instance_unused_keys,
    get_empty_instance, get_empty_standard_data_info)


def update_scannet_infos_nerfdet(pkl_path, out_dir):
    """Update the origin pkl to the new format which will be used in nerf-det.

    Args:
        pkl_path (str): Path of the origin pkl.
        out_dir (str): Output directory of the generated info file.

    Returns:
        The pkl will be overwritten.
        The new pkl is a dict containing two keys:
        metainfo: Some base information of the pkl.
        data_list (list): A list containing all the information of the scenes.
    """
    print('The new refactored process is running.')
    print(f'{pkl_path} will be modified.')
    if out_dir in pkl_path:
        print(f'Warning, you may be overwriting '
              f'the original data {pkl_path}.')
        time.sleep(5)
    METAINFO = {
        'classes':
        ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
         'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
         'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin')
    }
    print(f'Reading from input file: {pkl_path}.')
    data_list = mmengine.load(pkl_path)
    print('Start updating:')
    converted_list = []
    for ori_info_dict in mmengine.track_iter_progress(data_list):
        temp_data_info = get_empty_standard_data_info()

        # intrinsics, extrinsics and imgs
        temp_data_info['cam2img'] = ori_info_dict['intrinsics']
        temp_data_info['lidar2cam'] = ori_info_dict['extrinsics']
        temp_data_info['img_paths'] = ori_info_dict['img_paths']

        # annotation information
        anns = ori_info_dict.get('annos', None)
        ignore_class_name = set()
        if anns is not None:
            temp_data_info['axis_align_matrix'] = anns[
                'axis_align_matrix'].tolist()
            if anns['gt_num'] == 0:
                instance_list = []
            else:
                num_instances = len(anns['name'])
                instance_list = []
                for instance_id in range(num_instances):
                    empty_instance = get_empty_instance()
                    empty_instance['bbox_3d'] = anns[
                        'gt_boxes_upright_depth'][instance_id].tolist()

                    if anns['name'][instance_id] in METAINFO['classes']:
                        empty_instance['bbox_label_3d'] = METAINFO[
                            'classes'].index(anns['name'][instance_id])
                    else:
                        ignore_class_name.add(anns['name'][instance_id])
                        empty_instance['bbox_label_3d'] = -1

                    empty_instance = clear_instance_unused_keys(
                        empty_instance)
                    instance_list.append(empty_instance)
            temp_data_info['instances'] = instance_list
        temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
        converted_list.append(temp_data_info)
    pkl_name = Path(pkl_path).name
    out_path = osp.join(out_dir, pkl_name)
    print(f'Writing to output file: {out_path}.')
    print(f'ignore classes: {ignore_class_name}')

    # dataset metainfo
    metainfo = dict()
    metainfo['categories'] = {k: i for i, k in enumerate(METAINFO['classes'])}
    if ignore_class_name:
        for ignore_class in ignore_class_name:
            metainfo['categories'][ignore_class] = -1
    metainfo['dataset'] = 'scannet'
    metainfo['info_version'] = '1.1'

    converted_data_info = dict(metainfo=metainfo, data_list=converted_list)

    mmengine.dump(converted_data_info, out_path, 'pkl')


def scannet_data_prep(root_path, info_prefix, out_dir, workers):
    """Prepare the info files for the ScanNet dataset used by NeRF-Det.

    Args:
        root_path (str): Path of dataset root.
        info_prefix (str): The prefix of info filenames.
        out_dir (str): Output directory of the generated info file.
        workers (int): Number of threads to be used.
    """
    indoor.create_indoor_info_file(
        root_path, info_prefix, out_dir, workers=workers)
    info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl')
    info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl')
    info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
    update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_train_path)
    update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_val_path)
    update_scannet_infos_nerfdet(out_dir=out_dir, pkl_path=info_test_path)


parser = argparse.ArgumentParser(description='Data converter arg parser')
parser.add_argument(
    '--root-path',
    type=str,
    default='./data/scannet',
    help='specify the root path of dataset')
parser.add_argument(
    '--out-dir',
    type=str,
    default='./data/scannet',
    required=False,
    help='name of info pkl')
parser.add_argument('--extra-tag', type=str, default='scannet')
parser.add_argument(
    '--workers', type=int, default=4, help='number of threads to be used')
args = parser.parse_args()

if __name__ == '__main__':
    from mmdet3d.utils import register_all_modules
    register_all_modules()
    scannet_data_prep(
        root_path=args.root_path,
        info_prefix=args.extra_tag,
        out_dir=args.out_dir,
        workers=args.workers)
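Not part of the file above: to make the target format concrete, this is roughly the shape of the pkl that `update_scannet_infos_nerfdet` writes, sketched as a Python literal. All field values are invented placeholders, and the exact key set per entry depends on `get_empty_standard_data_info` and the key-clearing helpers.

converted_data_info = dict(
    metainfo=dict(
        categories={'cabinet': 0, 'bed': 1},  # truncated: 18 classes in total
        dataset='scannet',
        info_version='1.1'),
    data_list=[
        dict(
            cam2img=[[577.9, 0., 319.5], [0., 577.9, 239.5], [0., 0., 1.]],
            lidar2cam=[[[1., 0., 0., 0.], [0., 1., 0., 0.],
                        [0., 0., 1., 0.], [0., 0., 0., 1.]]],  # one 4x4 per frame
            img_paths=['posed_images/scene0000_00/00000.jpg'],
            axis_align_matrix=[[1., 0., 0., 0.], [0., 1., 0., 0.],
                               [0., 0., 1., 0.], [0., 0., 0., 1.]],
            instances=[
                dict(bbox_3d=[1.0, 2.0, 0.5, 0.8, 0.6, 1.1], bbox_label_3d=4),
            ]),
    ])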
projects/PETR/README.md
View file @
fe25f7a5
...
...
@@ -16,7 +16,7 @@ This is an implementation of *PETR*.
 In MMDet3D's root directory, run the following command to train the model:
 
 ```bash
-python tools/train.py projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py
+python tools/train.py projects/PETR/configs/petr_vovnet_gridmask_p4_800x320.py
 ```
 
 ### Testing commands
... ...
@@ -24,7 +24,7 @@ python tools/train.py projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.
 In MMDet3D's root directory, run the following command to test the model:
 
 ```bash
-python tools/test.py projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py ${CHECKPOINT_PATH}
+python tools/test.py projects/PETR/configs/petr_vovnet_gridmask_p4_800x320.py ${CHECKPOINT_PATH}
 ```
 
 ## Results
...
...
projects/PETR/petr/petr_head.py
View file @
fe25f7a5
...
...
@@ -446,7 +446,7 @@ class PETRHead(AnchorFreeHead):
         masks = x.new_ones((batch_size, num_cams, input_img_h, input_img_w))
         for img_id in range(batch_size):
             for cam_id in range(num_cams):
-                img_h, img_w, _ = img_metas[img_id]['img_shape'][cam_id]
+                img_h, img_w = img_metas[img_id]['img_shape'][cam_id]
                 masks[img_id, cam_id, :img_h, :img_w] = 0
         x = self.input_proj(x.flatten(0, 1))
         x = x.view(batch_size, num_cams, *x.shape[-3:])
...
...
tests/data/waymo/kitti_format/waymo_infos_train.pkl
View file @
fe25f7a5
No preview for this file type
tests/data/waymo/kitti_format/waymo_infos_val.pkl
View file @
fe25f7a5
No preview for this file type
tests/test_datasets/test_waymo_dataset.py
0 → 100644
View file @
fe25f7a5
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.transforms.base import BaseTransform
from mmengine.registry import TRANSFORMS
from mmengine.structures import InstanceData

from mmdet3d.datasets import WaymoDataset
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes


def _generate_waymo_dataset_config():
    data_root = 'tests/data/waymo/kitti_format'
    ann_file = 'waymo_infos_train.pkl'
    classes = ['Car', 'Pedestrian', 'Cyclist']
    # wait for pipeline refactor

    if 'Identity' not in TRANSFORMS:

        @TRANSFORMS.register_module()
        class Identity(BaseTransform):

            def transform(self, info):
                if 'ann_info' in info:
                    info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']
                data_sample = Det3DDataSample()
                gt_instances_3d = InstanceData()
                gt_instances_3d.labels_3d = info['gt_labels_3d']
                data_sample.gt_instances_3d = gt_instances_3d
                info['data_samples'] = data_sample
                return info

    pipeline = [
        dict(type='Identity'),
    ]
    modality = dict(use_lidar=True, use_camera=True)
    data_prefix = dict(pts='training/velodyne', CAM_FRONT='training/image_0')
    return data_root, ann_file, classes, data_prefix, pipeline, modality


def test_getitem():
    data_root, ann_file, classes, data_prefix, \
        pipeline, modality, = _generate_waymo_dataset_config()

    waymo_dataset = WaymoDataset(
        data_root,
        ann_file,
        data_prefix=data_prefix,
        pipeline=pipeline,
        metainfo=dict(classes=classes),
        modality=modality)

    waymo_dataset.prepare_data(0)
    input_dict = waymo_dataset.get_data_info(0)
    waymo_dataset[0]
    # assert that the path contains data_prefix and data_root
    assert data_prefix['pts'] in input_dict['lidar_points']['lidar_path']
    assert data_root in input_dict['lidar_points']['lidar_path']
    for cam_id, img_info in input_dict['images'].items():
        if 'img_path' in img_info:
            assert data_prefix['CAM_FRONT'] in img_info['img_path']
            assert data_root in img_info['img_path']

    ann_info = waymo_dataset.parse_ann_info(input_dict)

    # only one instance
    assert 'gt_labels_3d' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64

    assert 'gt_bboxes_3d' in ann_info
    assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
    assert torch.allclose(ann_info['gt_bboxes_3d'].tensor.sum(),
                          torch.tensor(43.3103))
    assert 'centers_2d' in ann_info
    assert ann_info['centers_2d'].dtype == np.float32
    assert 'depths' in ann_info
    assert ann_info['depths'].dtype == np.float32
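Not part of the file above: to run just this test locally, one option (assuming pytest is installed and the bundled `tests/data/waymo` fixtures are present in the checkout) is to invoke pytest programmatically from the repository root.

import pytest

# Equivalent to `pytest tests/test_datasets/test_waymo_dataset.py::test_getitem -q`.
pytest.main(['tests/test_datasets/test_waymo_dataset.py::test_getitem', '-q'])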
tools/create_data.py
View file @
fe25f7a5
...
...
@@ -2,6 +2,8 @@
 import argparse
 from os import path as osp
 
+from mmengine import print_log
+
 from tools.dataset_converters import indoor_converter as indoor
 from tools.dataset_converters import kitti_converter as kitti
 from tools.dataset_converters import lyft_converter as lyft_converter
... ...
@@ -171,8 +173,19 @@ def waymo_data_prep(root_path,
                     version,
                     out_dir,
                     workers,
-                    max_sweeps=5):
-    """Prepare the info file for waymo dataset.
+                    max_sweeps=10,
+                    only_gt_database=False,
+                    save_senor_data=False,
+                    skip_cam_instances_infos=False):
+    """Prepare waymo dataset. There are 3 steps as follows:
+
+    Step 1. Extract camera images and lidar point clouds from waymo raw
+        data in '*.tfrecord' and save as kitti format.
+    Step 2. Generate waymo train/val/test infos and save as pickle file.
+    Step 3. Generate waymo ground truth database (point clouds within
+        each 3D bounding box) for data augmentation in training.
+    Steps 1 and 2 will be done in Waymo2KITTI, and step 3 will be done in
+    GTDatabaseCreater.
 
     Args:
         root_path (str): Path of dataset root.
... ...
@@ -180,44 +193,55 @@ def waymo_data_prep(root_path,
         out_dir (str): Output directory of the generated info file.
         workers (int): Number of threads to be used.
         max_sweeps (int, optional): Number of input consecutive frames.
-            Default: 5. Here we store pose information of these frames
-            for later use.
+            Default to 10. Here we store ego2global information of these
+            frames for later use.
+        only_gt_database (bool, optional): Whether to only generate ground
+            truth database. Default to False.
+        save_senor_data (bool, optional): Whether to save image and lidar
+            data. Default to False.
+        skip_cam_instances_infos (bool, optional): Whether to skip
+            gathering cam_instances infos in Step 2. Default to False.
     """
     from tools.dataset_converters import waymo_converter as waymo
-    splits = [
-        'training', 'validation', 'testing',
-        'testing_3d_camera_only_detection'
-    ]
-    for i, split in enumerate(splits):
-        load_dir = osp.join(root_path, 'waymo_format', split)
-        if split == 'validation':
-            save_dir = osp.join(out_dir, 'kitti_format', 'training')
-        else:
-            save_dir = osp.join(out_dir, 'kitti_format', split)
-        converter = waymo.Waymo2KITTI(
-            load_dir,
-            save_dir,
-            prefix=str(i),
-            workers=workers,
-            test_mode=(split
-                       in ['testing', 'testing_3d_camera_only_detection']))
-        converter.convert()
-    from tools.dataset_converters.waymo_converter import \
-        create_ImageSets_img_ids
-    create_ImageSets_img_ids(osp.join(out_dir, 'kitti_format'), splits)
-    # Generate waymo infos
+    if version == 'v1.4':
+        splits = [
+            'training', 'validation', 'testing',
+            'testing_3d_camera_only_detection'
+        ]
+    elif version == 'v1.4-mini':
+        splits = ['training', 'validation']
+    else:
+        raise NotImplementedError(f'Unsupported Waymo version {version}!')
     out_dir = osp.join(out_dir, 'kitti_format')
-    kitti.create_waymo_info_file(
-        out_dir, info_prefix, max_sweeps=max_sweeps, workers=workers)
-    info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl')
-    info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl')
-    info_trainval_path = osp.join(out_dir, f'{info_prefix}_infos_trainval.pkl')
-    info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
-    update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_train_path)
-    update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_val_path)
-    update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_trainval_path)
-    update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_test_path)
+    if not only_gt_database:
+        for i, split in enumerate(splits):
+            load_dir = osp.join(root_path, 'waymo_format', split)
+            if split == 'validation':
+                save_dir = osp.join(out_dir, 'training')
+            else:
+                save_dir = osp.join(out_dir, split)
+            converter = waymo.Waymo2KITTI(
+                load_dir,
+                save_dir,
+                prefix=str(i),
+                workers=workers,
+                test_mode=(split
+                           in ['testing',
+                               'testing_3d_camera_only_detection']),
+                info_prefix=info_prefix,
+                max_sweeps=max_sweeps,
+                split=split,
+                save_senor_data=save_senor_data,
+                save_cam_instances=not skip_cam_instances_infos)
+            converter.convert()
+            if split == 'validation':
+                converter.merge_trainval_infos()
+        from tools.dataset_converters.waymo_converter import \
+            create_ImageSets_img_ids
+        create_ImageSets_img_ids(out_dir, splits)
     GTDatabaseCreater(
         'WaymoDataset',
         out_dir,
... ...
@@ -227,6 +251,8 @@ def waymo_data_prep(root_path,
         with_mask=False,
         num_worker=workers).create()
 
+    print_log('Successfully preparing Waymo Open Dataset')
+
 
 def semantickitti_data_prep(info_prefix, out_dir):
     """Prepare the info file for SemanticKITTI dataset.
... ...
@@ -274,12 +300,23 @@ parser.add_argument(
 parser.add_argument(
     '--only-gt-database',
     action='store_true',
-    help='Whether to only generate ground truth database.')
+    help='''Whether to only generate ground truth database.
+        Only used when dataset is NuScenes or Waymo!''')
+parser.add_argument(
+    '--skip-cam_instances-infos',
+    action='store_true',
+    help='''Whether to skip gathering cam_instances infos.
+        Only used when dataset is Waymo!''')
+parser.add_argument(
+    '--skip-saving-sensor-data',
+    action='store_true',
+    help='''Whether to skip saving image and lidar.
+        Only used when dataset is Waymo!''')
 args = parser.parse_args()
 
 if __name__ == '__main__':
-    from mmdet3d.utils import register_all_modules
-    register_all_modules()
+    from mmengine.registry import init_default_scope
+    init_default_scope('mmdet3d')
 
     if args.dataset == 'kitti':
         if args.only_gt_database:
... ...
@@ -334,6 +371,17 @@ if __name__ == '__main__':
             dataset_name='NuScenesDataset',
             out_dir=args.out_dir,
             max_sweeps=args.max_sweeps)
+    elif args.dataset == 'waymo':
+        waymo_data_prep(
+            root_path=args.root_path,
+            info_prefix=args.extra_tag,
+            version=args.version,
+            out_dir=args.out_dir,
+            workers=args.workers,
+            max_sweeps=args.max_sweeps,
+            only_gt_database=args.only_gt_database,
+            save_senor_data=not args.skip_saving_sensor_data,
+            skip_cam_instances_infos=args.skip_cam_instances_infos)
     elif args.dataset == 'lyft':
         train_version = f'{args.version}-train'
         lyft_data_prep(
... ...
@@ -347,14 +395,6 @@ if __name__ == '__main__':
             info_prefix=args.extra_tag,
             version=test_version,
             max_sweeps=args.max_sweeps)
-    elif args.dataset == 'waymo':
-        waymo_data_prep(
-            root_path=args.root_path,
-            info_prefix=args.extra_tag,
-            version=args.version,
-            out_dir=args.out_dir,
-            workers=args.workers,
-            max_sweeps=args.max_sweeps)
     elif args.dataset == 'scannet':
         scannet_data_prep(
             root_path=args.root_path,
...
...
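Not part of the diff above: a hedged sketch of how the new Waymo branch wires the CLI flags into `waymo_data_prep`, assuming that function from tools/create_data.py is in scope. Paths and the version string are placeholders, and the call assumes the raw Waymo data already sits under `./data/waymo/waymo_format`.

# Roughly what `python tools/create_data.py waymo --root-path ./data/waymo
#   --out-dir ./data/waymo --workers 64 --extra-tag waymo --version v1.4`
# ends up calling.
waymo_data_prep(
    root_path='./data/waymo',
    info_prefix='waymo',
    version='v1.4',
    out_dir='./data/waymo',
    workers=64,
    max_sweeps=10,
    only_gt_database=False,
    save_senor_data=True,  # i.e. --skip-saving-sensor-data not passed
    skip_cam_instances_infos=False)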
tools/create_data.sh
View file @
fe25f7a5
...
...
@@ -6,10 +6,11 @@ export PYTHONPATH=`pwd`:$PYTHONPATH
 PARTITION=$1
 JOB_NAME=$2
 DATASET=$3
 WORKERS=$4
 GPUS=${GPUS:-1}
 GPUS_PER_NODE=${GPUS_PER_NODE:-1}
 SRUN_ARGS=${SRUN_ARGS:-""}
 JOB_NAME=create_data
+PY_ARGS=${@:5}
 
 srun -p ${PARTITION} \
     --job-name=${JOB_NAME} \
... ...
@@ -21,4 +22,6 @@ srun -p ${PARTITION} \
     python -u tools/create_data.py ${DATASET} \
     --root-path ./data/${DATASET} \
     --out-dir ./data/${DATASET} \
-    --extra-tag ${DATASET} --workers ${WORKERS}
+    --extra-tag ${DATASET} \
+    --workers ${WORKERS} \
+    ${PY_ARGS}
tools/dataset_converters/create_gt_database.py
View file @
fe25f7a5
...
...
@@ -7,7 +7,7 @@ import mmengine
import
numpy
as
np
from
mmcv.ops
import
roi_align
from
mmdet.evaluation
import
bbox_overlaps
from
mmengine
import
track_iter_progress
from
mmengine
import
print_log
,
track_iter_progress
from
pycocotools
import
mask
as
maskUtils
from
pycocotools.coco
import
COCO
...
...
@@ -504,7 +504,9 @@ class GTDatabaseCreater:
return
single_db_infos
def
create
(
self
):
print
(
f
'Create GT Database of
{
self
.
dataset_class_name
}
'
)
print_log
(
f
'Create GT Database of
{
self
.
dataset_class_name
}
'
,
logger
=
'current'
)
dataset_cfg
=
dict
(
type
=
self
.
dataset_class_name
,
data_root
=
self
.
data_path
,
...
...
@@ -610,12 +612,19 @@ class GTDatabaseCreater:
            input_dict['box_mode_3d'] = self.dataset.box_mode_3d
            return input_dict

        multi_db_infos = mmengine.track_parallel_progress(
            self.create_single,
            ((loop_dataset(i)
              for i in range(len(self.dataset))), len(self.dataset)),
            self.num_worker)
        print('Make global unique group id')
        if self.num_worker == 0:
            multi_db_infos = mmengine.track_progress(
                self.create_single,
                ((loop_dataset(i)
                  for i in range(len(self.dataset))), len(self.dataset)))
        else:
            multi_db_infos = mmengine.track_parallel_progress(
                self.create_single,
                ((loop_dataset(i)
                  for i in range(len(self.dataset))), len(self.dataset)),
                self.num_worker,
                chunksize=1000)
        print_log('Make global unique group id', logger='current')
        group_counter_offset = 0
        all_db_infos = dict()
        for single_db_infos in track_iter_progress(multi_db_infos):
...
...
@@ -630,7 +639,8 @@ class GTDatabaseCreater:
            group_counter_offset += (group_id + 1)

        for k, v in all_db_infos.items():
            print(f'load {len(v)} {k} database infos')
            print_log(
                f'load {len(v)} {k} database infos', logger='current')

        print_log(f'Saving GT database infos into {self.db_info_save_path}')
        with open(self.db_info_save_path, 'wb') as f:
            pickle.dump(all_db_infos, f)
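The `num_worker == 0` branch added above follows a common mmengine pattern: run the work in-process with `track_progress` when no worker pool is requested, otherwise fan it out with `track_parallel_progress`. A minimal, self-contained sketch of that dispatch (using a toy task in place of the real GT-database creation):

# Minimal sketch of the serial/parallel dispatch shown in the diff above.
# square() is a stand-in for GTDatabaseCreater.create_single.
import mmengine


def square(x):
    return x * x


def run(tasks, num_worker):
    if num_worker == 0:
        # Serial: easier to debug, no per-worker pickling of the dataset.
        return mmengine.track_progress(square, tasks)
    # Parallel: chunk the tasks across a process pool.
    return mmengine.track_parallel_progress(
        square, tasks, num_worker, chunksize=1000)


print(run(list(range(10)), num_worker=0))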
tools/dataset_converters/waymo_converter.py
View file @
fe25f7a5
...
...
@@ -9,23 +9,33 @@ except ImportError:
    raise ImportError(
        'Please run "pip install waymo-open-dataset-tf-2-6-0" '
        '>1.4.5 to install the official devkit first.')

import copy
import os
import os.path as osp
from glob import glob
from io import BytesIO
from os.path import exists, join

import mmengine
import numpy as np
import tensorflow as tf
from mmengine import print_log
from nuscenes.utils.geometry_utils import view_points
from PIL import Image
from waymo_open_dataset.utils import range_image_utils, transform_utils
from waymo_open_dataset.utils.frame_utils import \
    parse_range_image_and_camera_projection

from mmdet3d.datasets.convert_utils import post_process_coords
from mmdet3d.structures import Box3DMode, LiDARInstance3DBoxes, points_cam2img


class Waymo2KITTI(object):
    """Waymo to KITTI converter.
    """Waymo to KITTI converter. There are 2 steps as follows:

    This class serves as the converter to change the waymo raw data to KITTI
    format.
    Step 1. Extract camera images and lidar point clouds from waymo raw data in
        '*.tfreord' and save as kitti format.
    Step 2. Generate waymo train/val/test infos and save as pickle file.

    Args:
        load_dir (str): Directory to load waymo raw data.
...
...
@@ -36,8 +46,16 @@ class Waymo2KITTI(object):
            Defaults to 64.
        test_mode (bool, optional): Whether in the test_mode.
            Defaults to False.
        save_cam_sync_labels (bool, optional): Whether to save cam sync labels.
            Defaults to True.
        save_senor_data (bool, optional): Whether to save image and lidar
            data. Defaults to True.
        save_cam_sync_instances (bool, optional): Whether to save cam sync
            instances. Defaults to True.
        save_cam_instances (bool, optional): Whether to save cam instances.
            Defaults to False.
        info_prefix (str, optional): Prefix of info filename.
            Defaults to 'waymo'.
        max_sweeps (int, optional): Max length of sweeps. Defaults to 10.
        split (str, optional): Split of the data. Defaults to 'training'.
    """

    def __init__(self,
...
...
@@ -46,18 +64,12 @@ class Waymo2KITTI(object):
                 prefix,
                 workers=64,
                 test_mode=False,
                 save_cam_sync_labels=True):
        self.filter_empty_3dboxes = True
        self.filter_no_label_zone_points = True

        self.selected_waymo_classes = ['VEHICLE', 'PEDESTRIAN', 'CYCLIST']

        # Only data collected in specific locations will be converted
        # If set None, this filter is disabled
        # Available options: location_sf (main dataset)
        self.selected_waymo_locations = None
        self.save_track_id = False
                 save_senor_data=True,
                 save_cam_sync_instances=True,
                 save_cam_instances=True,
                 info_prefix='waymo',
                 max_sweeps=10,
                 split='training'):
        # turn on eager execution for older tensorflow versions
        if int(tf.__version__.split('.')[0]) < 2:
            tf.enable_eager_execution()
...
...
@@ -74,12 +86,21 @@ class Waymo2KITTI(object):
        self.type_list = [
            'UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST'
        ]
        self.waymo_to_kitti_class_map = {
            'UNKNOWN': 'DontCare',
            'PEDESTRIAN': 'Pedestrian',
            'VEHICLE': 'Car',
            'CYCLIST': 'Cyclist',
            'SIGN': 'Sign'  # not in kitti
        # MMDetection3D unified camera keys & class names
        self.camera_types = [
            'CAM_FRONT',
            'CAM_FRONT_LEFT',
            'CAM_FRONT_RIGHT',
            'CAM_SIDE_LEFT',
            'CAM_SIDE_RIGHT',
        ]
        self.selected_waymo_classes = ['VEHICLE', 'PEDESTRIAN', 'CYCLIST']
        self.info_map = {
            'training': '_infos_train.pkl',
            'validation': '_infos_val.pkl',
            'testing': '_infos_test.pkl',
            'testing_3d_camera_only_detection': '_infos_test_cam_only.pkl'
        }

        self.load_dir = load_dir
...
...
@@ -87,61 +108,87 @@ class Waymo2KITTI(object):
        self.prefix = prefix
        self.workers = int(workers)
        self.test_mode = test_mode
        self.save_cam_sync_labels = save_cam_sync_labels
        self.save_senor_data = save_senor_data
        self.save_cam_sync_instances = save_cam_sync_instances
        self.save_cam_instances = save_cam_instances
        self.info_prefix = info_prefix
        self.max_sweeps = max_sweeps
        self.split = split

        # TODO: Discuss filter_empty_3dboxes and filter_no_label_zone_points
        self.filter_empty_3dboxes = True
        self.filter_no_label_zone_points = True
        self.save_track_id = False

        self.tfrecord_pathnames = sorted(
            glob(join(self.load_dir, '*.tfrecord')))

        self.label_save_dir = f'{self.save_dir}/label_'
        self.label_all_save_dir = f'{self.save_dir}/label_all'
        self.image_save_dir = f'{self.save_dir}/image_'
        self.calib_save_dir = f'{self.save_dir}/calib'
        self.point_cloud_save_dir = f'{self.save_dir}/velodyne'
        self.pose_save_dir = f'{self.save_dir}/pose'
        self.timestamp_save_dir = f'{self.save_dir}/timestamp'
        if self.save_cam_sync_labels:
            self.cam_sync_label_save_dir = f'{self.save_dir}/cam_sync_label_'
            self.cam_sync_label_all_save_dir = \
                f'{self.save_dir}/cam_sync_label_all'

        self.create_folder()
        # Create folder for saving KITTI format camera images and
        # lidar point clouds.
        if 'testing_3d_camera_only_detection' not in self.load_dir:
            mmengine.mkdir_or_exist(self.point_cloud_save_dir)
        for i in range(5):
            mmengine.mkdir_or_exist(f'{self.image_save_dir}{str(i)}')

    def convert(self):
        """Convert action."""
        print('Start converting ...')
        mmengine.track_parallel_progress(self.convert_one, range(len(self)),
                                         self.workers)
        print('\nFinished ...')
        print_log(f'Start converting {self.split} dataset', logger='current')
        if self.workers == 0:
            data_infos = mmengine.track_progress(self.convert_one,
                                                 range(len(self)))
        else:
            data_infos = mmengine.track_parallel_progress(
                self.convert_one, range(len(self)), self.workers)
        data_list = []
        for data_info in data_infos:
            data_list.extend(data_info)
        metainfo = dict()
        metainfo['dataset'] = 'waymo'
        metainfo['version'] = 'waymo_v1.4'
        metainfo['info_version'] = 'mmdet3d_v1.4'
        waymo_infos = dict(data_list=data_list, metainfo=metainfo)
        filenames = osp.join(
            osp.dirname(self.save_dir),
            f'{self.info_prefix + self.info_map[self.split]}')
        print_log(f'Saving {self.split} dataset infos into {filenames}')
        mmengine.dump(waymo_infos, filenames)

    def convert_one(self, file_idx):
        """Convert action for single file.
        """Convert one '*.tfrecord' file to kitti format. Each file stores all
        the frames (about 200 frames) in current scene. We treat each frame as
        a sample, save their images and point clouds in kitti format, and then
        create info for all frames.

        Args:
            file_idx (int): Index of the file to be converted.

        Returns:
            List[dict]: Waymo infos for all frames in current file.
        """
        pathname = self.tfrecord_pathnames[file_idx]
        dataset = tf.data.TFRecordDataset(pathname, compression_type='')

        # NOTE: file_infos is not shared between processes, only stores frame
        # infos within the current file.
        file_infos = []
        for frame_idx, data in enumerate(dataset):
            frame = dataset_pb2.Frame()
            frame.ParseFromString(bytearray(data.numpy()))
            if (self.selected_waymo_locations is not None
                    and frame.context.stats.location
                    not in self.selected_waymo_locations):
                continue

            self.save_image(frame, file_idx, frame_idx)
            self.save_calib(frame, file_idx, frame_idx)
            self.save_lidar(frame, file_idx, frame_idx)
            self.save_pose(frame, file_idx, frame_idx)
            self.save_timestamp(frame, file_idx, frame_idx)

            # Step 1. Extract camera images and lidar point clouds from waymo
            # raw data in '*.tfreord' and save as kitti format.
            if self.save_senor_data:
                self.save_image(frame, file_idx, frame_idx)
                self.save_lidar(frame, file_idx, frame_idx)

            if not self.test_mode:
                # TODO save the depth image for waymo challenge solution.
                self.save_label(frame, file_idx, frame_idx)
                if self.save_cam_sync_labels:
                    self.save_label(frame, file_idx, frame_idx, cam_sync=True)

            # Step 2. Generate waymo train/val/test infos and save as pkl file.
            # TODO save the depth image for waymo challenge solution.
            self.create_waymo_info_file(frame, file_idx, frame_idx, file_infos)
        return file_infos

    def __len__(self):
        """Length of the filename list."""
...
...
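After this change, `convert()` no longer only writes per-frame KITTI files: each worker returns a list of per-frame info dicts, which are flattened and dumped as a single pickle alongside the metainfo. A hedged sketch of that gather-and-dump step (toy dicts, a made-up output name following the `info_prefix + info_map` pattern above):

# Hedged sketch of the gather-and-dump step performed by convert().
import mmengine

per_file_infos = [
    [{'sample_idx': 0}, {'sample_idx': 1}],       # infos from file 0
    [{'sample_idx': 100}, {'sample_idx': 101}],   # infos from file 1
]

data_list = []
for file_infos in per_file_infos:
    data_list.extend(file_infos)

waymo_infos = dict(
    data_list=data_list,
    metainfo=dict(
        dataset='waymo', version='waymo_v1.4', info_version='mmdet3d_v1.4'))
mmengine.dump(waymo_infos, 'waymo_infos_train.pkl')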
@@ -162,62 +209,6 @@ class Waymo2KITTI(object):
        with open(img_path, 'wb') as fp:
            fp.write(img.image)

    def save_calib(self, frame, file_idx, frame_idx):
        """Parse and save the calibration data.

        Args:
            frame (:obj:`Frame`): Open dataset frame proto.
            file_idx (int): Current file index.
            frame_idx (int): Current frame index.
        """
        # waymo front camera to kitti reference camera
        T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0],
                                       [1.0, 0.0, 0.0]])
        camera_calibs = []
        R0_rect = [f'{i:e}' for i in np.eye(3).flatten()]
        Tr_velo_to_cams = []
        calib_context = ''

        for camera in frame.context.camera_calibrations:
            # extrinsic parameters
            T_cam_to_vehicle = np.array(camera.extrinsic.transform).reshape(
                4, 4)
            T_vehicle_to_cam = np.linalg.inv(T_cam_to_vehicle)
            Tr_velo_to_cam = \
                self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam
            if camera.name == 1:  # FRONT = 1, see dataset.proto for details
                self.T_velo_to_front_cam = Tr_velo_to_cam.copy()
            Tr_velo_to_cam = Tr_velo_to_cam[:3, :].reshape((12, ))
            Tr_velo_to_cams.append([f'{i:e}' for i in Tr_velo_to_cam])

            # intrinsic parameters
            camera_calib = np.zeros((3, 4))
            camera_calib[0, 0] = camera.intrinsic[0]
            camera_calib[1, 1] = camera.intrinsic[1]
            camera_calib[0, 2] = camera.intrinsic[2]
            camera_calib[1, 2] = camera.intrinsic[3]
            camera_calib[2, 2] = 1
            camera_calib = list(camera_calib.reshape(12))
            camera_calib = [f'{i:e}' for i in camera_calib]
            camera_calibs.append(camera_calib)

        # all camera ids are saved as id-1 in the result because
        # camera 0 is unknown in the proto
        for i in range(5):
            calib_context += 'P' + str(i) + ': ' + \
                ' '.join(camera_calibs[i]) + '\n'
        calib_context += 'R0_rect' + ': ' + ' '.join(R0_rect) + '\n'
        for i in range(5):
            calib_context += 'Tr_velo_to_cam_' + str(i) + ': ' + \
                ' '.join(Tr_velo_to_cams[i]) + '\n'

        with open(
                f'{self.calib_save_dir}/{self.prefix}' +
                f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt',
                'w+') as fp_calib:
            fp_calib.write(calib_context)
            fp_calib.close()

    def save_lidar(self, frame, file_idx, frame_idx):
        """Parse and save the lidar data in psd format.
...
...
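The removed `save_calib` composed the Waymo camera extrinsic with the front-cam-to-KITTI-reference rotation and serialized the 3x4 intrinsic in scientific notation. A hedged toy sketch of that composition and formatting (identity extrinsic and made-up intrinsics, purely illustrative):

# Hedged sketch of the calibration composition written by save_calib.
import numpy as np

T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0],
                               [0.0, 0.0, -1.0],
                               [1.0, 0.0, 0.0]])
T_cam_to_vehicle = np.eye(4)                      # toy extrinsic

T_ref = np.eye(4)
T_ref[:3, :3] = T_front_cam_to_ref                # what cart_to_homo produces
Tr_velo_to_cam = T_ref @ np.linalg.inv(T_cam_to_vehicle)

fx, fy, cx, cy = 2000.0, 2000.0, 960.0, 640.0     # toy intrinsics
P = np.zeros((3, 4))
P[0, 0], P[1, 1], P[0, 2], P[1, 2], P[2, 2] = fx, fy, cx, cy, 1

calib_line = 'P0: ' + ' '.join(f'{v:e}' for v in P.reshape(12))
print(calib_line)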
@@ -275,194 +266,6 @@ class Waymo2KITTI(object):
            f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.bin'
        point_cloud.astype(np.float32).tofile(pc_path)

    def save_label(self, frame, file_idx, frame_idx, cam_sync=False):
        """Parse and save the label data in txt format.
        The relation between waymo and kitti coordinates is noteworthy:
        1. x, y, z correspond to l, w, h (waymo) -> l, h, w (kitti)
        2. x-y-z: front-left-up (waymo) -> right-down-front(kitti)
        3. bbox origin at volumetric center (waymo) -> bottom center (kitti)
        4. rotation: +x around y-axis (kitti) -> +x around z-axis (waymo)

        Args:
            frame (:obj:`Frame`): Open dataset frame proto.
            file_idx (int): Current file index.
            frame_idx (int): Current frame index.
            cam_sync (bool, optional): Whether to save the cam sync labels.
                Defaults to False.
        """
        label_all_path = f'{self.label_all_save_dir}/{self.prefix}' + \
            f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'
        if cam_sync:
            label_all_path = label_all_path.replace('label_',
                                                    'cam_sync_label_')
        fp_label_all = open(label_all_path, 'w+')
        id_to_bbox = dict()
        id_to_name = dict()
        for labels in frame.projected_lidar_labels:
            name = labels.name
            for label in labels.labels:
                # TODO: need a workaround as bbox may not belong to front cam
                bbox = [
                    label.box.center_x - label.box.length / 2,
                    label.box.center_y - label.box.width / 2,
                    label.box.center_x + label.box.length / 2,
                    label.box.center_y + label.box.width / 2
                ]
                id_to_bbox[label.id] = bbox
                id_to_name[label.id] = name - 1

        for obj in frame.laser_labels:
            bounding_box = None
            name = None
            id = obj.id
            for proj_cam in self.cam_list:
                if id + proj_cam in id_to_bbox:
                    bounding_box = id_to_bbox.get(id + proj_cam)
                    name = str(id_to_name.get(id + proj_cam))
                    break

            # NOTE: the 2D labels do not have strict correspondence with
            # the projected 2D lidar labels
            # e.g.: the projected 2D labels can be in camera 2
            # while the most_visible_camera can have id 4
            if cam_sync:
                if obj.most_visible_camera_name:
                    name = str(
                        self.cam_list.index(
                            f'_{obj.most_visible_camera_name}'))
                    box3d = obj.camera_synced_box
                else:
                    continue
            else:
                box3d = obj.box

            if bounding_box is None or name is None:
                name = '0'
                bounding_box = (0, 0, 0, 0)

            my_type = self.type_list[obj.type]

            if my_type not in self.selected_waymo_classes:
                continue

            if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1:
                continue

            my_type = self.waymo_to_kitti_class_map[my_type]

            height = box3d.height
            width = box3d.width
            length = box3d.length

            x = box3d.center_x
            y = box3d.center_y
            z = box3d.center_z - height / 2

            # project bounding box to the virtual reference frame
            pt_ref = self.T_velo_to_front_cam @ \
                np.array([x, y, z, 1]).reshape((4, 1))
            x, y, z, _ = pt_ref.flatten().tolist()

            rotation_y = -box3d.heading - np.pi / 2
            track_id = obj.id

            # not available
            truncated = 0
            occluded = 0
            alpha = -10

            line = my_type + \
                ' {} {} {} {} {} {} {} {} {} {} {} {} {} {}\n'.format(
                    round(truncated, 2), occluded, round(alpha, 2),
                    round(bounding_box[0], 2), round(bounding_box[1], 2),
                    round(bounding_box[2], 2), round(bounding_box[3], 2),
                    round(height, 2), round(width, 2), round(length, 2),
                    round(x, 2), round(y, 2), round(z, 2),
                    round(rotation_y, 2))

            if self.save_track_id:
                line_all = line[:-1] + ' ' + name + ' ' + track_id + '\n'
            else:
                line_all = line[:-1] + ' ' + name + '\n'

            label_path = f'{self.label_save_dir}{name}/{self.prefix}' + \
                f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'
            if cam_sync:
                label_path = label_path.replace('label_', 'cam_sync_label_')
            fp_label = open(label_path, 'a')
            fp_label.write(line)
            fp_label.close()

            fp_label_all.write(line_all)

        fp_label_all.close()

    def save_pose(self, frame, file_idx, frame_idx):
        """Parse and save the pose data.

        Note that SDC's own pose is not included in the regular training
        of KITTI dataset. KITTI raw dataset contains ego motion files
        but are not often used. Pose is important for algorithms that
        take advantage of the temporal information.

        Args:
            frame (:obj:`Frame`): Open dataset frame proto.
            file_idx (int): Current file index.
            frame_idx (int): Current frame index.
        """
        pose = np.array(frame.pose.transform).reshape(4, 4)
        np.savetxt(
            join(f'{self.pose_save_dir}/{self.prefix}' +
                 f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'),
            pose)

    def save_timestamp(self, frame, file_idx, frame_idx):
        """Save the timestamp data in a separate file instead of the
        pointcloud.

        Note that SDC's own pose is not included in the regular training
        of KITTI dataset. KITTI raw dataset contains ego motion files
        but are not often used. Pose is important for algorithms that
        take advantage of the temporal information.

        Args:
            frame (:obj:`Frame`): Open dataset frame proto.
            file_idx (int): Current file index.
            frame_idx (int): Current frame index.
        """
        with open(
                join(f'{self.timestamp_save_dir}/{self.prefix}' +
                     f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'),
                'w') as f:
            f.write(str(frame.timestamp_micros))

    def create_folder(self):
        """Create folder for data preprocessing."""
        if not self.test_mode:
            dir_list1 = [
                self.label_all_save_dir,
                self.calib_save_dir,
                self.pose_save_dir,
                self.timestamp_save_dir,
            ]
            dir_list2 = [self.label_save_dir, self.image_save_dir]
            if self.save_cam_sync_labels:
                dir_list1.append(self.cam_sync_label_all_save_dir)
                dir_list2.append(self.cam_sync_label_save_dir)
        else:
            dir_list1 = [
                self.calib_save_dir, self.pose_save_dir,
                self.timestamp_save_dir
            ]
            dir_list2 = [self.image_save_dir]
        if 'testing_3d_camera_only_detection' not in self.load_dir:
            dir_list1.append(self.point_cloud_save_dir)
        for d in dir_list1:
            mmengine.mkdir_or_exist(d)
        for d in dir_list2:
            for i in range(5):
                mmengine.mkdir_or_exist(f'{d}{str(i)}')

    def convert_range_image_to_point_cloud(self,
                                           frame,
                                           range_images,
...
...
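The removed `save_label` documented the Waymo-to-KITTI coordinate conventions: boxes move from a volumetric center to a bottom center, and heading is remapped from the lidar z-axis to the camera y-axis. A hedged numeric example of those two steps (toy box values, not from any real frame):

# Hedged numeric example of the Waymo -> KITTI label conversion above.
import numpy as np

# Toy Waymo laser label: center (x, y, z), size (l, w, h), heading.
center_x, center_y, center_z = 10.0, -2.0, 1.0
length, width, height = 4.5, 1.9, 1.6
heading = 0.3

# KITTI stores the bottom center of the box ...
z_bottom = center_z - height / 2
# ... and rotation about the camera y-axis instead of the lidar z-axis.
rotation_y = -heading - np.pi / 2

print(round(z_bottom, 2), round(rotation_y, 2))  # 0.2 -1.87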
@@ -604,29 +407,317 @@ class Waymo2KITTI(object):
            raise ValueError(mat.shape)
        return ret

    def create_waymo_info_file(self, frame, file_idx, frame_idx, file_infos):
        r"""Generate waymo train/val/test infos.

        For more details about infos, please refer to:
        https://mmdetection3d.readthedocs.io/en/latest/advanced_guides/datasets/waymo.html
        """  # noqa: E501
        frame_infos = dict()

        # Gather frame infos
        sample_idx = \
            f'{self.prefix}{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}'
        frame_infos['sample_idx'] = int(sample_idx)
        frame_infos['timestamp'] = frame.timestamp_micros
        frame_infos['ego2global'] = np.array(frame.pose.transform).reshape(
            4, 4).astype(np.float32).tolist()
        frame_infos['context_name'] = frame.context.name

        # Gather camera infos
        frame_infos['images'] = dict()
        # waymo front camera to kitti reference camera
        T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0],
                                       [1.0, 0.0, 0.0]])
        camera_calibs = []
        Tr_velo_to_cams = []
        for camera in frame.context.camera_calibrations:
            # extrinsic parameters
            T_cam_to_vehicle = np.array(camera.extrinsic.transform).reshape(
                4, 4)
            T_vehicle_to_cam = np.linalg.inv(T_cam_to_vehicle)
            Tr_velo_to_cam = \
                self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam
            Tr_velo_to_cams.append(Tr_velo_to_cam)

            # intrinsic parameters
            camera_calib = np.zeros((3, 4))
            camera_calib[0, 0] = camera.intrinsic[0]
            camera_calib[1, 1] = camera.intrinsic[1]
            camera_calib[0, 2] = camera.intrinsic[2]
            camera_calib[1, 2] = camera.intrinsic[3]
            camera_calib[2, 2] = 1
            camera_calibs.append(camera_calib)

        for i, (cam_key, camera_calib, Tr_velo_to_cam) in enumerate(
                zip(self.camera_types, camera_calibs, Tr_velo_to_cams)):
            cam_infos = dict()
            cam_infos['img_path'] = str(sample_idx) + '.jpg'
            # NOTE: frames.images order is different
            for img in frame.images:
                if img.name == i + 1:
                    width, height = Image.open(BytesIO(img.image)).size
            cam_infos['height'] = height
            cam_infos['width'] = width
            cam_infos['lidar2cam'] = Tr_velo_to_cam.astype(np.float32).tolist()
            cam_infos['cam2img'] = camera_calib.astype(np.float32).tolist()
            cam_infos['lidar2img'] = (camera_calib @ Tr_velo_to_cam).astype(
                np.float32).tolist()
            frame_infos['images'][cam_key] = cam_infos

        # Gather lidar infos
        lidar_infos = dict()
        lidar_infos['lidar_path'] = str(sample_idx) + '.bin'
        lidar_infos['num_pts_feats'] = 6
        frame_infos['lidar_points'] = lidar_infos

        # Gather lidar sweeps and camera sweeps infos
        # TODO: Add lidar2img in image sweeps infos when we need it.
        # TODO: Consider merging lidar sweeps infos and image sweeps infos.
        lidar_sweeps_infos, image_sweeps_infos = [], []
        for prev_offset in range(-1, -self.max_sweeps - 1, -1):
            prev_lidar_infos = dict()
            prev_image_infos = dict()
            if frame_idx + prev_offset >= 0:
                prev_frame_infos = file_infos[prev_offset]
                prev_lidar_infos['timestamp'] = prev_frame_infos['timestamp']
                prev_lidar_infos['ego2global'] = prev_frame_infos['ego2global']
                prev_lidar_infos['lidar_points'] = dict()
                lidar_path = prev_frame_infos['lidar_points']['lidar_path']
                prev_lidar_infos['lidar_points']['lidar_path'] = lidar_path
                lidar_sweeps_infos.append(prev_lidar_infos)

                prev_image_infos['timestamp'] = prev_frame_infos['timestamp']
                prev_image_infos['ego2global'] = prev_frame_infos['ego2global']
                prev_image_infos['images'] = dict()
                for cam_key in self.camera_types:
                    prev_image_infos['images'][cam_key] = dict()
                    img_path = prev_frame_infos['images'][cam_key]['img_path']
                    prev_image_infos['images'][cam_key]['img_path'] = img_path
                image_sweeps_infos.append(prev_image_infos)
        if lidar_sweeps_infos:
            frame_infos['lidar_sweeps'] = lidar_sweeps_infos
        if image_sweeps_infos:
            frame_infos['image_sweeps'] = image_sweeps_infos

        if not self.test_mode:
            # Gather instances infos which is used for lidar-based 3D detection
            frame_infos['instances'] = self.gather_instance_info(frame)
            # Gather cam_sync_instances infos which is used for image-based
            # (multi-view) 3D detection.
            if self.save_cam_sync_instances:
                frame_infos['cam_sync_instances'] = self.gather_instance_info(
                    frame, cam_sync=True)
            # Gather cam_instances infos which is used for image-based
            # (monocular) 3D detection (optional).
            # TODO: Should we use cam_sync_instances to generate cam_instances?
            if self.save_cam_instances:
                frame_infos['cam_instances'] = self.gather_cam_instance_info(
                    copy.deepcopy(frame_infos['instances']),
                    frame_infos['images'])
        file_infos.append(frame_infos)

    def gather_instance_info(self, frame, cam_sync=False):
        """Generate instances and cam_sync_instances infos.

        For more details about infos, please refer to:
        https://mmdetection3d.readthedocs.io/en/latest/advanced_guides/datasets/waymo.html
        """  # noqa: E501
        id_to_bbox = dict()
        id_to_name = dict()
        for labels in frame.projected_lidar_labels:
            name = labels.name
            for label in labels.labels:
                # TODO: need a workaround as bbox may not belong to front cam
                bbox = [
                    label.box.center_x - label.box.length / 2,
                    label.box.center_y - label.box.width / 2,
                    label.box.center_x + label.box.length / 2,
                    label.box.center_y + label.box.width / 2
                ]
                id_to_bbox[label.id] = bbox
                id_to_name[label.id] = name - 1

        group_id = 0
        instance_infos = []
        for obj in frame.laser_labels:
            instance_info = dict()
            bounding_box = None
            name = None
            id = obj.id
            for proj_cam in self.cam_list:
                if id + proj_cam in id_to_bbox:
                    bounding_box = id_to_bbox.get(id + proj_cam)
                    name = id_to_name.get(id + proj_cam)
                    break

            # NOTE: the 2D labels do not have strict correspondence with
            # the projected 2D lidar labels
            # e.g.: the projected 2D labels can be in camera 2
            # while the most_visible_camera can have id 4
            if cam_sync:
                if obj.most_visible_camera_name:
                    name = self.cam_list.index(
                        f'_{obj.most_visible_camera_name}')
                    box3d = obj.camera_synced_box
                else:
                    continue
            else:
                box3d = obj.box

            if bounding_box is None or name is None:
                name = 0
                bounding_box = [0.0, 0.0, 0.0, 0.0]

            my_type = self.type_list[obj.type]

            if my_type not in self.selected_waymo_classes:
                continue
            else:
                label = self.selected_waymo_classes.index(my_type)

            if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1:
                continue

            group_id += 1
            instance_info['group_id'] = group_id
            instance_info['camera_id'] = name
            instance_info['bbox'] = bounding_box
            instance_info['bbox_label'] = label

            height = box3d.height
            width = box3d.width
            length = box3d.length

            # NOTE: We save the bottom center of 3D bboxes.
            x = box3d.center_x
            y = box3d.center_y
            z = box3d.center_z - height / 2

            rotation_y = box3d.heading

            instance_info['bbox_3d'] = np.array(
                [x, y, z, length, width, height,
                 rotation_y]).astype(np.float32).tolist()
            instance_info['bbox_label_3d'] = label
            instance_info['num_lidar_pts'] = obj.num_lidar_points_in_box

            if self.save_track_id:
                instance_info['track_id'] = obj.id
            instance_infos.append(instance_info)
        return instance_infos

    def gather_cam_instance_info(self, instances: dict, images: dict):
        """Generate cam_instances infos.

        For more details about infos, please refer to:
        https://mmdetection3d.readthedocs.io/en/latest/advanced_guides/datasets/waymo.html
        """  # noqa: E501
        cam_instances = dict()
        for cam_type in self.camera_types:
            lidar2cam = np.array(images[cam_type]['lidar2cam'])
            cam2img = np.array(images[cam_type]['cam2img'])
            cam_instances[cam_type] = []
            for instance in instances:
                cam_instance = dict()
                gt_bboxes_3d = np.array(instance['bbox_3d'])
                # Convert lidar coordinates to camera coordinates
                gt_bboxes_3d = LiDARInstance3DBoxes(
                    gt_bboxes_3d[None, :]).convert_to(
                        Box3DMode.CAM, lidar2cam, correct_yaw=True)
                corners_3d = gt_bboxes_3d.corners.numpy()
                corners_3d = corners_3d[0].T  # (1, 8, 3) -> (3, 8)
                in_camera = np.argwhere(corners_3d[2, :] > 0).flatten()
                corners_3d = corners_3d[:, in_camera]
                # Project 3d box to 2d.
                corner_coords = view_points(corners_3d, cam2img,
                                            True).T[:, :2].tolist()

                # Keep only corners that fall within the image.
                # TODO: imsize should be determined by the current image size
                # CAM_FRONT: (1920, 1280)
                # CAM_FRONT_LEFT: (1920, 1280)
                # CAM_SIDE_LEFT: (1920, 886)
                final_coords = post_process_coords(
                    corner_coords,
                    imsize=(images['CAM_FRONT']['width'],
                            images['CAM_FRONT']['height']))

                # Skip if the convex hull of the re-projected corners
                # does not intersect the image canvas.
                if final_coords is None:
                    continue
                else:
                    min_x, min_y, max_x, max_y = final_coords

                cam_instance['bbox'] = [min_x, min_y, max_x, max_y]
                cam_instance['bbox_label'] = instance['bbox_label']
                cam_instance['bbox_3d'] = gt_bboxes_3d.numpy().squeeze(
                ).astype(np.float32).tolist()
                cam_instance['bbox_label_3d'] = instance['bbox_label_3d']

                center_3d = gt_bboxes_3d.gravity_center.numpy()
                center_2d_with_depth = points_cam2img(
                    center_3d, cam2img, with_depth=True)
                center_2d_with_depth = center_2d_with_depth.squeeze().tolist()

                # normalized center2D + depth
                # if samples with depth < 0 will be removed
                if center_2d_with_depth[2] <= 0:
                    continue
                cam_instance['center_2d'] = center_2d_with_depth[:2]
                cam_instance['depth'] = center_2d_with_depth[2]

                # TODO: Discuss whether following info is necessary
                cam_instance['bbox_3d_isvalid'] = True
                cam_instance['velocity'] = -1
                cam_instances[cam_type].append(cam_instance)

        return cam_instances

    def merge_trainval_infos(self):
        """Merge training and validation infos into a single file."""
        train_infos_path = osp.join(
            osp.dirname(self.save_dir), f'{self.info_prefix}_infos_train.pkl')
        val_infos_path = osp.join(
            osp.dirname(self.save_dir), f'{self.info_prefix}_infos_val.pkl')
        train_infos = mmengine.load(train_infos_path)
        val_infos = mmengine.load(val_infos_path)
        trainval_infos = dict(
            metainfo=train_infos['metainfo'],
            data_list=train_infos['data_list'] + val_infos['data_list'])
        mmengine.dump(
            trainval_infos,
            osp.join(
                osp.dirname(self.save_dir),
                f'{self.info_prefix}_infos_trainval.pkl'))


def create_ImageSets_img_ids(root_dir, splits):
    """Create txt files indicating what to collect in each split."""
    save_dir = join(root_dir, 'ImageSets/')
    if not exists(save_dir):
        os.mkdir(save_dir)

    idx_all = [[] for i in splits]
    idx_all = [[] for _ in splits]
    for i, split in enumerate(splits):
        path = join(root_dir, splits[i], 'calib')
        path = join(root_dir, split, 'image_0')
        if not exists(path):
            RawNames = []
        else:
            RawNames = os.listdir(path)

        for name in RawNames:
            if name.endswith('.txt'):
                idx = name.replace('.txt', '\n')
            if name.endswith('.jpg'):
                idx = name.replace('.jpg', '\n')
                idx_all[int(idx[0])].append(idx)
        idx_all[i].sort()

    open(save_dir + 'train.txt', 'w').writelines(idx_all[0])
    open(save_dir + 'val.txt', 'w').writelines(idx_all[1])
    open(save_dir + 'trainval.txt', 'w').writelines(idx_all[0] + idx_all[1])
    open(save_dir + 'test.txt', 'w').writelines(idx_all[2])
    # open(save_dir+'test_cam_only.txt','w').writelines(idx_all[3])
    if len(idx_all) >= 3:
        open(save_dir + 'test.txt', 'w').writelines(idx_all[2])
    if len(idx_all) >= 4:
        open(save_dir + 'test_cam_only.txt', 'w').writelines(idx_all[3])
    print('created txt files indicating what to collect in ', splits)
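The new info files store, per camera, `lidar2cam`, `cam2img`, and their product `lidar2img = cam2img @ lidar2cam`, as written by `create_waymo_info_file` above. A hedged toy sketch of how those matrices compose to map a lidar point to pixel coordinates (identity extrinsics and made-up intrinsics, purely illustrative):

# Hedged sketch of the lidar -> image projection implied by 'lidar2img'.
import numpy as np

cam2img = np.array([[2000., 0., 960., 0.],
                    [0., 2000., 640., 0.],
                    [0., 0., 1., 0.]])          # toy 3x4 intrinsics
lidar2cam = np.eye(4)                           # toy extrinsics
lidar2img = cam2img @ lidar2cam                 # 3x4, as stored in the infos

point_lidar = np.array([1.0, 0.0, 20.0, 1.0])   # homogeneous lidar point
proj = lidar2img @ point_lidar                  # (u*z, v*z, z)
u, v = proj[0] / proj[2], proj[1] / proj[2]
print(u, v, proj[2])                            # 1060.0 640.0 20.0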
tools/train.py
View file @
fe25f7a5
...
...
@@ -21,6 +21,12 @@ def parse_args():
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--sync_bn',
        choices=['none', 'torch', 'mmcv'],
        default='none',
        help='convert all BatchNorm layers in the model to SyncBatchNorm '
        '(SyncBN) or mmcv.ops.sync_bn.SyncBatchNorm (MMSyncBN) layers.')
    parser.add_argument(
        '--auto-scale-lr',
        action='store_true',
...
...
@@ -98,6 +104,10 @@ def main():
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.loss_scale = 'dynamic'

    # convert BatchNorm layers
    if args.sync_bn != 'none':
        cfg.sync_bn = args.sync_bn

    # enable automatically scaling LR
    if args.auto_scale_lr:
        if 'auto_scale_lr' in cfg and \
...
...
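The diff only records the chosen `--sync_bn` mode on the config (`cfg.sync_bn`); the actual layer conversion is handled downstream. As a hedged illustration of what the 'torch' option means, the standard PyTorch helper performs the conversion like this (a toy model, not the runner's internal code path):

# Hedged illustration of the 'torch' SyncBN option.
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
# Replaces every BatchNorm*d module with nn.SyncBatchNorm in place.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model)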