lishj6 / Sparse4d · Commits

Commit a9dc86e9
Authored Sep 05, 2025 by lishj6
Message: init_0905
Parent: 18eda5c1
Changes: 79

Showing 20 changed files with 3521 additions and 0 deletions (+3521, -0)
projects/mmdet3d_plugin/datasets/pipelines/loading.py                +188  -0
projects/mmdet3d_plugin/datasets/pipelines/transform.py              +222  -0
projects/mmdet3d_plugin/datasets/samplers/__init__.py                +6    -0
projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py     +82   -0
projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py  +178  -0
projects/mmdet3d_plugin/datasets/samplers/group_sampler.py           +119  -0
projects/mmdet3d_plugin/datasets/samplers/sampler.py                 +7    -0
projects/mmdet3d_plugin/datasets/utils.py                            +225  -0
projects/mmdet3d_plugin/models/__init__.py                           +30   -0
projects/mmdet3d_plugin/models/base_target.py                        +49   -0
projects/mmdet3d_plugin/models/blocks.py                             +394  -0
projects/mmdet3d_plugin/models/detection3d/__init__.py               +8    -0
projects/mmdet3d_plugin/models/detection3d/decoder.py                +106  -0
projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py     +304  -0
projects/mmdet3d_plugin/models/detection3d/losses.py                 +92   -0
projects/mmdet3d_plugin/models/detection3d/target.py                 +436  -0
projects/mmdet3d_plugin/models/grid_mask.py                          +139  -0
projects/mmdet3d_plugin/models/instance_bank.py                      +254  -0
projects/mmdet3d_plugin/models/sparse4d.py                           +130  -0
projects/mmdet3d_plugin/models/sparse4d_head.py                      +552  -0
projects/mmdet3d_plugin/datasets/pipelines/loading.py — new file, mode 100644

import numpy as np
import mmcv
from mmdet.datasets.builder import PIPELINES


@PIPELINES.register_module()
class LoadMultiViewImageFromFiles(object):
    """Load multi channel images from a list of separate channel files.

    Expects results['img_filename'] to be a list of filenames.

    Args:
        to_float32 (bool, optional): Whether to convert the img to float32.
            Defaults to False.
        color_type (str, optional): Color type of the file.
            Defaults to 'unchanged'.
    """

    def __init__(self, to_float32=False, color_type="unchanged"):
        self.to_float32 = to_float32
        self.color_type = color_type

    def __call__(self, results):
        """Call function to load multi-view image from files.

        Args:
            results (dict): Result dict containing multi-view image filenames.

        Returns:
            dict: The result dict containing the multi-view image data.
                Added keys and values are described below.

                - filename (str): Multi-view image filenames.
                - img (np.ndarray): Multi-view image arrays.
                - img_shape (tuple[int]): Shape of multi-view image arrays.
                - ori_shape (tuple[int]): Shape of original image arrays.
                - pad_shape (tuple[int]): Shape of padded image arrays.
                - scale_factor (float): Scale factor.
                - img_norm_cfg (dict): Normalization configuration of images.
        """
        filename = results["img_filename"]
        # img is of shape (h, w, c, num_views)
        img = np.stack(
            [mmcv.imread(name, self.color_type) for name in filename], axis=-1
        )
        if self.to_float32:
            img = img.astype(np.float32)
        results["filename"] = filename
        # unravel to list, see `DefaultFormatBundle` in formatting.py
        # which will transpose each image separately and then stack into array
        results["img"] = [img[..., i] for i in range(img.shape[-1])]
        results["img_shape"] = img.shape
        results["ori_shape"] = img.shape
        # Set initial values for default meta_keys
        results["pad_shape"] = img.shape
        results["scale_factor"] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results["img_norm_cfg"] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False,
        )
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f"(to_float32={self.to_float32}, "
        repr_str += f"color_type='{self.color_type}')"
        return repr_str


@PIPELINES.register_module()
class LoadPointsFromFile(object):
    """Load Points From File.

    Load points from file.

    Args:
        coord_type (str): The type of coordinates of points cloud.
            Available options includes:
            - 'LIDAR': Points in LiDAR coordinates.
            - 'DEPTH': Points in depth coordinates, usually for indoor dataset.
            - 'CAMERA': Points in camera coordinates.
        load_dim (int, optional): The dimension of the loaded points.
            Defaults to 6.
        use_dim (list[int], optional): Which dimensions of the points to use.
            Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
            or use_dim=[0, 1, 2, 3] to use the intensity dimension.
        shift_height (bool, optional): Whether to use shifted height.
            Defaults to False.
        use_color (bool, optional): Whether to use color features.
            Defaults to False.
        file_client_args (dict, optional): Config dict of file clients,
            refer to
            https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
            for more details. Defaults to dict(backend='disk').
    """

    def __init__(
        self,
        coord_type,
        load_dim=6,
        use_dim=[0, 1, 2],
        shift_height=False,
        use_color=False,
        file_client_args=dict(backend="disk"),
    ):
        self.shift_height = shift_height
        self.use_color = use_color
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        assert (
            max(use_dim) < load_dim
        ), f"Expect all used dimensions < {load_dim}, got {use_dim}"
        assert coord_type in ["CAMERA", "LIDAR", "DEPTH"]

        self.coord_type = coord_type
        self.load_dim = load_dim
        self.use_dim = use_dim
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def _load_points(self, pts_filename):
        """Private function to load point clouds data.

        Args:
            pts_filename (str): Filename of point clouds data.

        Returns:
            np.ndarray: An array containing point clouds data.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)
        try:
            pts_bytes = self.file_client.get(pts_filename)
            points = np.frombuffer(pts_bytes, dtype=np.float32)
        except ConnectionError:
            mmcv.check_file_exist(pts_filename)
            if pts_filename.endswith(".npy"):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)
        return points

    def __call__(self, results):
        """Call function to load points data from file.

        Args:
            results (dict): Result dict containing point clouds data.

        Returns:
            dict: The result dict containing the point clouds data.
                Added key and value are described below.

                - points (:obj:`BasePoints`): Point clouds data.
        """
        pts_filename = results["pts_filename"]
        points = self._load_points(pts_filename)
        points = points.reshape(-1, self.load_dim)
        points = points[:, self.use_dim]
        attribute_dims = None

        if self.shift_height:
            floor_height = np.percentile(points[:, 2], 0.99)
            height = points[:, 2] - floor_height
            points = np.concatenate(
                [points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1
            )
            attribute_dims = dict(height=3)

        if self.use_color:
            assert len(self.use_dim) >= 6
            if attribute_dims is None:
                attribute_dims = dict()
            attribute_dims.update(
                dict(
                    color=[
                        points.shape[1] - 3,
                        points.shape[1] - 2,
                        points.shape[1] - 1,
                    ]
                )
            )
        results["points"] = points
        return results
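As a usage sketch outside this diff: both transforms above are registered in the mmdet PIPELINES registry, so a dataset config can reference them by class name. The load_dim/use_dim values below are assumptions for a nuScenes-style .bin point layout, not values taken from this commit.

train_pipeline = [
    dict(type="LoadMultiViewImageFromFiles", to_float32=True),
    dict(
        type="LoadPointsFromFile",
        coord_type="LIDAR",
        load_dim=5,          # assumed layout: x, y, z, intensity, ring index
        use_dim=[0, 1, 2],   # keep xyz only
    ),
]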
projects/mmdet3d_plugin/datasets/pipelines/transform.py — new file, mode 100644

import numpy as np
import mmcv
from mmcv.parallel import DataContainer as DC

from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor


@PIPELINES.register_module()
class MultiScaleDepthMapGenerator(object):
    def __init__(self, downsample=1, max_depth=60):
        if not isinstance(downsample, (list, tuple)):
            downsample = [downsample]
        self.downsample = downsample
        self.max_depth = max_depth

    def __call__(self, input_dict):
        points = input_dict["points"][..., :3, None]
        gt_depth = []
        for i, lidar2img in enumerate(input_dict["lidar2img"]):
            H, W = input_dict["img_shape"][i][:2]

            pts_2d = (
                np.squeeze(lidar2img[:3, :3] @ points, axis=-1)
                + lidar2img[:3, 3]
            )
            pts_2d[:, :2] /= pts_2d[:, 2:3]
            U = np.round(pts_2d[:, 0]).astype(np.int32)
            V = np.round(pts_2d[:, 1]).astype(np.int32)
            depths = pts_2d[:, 2]
            mask = np.logical_and.reduce(
                [
                    V >= 0,
                    V < H,
                    U >= 0,
                    U < W,
                    depths >= 0.1,
                    # depths <= self.max_depth,
                ]
            )
            V, U, depths = V[mask], U[mask], depths[mask]
            sort_idx = np.argsort(depths)[::-1]
            V, U, depths = V[sort_idx], U[sort_idx], depths[sort_idx]
            depths = np.clip(depths, 0.1, self.max_depth)

            for j, downsample in enumerate(self.downsample):
                if len(gt_depth) < j + 1:
                    gt_depth.append([])
                h, w = (int(H / downsample), int(W / downsample))
                u = np.floor(U / downsample).astype(np.int32)
                v = np.floor(V / downsample).astype(np.int32)
                depth_map = np.ones([h, w], dtype=np.float32) * -1
                depth_map[v, u] = depths
                gt_depth[j].append(depth_map)

        input_dict["gt_depth"] = [np.stack(x) for x in gt_depth]
        return input_dict


@PIPELINES.register_module()
class NuScenesSparse4DAdaptor(object):
    def __init__(self):  # fixed: was misspelled `__init` in the original
        pass

    def __call__(self, input_dict):
        input_dict["projection_mat"] = np.float32(
            np.stack(input_dict["lidar2img"])
        )
        input_dict["image_wh"] = np.ascontiguousarray(
            np.array(input_dict["img_shape"], dtype=np.float32)[:, :2][:, ::-1]
        )
        input_dict["T_global_inv"] = np.linalg.inv(input_dict["lidar2global"])
        input_dict["T_global"] = input_dict["lidar2global"]
        if "cam_intrinsic" in input_dict:
            input_dict["cam_intrinsic"] = np.float32(
                np.stack(input_dict["cam_intrinsic"])
            )
            input_dict["focal"] = input_dict["cam_intrinsic"][..., 0, 0]
            # input_dict["focal"] = np.sqrt(
            #     np.abs(np.linalg.det(input_dict["cam_intrinsic"][:, :2, :2]))
            # )
        if "instance_inds" in input_dict:
            input_dict["instance_id"] = input_dict["instance_inds"]
        if "gt_bboxes_3d" in input_dict:
            input_dict["gt_bboxes_3d"][:, 6] = self.limit_period(
                input_dict["gt_bboxes_3d"][:, 6], offset=0.5, period=2 * np.pi
            )
            input_dict["gt_bboxes_3d"] = DC(
                to_tensor(input_dict["gt_bboxes_3d"]).float()
            )
        if "gt_labels_3d" in input_dict:
            input_dict["gt_labels_3d"] = DC(
                to_tensor(input_dict["gt_labels_3d"]).long()
            )
        imgs = [img.transpose(2, 0, 1) for img in input_dict["img"]]
        imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
        input_dict["img"] = DC(to_tensor(imgs), stack=True)
        return input_dict

    def limit_period(
        self, val: np.ndarray, offset: float = 0.5, period: float = np.pi
    ) -> np.ndarray:
        limited_val = val - np.floor(val / period + offset) * period
        return limited_val


@PIPELINES.register_module()
class InstanceNameFilter(object):
    """Filter GT objects by their names.

    Args:
        classes (list[str]): List of class names to be kept for training.
    """

    def __init__(self, classes):
        self.classes = classes
        self.labels = list(range(len(self.classes)))

    def __call__(self, input_dict):
        """Call function to filter objects by their names.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after filtering, 'gt_bboxes_3d' and 'gt_labels_3d'
                keys are updated in the result dict.
        """
        gt_labels_3d = input_dict["gt_labels_3d"]
        gt_bboxes_mask = np.array(
            [n in self.labels for n in gt_labels_3d], dtype=np.bool_
        )
        input_dict["gt_bboxes_3d"] = input_dict["gt_bboxes_3d"][gt_bboxes_mask]
        input_dict["gt_labels_3d"] = input_dict["gt_labels_3d"][gt_bboxes_mask]
        if "instance_inds" in input_dict:
            input_dict["instance_inds"] = input_dict["instance_inds"][
                gt_bboxes_mask
            ]
        return input_dict

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f"(classes={self.classes})"
        return repr_str


@PIPELINES.register_module()
class CircleObjectRangeFilter(object):
    def __init__(
        self, class_dist_thred=[52.5] * 5 + [31.5] + [42] * 3 + [31.5]
    ):
        self.class_dist_thred = class_dist_thred

    def __call__(self, input_dict):
        gt_bboxes_3d = input_dict["gt_bboxes_3d"]
        gt_labels_3d = input_dict["gt_labels_3d"]
        dist = np.sqrt(np.sum(gt_bboxes_3d[:, :2] ** 2, axis=-1))
        mask = np.array([False] * len(dist))
        for label_idx, dist_thred in enumerate(self.class_dist_thred):
            mask = np.logical_or(
                mask,
                np.logical_and(gt_labels_3d == label_idx, dist <= dist_thred),
            )
        gt_bboxes_3d = gt_bboxes_3d[mask]
        gt_labels_3d = gt_labels_3d[mask]
        input_dict["gt_bboxes_3d"] = gt_bboxes_3d
        input_dict["gt_labels_3d"] = gt_labels_3d
        if "instance_inds" in input_dict:
            input_dict["instance_inds"] = input_dict["instance_inds"][mask]
        return input_dict

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f"(class_dist_thred={self.class_dist_thred})"
        return repr_str


@PIPELINES.register_module()
class NormalizeMultiviewImage(object):
    """Normalize the image.

    Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Call function to normalize images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Normalized results, 'img_norm_cfg' key is added into
                result dict.
        """
        results["img"] = [
            mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
            for img in results["img"]
        ]
        results["img_norm_cfg"] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb
        )
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})"
        return repr_str
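A quick numeric check, outside this diff, of limit_period as used above with offset=0.5 and period=2*pi, which wraps yaw angles into [-pi, pi):

import numpy as np

def limit_period(val, offset=0.5, period=np.pi):
    return val - np.floor(val / period + offset) * period

yaws = np.array([3.5 * np.pi, -np.pi, np.pi])
print(limit_period(yaws, offset=0.5, period=2 * np.pi))
# [-1.5708 -3.1416 -3.1416]  -> every value lands in [-pi, pi)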
projects/mmdet3d_plugin/datasets/samplers/__init__.py — new file, mode 100644

from .group_sampler import DistributedGroupSampler
from .distributed_sampler import DistributedSampler
from .sampler import SAMPLER, build_sampler
from .group_in_batch_sampler import (
    GroupInBatchSampler,
)
projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py — new file, mode 100644

import math

import torch
from torch.utils.data import DistributedSampler as _DistributedSampler

from .sampler import SAMPLER

import pdb
import sys


class ForkedPdb(pdb.Pdb):
    def interaction(self, *args, **kwargs):
        _stdin = sys.stdin
        try:
            sys.stdin = open("/dev/stdin")
            pdb.Pdb.interaction(self, *args, **kwargs)
        finally:
            sys.stdin = _stdin


def set_trace():
    ForkedPdb().set_trace(sys._getframe().f_back)


@SAMPLER.register_module()
class DistributedSampler(_DistributedSampler):
    def __init__(
        self, dataset=None, num_replicas=None, rank=None, shuffle=True, seed=0
    ):
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle
        )
        # for the compatibility from PyTorch 1.3+
        self.seed = seed if seed is not None else 0

    def __iter__(self):
        # deterministically shuffle based on epoch
        assert not self.shuffle
        if "data_infos" in dir(self.dataset):
            timestamps = [
                x["timestamp"] / 1e6 for x in self.dataset.data_infos
            ]
            vehicle_idx = [
                x["lidar_path"].split("/")[-1][:4]
                if "lidar_path" in x
                else None
                for x in self.dataset.data_infos
            ]
        else:
            timestamps = [
                x["timestamp"] / 1e6
                for x in self.dataset.datasets[0].data_infos
            ] * len(self.dataset.datasets)
            vehicle_idx = [
                x["lidar_path"].split("/")[-1][:4]
                if "lidar_path" in x
                else None
                for x in self.dataset.datasets[0].data_infos
            ] * len(self.dataset.datasets)

        sequence_splits = []
        for i in range(len(timestamps)):
            if i == 0 or (
                abs(timestamps[i] - timestamps[i - 1]) > 4
                or vehicle_idx[i] != vehicle_idx[i - 1]
            ):
                sequence_splits.append([i])
            else:
                sequence_splits[-1].append(i)

        indices = []
        prefix_sum = 0  # fixed: was misspelled `perfix_sum` in the original
        split_length = len(self.dataset) // self.num_replicas
        for i in range(len(sequence_splits)):
            if prefix_sum >= (self.rank + 1) * split_length:
                break
            elif prefix_sum >= self.rank * split_length:
                indices.extend(sequence_splits[i])
            prefix_sum += len(sequence_splits[i])

        self.num_samples = len(indices)
        return iter(indices)
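The sampler above assigns whole sequences to ranks: a new sequence starts on a timestamp gap larger than 4 s or a change of the four-character vehicle prefix parsed from lidar_path. A self-contained illustration of that splitting rule on toy data (not part of this commit):

timestamps = [0.0, 0.5, 1.0, 10.0, 10.5, 11.0]
vehicle_idx = ["n008", "n008", "n008", "n008", "n008", "n015"]

sequence_splits = []
for i in range(len(timestamps)):
    if i == 0 or (
        abs(timestamps[i] - timestamps[i - 1]) > 4       # time gap > 4 s
        or vehicle_idx[i] != vehicle_idx[i - 1]          # vehicle changed
    ):
        sequence_splits.append([i])
    else:
        sequence_splits[-1].append(i)

print(sequence_splits)  # [[0, 1, 2], [3, 4], [5]]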
projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py — new file, mode 100644

# https://github.com/Divadi/SOLOFusion/blob/main/mmdet3d/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py
import itertools
import copy

import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info
from torch.utils.data.sampler import Sampler


# https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157
def sync_random_seed(seed=None, device="cuda"):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


class GroupInBatchSampler(Sampler):
    """
    Pardon this horrendous name. Basically, we want every sample to be from
    its own group. If batch size is 4 and # of GPUs is 8, each sample of
    these 32 should be operating on its own group.

    Shuffling is only done for group order, not done within groups.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        world_size=None,
        rank=None,
        seed=0,
        skip_prob=0.5,
        sequence_flip_prob=0.1,
    ):
        _rank, _world_size = get_dist_info()
        if world_size is None:
            world_size = _world_size
        if rank is None:
            rank = _rank

        self.dataset = dataset
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.seed = sync_random_seed(seed)

        self.size = len(self.dataset)

        assert hasattr(self.dataset, "flag")
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)
        self.groups_num = len(self.group_sizes)
        self.global_batch_size = batch_size * world_size
        assert self.groups_num >= self.global_batch_size

        # Now, for efficiency, make a dict group_idx: List[dataset sample_idxs]
        self.group_idx_to_sample_idxs = {
            group_idx: np.where(self.flag == group_idx)[0].tolist()
            for group_idx in range(self.groups_num)
        }

        # Get a generator per sample idx. Considering samples over all
        # GPUs, each sample position has its own generator
        self.group_indices_per_global_sample_idx = [
            self._group_indices_per_global_sample_idx(
                self.rank * self.batch_size + local_sample_idx
            )
            for local_sample_idx in range(self.batch_size)
        ]

        # Keep track of a buffer of dataset sample idxs for each local sample idx
        self.buffer_per_local_sample = [[] for _ in range(self.batch_size)]
        self.aug_per_local_sample = [None for _ in range(self.batch_size)]
        self.skip_prob = skip_prob
        self.sequence_flip_prob = sequence_flip_prob

    def _infinite_group_indices(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            yield from torch.randperm(self.groups_num, generator=g).tolist()

    def _group_indices_per_global_sample_idx(self, global_sample_idx):
        yield from itertools.islice(
            self._infinite_group_indices(),
            global_sample_idx,
            None,
            self.global_batch_size,
        )

    def __iter__(self):
        while True:
            curr_batch = []
            for local_sample_idx in range(self.batch_size):
                skip = (
                    np.random.uniform() < self.skip_prob
                    and len(self.buffer_per_local_sample[local_sample_idx]) > 1
                )
                if len(self.buffer_per_local_sample[local_sample_idx]) == 0:
                    # Finished current group, refill with next group
                    # skip = False
                    new_group_idx = next(
                        self.group_indices_per_global_sample_idx[
                            local_sample_idx
                        ]
                    )
                    self.buffer_per_local_sample[
                        local_sample_idx
                    ] = copy.deepcopy(
                        self.group_idx_to_sample_idxs[new_group_idx]
                    )
                    if np.random.uniform() < self.sequence_flip_prob:
                        self.buffer_per_local_sample[
                            local_sample_idx
                        ] = self.buffer_per_local_sample[local_sample_idx][::-1]
                    if self.dataset.keep_consistent_seq_aug:
                        self.aug_per_local_sample[
                            local_sample_idx
                        ] = self.dataset.get_augmentation()

                if not self.dataset.keep_consistent_seq_aug:
                    self.aug_per_local_sample[
                        local_sample_idx
                    ] = self.dataset.get_augmentation()

                if skip:
                    self.buffer_per_local_sample[local_sample_idx].pop(0)

                curr_batch.append(
                    dict(
                        idx=self.buffer_per_local_sample[
                            local_sample_idx
                        ].pop(0),
                        aug_config=self.aug_per_local_sample[local_sample_idx],
                    )
                )
            yield curr_batch

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        self.epoch = epoch
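A minimal sketch, outside this diff, of driving GroupInBatchSampler directly; the toy dataset supplies the flag array, keep_consistent_seq_aug attribute, and get_augmentation() hook that the sampler requires, with placeholder values:

from torch.utils.data import Dataset


class ToySeqDataset(Dataset):
    keep_consistent_seq_aug = True

    def __init__(self):
        # Two groups (sequences) of three samples each.
        self.flag = np.array([0, 0, 0, 1, 1, 1])

    def __len__(self):
        return len(self.flag)

    def __getitem__(self, idx):
        return idx

    def get_augmentation(self):
        return {"flip": False}  # placeholder augmentation config


sampler = GroupInBatchSampler(
    ToySeqDataset(), batch_size=2, world_size=1, rank=0, seed=0
)
batch = next(iter(sampler))  # [{'idx': ..., 'aug_config': ...}, ...]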
projects/mmdet3d_plugin/datasets/samplers/group_sampler.py — new file, mode 100644

# Copyright (c) OpenMMLab. All rights reserved.
import math

import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler

from .sampler import SAMPLER


@SAMPLER.register_module()
class DistributedGroupSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        seed (int, optional): random seed used to shuffle the sampler if
            ``shuffle=True``. This number should be identical across all
            processes in the distributed group. Default: 0.
    """

    def __init__(
        self,
        dataset,
        samples_per_gpu=1,
        num_replicas=None,
        rank=None,
        seed=0,
    ):
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed if seed is not None else 0

        assert hasattr(self.dataset, "flag")
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        self.num_samples = 0
        for i, j in enumerate(self.group_sizes):
            self.num_samples += (
                int(
                    math.ceil(
                        self.group_sizes[i]
                        * 1.0
                        / self.samples_per_gpu
                        / self.num_replicas
                    )
                )
                * self.samples_per_gpu
            )
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)

        indices = []
        for i, size in enumerate(self.group_sizes):
            if size > 0:
                indice = np.where(self.flag == i)[0]
                assert len(indice) == size
                # add .numpy() to avoid bug when selecting indice in parrots.
                # TODO: check whether torch.randperm() can be replaced by
                # numpy.random.permutation().
                indice = indice[
                    list(torch.randperm(int(size), generator=g).numpy())
                ].tolist()
                extra = int(
                    math.ceil(
                        size * 1.0 / self.samples_per_gpu / self.num_replicas
                    )
                ) * self.samples_per_gpu * self.num_replicas - len(indice)
                # pad indice
                tmp = indice.copy()
                for _ in range(extra // size):
                    indice.extend(tmp)
                indice.extend(tmp[: extra % size])
                indices.extend(indice)

        assert len(indices) == self.total_size

        indices = [
            indices[j]
            for i in list(
                torch.randperm(
                    len(indices) // self.samples_per_gpu, generator=g
                )
            )
            for j in range(
                i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu
            )
        ]

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
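The num_samples computation above pads each group so it divides evenly across replicas and per-GPU batches. Worked numbers for one group, as an illustration only:

import math

size, samples_per_gpu, num_replicas = 10, 4, 2
per_rank = math.ceil(size / samples_per_gpu / num_replicas) * samples_per_gpu
total = per_rank * num_replicas
extra = total - size
print(per_rank, total, extra)  # 8 16 6 -> 6 indices get repeated as padding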
projects/mmdet3d_plugin/datasets/samplers/sampler.py — new file, mode 100644

from mmcv.utils.registry import Registry, build_from_cfg

SAMPLER = Registry("sampler")


def build_sampler(cfg, default_args):
    return build_from_cfg(cfg, SAMPLER, default_args)
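Usage sketch, not from this commit: any class decorated with @SAMPLER.register_module() can be built by name through this registry. DistributedGroupSampler additionally expects a flag attribute on the dataset, so a toy one is attached here:

import numpy as np
import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.arange(8))
dataset.flag = np.zeros(8, dtype=np.int64)  # one group containing all samples

cfg = dict(type="DistributedGroupSampler", samples_per_gpu=2)
sampler = build_sampler(cfg, default_args=dict(dataset=dataset))
print(len(sampler))  # 8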
projects/mmdet3d_plugin/datasets/utils.py — new file, mode 100644

import copy

import cv2
import numpy as np
import torch

from projects.mmdet3d_plugin.core.box3d import *


def box3d_to_corners(box3d):
    if isinstance(box3d, torch.Tensor):
        box3d = box3d.detach().cpu().numpy()
    corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)
    corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
    # use relative origin [0.5, 0.5, 0.5]
    corners_norm = corners_norm - np.array([0.5, 0.5, 0.5])
    corners = box3d[:, None, [W, L, H]] * corners_norm.reshape([1, 8, 3])

    # rotate around z axis
    rot_cos = np.cos(box3d[:, YAW])
    rot_sin = np.sin(box3d[:, YAW])
    rot_mat = np.tile(np.eye(3)[None], (box3d.shape[0], 1, 1))
    rot_mat[:, 0, 0] = rot_cos
    rot_mat[:, 0, 1] = -rot_sin
    rot_mat[:, 1, 0] = rot_sin
    rot_mat[:, 1, 1] = rot_cos
    corners = (rot_mat[:, None] @ corners[..., None]).squeeze(axis=-1)
    corners += box3d[:, None, :3]
    return corners


def plot_rect3d_on_img(
    img, num_rects, rect_corners, color=(0, 255, 0), thickness=1
):
    """Plot the boundary lines of 3D rectangular on 2D images.

    Args:
        img (numpy.array): The numpy array of image.
        num_rects (int): Number of 3D rectangulars.
        rect_corners (numpy.array): Coordinates of the corners of 3D
            rectangulars. Should be in the shape of [num_rect, 8, 2].
        color (tuple[int], optional): The color to draw bboxes.
            Default: (0, 255, 0).
        thickness (int, optional): The thickness of bboxes. Default: 1.
    """
    line_indices = (
        (0, 1),
        (0, 3),
        (0, 4),
        (1, 2),
        (1, 5),
        (3, 2),
        (3, 7),
        (4, 5),
        (4, 7),
        (2, 6),
        (5, 6),
        (6, 7),
    )
    h, w = img.shape[:2]
    for i in range(num_rects):
        corners = np.clip(rect_corners[i], -1e4, 1e5).astype(np.int32)
        for start, end in line_indices:
            if (
                (corners[start, 1] >= h or corners[start, 1] < 0)
                or (corners[start, 0] >= w or corners[start, 0] < 0)
            ) and (
                (corners[end, 1] >= h or corners[end, 1] < 0)
                or (corners[end, 0] >= w or corners[end, 0] < 0)
            ):
                continue
            if isinstance(color[0], int):
                cv2.line(
                    img,
                    (corners[start, 0], corners[start, 1]),
                    (corners[end, 0], corners[end, 1]),
                    color,
                    thickness,
                    cv2.LINE_AA,
                )
            else:
                cv2.line(
                    img,
                    (corners[start, 0], corners[start, 1]),
                    (corners[end, 0], corners[end, 1]),
                    color[i],
                    thickness,
                    cv2.LINE_AA,
                )
    return img.astype(np.uint8)


def draw_lidar_bbox3d_on_img(
    bboxes3d,
    raw_img,
    lidar2img_rt,
    img_metas=None,
    color=(0, 255, 0),
    thickness=1,
):
    """Project the 3D bbox on 2D plane and draw on input image.

    Args:
        bboxes3d (:obj:`LiDARInstance3DBoxes`):
            3d bbox in lidar coordinate system to visualize.
        raw_img (numpy.array): The numpy array of image.
        lidar2img_rt (numpy.array, shape=[4, 4]): The projection matrix
            according to the camera intrinsic parameters.
        img_metas (dict): Useless here.
        color (tuple[int], optional): The color to draw bboxes.
            Default: (0, 255, 0).
        thickness (int, optional): The thickness of bboxes. Default: 1.
    """
    img = raw_img.copy()
    # corners_3d = bboxes3d.corners
    corners_3d = box3d_to_corners(bboxes3d)
    num_bbox = corners_3d.shape[0]
    pts_4d = np.concatenate(
        [corners_3d.reshape(-1, 3), np.ones((num_bbox * 8, 1))], axis=-1
    )
    lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4)
    if isinstance(lidar2img_rt, torch.Tensor):
        lidar2img_rt = lidar2img_rt.cpu().numpy()
    pts_2d = pts_4d @ lidar2img_rt.T

    pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5)
    pts_2d[:, 0] /= pts_2d[:, 2]
    pts_2d[:, 1] /= pts_2d[:, 2]
    imgfov_pts_2d = pts_2d[..., :2].reshape(num_bbox, 8, 2)

    return plot_rect3d_on_img(img, num_bbox, imgfov_pts_2d, color, thickness)


def draw_points_on_img(points, img, lidar2img_rt, color=(0, 255, 0), circle=4):
    img = img.copy()
    N = points.shape[0]
    points = points.cpu().numpy()
    lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4)
    if isinstance(lidar2img_rt, torch.Tensor):
        lidar2img_rt = lidar2img_rt.cpu().numpy()

    pts_2d = (
        np.sum(points[:, :, None] * lidar2img_rt[:3, :3], axis=-1)
        + lidar2img_rt[:3, 3]
    )
    pts_2d[..., 2] = np.clip(pts_2d[..., 2], a_min=1e-5, a_max=1e5)
    pts_2d = pts_2d[..., :2] / pts_2d[..., 2:3]
    pts_2d = np.clip(pts_2d, -1e4, 1e4).astype(np.int32)

    for i in range(N):
        for point in pts_2d[i]:
            if isinstance(color[0], int):
                color_tmp = color
            else:
                color_tmp = color[i]
            cv2.circle(img, point.tolist(), circle, color_tmp, thickness=-1)
    return img.astype(np.uint8)


def draw_lidar_bbox3d_on_bev(
    bboxes_3d, bev_size, bev_range=115, color=(255, 0, 0), thickness=3
):
    if isinstance(bev_size, (list, tuple)):
        bev_h, bev_w = bev_size
    else:
        bev_h, bev_w = bev_size, bev_size
    bev = np.zeros([bev_h, bev_w, 3])

    marking_color = (127, 127, 127)
    bev_resolution = bev_range / bev_h
    for cir in range(int(bev_range / 2 / 10)):
        cv2.circle(
            bev,
            (int(bev_h / 2), int(bev_w / 2)),
            int((cir + 1) * 10 / bev_resolution),
            marking_color,
            thickness=thickness,
        )
    cv2.line(
        bev,
        (0, int(bev_h / 2)),
        (bev_w, int(bev_h / 2)),
        marking_color,
    )
    cv2.line(
        bev,
        (int(bev_w / 2), 0),
        (int(bev_w / 2), bev_h),
        marking_color,
    )
    if len(bboxes_3d) != 0:
        bev_corners = box3d_to_corners(bboxes_3d)[:, [0, 3, 4, 7]][
            ..., [0, 1]
        ]
        xs = bev_corners[..., 0] / bev_resolution + bev_w / 2
        ys = -bev_corners[..., 1] / bev_resolution + bev_h / 2
        for obj_idx, (x, y) in enumerate(zip(xs, ys)):
            for p1, p2 in ((0, 1), (0, 2), (1, 3), (2, 3)):
                if isinstance(color[0], (list, tuple)):
                    tmp = color[obj_idx]
                else:
                    tmp = color
                cv2.line(
                    bev,
                    (int(x[p1]), int(y[p1])),
                    (int(x[p2]), int(y[p2])),
                    tmp,
                    thickness=thickness,
                )
    return bev.astype(np.uint8)


def draw_lidar_bbox3d(bboxes_3d, imgs, lidar2imgs, color=(255, 0, 0)):
    vis_imgs = []
    for i, (img, lidar2img) in enumerate(zip(imgs, lidar2imgs)):
        vis_imgs.append(
            draw_lidar_bbox3d_on_img(bboxes_3d, img, lidar2img, color=color)
        )

    num_imgs = len(vis_imgs)
    if num_imgs < 4 or num_imgs % 2 != 0:
        vis_imgs = np.concatenate(vis_imgs, axis=1)
    else:
        vis_imgs = np.concatenate(
            [
                np.concatenate(vis_imgs[: num_imgs // 2], axis=1),
                np.concatenate(vis_imgs[num_imgs // 2 :], axis=1),
            ],
            axis=0,
        )

    bev = draw_lidar_bbox3d_on_bev(bboxes_3d, vis_imgs.shape[0], color=color)
    vis_imgs = np.concatenate([bev, vis_imgs], axis=1)
    return vis_imgs
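A toy check, outside this diff, of the homogeneous projection used in draw_lidar_bbox3d_on_img, with an assumed pinhole lidar2img matrix (focal length 500 px, principal point (320, 240)):

import numpy as np

lidar2img_rt = np.array(
    [
        [500.0, 0.0, 320.0, 0.0],
        [0.0, 500.0, 240.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
pts_4d = np.array([[1.0, 0.5, 10.0, 1.0]])  # one corner, 10 m ahead
pts_2d = pts_4d @ lidar2img_rt.T
pts_2d[:, :2] /= pts_2d[:, 2:3]             # perspective divide by depth
print(pts_2d[:, :2])                        # [[370. 265.]]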
projects/mmdet3d_plugin/models/__init__.py — new file, mode 100644

from .sparse4d import Sparse4D
from .sparse4d_head import Sparse4DHead
from .blocks import (
    DeformableFeatureAggregation,
    DenseDepthNet,
    AsymmetricFFN,
)
from .instance_bank import InstanceBank
from .detection3d import (
    SparseBox3DDecoder,
    SparseBox3DTarget,
    SparseBox3DRefinementModule,
    SparseBox3DKeyPointsGenerator,
    SparseBox3DEncoder,
)

__all__ = [
    "Sparse4D",
    "Sparse4DHead",
    "DeformableFeatureAggregation",
    "DenseDepthNet",
    "AsymmetricFFN",
    "InstanceBank",
    "SparseBox3DDecoder",
    "SparseBox3DTarget",
    "SparseBox3DRefinementModule",
    "SparseBox3DKeyPointsGenerator",
    "SparseBox3DEncoder",
]
projects/mmdet3d_plugin/models/base_target.py — new file, mode 100644

from abc import ABC, abstractmethod

__all__ = ["BaseTargetWithDenoising"]


class BaseTargetWithDenoising(ABC):
    def __init__(self, num_dn_groups=0, num_temp_dn_groups=0):
        super(BaseTargetWithDenoising, self).__init__()
        self.num_dn_groups = num_dn_groups
        self.num_temp_dn_groups = num_temp_dn_groups
        self.dn_metas = None

    @abstractmethod
    def sample(self, cls_pred, box_pred, cls_target, box_target):
        """
        Perform Hungarian matching between predictions and ground truth,
        returning the matched ground truth corresponding to the predictions
        along with the corresponding regression weights.
        """

    def get_dn_anchors(self, cls_target, box_target, *args, **kwargs):
        """
        Generate noisy instances for the current frame, with a total of
        'self.num_dn_groups' groups.
        """
        return None

    def update_dn(self, instance_feature, anchor, *args, **kwargs):
        """
        Insert the previously saved 'self.dn_metas' into the noisy instances
        of the current frame.
        """

    def cache_dn(
        self,
        dn_instance_feature,
        dn_anchor,
        dn_cls_target,
        valid_mask,
        dn_id_target,
    ):
        """
        Randomly save information for 'self.num_temp_dn_groups' groups of
        temporal noisy instances to 'self.dn_metas'.
        """
        if self.num_temp_dn_groups < 0:
            return
        self.dn_metas = dict(dn_anchor=dn_anchor[:, : self.num_temp_dn_groups])
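A minimal concrete subclass sketch, not from this commit; the identity matching below is a placeholder for the Hungarian assignment that sample() is documented to perform, and the return convention (targets plus per-box regression weights) is an assumption based on the docstring:

class DummyTarget(BaseTargetWithDenoising):
    def sample(self, cls_pred, box_pred, cls_target, box_target):
        # Identity "matching": pair prediction i with ground truth i and
        # give every matched box a regression weight of 1.
        reg_weights = [t.new_ones(t.shape[0]) for t in box_target]
        return cls_target, box_target, reg_weights


target = DummyTarget(num_dn_groups=5, num_temp_dn_groups=3)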
projects/mmdet3d_plugin/models/blocks.py — new file, mode 100644

# Copyright (c) Horizon Robotics. All rights reserved.
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp.autocast_mode import autocast

from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import Sequential, BaseModule
from mmcv.cnn.bricks.transformer import FFN
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (
    ATTENTION,
    PLUGIN_LAYERS,
    FEEDFORWARD_NETWORK,
)

try:
    from ..ops import deformable_aggregation_function as DAF
except ImportError:  # the compiled deformable-aggregation op is optional
    DAF = None

__all__ = [
    "DeformableFeatureAggregation",
    "DenseDepthNet",
    "AsymmetricFFN",
]


def linear_relu_ln(embed_dims, in_loops, out_loops, input_dims=None):
    if input_dims is None:
        input_dims = embed_dims
    layers = []
    for _ in range(out_loops):
        for _ in range(in_loops):
            layers.append(Linear(input_dims, embed_dims))
            layers.append(nn.ReLU(inplace=True))
            input_dims = embed_dims
        layers.append(nn.LayerNorm(embed_dims))
    return layers


@ATTENTION.register_module()
class DeformableFeatureAggregation(BaseModule):
    def __init__(
        self,
        embed_dims: int = 256,
        num_groups: int = 8,
        num_levels: int = 4,
        num_cams: int = 6,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        kps_generator: dict = None,
        temporal_fusion_module=None,
        use_temporal_anchor_embed=True,
        use_deformable_func=False,
        use_camera_embed=False,
        residual_mode="add",
    ):
        super(DeformableFeatureAggregation, self).__init__()
        if embed_dims % num_groups != 0:
            raise ValueError(
                f"embed_dims must be divisible by num_groups, "
                f"but got {embed_dims} and {num_groups}"
            )
        self.group_dims = int(embed_dims / num_groups)
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_groups = num_groups
        self.num_cams = num_cams
        self.use_temporal_anchor_embed = use_temporal_anchor_embed
        if use_deformable_func:
            assert DAF is not None, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        self.attn_drop = attn_drop
        self.residual_mode = residual_mode
        self.proj_drop = nn.Dropout(proj_drop)
        kps_generator["embed_dims"] = embed_dims
        self.kps_generator = build_from_cfg(kps_generator, PLUGIN_LAYERS)
        self.num_pts = self.kps_generator.num_pts
        if temporal_fusion_module is not None:
            if "embed_dims" not in temporal_fusion_module:
                temporal_fusion_module["embed_dims"] = embed_dims
            self.temp_module = build_from_cfg(
                temporal_fusion_module, PLUGIN_LAYERS
            )
        else:
            self.temp_module = None
        self.output_proj = Linear(embed_dims, embed_dims)

        if use_camera_embed:
            self.camera_encoder = Sequential(
                *linear_relu_ln(embed_dims, 1, 2, 12)
            )
            self.weights_fc = Linear(
                embed_dims, num_groups * num_levels * self.num_pts
            )
        else:
            self.camera_encoder = None
            self.weights_fc = Linear(
                embed_dims, num_groups * num_cams * num_levels * self.num_pts
            )

    def init_weight(self):
        constant_init(self.weights_fc, val=0.0, bias=0.0)
        xavier_init(self.output_proj, distribution="uniform", bias=0.0)

    def forward(
        self,
        instance_feature: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        feature_maps: List[torch.Tensor],
        metas: dict,
        **kwargs: dict,
    ):
        bs, num_anchor = instance_feature.shape[:2]
        key_points = self.kps_generator(anchor, instance_feature)
        weights = self._get_weights(instance_feature, anchor_embed, metas)
        if self.use_deformable_func:
            points_2d = (
                self.project_points(
                    key_points,
                    metas["projection_mat"],
                    metas.get("image_wh"),
                )
                .permute(0, 2, 3, 1, 4)
                .reshape(bs, num_anchor, self.num_pts, self.num_cams, 2)
            )
            weights = (
                weights.permute(0, 1, 4, 2, 3, 5)
                .contiguous()
                .reshape(
                    bs,
                    num_anchor,
                    self.num_pts,
                    self.num_cams,
                    self.num_levels,
                    self.num_groups,
                )
            )
            features = DAF(*feature_maps, points_2d, weights).reshape(
                bs, num_anchor, self.embed_dims
            )
        else:
            features = self.feature_sampling(
                feature_maps,
                key_points,
                metas["projection_mat"],
                metas.get("image_wh"),
            )
            features = self.multi_view_level_fusion(features, weights)
            features = features.sum(dim=2)  # fuse multi-point features
        output = self.proj_drop(self.output_proj(features))
        if self.residual_mode == "add":
            output = output + instance_feature
        elif self.residual_mode == "cat":
            output = torch.cat([output, instance_feature], dim=-1)
        return output

    def _get_weights(self, instance_feature, anchor_embed, metas=None):
        bs, num_anchor = instance_feature.shape[:2]
        feature = instance_feature + anchor_embed
        if self.camera_encoder is not None:
            camera_embed = self.camera_encoder(
                metas["projection_mat"][:, :, :3].reshape(
                    bs, self.num_cams, -1
                )
            )
            feature = feature[:, :, None] + camera_embed[:, None]
        weights = (
            self.weights_fc(feature)
            .reshape(bs, num_anchor, -1, self.num_groups)
            .softmax(dim=-2)
            .reshape(
                bs,
                num_anchor,
                self.num_cams,
                self.num_levels,
                self.num_pts,
                self.num_groups,
            )
        )
        if self.training and self.attn_drop > 0:
            mask = torch.rand(
                bs, num_anchor, self.num_cams, 1, self.num_pts, 1
            )
            mask = mask.to(device=weights.device, dtype=weights.dtype)
            weights = ((mask > self.attn_drop) * weights) / (
                1 - self.attn_drop
            )
        return weights

    @staticmethod
    def project_points(key_points, projection_mat, image_wh=None):
        bs, num_anchor, num_pts = key_points.shape[:3]

        pts_extend = torch.cat(
            [key_points, torch.ones_like(key_points[..., :1])], dim=-1
        )
        points_2d = torch.matmul(
            projection_mat[:, :, None, None], pts_extend[:, None, ..., None]
        ).squeeze(-1)
        points_2d = points_2d[..., :2] / torch.clamp(
            points_2d[..., 2:3], min=1e-5
        )
        if image_wh is not None:
            points_2d = points_2d / image_wh[:, :, None, None]
        return points_2d

    @staticmethod
    def feature_sampling(
        feature_maps: List[torch.Tensor],
        key_points: torch.Tensor,
        projection_mat: torch.Tensor,
        image_wh: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        num_levels = len(feature_maps)
        num_cams = feature_maps[0].shape[1]
        bs, num_anchor, num_pts = key_points.shape[:3]

        points_2d = DeformableFeatureAggregation.project_points(
            key_points, projection_mat, image_wh
        )
        points_2d = points_2d * 2 - 1
        points_2d = points_2d.flatten(end_dim=1)

        features = []
        for fm in feature_maps:
            features.append(
                torch.nn.functional.grid_sample(
                    fm.flatten(end_dim=1), points_2d
                )
            )
        features = torch.stack(features, dim=1)
        features = features.reshape(
            bs, num_cams, num_levels, -1, num_anchor, num_pts
        ).permute(
            0, 4, 1, 2, 5, 3
        )  # bs, num_anchor, num_cams, num_levels, num_pts, embed_dims
        return features

    def multi_view_level_fusion(
        self,
        features: torch.Tensor,
        weights: torch.Tensor,
    ):
        bs, num_anchor = weights.shape[:2]
        features = weights[..., None] * features.reshape(
            features.shape[:-1] + (self.num_groups, self.group_dims)
        )
        features = features.sum(dim=2).sum(dim=2)
        features = features.reshape(
            bs, num_anchor, self.num_pts, self.embed_dims
        )
        return features


@PLUGIN_LAYERS.register_module()
class DenseDepthNet(BaseModule):
    def __init__(
        self,
        embed_dims=256,
        num_depth_layers=1,
        equal_focal=100,
        max_depth=60,
        loss_weight=1.0,
    ):
        super().__init__()
        self.embed_dims = embed_dims
        self.equal_focal = equal_focal
        self.num_depth_layers = num_depth_layers
        self.max_depth = max_depth
        self.loss_weight = loss_weight

        self.depth_layers = nn.ModuleList()
        for i in range(num_depth_layers):
            self.depth_layers.append(
                nn.Conv2d(embed_dims, 1, kernel_size=1, stride=1, padding=0)
            )

    def forward(self, feature_maps, focal=None, gt_depths=None):
        if focal is None:
            focal = self.equal_focal
        else:
            focal = focal.reshape(-1)
        depths = []
        for i, feat in enumerate(feature_maps[: self.num_depth_layers]):
            depth = self.depth_layers[i](feat.flatten(end_dim=1).float()).exp()
            depth = depth.transpose(0, -1) * focal / self.equal_focal
            depth = depth.transpose(0, -1)
            depths.append(depth)
        if gt_depths is not None and self.training:
            loss = self.loss(depths, gt_depths)
            return loss
        return depths

    def loss(self, depth_preds, gt_depths):
        loss = 0.0
        for pred, gt in zip(depth_preds, gt_depths):
            pred = pred.permute(0, 2, 3, 1).contiguous().reshape(-1)
            gt = gt.reshape(-1)
            fg_mask = torch.logical_and(
                gt > 0.0, torch.logical_not(torch.isnan(pred))
            )
            gt = gt[fg_mask]
            pred = pred[fg_mask]
            pred = torch.clip(pred, 0.0, self.max_depth)
            with autocast(enabled=False):
                error = torch.abs(pred - gt).sum()
                _loss = (
                    error
                    / max(1.0, len(gt) * len(depth_preds))
                    * self.loss_weight
                )
            loss = loss + _loss
        return loss


@FEEDFORWARD_NETWORK.register_module()
class AsymmetricFFN(BaseModule):
    def __init__(
        self,
        in_channels=None,
        pre_norm=None,
        embed_dims=256,
        feedforward_channels=1024,
        num_fcs=2,
        act_cfg=dict(type="ReLU", inplace=True),
        ffn_drop=0.0,
        dropout_layer=None,
        add_identity=True,
        init_cfg=None,
        **kwargs,
    ):
        super(AsymmetricFFN, self).__init__(init_cfg)
        assert num_fcs >= 2, (
            "num_fcs should be no less " f"than 2. got {num_fcs}."
        )
        self.in_channels = in_channels
        self.pre_norm = pre_norm
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        layers = []
        if in_channels is None:
            in_channels = embed_dims
        if pre_norm is not None:
            self.pre_norm = build_norm_layer(pre_norm, in_channels)[1]

        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels),
                    self.activate,
                    nn.Dropout(ffn_drop),
                )
            )
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        self.dropout_layer = (
            build_dropout(dropout_layer)
            if dropout_layer
            else torch.nn.Identity()
        )
        self.add_identity = add_identity
        if self.add_identity:
            self.identity_fc = (
                torch.nn.Identity()
                if in_channels == embed_dims
                else Linear(self.in_channels, embed_dims)
            )

    def forward(self, x, identity=None):
        if self.pre_norm is not None:
            x = self.pre_norm(x)
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        identity = self.identity_fc(identity)
        return identity + self.dropout_layer(out)
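A toy forward pass of DenseDepthNet, not from this commit, with assumed feature-map shapes (bs, num_cams, C, H, W); the module flattens the batch and camera dims before its 1x1 depth conv:

import torch

net = DenseDepthNet(embed_dims=256, num_depth_layers=1, equal_focal=100)
feats = [torch.randn(2, 6, 256, 32, 88)]  # one FPN level, 2 samples x 6 cams
depths = net(feats)                       # no gt_depths -> returns predictions
print(depths[0].shape)                    # torch.Size([12, 1, 32, 88])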
projects/mmdet3d_plugin/models/detection3d/__init__.py — new file, mode 100644

from .decoder import SparseBox3DDecoder
from .target import SparseBox3DTarget
from .detection3d_blocks import (
    SparseBox3DRefinementModule,
    SparseBox3DKeyPointsGenerator,
    SparseBox3DEncoder,
)
from .losses import SparseBox3DLoss
projects/mmdet3d_plugin/models/detection3d/decoder.py — new file, mode 100644

# Copyright (c) Horizon Robotics. All rights reserved.
from typing import Optional

import torch

from mmdet.core.bbox.builder import BBOX_CODERS
from projects.mmdet3d_plugin.core.box3d import *


@BBOX_CODERS.register_module()
class SparseBox3DDecoder(object):
    def __init__(
        self,
        num_output: int = 300,
        score_threshold: Optional[float] = None,
        sorted: bool = True,
    ):
        super(SparseBox3DDecoder, self).__init__()
        self.num_output = num_output
        self.score_threshold = score_threshold
        self.sorted = sorted

    def decode_box(self, box):
        yaw = torch.atan2(box[:, SIN_YAW], box[:, COS_YAW])
        box = torch.cat(
            [
                box[:, [X, Y, Z]],
                box[:, [W, L, H]].exp(),
                yaw[:, None],
                box[:, VX:],
            ],
            dim=-1,
        )
        return box

    def decode(
        self,
        cls_scores,
        box_preds,
        instance_id=None,
        qulity=None,  # NOTE: "qulity" (sic) is kept to match callers elsewhere in this repo
        output_idx=-1,
    ):
        squeeze_cls = instance_id is not None

        cls_scores = cls_scores[output_idx].sigmoid()

        if squeeze_cls:
            cls_scores, cls_ids = cls_scores.max(dim=-1)
            cls_scores = cls_scores.unsqueeze(dim=-1)

        box_preds = box_preds[output_idx]
        bs, num_pred, num_cls = cls_scores.shape
        cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
            self.num_output, dim=1, sorted=self.sorted
        )
        if not squeeze_cls:
            cls_ids = indices % num_cls
        if self.score_threshold is not None:
            mask = cls_scores >= self.score_threshold

        if qulity is not None:
            centerness = qulity[output_idx][..., CNS]
            centerness = torch.gather(centerness, 1, indices // num_cls)
            cls_scores_origin = cls_scores.clone()
            cls_scores *= centerness.sigmoid()
            cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True)
            if not squeeze_cls:
                cls_ids = torch.gather(cls_ids, 1, idx)
            if self.score_threshold is not None:
                mask = torch.gather(mask, 1, idx)
            indices = torch.gather(indices, 1, idx)

        output = []
        for i in range(bs):
            category_ids = cls_ids[i]
            if squeeze_cls:
                category_ids = category_ids[indices[i]]
            scores = cls_scores[i]
            box = box_preds[i, indices[i] // num_cls]
            if self.score_threshold is not None:
                category_ids = category_ids[mask[i]]
                scores = scores[mask[i]]
                box = box[mask[i]]
            if qulity is not None:
                scores_origin = cls_scores_origin[i]
                if self.score_threshold is not None:
                    scores_origin = scores_origin[mask[i]]

            box = self.decode_box(box)
            output.append(
                {
                    "boxes_3d": box.cpu(),
                    "scores_3d": scores.cpu(),
                    "labels_3d": category_ids.cpu(),
                }
            )
            if qulity is not None:
                output[-1]["cls_scores"] = scores_origin.cpu()
            if instance_id is not None:
                ids = instance_id[i, indices[i]]
                if self.score_threshold is not None:
                    ids = ids[mask[i]]
                output[-1]["instance_ids"] = ids
        return output
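The decode path above flattens the (num_pred, num_cls) score grid before topk, then recovers box and class indices with // and %. A tiny check, outside this diff:

import torch

num_pred, num_cls = 4, 3
scores = torch.arange(num_pred * num_cls).float().reshape(1, num_pred, num_cls)
flat_scores, indices = scores.flatten(start_dim=1).topk(2, dim=1)
print(indices // num_cls)  # box indices   -> tensor([[3, 3]])
print(indices % num_cls)   # class indices -> tensor([[2, 1]])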
projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py
0 → 100644
View file @
a9dc86e9
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
mmcv.cnn
import
Linear
,
Scale
,
bias_init_with_prob
from
mmcv.runner.base_module
import
Sequential
,
BaseModule
from
mmcv.cnn
import
xavier_init
from
mmcv.cnn.bricks.registry
import
(
PLUGIN_LAYERS
,
POSITIONAL_ENCODING
,
)
from
projects.mmdet3d_plugin.core.box3d
import
*
from
..blocks
import
linear_relu_ln
__all__
=
[
"SparseBox3DRefinementModule"
,
"SparseBox3DKeyPointsGenerator"
,
"SparseBox3DEncoder"
,
]
@
POSITIONAL_ENCODING
.
register_module
()
class
SparseBox3DEncoder
(
BaseModule
):
def
__init__
(
self
,
embed_dims
,
vel_dims
=
3
,
mode
=
"add"
,
output_fc
=
True
,
in_loops
=
1
,
out_loops
=
2
,
):
super
().
__init__
()
assert
mode
in
[
"add"
,
"cat"
]
self
.
embed_dims
=
embed_dims
self
.
vel_dims
=
vel_dims
self
.
mode
=
mode
def
embedding_layer
(
input_dims
,
output_dims
):
return
nn
.
Sequential
(
*
linear_relu_ln
(
output_dims
,
in_loops
,
out_loops
,
input_dims
)
)
if
not
isinstance
(
embed_dims
,
(
list
,
tuple
)):
embed_dims
=
[
embed_dims
]
*
5
self
.
pos_fc
=
embedding_layer
(
3
,
embed_dims
[
0
])
self
.
size_fc
=
embedding_layer
(
3
,
embed_dims
[
1
])
self
.
yaw_fc
=
embedding_layer
(
2
,
embed_dims
[
2
])
if
vel_dims
>
0
:
self
.
vel_fc
=
embedding_layer
(
self
.
vel_dims
,
embed_dims
[
3
])
if
output_fc
:
self
.
output_fc
=
embedding_layer
(
embed_dims
[
-
1
],
embed_dims
[
-
1
])
else
:
self
.
output_fc
=
None
def
forward
(
self
,
box_3d
:
torch
.
Tensor
):
pos_feat
=
self
.
pos_fc
(
box_3d
[...,
[
X
,
Y
,
Z
]])
size_feat
=
self
.
size_fc
(
box_3d
[...,
[
W
,
L
,
H
]])
yaw_feat
=
self
.
yaw_fc
(
box_3d
[...,
[
SIN_YAW
,
COS_YAW
]])
if
self
.
mode
==
"add"
:
output
=
pos_feat
+
size_feat
+
yaw_feat
elif
self
.
mode
==
"cat"
:
output
=
torch
.
cat
([
pos_feat
,
size_feat
,
yaw_feat
],
dim
=-
1
)
if
self
.
vel_dims
>
0
:
vel_feat
=
self
.
vel_fc
(
box_3d
[...,
VX
:
VX
+
self
.
vel_dims
])
if
self
.
mode
==
"add"
:
output
=
output
+
vel_feat
elif
self
.
mode
==
"cat"
:
output
=
torch
.
cat
([
output
,
vel_feat
],
dim
=-
1
)
if
self
.
output_fc
is
not
None
:
output
=
self
.
output_fc
(
output
)
return
output
@
PLUGIN_LAYERS
.
register_module
()
class
SparseBox3DRefinementModule
(
BaseModule
):
def
__init__
(
self
,
embed_dims
=
256
,
output_dim
=
11
,
num_cls
=
10
,
normalize_yaw
=
False
,
refine_yaw
=
False
,
with_cls_branch
=
True
,
with_quality_estimation
=
False
,
):
super
(
SparseBox3DRefinementModule
,
self
).
__init__
()
self
.
embed_dims
=
embed_dims
self
.
output_dim
=
output_dim
self
.
num_cls
=
num_cls
self
.
normalize_yaw
=
normalize_yaw
self
.
refine_yaw
=
refine_yaw
self
.
refine_state
=
[
X
,
Y
,
Z
,
W
,
L
,
H
]
if
self
.
refine_yaw
:
self
.
refine_state
+=
[
SIN_YAW
,
COS_YAW
]
self
.
layers
=
nn
.
Sequential
(
*
linear_relu_ln
(
embed_dims
,
2
,
2
),
Linear
(
self
.
embed_dims
,
self
.
output_dim
),
Scale
([
1.0
]
*
self
.
output_dim
),
)
self
.
with_cls_branch
=
with_cls_branch
if
with_cls_branch
:
self
.
cls_layers
=
nn
.
Sequential
(
*
linear_relu_ln
(
embed_dims
,
1
,
2
),
Linear
(
self
.
embed_dims
,
self
.
num_cls
),
)
self
.
with_quality_estimation
=
with_quality_estimation
if
with_quality_estimation
:
self
.
quality_layers
=
nn
.
Sequential
(
*
linear_relu_ln
(
embed_dims
,
1
,
2
),
Linear
(
self
.
embed_dims
,
2
),
)
def
init_weight
(
self
):
if
self
.
with_cls_branch
:
bias_init
=
bias_init_with_prob
(
0.01
)
nn
.
init
.
constant_
(
self
.
cls_layers
[
-
1
].
bias
,
bias_init
)
def
forward
(
self
,
instance_feature
:
torch
.
Tensor
,
anchor
:
torch
.
Tensor
,
anchor_embed
:
torch
.
Tensor
,
time_interval
:
torch
.
Tensor
=
1.0
,
return_cls
=
True
,
):
feature
=
instance_feature
+
anchor_embed
output
=
self
.
layers
(
feature
)
output
[...,
self
.
refine_state
]
=
(
output
[...,
self
.
refine_state
]
+
anchor
[...,
self
.
refine_state
]
)
if
self
.
normalize_yaw
:
output
[...,
[
SIN_YAW
,
COS_YAW
]]
=
torch
.
nn
.
functional
.
normalize
(
output
[...,
[
SIN_YAW
,
COS_YAW
]],
dim
=-
1
)
if
self
.
output_dim
>
8
:
if
not
isinstance
(
time_interval
,
torch
.
Tensor
):
time_interval
=
instance_feature
.
new_tensor
(
time_interval
)
translation
=
torch
.
transpose
(
output
[...,
VX
:],
0
,
-
1
)
velocity
=
torch
.
transpose
(
translation
/
time_interval
,
0
,
-
1
)
output
[...,
VX
:]
=
velocity
+
anchor
[...,
VX
:]
if
return_cls
:
assert
self
.
with_cls_branch
,
"Without classification layers !!!"
cls
=
self
.
cls_layers
(
instance_feature
)
else
:
cls
=
None
if
return_cls
and
self
.
with_quality_estimation
:
quality
=
self
.
quality_layers
(
feature
)
else
:
quality
=
None
return
output
,
cls
,
quality
@
PLUGIN_LAYERS
.
register_module
()
class
SparseBox3DKeyPointsGenerator
(
BaseModule
):
def
__init__
(
self
,
embed_dims
=
256
,
num_learnable_pts
=
0
,
fix_scale
=
None
,
):
super
(
SparseBox3DKeyPointsGenerator
,
self
).
__init__
()
self
.
embed_dims
=
embed_dims
self
.
num_learnable_pts
=
num_learnable_pts
if
fix_scale
is
None
:
fix_scale
=
((
0.0
,
0.0
,
0.0
),)
self
.
fix_scale
=
nn
.
Parameter
(
torch
.
tensor
(
fix_scale
),
requires_grad
=
False
)
self
.
num_pts
=
len
(
self
.
fix_scale
)
+
num_learnable_pts
if
num_learnable_pts
>
0
:
self
.
learnable_fc
=
Linear
(
self
.
embed_dims
,
num_learnable_pts
*
3
)
def
init_weight
(
self
):
if
self
.
num_learnable_pts
>
0
:
xavier_init
(
self
.
learnable_fc
,
distribution
=
"uniform"
,
bias
=
0.0
)
    def forward(
        self,
        anchor,
        instance_feature=None,
        T_cur2temp_list=None,
        cur_timestamp=None,
        temp_timestamps=None,
    ):
        bs, num_anchor = anchor.shape[:2]
        size = anchor[..., None, [W, L, H]].exp()
        key_points = self.fix_scale * size
        if self.num_learnable_pts > 0 and instance_feature is not None:
            learnable_scale = (
                self.learnable_fc(instance_feature)
                .reshape(bs, num_anchor, self.num_learnable_pts, 3)
                .sigmoid()
                - 0.5
            )
            key_points = torch.cat(
                [key_points, learnable_scale * size], dim=-2
            )

        rotation_mat = anchor.new_zeros([bs, num_anchor, 3, 3])
        rotation_mat[:, :, 0, 0] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 0, 1] = -anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 0] = anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 1] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 2, 2] = 1

        key_points = torch.matmul(
            rotation_mat[:, :, None], key_points[..., None]
        ).squeeze(-1)
        key_points = key_points + anchor[..., None, [X, Y, Z]]

        if (
            cur_timestamp is None
            or temp_timestamps is None
            or T_cur2temp_list is None
            or len(temp_timestamps) == 0
        ):
            return key_points

        temp_key_points_list = []
        velocity = anchor[..., VX:]
        for i, t_time in enumerate(temp_timestamps):
            time_interval = cur_timestamp - t_time
            translation = (
                velocity
                * time_interval.to(dtype=velocity.dtype)[:, None, None]
            )
            temp_key_points = key_points - translation[:, :, None]
            T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
            temp_key_points = (
                T_cur2temp[:, None, None, :3]
                @ torch.cat(
                    [
                        temp_key_points,
                        torch.ones_like(temp_key_points[..., :1]),
                    ],
                    dim=-1,
                ).unsqueeze(-1)
            )
            temp_key_points = temp_key_points.squeeze(-1)
            temp_key_points_list.append(temp_key_points)
        return key_points, temp_key_points_list
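For intuition, the fixed-keypoint construction in `forward` reduces, per box, to scaling unit offsets by the box size, rotating them by yaw, and translating to the box center. A single-box sketch with made-up values (the module does the same thing batched):

import math
import torch

cx, cy, cz, w, l, h, yaw = 1.0, 2.0, 0.0, 2.0, 4.0, 1.5, 0.3
fix_scale = torch.tensor([[0.0, 0.0, 0.0],    # box center
                          [0.5, 0.0, 0.0]])   # half a box width along +x

key_points = fix_scale * torch.tensor([w, l, h])          # scale to box size
c, s = math.cos(yaw), math.sin(yaw)
rot = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
key_points = key_points @ rot.T + torch.tensor([cx, cy, cz])
print(key_points)  # the center point plus one yaw-rotated face point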
    @staticmethod
    def anchor_projection(
        anchor,
        T_src2dst_list,
        src_timestamp=None,
        dst_timestamps=None,
        time_intervals=None,
    ):
        dst_anchors = []
        for i in range(len(T_src2dst_list)):
            vel = anchor[..., VX:]
            vel_dim = vel.shape[-1]
            T_src2dst = torch.unsqueeze(
                T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
            )

            center = anchor[..., [X, Y, Z]]
            if time_intervals is not None:
                time_interval = time_intervals[i]
            elif src_timestamp is not None and dst_timestamps is not None:
                time_interval = (src_timestamp - dst_timestamps[i]).to(
                    dtype=vel.dtype
                )
            else:
                time_interval = None
            if time_interval is not None:
                translation = vel.transpose(0, -1) * time_interval
                translation = translation.transpose(0, -1)
                center = center - translation
            center = (
                torch.matmul(
                    T_src2dst[..., :3, :3], center[..., None]
                ).squeeze(dim=-1)
                + T_src2dst[..., :3, 3]
            )

            size = anchor[..., [W, L, H]]
            yaw = torch.matmul(
                T_src2dst[..., :2, :2],
                anchor[..., [COS_YAW, SIN_YAW], None],
            ).squeeze(-1)
            vel = torch.matmul(
                T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]
            ).squeeze(-1)
            dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)
            # TODO: Fix bug
            # index = [X, Y, Z, W, L, H, COS_YAW, SIN_YAW] + [VX, VY, VZ][:vel_dim]
            # index = torch.tensor(index, device=dst_anchor.device)
            # index = torch.argsort(index)
            # dst_anchor = dst_anchor.index_select(dim=-1, index=index)
            dst_anchors.append(dst_anchor)
        return dst_anchors
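`anchor_projection` is easiest to read for a single anchor: first undo the object's own motion over the frame gap, then apply the rigid ego transform. A minimal sketch with a hypothetical 4x4 homogeneous transform:

import torch

T_src2dst = torch.eye(4)
T_src2dst[:3, 3] = torch.tensor([10.0, 0.0, 0.0])  # ego advanced 10 m along x

center = torch.tensor([5.0, 2.0, 0.0])
vel = torch.tensor([1.0, 0.0, 0.0])                # m/s in the source frame
time_interval = 0.5                                # s between frames

center = center - vel * time_interval              # compensate object motion
center = T_src2dst[:3, :3] @ center + T_src2dst[:3, 3]
print(center)  # tensor([14.5000,  2.0000,  0.0000])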
    @staticmethod
    def distance(anchor):
        return torch.norm(anchor[..., :2], p=2, dim=-1)
projects/mmdet3d_plugin/models/detection3d/losses.py
0 → 100644
View file @
a9dc86e9
import torch
import torch.nn as nn
from mmcv.utils import build_from_cfg
from mmdet.models.builder import LOSSES

from projects.mmdet3d_plugin.core.box3d import *
@LOSSES.register_module()
class SparseBox3DLoss(nn.Module):
    def __init__(
        self,
        loss_box,
        loss_centerness=None,
        loss_yawness=None,
        cls_allow_reverse=None,
    ):
        super().__init__()

        def build(cfg, registry):
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        self.loss_box = build(loss_box, LOSSES)
        self.loss_cns = build(loss_centerness, LOSSES)
        self.loss_yns = build(loss_yawness, LOSSES)
        self.cls_allow_reverse = cls_allow_reverse
    def forward(
        self,
        box,
        box_target,
        weight=None,
        avg_factor=None,
        suffix="",
        quality=None,
        cls_target=None,
        **kwargs,
    ):
        # Some categories do not distinguish between positive and negative
        # directions, e.g. the barrier class in the nuScenes dataset.
        if self.cls_allow_reverse is not None and cls_target is not None:
            if_reverse = (
                torch.nn.functional.cosine_similarity(
                    box_target[..., [SIN_YAW, COS_YAW]],
                    box[..., [SIN_YAW, COS_YAW]],
                    dim=-1,
                )
                < 0
            )
            if_reverse = (
                torch.isin(
                    cls_target, cls_target.new_tensor(self.cls_allow_reverse)
                )
                & if_reverse
            )
            box_target[..., [SIN_YAW, COS_YAW]] = torch.where(
                if_reverse[..., None],
                -box_target[..., [SIN_YAW, COS_YAW]],
                box_target[..., [SIN_YAW, COS_YAW]],
            )

        output = {}
        box_loss = self.loss_box(
            box, box_target, weight=weight, avg_factor=avg_factor
        )
        output[f"loss_box{suffix}"] = box_loss

        if quality is not None:
            cns = quality[..., CNS]
            yns = quality[..., YNS].sigmoid()
            cns_target = torch.norm(
                box_target[..., [X, Y, Z]] - box[..., [X, Y, Z]], p=2, dim=-1
            )
            cns_target = torch.exp(-cns_target)
            cns_loss = self.loss_cns(cns, cns_target, avg_factor=avg_factor)
            output[f"loss_cns{suffix}"] = cns_loss

            yns_target = (
                torch.nn.functional.cosine_similarity(
                    box_target[..., [SIN_YAW, COS_YAW]],
                    box[..., [SIN_YAW, COS_YAW]],
                    dim=-1,
                )
                > 0
            )
            yns_target = yns_target.float()
            yns_loss = self.loss_yns(yns, yns_target, avg_factor=avg_factor)
            output[f"loss_yns{suffix}"] = yns_loss
        return output
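For reference, a loss of this shape would be built through the mmdet registry roughly as below. The sub-loss choices, weights, and the class id are illustrative placeholders, not the project's actual training config:

from mmcv.utils import build_from_cfg
from mmdet.models.builder import LOSSES

loss_cfg = dict(
    type="SparseBox3DLoss",
    loss_box=dict(type="L1Loss", loss_weight=0.25),
    loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
    loss_yawness=dict(type="GaussianFocalLoss"),
    cls_allow_reverse=[8],  # hypothetical id of a direction-ambiguous class
)
criterion = build_from_cfg(loss_cfg, LOSSES)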
projects/mmdet3d_plugin/models/detection3d/target.py
0 → 100644
View file @
a9dc86e9
import torch
import numpy as np
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

from mmdet.core.bbox.builder import BBOX_SAMPLERS

from projects.mmdet3d_plugin.core.box3d import *
from ..base_target import BaseTargetWithDenoising

__all__ = ["SparseBox3DTarget"]
@BBOX_SAMPLERS.register_module()
class SparseBox3DTarget(BaseTargetWithDenoising):
    def __init__(
        self,
        cls_weight=2.0,
        alpha=0.25,
        gamma=2,
        eps=1e-12,
        box_weight=0.25,
        reg_weights=None,
        cls_wise_reg_weights=None,
        num_dn_groups=0,
        dn_noise_scale=0.5,
        max_dn_gt=32,
        add_neg_dn=True,
        num_temp_dn_groups=0,
    ):
        super(SparseBox3DTarget, self).__init__(
            num_dn_groups, num_temp_dn_groups
        )
        self.cls_weight = cls_weight
        self.box_weight = box_weight
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps
        self.reg_weights = reg_weights
        if self.reg_weights is None:
            self.reg_weights = [1.0] * 8 + [0.0] * 2
        self.cls_wise_reg_weights = cls_wise_reg_weights
        self.dn_noise_scale = dn_noise_scale
        self.max_dn_gt = max_dn_gt
        self.add_neg_dn = add_neg_dn
    def encode_reg_target(self, box_target, device=None):
        outputs = []
        for box in box_target:
            output = torch.cat(
                [
                    box[..., [X, Y, Z]],
                    box[..., [W, L, H]].log(),
                    torch.sin(box[..., YAW]).unsqueeze(-1),
                    torch.cos(box[..., YAW]).unsqueeze(-1),
                    box[..., YAW + 1 :],
                ],
                dim=-1,
            )
            if device is not None:
                output = output.to(device=device)
            outputs.append(output)
        return outputs
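The encoding maps a raw box [x, y, z, w, l, h, yaw, ...] to [x, y, z, log w, log l, log h, sin yaw, cos yaw, ...]: log-space sizes keep the regression scale-balanced, and the (sin, cos) pair avoids the 2*pi wrap-around in yaw. A one-box sketch, assuming that layout (indices come from core/box3d):

import torch

box = torch.tensor([[1.0, 2.0, 0.0, 2.0, 4.0, 1.5, 0.3, 1.0, 0.0]])
encoded = torch.cat(
    [
        box[..., 0:3],              # center, unchanged
        box[..., 3:6].log(),        # sizes in log space
        torch.sin(box[..., 6:7]),   # yaw as a unit vector
        torch.cos(box[..., 6:7]),
        box[..., 7:],               # velocity passthrough
    ],
    dim=-1,
)
print(encoded.shape)  # torch.Size([1, 10])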
    def sample(
        self,
        cls_pred,
        box_pred,
        cls_target,
        box_target,
    ):
        bs, num_pred, num_cls = cls_pred.shape
        cls_cost = self._cls_cost(cls_pred, cls_target)
        box_target = self.encode_reg_target(box_target, box_pred.device)

        instance_reg_weights = []
        for i in range(len(box_target)):
            weights = torch.logical_not(box_target[i].isnan()).to(
                dtype=box_target[i].dtype
            )
            if self.cls_wise_reg_weights is not None:
                for cls, weight in self.cls_wise_reg_weights.items():
                    weights = torch.where(
                        (cls_target[i] == cls)[:, None],
                        weights.new_tensor(weight),
                        weights,
                    )
            instance_reg_weights.append(weights)

        box_cost = self._box_cost(box_pred, box_target, instance_reg_weights)

        indices = []
        for i in range(bs):
            if cls_cost[i] is not None and box_cost[i] is not None:
                cost = (cls_cost[i] + box_cost[i]).detach().cpu().numpy()
                cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost)
                assign = linear_sum_assignment(cost)
                indices.append(
                    [cls_pred.new_tensor(x, dtype=torch.int64) for x in assign]
                )
            else:
                indices.append([None, None])

        output_cls_target = (
            cls_target[0].new_ones([bs, num_pred], dtype=torch.long) * num_cls
        )
        output_box_target = box_pred.new_zeros(box_pred.shape)
        output_reg_weights = box_pred.new_zeros(box_pred.shape)
        for i, (pred_idx, target_idx) in enumerate(indices):
            if len(cls_target[i]) == 0:
                continue
            output_cls_target[i, pred_idx] = cls_target[i][target_idx]
            output_box_target[i, pred_idx] = box_target[i][target_idx]
            output_reg_weights[i, pred_idx] = instance_reg_weights[i][
                target_idx
            ]
        return output_cls_target, output_box_target, output_reg_weights
    def _cls_cost(self, cls_pred, cls_target):
        bs = cls_pred.shape[0]
        cls_pred = cls_pred.sigmoid()
        cost = []
        for i in range(bs):
            if len(cls_target[i]) > 0:
                neg_cost = (
                    -(1 - cls_pred[i] + self.eps).log()
                    * (1 - self.alpha)
                    * cls_pred[i].pow(self.gamma)
                )
                pos_cost = (
                    -(cls_pred[i] + self.eps).log()
                    * self.alpha
                    * (1 - cls_pred[i]).pow(self.gamma)
                )
                cost.append(
                    (pos_cost[:, cls_target[i]] - neg_cost[:, cls_target[i]])
                    * self.cls_weight
                )
            else:
                cost.append(None)
        return cost
    def _box_cost(self, box_pred, box_target, instance_reg_weights):
        bs = box_pred.shape[0]
        cost = []
        for i in range(bs):
            if len(box_target[i]) > 0:
                cost.append(
                    torch.sum(
                        torch.abs(box_pred[i, :, None] - box_target[i][None])
                        * instance_reg_weights[i][None]
                        * box_pred.new_tensor(self.reg_weights),
                        dim=-1,
                    )
                    * self.box_weight
                )
            else:
                cost.append(None)
        return cost
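`sample` combines these two costs per scene and solves a one-to-one assignment with the Hungarian algorithm. A tiny self-contained sketch with made-up costs (three predictions, two ground truths):

import numpy as np
from scipy.optimize import linear_sum_assignment

cls_cost = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])  # (num_pred, num_gt)
box_cost = np.array([[1.0, 3.0], [2.5, 0.5], [2.0, 2.0]])
cost = cls_cost + box_cost
cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost)  # sanitize
pred_idx, gt_idx = linear_sum_assignment(cost)
print(pred_idx, gt_idx)  # -> [0 1] [0 1]: each gt matched to one prediction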
    def get_dn_anchors(self, cls_target, box_target, gt_instance_id=None):
        if self.num_dn_groups <= 0:
            return None
        if self.num_temp_dn_groups <= 0:
            gt_instance_id = None

        if self.max_dn_gt > 0:
            cls_target = [x[: self.max_dn_gt] for x in cls_target]
            box_target = [x[: self.max_dn_gt] for x in box_target]
            if gt_instance_id is not None:
                gt_instance_id = [x[: self.max_dn_gt] for x in gt_instance_id]

        max_dn_gt = max([len(x) for x in cls_target])
        if max_dn_gt == 0:
            return None
        cls_target = torch.stack(
            [
                F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
                for x in cls_target
            ]
        )
        box_target = self.encode_reg_target(box_target, cls_target.device)
        box_target = torch.stack(
            [F.pad(x, (0, 0, 0, max_dn_gt - x.shape[0])) for x in box_target]
        )
        box_target = torch.where(
            cls_target[..., None] == -1, box_target.new_tensor(0), box_target
        )
        if gt_instance_id is not None:
            gt_instance_id = torch.stack(
                [
                    F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
                    for x in gt_instance_id
                ]
            )

        bs, num_gt, state_dims = box_target.shape
        if self.num_dn_groups > 1:
            cls_target = cls_target.tile(self.num_dn_groups, 1)
            box_target = box_target.tile(self.num_dn_groups, 1, 1)
            if gt_instance_id is not None:
                gt_instance_id = gt_instance_id.tile(self.num_dn_groups, 1)

        noise = torch.rand_like(box_target) * 2 - 1
        noise *= box_target.new_tensor(self.dn_noise_scale)
        dn_anchor = box_target + noise
        if self.add_neg_dn:
            noise_neg = torch.rand_like(box_target) + 1
            flag = torch.where(
                torch.rand_like(box_target) > 0.5,
                noise_neg.new_tensor(1),
                noise_neg.new_tensor(-1),
            )
            noise_neg *= flag
            noise_neg *= box_target.new_tensor(self.dn_noise_scale)
            dn_anchor = torch.cat([dn_anchor, box_target + noise_neg], dim=1)
            num_gt *= 2

        box_cost = self._box_cost(
            dn_anchor, box_target, torch.ones_like(box_target)
        )
        dn_box_target = torch.zeros_like(dn_anchor)
        dn_cls_target = -torch.ones_like(cls_target) * 3
        if gt_instance_id is not None:
            dn_id_target = -torch.ones_like(gt_instance_id)
        if self.add_neg_dn:
            dn_cls_target = torch.cat([dn_cls_target, dn_cls_target], dim=1)
            if gt_instance_id is not None:
                dn_id_target = torch.cat([dn_id_target, dn_id_target], dim=1)
        for i in range(dn_anchor.shape[0]):
            cost = box_cost[i].cpu().numpy()
            anchor_idx, gt_idx = linear_sum_assignment(cost)
            anchor_idx = dn_anchor.new_tensor(anchor_idx, dtype=torch.int64)
            gt_idx = dn_anchor.new_tensor(gt_idx, dtype=torch.int64)
            dn_box_target[i, anchor_idx] = box_target[i, gt_idx]
            dn_cls_target[i, anchor_idx] = cls_target[i, gt_idx]
            if gt_instance_id is not None:
                dn_id_target[i, anchor_idx] = gt_instance_id[i, gt_idx]

        dn_anchor = (
            dn_anchor.reshape(self.num_dn_groups, bs, num_gt, state_dims)
            .permute(1, 0, 2, 3)
            .flatten(1, 2)
        )
        dn_box_target = (
            dn_box_target.reshape(self.num_dn_groups, bs, num_gt, state_dims)
            .permute(1, 0, 2, 3)
            .flatten(1, 2)
        )
        dn_cls_target = (
            dn_cls_target.reshape(self.num_dn_groups, bs, num_gt)
            .permute(1, 0, 2)
            .flatten(1)
        )
        if gt_instance_id is not None:
            dn_id_target = (
                dn_id_target.reshape(self.num_dn_groups, bs, num_gt)
                .permute(1, 0, 2)
                .flatten(1)
            )
        else:
            dn_id_target = None
        valid_mask = dn_cls_target >= 0
        if self.add_neg_dn:
            cls_target = (
                torch.cat([cls_target, cls_target], dim=1)
                .reshape(self.num_dn_groups, bs, num_gt)
                .permute(1, 0, 2)
                .flatten(1)
            )
            # `valid` marks entries that are real boxes rather than padding.
            valid_mask = torch.logical_or(
                valid_mask, ((cls_target >= 0) & (dn_cls_target == -3))
            )
        attn_mask = dn_box_target.new_ones(
            num_gt * self.num_dn_groups, num_gt * self.num_dn_groups
        )
        for i in range(self.num_dn_groups):
            start = num_gt * i
            end = start + num_gt
            attn_mask[start:end, start:end] = 0
        attn_mask = attn_mask == 1
        dn_cls_target = dn_cls_target.long()
        return (
            dn_anchor,
            dn_box_target,
            dn_cls_target,
            attn_mask,
            valid_mask,
            dn_id_target,
        )
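The returned `attn_mask` is block-diagonal so that denoising queries attend only within their own group (True means masked out). A standalone sketch of its structure:

import torch

num_gt, num_dn_groups = 2, 3
attn_mask = torch.ones(num_gt * num_dn_groups, num_gt * num_dn_groups)
for i in range(num_dn_groups):
    start, end = num_gt * i, num_gt * (i + 1)
    attn_mask[start:end, start:end] = 0
attn_mask = attn_mask == 1
print(attn_mask.int())  # zeros on the 2x2 diagonal blocks, ones elsewhere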
    def update_dn(
        self,
        instance_feature,
        anchor,
        dn_reg_target,
        dn_cls_target,
        valid_mask,
        dn_id_target,
        num_normal_anchor,
        temporal_valid_mask,
    ):
        bs, num_anchor = instance_feature.shape[:2]
        if temporal_valid_mask is None:
            self.dn_metas = None
        if self.dn_metas is None or num_normal_anchor >= num_anchor:
            return (
                instance_feature,
                anchor,
                dn_reg_target,
                dn_cls_target,
                valid_mask,
                dn_id_target,
            )

        # Split instance_feature and anchor into non-dn and dn parts.
        num_dn = num_anchor - num_normal_anchor
        dn_instance_feature = instance_feature[:, -num_dn:]
        dn_anchor = anchor[:, -num_dn:]
        instance_feature = instance_feature[:, :num_normal_anchor]
        anchor = anchor[:, :num_normal_anchor]

        # Reshape all dn metas from (bs, num_all_dn, ...) to
        # (bs, dn_group, num_dn_per_group, ...).
        num_dn_groups = self.num_dn_groups
        num_dn = num_dn // num_dn_groups
        dn_feat = dn_instance_feature.reshape(bs, num_dn_groups, num_dn, -1)
        dn_anchor = dn_anchor.reshape(bs, num_dn_groups, num_dn, -1)
        dn_reg_target = dn_reg_target.reshape(bs, num_dn_groups, num_dn, -1)
        dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_dn)
        valid_mask = valid_mask.reshape(bs, num_dn_groups, num_dn)
        if dn_id_target is not None:
            dn_id = dn_id_target.reshape(bs, num_dn_groups, num_dn)

        # Update temp_dn_metas by instance id.
        temp_dn_feat = self.dn_metas["dn_instance_feature"]
        _, num_temp_dn_groups, num_temp_dn = temp_dn_feat.shape[:3]
        temp_dn_id = self.dn_metas["dn_id_target"]
        # match: (bs, num_temp_dn_groups, num_temp_dn, num_dn)
        match = temp_dn_id[..., None] == dn_id[:, :num_temp_dn_groups, None]
        temp_reg_target = (
            match[..., None] * dn_reg_target[:, :num_temp_dn_groups, None]
        ).sum(dim=3)
        temp_cls_target = torch.where(
            torch.all(torch.logical_not(match), dim=-1),
            self.dn_metas["dn_cls_target"].new_tensor(-1),
            self.dn_metas["dn_cls_target"],
        )
        temp_valid_mask = self.dn_metas["valid_mask"]
        temp_dn_anchor = self.dn_metas["dn_anchor"]

        # Handle the length misalignment between temp_dn and dn caused by
        # changes of num_gt, then concatenate temp_dn and dn.
        temp_dn_metas = [
            temp_dn_feat,
            temp_dn_anchor,
            temp_reg_target,
            temp_cls_target,
            temp_valid_mask,
            temp_dn_id,
        ]
        dn_metas = [
            dn_feat,
            dn_anchor,
            dn_reg_target,
            dn_cls_target,
            valid_mask,
            dn_id,
        ]
        output = []
        for temp_meta, meta in zip(temp_dn_metas, dn_metas):
            if num_temp_dn < num_dn:
                pad = (0, num_dn - num_temp_dn)
                if temp_meta.dim() == 4:
                    pad = (0, 0) + pad
                else:
                    assert temp_meta.dim() == 3
                temp_meta = F.pad(temp_meta, pad, value=0)
            else:
                temp_meta = temp_meta[:, :, :num_dn]
            mask = temporal_valid_mask[:, None, None]
            if meta.dim() == 4:
                mask = mask.unsqueeze(dim=-1)
            temp_meta = torch.where(
                mask, temp_meta, meta[:, :num_temp_dn_groups]
            )
            meta = torch.cat([temp_meta, meta[:, num_temp_dn_groups:]], dim=1)
            meta = meta.flatten(1, 2)
            output.append(meta)
        output[0] = torch.cat([instance_feature, output[0]], dim=1)
        output[1] = torch.cat([anchor, output[1]], dim=1)
        return output
    def cache_dn(
        self,
        dn_instance_feature,
        dn_anchor,
        dn_cls_target,
        valid_mask,
        dn_id_target,
    ):
        if self.num_temp_dn_groups < 0:
            return
        num_dn_groups = self.num_dn_groups
        bs, num_dn = dn_instance_feature.shape[:2]
        num_temp_dn = num_dn // num_dn_groups
        temp_group_mask = (
            torch.randperm(num_dn_groups) < self.num_temp_dn_groups
        )
        temp_group_mask = temp_group_mask.to(device=dn_anchor.device)
        dn_instance_feature = dn_instance_feature.detach().reshape(
            bs, num_dn_groups, num_temp_dn, -1
        )[:, temp_group_mask]
        dn_anchor = dn_anchor.detach().reshape(
            bs, num_dn_groups, num_temp_dn, -1
        )[:, temp_group_mask]
        dn_cls_target = dn_cls_target.reshape(
            bs, num_dn_groups, num_temp_dn
        )[:, temp_group_mask]
        valid_mask = valid_mask.reshape(
            bs, num_dn_groups, num_temp_dn
        )[:, temp_group_mask]
        if dn_id_target is not None:
            dn_id_target = dn_id_target.reshape(
                bs, num_dn_groups, num_temp_dn
            )[:, temp_group_mask]
        self.dn_metas = dict(
            dn_instance_feature=dn_instance_feature,
            dn_anchor=dn_anchor,
            dn_cls_target=dn_cls_target,
            valid_mask=valid_mask,
            dn_id_target=dn_id_target,
        )
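The `randperm(n) < k` idiom above picks exactly `k` of the `n` denoising groups uniformly at random; only those groups are cached for the next frame:

import torch

num_dn_groups, num_temp_dn_groups = 5, 3
temp_group_mask = torch.randperm(num_dn_groups) < num_temp_dn_groups
print(temp_group_mask.sum().item())  # always 3; the positions vary per call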
projects/mmdet3d_plugin/models/grid_mask.py
0 → 100644
View file @
a9dc86e9
projects/mmdet3d_plugin/models/instance_bank.py
0 → 100644
View file @
a9dc86e9
projects/mmdet3d_plugin/models/sparse4d.py
0 → 100644
View file @
a9dc86e9
# Copyright (c) Horizon Robotics. All rights reserved.
from inspect import signature

import torch
from mmcv.runner import force_fp32, auto_fp16
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
from mmdet.models import (
    DETECTORS,
    BaseDetector,
    build_backbone,
    build_head,
    build_neck,
)

from .grid_mask import GridMask

try:
    from ..ops import feature_maps_format

    DAF_VALID = True
except ImportError:
    DAF_VALID = False

__all__ = ["Sparse4D"]


@DETECTORS.register_module()
class Sparse4D(BaseDetector):
    def __init__(
        self,
        img_backbone,
        head,
        img_neck=None,
        init_cfg=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        use_grid_mask=True,
        use_deformable_func=False,
        depth_branch=None,
    ):
        super(Sparse4D, self).__init__(init_cfg=init_cfg)
        if pretrained is not None:
            img_backbone.pretrained = pretrained
        self.img_backbone = build_backbone(img_backbone)
        self.img_backbone = self.img_backbone.to(
            device="cuda", memory_format=torch.channels_last
        )
        if img_neck is not None:
            self.img_neck = build_neck(img_neck)
        else:
            self.img_neck = None
        self.head = build_head(head)
        self.use_grid_mask = use_grid_mask
        if use_deformable_func:
            assert DAF_VALID, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        if depth_branch is not None:
            self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS)
        else:
            self.depth_branch = None
        if use_grid_mask:
            self.grid_mask = GridMask(
                True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
            )

    @auto_fp16(apply_to=("img",), out_fp32=True)
    def extract_feat(self, img, return_depth=False, metas=None):
        bs = img.shape[0]
        if img.dim() == 5:  # multi-view input: (bs, num_cams, C, H, W)
            num_cams = img.shape[1]
            img = img.flatten(end_dim=1)
        else:
            num_cams = 1
        img = img.to(memory_format=torch.channels_last)
        if self.use_grid_mask:
            img = self.grid_mask(img)
        if "metas" in signature(self.img_backbone.forward).parameters:
            feature_maps = self.img_backbone(img, num_cams, metas=metas)
        else:
            feature_maps = self.img_backbone(img)
        if self.img_neck is not None:
            feature_maps = list(self.img_neck(feature_maps))
        for i, feat in enumerate(feature_maps):
            feature_maps[i] = torch.reshape(
                feat, (bs, num_cams) + feat.shape[1:]
            )
        if return_depth and self.depth_branch is not None:
            depths = self.depth_branch(feature_maps, metas.get("focal"))
        else:
            depths = None
        if self.use_deformable_func:
            feature_maps = feature_maps_format(feature_maps)
        if return_depth:
            return feature_maps, depths
        return feature_maps

    @force_fp32(apply_to=("img",))
    def forward(self, img, **data):
        if self.training:
            return self.forward_train(img, **data)
        else:
            return self.forward_test(img, **data)

    def forward_train(self, img, **data):
        feature_maps, depths = self.extract_feat(img, True, data)
        model_outs = self.head(feature_maps, data)
        output = self.head.loss(model_outs, data)
        if depths is not None and "gt_depth" in data:
            output["loss_dense_depth"] = self.depth_branch.loss(
                depths, data["gt_depth"]
            )
        return output

    def forward_test(self, img, **data):
        if isinstance(img, list):
            return self.aug_test(img, **data)
        else:
            return self.simple_test(img, **data)

    def simple_test(self, img, **data):
        feature_maps = self.extract_feat(img)
        model_outs = self.head(feature_maps, data)
        results = self.head.post_process(model_outs)
        output = [{"img_bbox": result} for result in results]
        return output

    def aug_test(self, img, **data):
        # Fake test-time augmentation: unwrap the lists and run a single pass.
        for key in data.keys():
            if isinstance(data[key], list):
                data[key] = data[key][0]
        return self.simple_test(img[0], **data)
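For orientation, a detector of this class is assembled from a config dict along these lines. Everything below is a hypothetical skeleton (the real configs live under the project's config directory; the head config in particular is elided):

model = dict(
    type="Sparse4D",
    use_grid_mask=True,
    use_deformable_func=False,  # True requires the compiled ..ops package
    img_backbone=dict(
        type="ResNet", depth=50, num_stages=4, out_indices=(0, 1, 2, 3)
    ),
    img_neck=dict(
        type="FPN", in_channels=[256, 512, 1024, 2048],
        out_channels=256, num_outs=4,
    ),
    head=dict(type="Sparse4DHead"),  # placeholder; see sparse4d_head.py
)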
projects/mmdet3d_plugin/models/sparse4d_head.py
0 → 100644
View file @
a9dc86e9