Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
TS-MODELS-OPT
training
Autonomous-Driving-models
Commits
d2b71343
Commit
d2b71343
authored
Apr 08, 2026
by
雍大凯
Browse files
add code
parent
69e57885
Changes
259
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4284 additions
and
0 deletions
+4284
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdepth.py
...hocc/projects/mmdet3d_plugin/models/detectors/bevdepth.py
+258
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdepth4d.py
...cc/projects/mmdet3d_plugin/models/detectors/bevdepth4d.py
+58
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdet.py
...ashocc/projects/mmdet3d_plugin/models/detectors/bevdet.py
+250
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdet4d.py
...hocc/projects/mmdet3d_plugin/models/detectors/bevdet4d.py
+387
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdet_occ.py
...cc/projects/mmdet3d_plugin/models/detectors/bevdet_occ.py
+1480
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevstereo4d.py
...c/projects/mmdet3d_plugin/models/detectors/bevstereo4d.py
+287
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__init__.py
...lashocc/projects/mmdet3d_plugin/models/losses/__init__.py
+4
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/__init__.cpython-310.pyc
...plugin/models/losses/__pycache__/__init__.cpython-310.pyc
+0
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/cross_entropy_loss.cpython-310.pyc
...els/losses/__pycache__/cross_entropy_loss.cpython-310.pyc
+0
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/focal_loss.cpython-310.pyc
...ugin/models/losses/__pycache__/focal_loss.cpython-310.pyc
+0
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/lovasz_softmax.cpython-310.pyc
.../models/losses/__pycache__/lovasz_softmax.cpython-310.pyc
+0
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/semkitti_loss.cpython-310.pyc
...n/models/losses/__pycache__/semkitti_loss.cpython-310.pyc
+0
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/cross_entropy_loss.py
...ojects/mmdet3d_plugin/models/losses/cross_entropy_loss.py
+302
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/focal_loss.py
...shocc/projects/mmdet3d_plugin/models/losses/focal_loss.py
+265
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/lovasz_softmax.py
...c/projects/mmdet3d_plugin/models/losses/lovasz_softmax.py
+329
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/semkitti_loss.py
...cc/projects/mmdet3d_plugin/models/losses/semkitti_loss.py
+185
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/model_utils/__init__.py
...cc/projects/mmdet3d_plugin/models/model_utils/__init__.py
+3
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/model_utils/__pycache__/__init__.cpython-310.pyc
...n/models/model_utils/__pycache__/__init__.cpython-310.pyc
+0
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/model_utils/__pycache__/depthnet.cpython-310.pyc
...n/models/model_utils/__pycache__/depthnet.cpython-310.pyc
+0
-0
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/model_utils/depthnet.py
...cc/projects/mmdet3d_plugin/models/model_utils/depthnet.py
+476
-0
No files found.
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdepth.py
0 → 100644
View file @
d2b71343
# Copyright (c) Phigent Robotics. All rights reserved.
import
torch
import
torch.nn.functional
as
F
from
mmcv.runner
import
force_fp32
from
mmdet3d.models
import
DETECTORS
from
.bevdet
import
BEVDet
from
mmdet3d.models
import
builder
@
DETECTORS
.
register_module
()
class
BEVDepth
(
BEVDet
):
def __init__(self,
             img_backbone,
             img_neck,
             img_view_transformer,
             img_bev_encoder_backbone,
             img_bev_encoder_neck,
             pts_bbox_head=None,
             **kwargs):
    """Initialise the detector by handing each component to the parent class.

    Args:
        img_backbone: Forwarded to the parent constructor as ``img_backbone``.
        img_neck: Forwarded to the parent constructor as ``img_neck``.
        img_view_transformer: Forwarded as ``img_view_transformer``.
        img_bev_encoder_backbone: Forwarded as ``img_bev_encoder_backbone``.
        img_bev_encoder_neck: Forwarded as ``img_bev_encoder_neck``.
        pts_bbox_head: Forwarded as ``pts_bbox_head``. Defaults to None.
        **kwargs: Accepted for signature compatibility; not forwarded.
    """
    # NOTE(review): ``kwargs`` is deliberately left out of the parent call,
    # matching the existing behaviour; any extra option passed here is
    # dropped. Confirm that this is intended before relying on such options.
    components = dict(
        img_backbone=img_backbone,
        img_neck=img_neck,
        img_view_transformer=img_view_transformer,
        img_bev_encoder_backbone=img_bev_encoder_backbone,
        img_bev_encoder_neck=img_bev_encoder_neck,
        pts_bbox_head=pts_bbox_head,
    )
    super(BEVDepth, self).__init__(**components)
def image_encoder(self, img, stereo=False):
    """Encode every camera view with the image backbone and optional neck.

    Args:
        img: (B, N, 3, H, W)
        stereo: bool. When True, the first backbone output is split off and
            returned as ``stereo_feat``; only the remaining outputs continue
            to the neck.
    Returns:
        x: (B, N, C, fH, fW)
        stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo) / None
    """
    B, N, C, imH, imW = img.shape
    # Fold the view axis into the batch axis so the backbone sees 4-D input.
    flat_imgs = img.view(B * N, C, imH, imW)
    x = self.img_backbone(flat_imgs)

    stereo_feat = None
    if stereo:
        stereo_feat = x[0]
        x = x[1:]

    if self.with_img_neck:
        x = self.img_neck(x)
        # isinstance (instead of an exact type match) also accepts
        # list/tuple subclasses such as namedtuples.
        if isinstance(x, (list, tuple)):
            x = x[0]

    # Unfold the batch axis back into (batch, view).
    _, output_dim, output_H, output_W = x.shape
    x = x.view(B, N, output_dim, output_H, output_W)
    return x, stereo_feat
@force_fp32()
def bev_encoder(self, x):
    """Run the BEV encoder backbone followed by the BEV encoder neck.

    Args:
        x: (B, C, Dy, Dx)
    Returns:
        x: (B, C', 2*Dy, 2*Dx)
    """
    x = self.img_bev_encoder_backbone(x)
    x = self.img_bev_encoder_neck(x)
    # A neck may return a sequence of feature maps; keep only the first one.
    # isinstance also covers list/tuple subclasses (e.g. namedtuples).
    if isinstance(x, (list, tuple)):
        x = x[0]
    return x
def prepare_inputs(self, inputs):
    """Re-express every sensor pose relative to the key ego frame.

    Args:
        inputs: Sequence of exactly seven items
            (imgs, sensor2egos, ego2globals, intrins, post_rots,
            post_trans, bda). ``imgs`` must be 5-D, (B, N, C, H, W).
    Returns:
        list: [imgs, sensor2keyegos, ego2globals, intrins, post_rots,
            post_trans, bda], where ``sensor2keyegos`` is a float tensor of
            shape (B, N, 4, 4) and ``ego2globals`` has been viewed as
            (B, N, 4, 4).
    """
    assert len(inputs) == 7
    batch, num_views, _, _, _ = inputs[0].shape
    (imgs, sensor2egos, ego2globals, intrins,
     post_rots, post_trans, bda) = inputs

    pose_shape = (batch, num_views, 4, 4)
    sensor2egos = sensor2egos.view(*pose_shape)
    ego2globals = ego2globals.view(*pose_shape)

    # The ego pose attached to the first view defines the key ego frame.
    keyego2global = ego2globals[:, 0:1]  # (B, 1, 4, 4)
    # Invert and chain in double precision, then cast back to float.
    global2keyego = torch.inverse(keyego2global.double())  # (B, 1, 4, 4)
    # sensor --> ego --> global --> key ego
    ego2keyego = torch.matmul(global2keyego, ego2globals.double())
    sensor2keyegos = torch.matmul(ego2keyego, sensor2egos.double()).float()

    return [imgs, sensor2keyegos, ego2globals, intrins,
            post_rots, post_trans, bda]
def extract_img_feat(self, img_inputs, img_metas, **kwargs):
    """Extract BEV features and the depth prediction from the images.

    ``img_metas`` and ``**kwargs`` are accepted but not used here.

    img_inputs:
        imgs:  (B, N_views, 3, H, W)
        sensor2egos: (B, N_views, 4, 4)
        ego2globals: (B, N_views, 4, 4)
        intrins:     (B, N_views, 3, 3)
        post_rots:   (B, N_views, 3, 3)
        post_trans:  (B, N_views, 3)
        bda_rot:  (B, 3, 3)
    Returns:
        x: [(B, C', H', W'), ]
        depth: (B*N, D, fH, fW)
    """
    (imgs, sensor2keyegos, ego2globals, intrins,
     post_rots, post_trans, bda) = self.prepare_inputs(img_inputs)
    geometry = (sensor2keyegos, ego2globals, intrins,
                post_rots, post_trans, bda)

    img_feats, _ = self.image_encoder(imgs)  # (B, N, C, fH, fW)

    # The view transformer receives the geometry twice: raw, and summarised
    # by its own ``get_mlp_input`` helper.
    mlp_input = self.img_view_transformer.get_mlp_input(*geometry)
    bev_feat, depth = self.img_view_transformer(
        [img_feats, *geometry, mlp_input])
    # bev_feat: (B, C, Dy, Dx)
    # depth: (B*N, D, fH, fW)

    return [self.bev_encoder(bev_feat)], depth
def extract_feat(self, points, img_inputs, img_metas, **kwargs):
    """Extract features from images; no point features are produced.

    Args:
        points (list[torch.Tensor], optional): Points of each sample.
            Not used by this method.
        img_inputs:
            imgs:  (B, N_views, 3, H, W)
            sensor2egos: (B, N_views, 4, 4)
            ego2globals: (B, N_views, 4, 4)
            intrins:     (B, N_views, 3, 3)
            post_rots:   (B, N_views, 3, 3)
            post_trans:  (B, N_views, 3)
            bda_rot:  (B, 3, 3)
        img_metas: Forwarded to ``extract_img_feat``.
        **kwargs: Forwarded to ``extract_img_feat``.
    Returns:
        tuple: (img_feats, None, depth); the middle slot is the
            point-feature placeholder and is always None.
    """
    img_feats, depth = self.extract_img_feat(img_inputs, img_metas, **kwargs)
    return img_feats, None, depth
def forward_train(self,
                  points=None,
                  img_inputs=None,
                  gt_bboxes_3d=None,
                  gt_labels_3d=None,
                  img_metas=None,
                  gt_bboxes=None,
                  gt_labels=None,
                  gt_bboxes_ignore=None,
                  **kwargs):
    """Compute the depth loss plus the detection-head losses.

    Args:
        points (list[torch.Tensor], optional): Points of each sample.
            Defaults to None.
        img_inputs:
            imgs:  (B, N_views, 3, H, W)        # N_views = 6 * (N_history + 1)
            sensor2egos: (B, N_views, 4, 4)
            ego2globals: (B, N_views, 4, 4)
            intrins:     (B, N_views, 3, 3)
            post_rots:   (B, N_views, 3, 3)
            post_trans:  (B, N_views, 3)
            bda_rot:  (B, 3, 3)
        gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
            Ground truth 3D boxes. Defaults to None.
        gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
            of 3D boxes. Defaults to None.
        img_metas (list[dict], optional): Meta information of each sample.
            Defaults to None.
        gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
            images. Not used here. Defaults to None.
        gt_labels (list[torch.Tensor], optional): Ground truth labels
            of 2D boxes in images. Not used here. Defaults to None.
        gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
            2D boxes in images to be ignored. Defaults to None.
        **kwargs: Must contain ``gt_depth`` (B, N_views, img_H, img_W); the
            whole dict is also forwarded to ``extract_feat``.
    Returns:
        dict: Losses of different branches, including ``loss_depth``.
    """
    img_feats, _pts_feats, depth = self.extract_feat(
        points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

    # Depth supervision: compare the predicted depth with the ground truth.
    losses = {
        'loss_depth': self.img_view_transformer.get_depth_loss(
            kwargs['gt_depth'], depth),
    }
    # Detection losses from the point/BEV head are merged on top.
    losses.update(
        self.forward_pts_train(img_feats, gt_bboxes_3d, gt_labels_3d,
                               img_metas, gt_bboxes_ignore))
    return losses
def forward_test(self,
                 points=None,
                 img_inputs=None,
                 img_metas=None,
                 **kwargs):
    """Validate the test inputs and dispatch to simple_test or aug_test.

    Args:
        points (list[torch.Tensor], optional): The outer list indicates
            test-time augmentations. Only ``points[0]`` is forwarded;
            None is forwarded as None.
        img_inputs (list): The outer list indicates test-time
            augmentations. When ``img_inputs[0][0]`` is not a list the
            call goes to ``simple_test``; otherwise to ``aug_test``.
        img_metas (list[list[dict]]): The outer list indicates test-time
            augs (multiscale, flip, etc.) and the inner list indicates
            images in a batch.
        **kwargs: Forwarded to the selected test function.
    Returns:
        Whatever ``simple_test`` / ``aug_test`` returns.
    Raises:
        TypeError: If ``img_inputs`` or ``img_metas`` is not a list.
        ValueError: If the two lists differ in length.
    """
    for var, name in [(img_inputs, 'img_inputs'),
                      (img_metas, 'img_metas')]:
        if not isinstance(var, list):
            raise TypeError('{} must be a list, but got {}'.format(
                name, type(var)))

    num_augs = len(img_inputs)
    if num_augs != len(img_metas):
        raise ValueError(
            'num of augmentations ({}) != num of image meta ({})'.format(
                len(img_inputs), len(img_metas)))

    if not isinstance(img_inputs[0][0], list):
        # ``img_inputs`` is known to be a list at this point, so the former
        # ``img_inputs is None`` fallback could never trigger and was removed.
        points = [points] if points is None else points
        return self.simple_test(points[0], img_metas[0], img_inputs[0],
                                **kwargs)
    return self.aug_test(None, img_metas[0], img_inputs[0], **kwargs)
def aug_test(self, points, img_metas, img=None, rescale=False, **kwargs):
    """Test function with augmentation; not supported by this detector.

    The parameters exist only so the method can be called the way
    ``forward_test`` calls it (including forwarded keyword arguments).

    Raises:
        NotImplementedError: Always. An explicit exception is used instead
            of ``assert False`` so the failure survives ``python -O``.
    """
    raise NotImplementedError(
        '{} does not support test-time augmentation'.format(
            type(self).__name__))
def simple_test(self,
                points,
                img_metas,
                img_inputs=None,
                rescale=False,
                **kwargs):
    """Test function without augmentation.

    Args:
        points: Forwarded to ``extract_feat``.
        img_metas: Per-sample meta information; its length sets the number
            of result dicts.
        img_inputs: Forwarded to ``extract_feat``.
        rescale: Forwarded to ``simple_test_pts``.
        **kwargs: Forwarded to ``extract_feat``.
    Returns:
        bbox_list: List[dict0, dict1, ...]   len = bs
        dict: {
            'pts_bbox':  dict: {
                          'boxes_3d': (N, 9)
                          'scores_3d': (N, )
                          'labels_3d': (N, )
                         }
        }
    """
    img_feats, _pts_feats, _depth = self.extract_feat(
        points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

    # One (initially empty) result dict per sample in the batch.
    results = [{} for _ in range(len(img_metas))]

    # detections: List[dict0, dict1, ...],  len = batch_size
    # dict: {
    #   'boxes_3d': (N, 9)
    #   'scores_3d': (N, )
    #   'labels_3d': (N, )
    # }
    detections = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
    for sample_result, detection in zip(results, detections):
        sample_result.update(pts_bbox=detection)
    return results
def forward_dummy(self,
                  points=None,
                  img_metas=None,
                  img_inputs=None,
                  **kwargs):
    """Run feature extraction and the bbox head without any post-processing.

    Args:
        points: Forwarded to ``extract_feat``.
        img_metas: Forwarded to ``extract_feat``.
        img_inputs: Forwarded to ``extract_feat``.
        **kwargs: Forwarded to ``extract_feat``.
    Returns:
        The raw output of ``self.pts_bbox_head`` on the image features.
    """
    # ``extract_feat`` names this parameter ``img_inputs``; passing it as
    # ``img=`` left the required argument unfilled and raised TypeError.
    img_feats, _, _ = self.extract_feat(
        points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
    assert self.with_pts_bbox
    outs = self.pts_bbox_head(img_feats)
    return outs
\ No newline at end of file
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdepth4d.py
0 → 100644
View file @
d2b71343
# Copyright (c) Phigent Robotics. All rights reserved.
import
torch
import
torch.nn.functional
as
F
from
mmcv.runner
import
force_fp32
from
mmdet3d.models
import
DETECTORS
from
mmdet3d.models
import
builder
from
.bevdet4d
import
BEVDet4D
@
DETECTORS
.
register_module
()
class
BEVDepth4D
(
BEVDet4D
):
def forward_train(self,
                  points=None,
                  img_metas=None,
                  gt_bboxes_3d=None,
                  gt_labels_3d=None,
                  gt_labels=None,
                  gt_bboxes=None,
                  img_inputs=None,
                  proposals=None,
                  gt_bboxes_ignore=None,
                  **kwargs):
    """Compute the depth loss plus the detection-head losses.

    Args:
        points (list[torch.Tensor], optional): Points of each sample.
            Defaults to None.
        img_metas (list[dict], optional): Meta information of each sample.
            Defaults to None.
        gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
            Ground truth 3D boxes. Defaults to None.
        gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
            of 3D boxes. Defaults to None.
        gt_labels (list[torch.Tensor], optional): Ground truth labels
            of 2D boxes in images. Not used here. Defaults to None.
        gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
            images. Not used here. Defaults to None.
        img_inputs (optional): Image inputs forwarded to ``extract_feat``.
            Defaults to None.
        proposals ([list[torch.Tensor], optional): Predicted proposals
            used for training Fast RCNN. Not used here. Defaults to None.
        gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
            2D boxes in images to be ignored. Defaults to None.
        **kwargs: Must contain ``gt_depth``; the whole dict is also
            forwarded to ``extract_feat``.
    Returns:
        dict: Losses of different branches, including ``loss_depth``.
    """
    img_feats, _pts_feats, depth = self.extract_feat(
        points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

    # Depth supervision comes first, then the detection losses are merged in.
    losses = {
        'loss_depth': self.img_view_transformer.get_depth_loss(
            kwargs['gt_depth'], depth),
    }
    losses.update(
        self.forward_pts_train(img_feats, gt_bboxes_3d, gt_labels_3d,
                               img_metas, gt_bboxes_ignore))
    return losses
\ No newline at end of file
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdet.py
0 → 100644
View file @
d2b71343
# Copyright (c) Phigent Robotics. All rights reserved.
import
torch
import
torch.nn.functional
as
F
from
mmcv.runner
import
force_fp32
from
mmdet3d.models
import
DETECTORS
from
mmdet3d.models
import
CenterPoint
from
mmdet3d.models
import
builder
@
DETECTORS
.
register_module
()
class
BEVDet
(
CenterPoint
):
def __init__(self,
             img_backbone,
             img_neck,
             img_view_transformer,
             img_bev_encoder_backbone,
             img_bev_encoder_neck,
             pts_bbox_head=None,
             **kwargs):
    """Build the image branch via the parent class, then the BEV modules.

    Args:
        img_backbone: Passed to the parent constructor.
        img_neck: Passed to the parent constructor.
        img_view_transformer: Built with ``builder.build_neck``.
        img_bev_encoder_backbone: Built with ``builder.build_backbone``.
        img_bev_encoder_neck: Built with ``builder.build_neck``.
        pts_bbox_head: Passed to the parent constructor. Defaults to None.
        **kwargs: Passed to the parent constructor unchanged.
    """
    super(BEVDet, self).__init__(
        img_backbone=img_backbone,
        img_neck=img_neck,
        pts_bbox_head=pts_bbox_head,
        **kwargs)

    # Modules that turn image features into an encoded BEV feature map.
    build_neck = builder.build_neck
    self.img_view_transformer = build_neck(img_view_transformer)
    self.img_bev_encoder_backbone = builder.build_backbone(
        img_bev_encoder_backbone)
    self.img_bev_encoder_neck = build_neck(img_bev_encoder_neck)
@torch.compile
def image_encoder(self, img, stereo=False):
    """Encode every camera view with the image backbone and optional neck.

    Args:
        img: (B, N, 3, H, W)
        stereo: bool. When True, the first backbone output is split off and
            returned as ``stereo_feat``; only the remaining outputs continue
            to the neck.
    Returns:
        x: (B, N, C, fH, fW)
        stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo) / None
    """
    B, N, C, imH, imW = img.shape
    # Fold the view axis into the batch axis so the backbone sees 4-D input.
    flat_imgs = img.view(B * N, C, imH, imW)
    x = self.img_backbone(flat_imgs)

    stereo_feat = None
    if stereo:
        stereo_feat = x[0]
        x = x[1:]

    if self.with_img_neck:
        x = self.img_neck(x)
        # isinstance (instead of an exact type match) also accepts
        # list/tuple subclasses such as namedtuples.
        if isinstance(x, (list, tuple)):
            x = x[0]

    # Unfold the batch axis back into (batch, view).
    _, output_dim, output_H, output_W = x.shape
    x = x.view(B, N, output_dim, output_H, output_W)
    return x, stereo_feat
@torch.compile
@force_fp32()
def bev_encoder(self, x):
    """Run the BEV encoder backbone followed by the BEV encoder neck.

    Args:
        x: (B, C, Dy, Dx)
    Returns:
        x: (B, C', 2*Dy, 2*Dx)
    """
    x = self.img_bev_encoder_backbone(x)
    x = self.img_bev_encoder_neck(x)
    # A neck may return a sequence of feature maps; keep only the first one.
    # isinstance also covers list/tuple subclasses (e.g. namedtuples).
    if isinstance(x, (list, tuple)):
        x = x[0]
    return x
@torch.compile
def prepare_inputs(self, inputs):
    """Re-express every sensor pose relative to the key ego frame.

    Args:
        inputs: Sequence of exactly seven items
            (imgs, sensor2egos, ego2globals, intrins, post_rots,
            post_trans, bda). ``imgs`` must be 5-D, (B, N, C, H, W).
    Returns:
        list: [imgs, sensor2keyegos, ego2globals, intrins, post_rots,
            post_trans, bda], where ``sensor2keyegos`` is a float tensor of
            shape (B, N, 4, 4) and ``ego2globals`` has been viewed as
            (B, N, 4, 4).
    """
    assert len(inputs) == 7
    batch, num_views, _, _, _ = inputs[0].shape
    (imgs, sensor2egos, ego2globals, intrins,
     post_rots, post_trans, bda) = inputs

    pose_shape = (batch, num_views, 4, 4)
    sensor2egos = sensor2egos.view(*pose_shape)
    ego2globals = ego2globals.view(*pose_shape)

    # The ego pose attached to the first view defines the key ego frame.
    keyego2global = ego2globals[:, 0:1]  # (B, 1, 4, 4)
    # Invert and chain in double precision, then cast back to float.
    global2keyego = torch.inverse(keyego2global.double())  # (B, 1, 4, 4)
    # sensor --> ego --> global --> key ego
    ego2keyego = torch.matmul(global2keyego, ego2globals.double())
    sensor2keyegos = torch.matmul(ego2keyego, sensor2egos.double()).float()

    return [imgs, sensor2keyegos, ego2globals, intrins,
            post_rots, post_trans, bda]
def extract_img_feat(self, img_inputs, img_metas, **kwargs):
    """Extract BEV features and the depth prediction from the images.

    ``img_metas`` and ``**kwargs`` are accepted but not used here.

    img_inputs:
        imgs:  (B, N_views, 3, H, W)
        sensor2egos: (B, N_views, 4, 4)
        ego2globals: (B, N_views, 4, 4)
        intrins:     (B, N_views, 3, 3)
        post_rots:   (B, N_views, 3, 3)
        post_trans:  (B, N_views, 3)
        bda_rot:  (B, 3, 3)
    Returns:
        x: [(B, C', H', W'), ]
        depth: (B*N, D, fH, fW)
    """
    prepared = self.prepare_inputs(img_inputs)

    # Slot 0 holds the images; slots 1..6 hold the geometry.
    img_feats, _ = self.image_encoder(prepared[0])  # (B, N, C, fH, fW)

    # The view transformer consumes the image features followed by the
    # six geometry entries.
    bev_feat, depth = self.img_view_transformer([img_feats] + prepared[1:7])
    # bev_feat: (B, C, Dy, Dx)
    # depth: (B*N, D, fH, fW)

    return [self.bev_encoder(bev_feat)], depth
def extract_feat(self, points, img_inputs, img_metas, **kwargs):
    """Extract features from images; no point features are produced.

    Args:
        points (list[torch.Tensor], optional): Points of each sample.
            Not used by this method.
        img_inputs:
            imgs:  (B, N_views, 3, H, W)
            sensor2egos: (B, N_views, 4, 4)
            ego2globals: (B, N_views, 4, 4)
            intrins:     (B, N_views, 3, 3)
            post_rots:   (B, N_views, 3, 3)
            post_trans:  (B, N_views, 3)
            bda_rot:  (B, 3, 3)
        img_metas: Forwarded to ``extract_img_feat``.
        **kwargs: Forwarded to ``extract_img_feat``.
    Returns:
        tuple: (img_feats, None, depth); the middle slot is the
            point-feature placeholder and is always None.
    """
    img_feats, depth = self.extract_img_feat(img_inputs, img_metas, **kwargs)
    return img_feats, None, depth
def forward_train(self,
                  points=None,
                  img_inputs=None,
                  gt_bboxes_3d=None,
                  gt_labels_3d=None,
                  img_metas=None,
                  gt_bboxes=None,
                  gt_labels=None,
                  gt_bboxes_ignore=None,
                  **kwargs):
    """Compute the detection-head losses for one training batch.

    Args:
        points (list[torch.Tensor], optional): Points of each sample.
            Defaults to None.
        img_inputs:
            imgs:  (B, N_views, 3, H, W)        # N_views = 6 * (N_history + 1)
            sensor2egos: (B, N_views, 4, 4)
            ego2globals: (B, N_views, 4, 4)
            intrins:     (B, N_views, 3, 3)
            post_rots:   (B, N_views, 3, 3)
            post_trans:  (B, N_views, 3)
            bda_rot:  (B, 3, 3)
        gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
            Ground truth 3D boxes. Defaults to None.
        gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
            of 3D boxes. Defaults to None.
        img_metas (list[dict], optional): Meta information of each sample.
            Defaults to None.
        gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
            images. Not used here. Defaults to None.
        gt_labels (list[torch.Tensor], optional): Ground truth labels
            of 2D boxes in images. Not used here. Defaults to None.
        gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
            2D boxes in images to be ignored. Defaults to None.
        **kwargs: Forwarded to ``extract_feat``.
    Returns:
        dict: Losses of different branches.
    """
    # The depth output of ``extract_feat`` is not supervised here.
    img_feats, _pts_feats, _depth = self.extract_feat(
        points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

    losses = {}
    losses.update(
        self.forward_pts_train(img_feats, gt_bboxes_3d, gt_labels_3d,
                               img_metas, gt_bboxes_ignore))
    return losses
def forward_test(self,
                 points=None,
                 img_inputs=None,
                 img_metas=None,
                 **kwargs):
    """Validate the test inputs and dispatch to simple_test or aug_test.

    Args:
        points (list[torch.Tensor], optional): The outer list indicates
            test-time augmentations. Only ``points[0]`` is forwarded;
            None is forwarded as None.
        img_inputs (list): The outer list indicates test-time
            augmentations. When ``img_inputs[0][0]`` is not a list the
            call goes to ``simple_test``; otherwise to ``aug_test``.
        img_metas (list[list[dict]]): The outer list indicates test-time
            augs (multiscale, flip, etc.) and the inner list indicates
            images in a batch.
        **kwargs: Forwarded to the selected test function.
    Returns:
        Whatever ``simple_test`` / ``aug_test`` returns.
    Raises:
        TypeError: If ``img_inputs`` or ``img_metas`` is not a list.
        ValueError: If the two lists differ in length.
    """
    for var, name in [(img_inputs, 'img_inputs'),
                      (img_metas, 'img_metas')]:
        if not isinstance(var, list):
            raise TypeError('{} must be a list, but got {}'.format(
                name, type(var)))

    num_augs = len(img_inputs)
    if num_augs != len(img_metas):
        raise ValueError(
            'num of augmentations ({}) != num of image meta ({})'.format(
                len(img_inputs), len(img_metas)))

    if not isinstance(img_inputs[0][0], list):
        # ``img_inputs`` is known to be a list at this point, so the former
        # ``img_inputs is None`` fallback could never trigger and was removed.
        points = [points] if points is None else points
        return self.simple_test(points[0], img_metas[0], img_inputs[0],
                                **kwargs)
    return self.aug_test(None, img_metas[0], img_inputs[0], **kwargs)
def aug_test(self, points, img_metas, img=None, rescale=False, **kwargs):
    """Test function with augmentation; not supported by this detector.

    The parameters exist only so the method can be called the way
    ``forward_test`` calls it (including forwarded keyword arguments).

    Raises:
        NotImplementedError: Always. An explicit exception is used instead
            of ``assert False`` so the failure survives ``python -O``.
    """
    raise NotImplementedError(
        '{} does not support test-time augmentation'.format(
            type(self).__name__))
def simple_test(self,
                points,
                img_metas,
                img_inputs=None,
                rescale=False,
                **kwargs):
    """Test function without augmentation.

    Args:
        points: Forwarded to ``extract_feat``.
        img_metas: Per-sample meta information; its length sets the number
            of result dicts.
        img_inputs: Forwarded to ``extract_feat``.
        rescale: Forwarded to ``simple_test_pts``.
        **kwargs: Forwarded to ``extract_feat``.
    Returns:
        bbox_list: List[dict0, dict1, ...]   len = bs
        dict: {
            'pts_bbox':  dict: {
                          'boxes_3d': (N, 9)
                          'scores_3d': (N, )
                          'labels_3d': (N, )
                         }
        }
    """
    img_feats, _pts_feats, _depth = self.extract_feat(
        points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

    # One (initially empty) result dict per sample in the batch.
    results = [{} for _ in range(len(img_metas))]

    # detections: List[dict0, dict1, ...],  len = batch_size
    # dict: {
    #   'boxes_3d': (N, 9)
    #   'scores_3d': (N, )
    #   'labels_3d': (N, )
    # }
    detections = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
    for sample_result, detection in zip(results, detections):
        sample_result.update(pts_bbox=detection)
    return results
def forward_dummy(self,
                  points=None,
                  img_metas=None,
                  img_inputs=None,
                  **kwargs):
    """Run feature extraction and the bbox head without any post-processing.

    Args:
        points: Forwarded to ``extract_feat``.
        img_metas: Forwarded to ``extract_feat``.
        img_inputs: Forwarded to ``extract_feat``.
        **kwargs: Forwarded to ``extract_feat``.
    Returns:
        The raw output of ``self.pts_bbox_head`` on the image features.
    """
    # ``extract_feat`` names this parameter ``img_inputs``; passing it as
    # ``img=`` left the required argument unfilled and raised TypeError.
    img_feats, _, _ = self.extract_feat(
        points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
    assert self.with_pts_bbox
    outs = self.pts_bbox_head(img_feats)
    return outs
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdet4d.py
0 → 100644
View file @
d2b71343
# Copyright (c) Phigent Robotics. All rights reserved.
import
torch
import
torch.nn.functional
as
F
from
mmcv.runner
import
force_fp32
from
mmdet3d.models
import
DETECTORS
from
mmdet3d.models
import
builder
from
.bevdet
import
BEVDet
@
DETECTORS
.
register_module
()
class
BEVDet4D
(
BEVDet
):
r
"""BEVDet4D paradigm for multi-camera 3D object detection.
Please refer to the `paper <https://arxiv.org/abs/2203.17054>`_
Args:
pre_process (dict | None): Configuration dict of BEV pre-process net.
align_after_view_transfromation (bool): Whether to align the BEV
Feature after view transformation. By default, the BEV feature of
the previous frame is aligned during the view transformation.
num_adj (int): Number of adjacent frames.
with_prev (bool): Whether to set the BEV feature of previous frame as
all zero. By default, False.
"""
def __init__(self,
             pre_process=None,
             align_after_view_transfromation=False,
             num_adj=1,
             with_prev=True,
             **kwargs):
    """Set up the temporal options on top of the parent detector.

    Args:
        pre_process (dict | None): When not None, a pre-process network is
            built from it with ``builder.build_backbone``.
        align_after_view_transfromation (bool): Stored unchanged on the
            instance. Defaults to False.
        num_adj (int): Number of adjacent frames; ``num_frame`` is set to
            ``num_adj + 1``. Defaults to 1.
        with_prev (bool): Stored unchanged on the instance.
            Defaults to True.
        **kwargs: Passed to the parent constructor.
    """
    super(BEVDet4D, self).__init__(**kwargs)

    # Plain configuration flags.
    self.align_after_view_transfromation = align_after_view_transfromation
    self.with_prev = with_prev
    self.num_frame = num_adj + 1

    # Coordinate grid cache; starts empty and is filled by ``gen_grid``.
    self.grid = None

    # ``pre_process`` becomes a boolean flag; the network itself lives in
    # ``pre_process_net`` and only exists when the flag is True.
    self.pre_process = pre_process is not None
    if self.pre_process:
        self.pre_process_net = builder.build_backbone(pre_process)
def gen_grid(self, input, sensor2keyegos, bda, bda_adj=None):
    """Build the sampling grid that maps current BEV cells to the previous frame.

    Args:
        input: (B, C, Dy, Dx) bev_feat
        sensor2keyegos: List[
            curr_sensor-->key_ego: (B, N_views, 4, 4)
            prev_sensor-->key_ego: (B, N_views, 4, 4)
        ]
        bda:  (B, 3, 3)
        bda_adj: Optional augmentation matrix used for the previous frame
            in place of ``bda``; None reuses ``bda``.
    Returns:
        grid: (B, Dy, Dx, 2), scaled so that cell index 0 maps to -1 and
            the last cell index maps to +1 along each axis.
    """
    B, C, H, W = input.shape

    # The base grid is cached on the instance. Rebuild it whenever the BEV
    # size, dtype or device differs from the cached one; reusing a stale
    # grid would fail in the ``view`` below or mix devices/dtypes.
    base_grid = self.grid
    if (base_grid is None
            or tuple(base_grid.shape[:2]) != (H, W)
            or base_grid.dtype != input.dtype
            or base_grid.device != input.device):
        xs = torch.linspace(
            0, W - 1, W, dtype=input.dtype,
            device=input.device).view(1, W).expand(H, W)     # (Dy, Dx)
        ys = torch.linspace(
            0, H - 1, H, dtype=input.dtype,
            device=input.device).view(H, 1).expand(H, W)     # (Dy, Dx)
        # (Dy, Dx, 3)   3: (x, y, 1)
        base_grid = torch.stack((xs, ys, torch.ones_like(xs)), -1)
        self.grid = base_grid

    # (Dy, Dx, 3) --> (1, Dy, Dx, 3) --> (B, Dy, Dx, 3) --> (B, Dy, Dx, 3, 1)
    grid = base_grid.view(1, H, W, 3).expand(B, H, W, 3).view(B, H, W, 3, 1)

    def _homogeneous(rot):
        # Embed a 3x3 BEV augmentation into a (B, 1, 4, 4) transform.
        mat = torch.zeros((B, 1, 4, 4), dtype=grid.dtype).to(grid)
        mat[:, :, :3, :3] = rot.unsqueeze(1)
        mat[:, :, 3, 3] = 1
        return mat

    # Only the first view of each frame is used to relate the two frames.
    curr_sensor2keyego = sensor2keyegos[0][:, 0:1, :, :]    # (B, 1, 4, 4)
    prev_sensor2keyego = sensor2keyegos[1][:, 0:1, :, :]    # (B, 1, 4, 4)

    # Add the BEV data augmentation to both poses; the previous frame uses
    # ``bda_adj`` when it is given and ``bda`` otherwise.
    curr_bda = _homogeneous(bda)
    prev_bda = curr_bda if bda_adj is None else _homogeneous(bda_adj)
    curr_sensor2keyego = curr_bda.matmul(curr_sensor2keyego)
    prev_sensor2keyego = prev_bda.matmul(prev_sensor2keyego)

    # Transformation from current ego frame to adjacent ego frame.
    keyego2adjego = curr_sensor2keyego.matmul(
        torch.inverse(prev_sensor2keyego))
    keyego2adjego = keyego2adjego.unsqueeze(dim=1)      # (B, 1, 1, 4, 4)

    # Drop the third row and column, keeping the planar part: (B, 1, 1, 3, 3)
    keyego2adjego = keyego2adjego[..., [True, True, False, True], :][
        ..., [True, True, False, True]]

    # x = grid_x * vx + x_min;  y = grid_y * vy + y_min;
    # feat2bev:
    #   [[vx, 0, x_min],
    #    [0, vy, y_min],
    #    [0,  0,   1  ]]
    view_transformer = self.img_view_transformer
    feat2bev = torch.zeros((3, 3), dtype=grid.dtype).to(grid)
    feat2bev[0, 0] = view_transformer.grid_interval[0]
    feat2bev[1, 1] = view_transformer.grid_interval[1]
    feat2bev[0, 2] = view_transformer.grid_lower_bound[0]
    feat2bev[1, 2] = view_transformer.grid_lower_bound[1]
    feat2bev[2, 2] = 1
    feat2bev = feat2bev.view(1, 3, 3)      # (1, 3, 3)

    # Cell index --> metric coordinates --> frame change --> cell index.
    tf = torch.inverse(feat2bev).matmul(keyego2adjego).matmul(feat2bev)
    grid = tf.matmul(grid)     # (B, Dy, Dx, 3, 1)   3: (grid_x, grid_y, 1)

    # Scale the cell indices to [-1, 1]: (B, Dy, Dx, 2)
    normalize_factor = torch.tensor([W - 1.0, H - 1.0],
                                    dtype=input.dtype,
                                    device=input.device)    # (2, )
    grid = grid[:, :, :, :2, 0] / normalize_factor.view(1, 1, 1, 2) * 2.0 - 1.0
    return grid
@force_fp32()
def shift_feature(self, input, sensor2keyegos, bda, bda_adj=None):
    """Resample a BEV feature map with the grid produced by ``gen_grid``.

    Args:
        input: (B, C, Dy, Dx) bev_feat
        sensor2keyegos: List[
            curr_sensor-->key_ego: (B, N_views, 4, 4)
            prev_sensor-->key_ego: (B, N_views, 4, 4)
        ]
        bda:  (B, 3, 3)
        bda_adj: None
    Returns:
        output: aligned bev feat (B, C, Dy, Dx).
    """
    # sampling_grid: (B, Dy, Dx, 2), normalised to the (-1, 1) range.
    sampling_grid = self.gen_grid(
        input, sensor2keyegos, bda, bda_adj=bda_adj)
    sampling_grid = sampling_grid.to(input.dtype)
    # (B, C, Dy, Dx)
    return F.grid_sample(input, sampling_grid, align_corners=True)
def prepare_bev_feat(self, img, sensor2egos, ego2globals, intrin, post_rot,
                     post_tran, bda, mlp_input):
    """Turn the images of one frame into a BEV feature map.

    Args:
        img:  (B, N_views, 3, H, W)
        sensor2egos: (B, N_views, 4, 4)
        ego2globals: (B, N_views, 4, 4)
        intrin:     (B, N_views, 3, 3)
        post_rot:   (B, N_views, 3, 3)
        post_tran:  (B, N_views, 3)
        bda:  (B, 3, 3)
        mlp_input: Passed through to the view transformer unchanged.
    Returns:
        bev_feat: (B, C, Dy, Dx)
        depth: (B*N, D, fH, fW)
    """
    img_feats, _ = self.image_encoder(img)      # (B, N, C, fH, fW)

    # bev_feat: (B, C * Dz(=1), Dy, Dx)
    # depth: (B * N, D, fH, fW)
    transformer_inputs = [img_feats, sensor2egos, ego2globals, intrin,
                          post_rot, post_tran, bda, mlp_input]
    bev_feat, depth = self.img_view_transformer(transformer_inputs)

    if not self.pre_process:
        return bev_feat, depth
    # The pre-process network returns a sequence; only its first entry,
    # (B, C, Dy, Dx), is kept.
    return self.pre_process_net(bev_feat)[0], depth
def extract_img_feat_sequential(self, inputs, feat_prev):
    """Fuse the current frame with already-computed previous BEV features.

    Args:
        inputs:
            curr_img: (1, N_views, 3, H, W)
            sensor2keyegos_curr:  (N_prev, N_views, 4, 4)
            ego2globals_curr:  (N_prev, N_views, 4, 4)
            intrins:  (1, N_views, 3, 3)
            sensor2keyegos_prev:  (N_prev, N_views, 4, 4)
            ego2globals_prev:  (N_prev, N_views, 4, 4)   (not used)
            post_rots:  (1, N_views, 3, 3)
            post_trans: (1, N_views, 3, )
            bda_curr:  (N_prev, 3, 3)
        feat_prev: (N_prev, C, Dy, Dx)
    Returns:
        x: [encoded BEV feature, ]
        depth: depth output of ``prepare_bev_feat`` for the current frame.
    """
    imgs, sensor2keyegos_curr, ego2globals_curr, intrins = inputs[:4]
    sensor2keyegos_prev, _, post_rots, post_trans, bda = inputs[4:]

    # Current frame: only the first entry of the per-previous-frame tensors
    # is used to build its BEV feature.
    curr_geometry = (sensor2keyegos_curr[0:1, ...],
                     ego2globals_curr[0:1, ...],
                     intrins, post_rots, post_trans,
                     bda[0:1, ...])
    mlp_input = self.img_view_transformer.get_mlp_input(*curr_geometry)
    # (1, C, Dx, Dy), (1*N, D, fH, fW)
    curr_bev_feat, depth = self.prepare_bev_feat(
        imgs, *curr_geometry, mlp_input)

    # Align the previous features; shape is read before the warp.
    _, C, H, W = feat_prev.shape
    aligned_prev = self.shift_feature(
        feat_prev,                                      # (N_prev, C, Dy, Dx)
        [sensor2keyegos_curr, sensor2keyegos_prev],     # (N_prev, N_views, 4, 4) each
        bda)                                            # (N_prev, 3, 3)
    # Stack the previous frames along channels: (1, N_prev*C, Dy, Dx)
    stacked_prev = aligned_prev.view(1, (self.num_frame - 1) * C, H, W)

    # (1, N_frames*C, Dy, Dx)
    fused = torch.cat([curr_bev_feat, stacked_prev], dim=1)
    return [self.bev_encoder(fused)], depth
def prepare_inputs(self, img_inputs, stereo=False):
    """Split the stacked multi-frame inputs into per-frame lists.

    Args:
        img_inputs:
            imgs:  (B, N, 3, H, W)        # N = 6 * (N_history + 1)
            sensor2egos: (B, N, 4, 4)
            ego2globals: (B, N, 4, 4)
            intrins:     (B, N, 3, 3)
            post_rots:   (B, N, 3, 3)
            post_trans:  (B, N, 3)
            bda_rot:  (B, 3, 3)
        stereo: bool. When True, ``curr2adjsensor`` is also computed; this
            path reads ``self.temporal_frame`` and ``self.extra_ref_frames``.

    Returns:
        imgs: List[(B, N_views, C, H, W), (B, N_views, C, H, W), ...]       len = N_frames
        sensor2keyegos: List[(B, N_views, 4, 4), (B, N_views, 4, 4), ...]
        ego2globals: List[(B, N_views, 4, 4), (B, N_views, 4, 4), ...]
        intrins: List[(B, N_views, 3, 3), (B, N_views, 3, 3), ...]
        post_rots: List[(B, N_views, 3, 3), (B, N_views, 3, 3), ...]
        post_trans: List[(B, N_views, 3), (B, N_views, 3), ...]
        bda: (B, 3, 3)
        curr2adjsensor: None, or (stereo) a list of length N_frames whose
            trailing ``extra_ref_frames`` entries are None.
    """
    num_frame = self.num_frame
    B, N_total, C, H, W = img_inputs[0].shape
    N = N_total // num_frame    # views per frame

    # NOTE(review): the images are viewed view-major, (B, N_views, N_frames, ...),
    # while the pose tensors below are viewed frame-major,
    # (B, N_frames, N_views, ...). Presumably this mirrors how the data
    # pipeline stacks them - confirm against the loader.
    imgs = img_inputs[0].view(B, N, num_frame, C, H, W)
    # List[(B, N_views, C, H, W), (B, N_views, C, H, W), ...]
    imgs = list(imgs.unbind(2))

    (sensor2egos, ego2globals, intrins,
     post_rots, post_trans, bda) = img_inputs[1:7]
    frame_pose_shape = (B, num_frame, N, 4, 4)
    sensor2egos = sensor2egos.view(*frame_pose_shape)
    ego2globals = ego2globals.view(*frame_pose_shape)

    # The ego pose of frame 0 / view 0 defines the key ego frame.
    # key_ego --> global  (B, 1, 1, 4, 4)
    keyego2global = ego2globals[:, 0, 0, ...].unsqueeze(1).unsqueeze(1)
    # global --> key_ego  (B, 1, 1, 4, 4), inverted in double precision
    global2keyego = torch.inverse(keyego2global.double())
    # sensor --> ego --> global --> key_ego    (B, N_frames, N_views, 4, 4)
    sensor2keyegos = (global2keyego @ ego2globals.double()
                      @ sensor2egos.double()).float()

    # -------------------- for stereo --------------------------
    curr2adjsensor = None
    if stereo:
        n_temporal = self.temporal_frame
        # Frames [0, n_temporal) are paired with frames [1, n_temporal].
        # Each slice: (B, N_temporal, N_views, 4, 4)
        sensor2egos_curr = sensor2egos[:, :n_temporal, ...].double()
        ego2globals_curr = ego2globals[:, :n_temporal, ...].double()
        sensor2egos_adj = sensor2egos[:, 1:n_temporal + 1, ...].double()
        ego2globals_adj = ego2globals[:, 1:n_temporal + 1, ...].double()

        # curr_sensor --> curr_ego --> global --> prev_ego --> prev_sensor
        curr2adjsensor = (torch.inverse(ego2globals_adj @ sensor2egos_adj)
                          @ ego2globals_curr @ sensor2egos_curr)
        curr2adjsensor = curr2adjsensor.float()
        # One (B, N_views, 4, 4) entry per temporal frame, padded with None
        # for the extra reference frames.
        curr2adjsensor = list(curr2adjsensor.unbind(1))
        curr2adjsensor.extend([None] * self.extra_ref_frames)
        assert len(curr2adjsensor) == self.num_frame
    # -------------------- for stereo --------------------------

    # Split every per-frame tensor along the frame axis into a list.
    per_frame = [
        sensor2keyegos,                             # (B, N_frames, N_views, 4, 4)
        ego2globals,                                # (B, N_frames, N_views, 4, 4)
        intrins.view(B, num_frame, N, 3, 3),        # (B, N_frames, N_views, 3, 3)
        post_rots.view(B, num_frame, N, 3, 3),      # (B, N_frames, N_views, 3, 3)
        post_trans.view(B, num_frame, N, 3),        # (B, N_frames, N_views, 3)
    ]
    (sensor2keyegos, ego2globals, intrins,
     post_rots, post_trans) = [list(t.unbind(1)) for t in per_frame]

    return (imgs, sensor2keyegos, ego2globals, intrins,
            post_rots, post_trans, bda, curr2adjsensor)
def extract_img_feat(self,
                     img_inputs,
                     img_metas,
                     pred_prev=False,
                     sequential=False,
                     **kwargs):
    """Turn the multi-frame camera inputs into one encoded BEV feature.

    Args:
        img_inputs: Packed inputs forwarded to ``self.prepare_inputs``.
            The original annotations describe the entries as
                imgs:        (B, N, 3, H, W)   # N = 6 * (N_history + 1)
                sensor2egos: (B, N, 4, 4)
                ego2globals: (B, N, 4, 4)
                intrins:     (B, N, 3, 3)
                post_rots:   (B, N, 3, 3)
                post_trans:  (B, N, 3)
                bda_rot:     (B, 3, 3)
        img_metas: Accepted for interface compatibility; not read here.
        pred_prev (bool): When true, return the features of the non-key
            frames together with the inputs needed to align them later,
            instead of running the BEV encoder.
        sequential (bool): When true, delegate to
            ``self.extract_img_feat_sequential`` using
            ``kwargs['feat_prev']``.
        **kwargs: Only ``feat_prev`` is read (sequential mode).

    Returns:
        * sequential: whatever ``extract_img_feat_sequential`` returns.
        * pred_prev: ``(feat_prev, inputs)`` where ``inputs`` is a
          9-element list (see the return statement below).
        * otherwise: ``([x], depth)`` with ``x`` the BEV-encoder output
          and ``depth`` the depth returned for the first (key) frame;
          annotated in the original as ``(B, C', H', W')`` and
          ``(B*N_views, D, fH, fW)``.
    """
    if sequential:
        return self.extract_img_feat_sequential(img_inputs,
                                                kwargs['feat_prev'])

    (imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans,
     bda, _) = self.prepare_inputs(img_inputs)

    # Extract features of images, one frame at a time.
    frame_feats = []
    frame_depths = []
    per_frame_inputs = zip(imgs, sensor2keyegos, ego2globals, intrins,
                           post_rots, post_trans)
    for frame_idx, (img, sensor2keyego, ego2global, intrin, post_rot,
                    post_tran) in enumerate(per_frame_inputs):
        # The first frame is the key frame; only it is run with the
        # ambient autograd mode (back propagation for key frame only).
        is_key_frame = frame_idx == 0

        if not (is_key_frame or self.with_prev):
            # Placeholder for skipped history frames, see
            # https://github.com/HuangJunJie2017/BEVDet/issues/275
            frame_feats.append(torch.zeros_like(frame_feats[0]))
            frame_depths.append(None)
            continue

        if self.align_after_view_transfromation:
            # Lift with the key-frame poses; alignment happens afterwards.
            sensor2keyego = sensor2keyegos[0]
            ego2global = ego2globals[0]

        # (B, N_views, 27) per the original annotation.
        mlp_input = self.img_view_transformer.get_mlp_input(
            sensor2keyegos[0], ego2globals[0], intrin, post_rot, post_tran,
            bda)
        view_args = (img, sensor2keyego, ego2global, intrin, post_rot,
                     post_tran, bda, mlp_input)

        if is_key_frame:
            # bev_feat: (B, C, Dy, Dx); depth: (B*N_views, D, fH, fW)
            bev_feat, depth = self.prepare_bev_feat(*view_args)
        else:
            with torch.no_grad():
                bev_feat, depth = self.prepare_bev_feat(*view_args)

        frame_feats.append(bev_feat)
        frame_depths.append(depth)

    if pred_prev:
        assert self.align_after_view_transfromation
        assert sensor2keyegos[0].shape[0] == 1  # batch_size = 1
        feat_prev = torch.cat(frame_feats[1:], dim=0)
        n_prev = self.num_frame - 1
        # (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
        ego2globals_curr = ego2globals[0].repeat(n_prev, 1, 1, 1)
        sensor2keyegos_curr = sensor2keyegos[0].repeat(n_prev, 1, 1, 1)
        # (N_prev, N_views, 4, 4)
        ego2globals_prev = torch.cat(ego2globals[1:], dim=0)
        sensor2keyegos_prev = torch.cat(sensor2keyegos[1:], dim=0)
        # (N_prev, 3, 3)
        bda_curr = bda.repeat(n_prev, 1, 1)
        return feat_prev, [
            imgs[0],              # (1, N_views, 3, H, W)
            sensor2keyegos_curr,  # (N_prev, N_views, 4, 4)
            ego2globals_curr,     # (N_prev, N_views, 4, 4)
            intrins[0],           # (1, N_views, 3, 3)
            sensor2keyegos_prev,  # (N_prev, N_views, 4, 4)
            ego2globals_prev,     # (N_prev, N_views, 4, 4)
            post_rots[0],         # (1, N_views, 3, 3)
            post_trans[0],        # (1, N_views, 3)
            bda_curr,             # (N_prev, 3, 3)
        ]

    if self.align_after_view_transfromation:
        # Warp every history feature into the key-frame ego coordinates.
        for adj_id in range(1, self.num_frame):
            frame_feats[adj_id] = self.shift_feature(
                frame_feats[adj_id],                            # (B, C, Dy, Dx)
                [sensor2keyegos[0], sensor2keyegos[adj_id]],    # 2 x (B, N_views, 4, 4)
                bda)                                            # (B, 3, 3)

    # (B, N_frames*C, Dy, Dx)
    stacked_bev = torch.cat(frame_feats, dim=1)
    x = self.bev_encoder(stacked_bev)
    return [x], frame_depths[0]
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevdet_occ.py
0 → 100644
View file @
d2b71343
# Copyright (c) Phigent Robotics. All rights reserved.
from
...ops
import
TRTBEVPoolv2
from
.bevdet
import
BEVDet
from
.bevdepth
import
BEVDepth
from
.bevdepth4d
import
BEVDepth4D
from
.bevstereo4d
import
BEVStereo4D
from
mmdet3d.models
import
DETECTORS
from
mmdet3d.models.builder
import
build_head
import
torch.nn.functional
as
F
from
mmdet3d.core
import
bbox3d2result
import
numpy
as
np
from
multiprocessing.dummy
import
Pool
as
ThreadPool
from
...ops
import
nearest_assign
# pool = ThreadPool(processes=4)  # create a thread pool

# Grid configuration for the panoptic ("pano") post-processing.
# Entry 0 of each axis is used below as the grid origin (``st``) and
# entry 2 of x / y as the voxel size (``sx``).
grid_config_occ = {
    'x': [-40, 40, 0.4],
    'y': [-40, 40, 0.4],
    'z': [-1, 5.4, 6.4],
    'depth': [1.0, 45.0, 1.0],
}

# Detection class names; a detection label is an index into this list.
det_class_name = [
    'car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
    'motorcycle', 'pedestrian', 'traffic_cone', 'barrier',
]

# Occupancy class names; an occupancy label is an index into this list.
occ_class_names = [
    'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
    'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
    'driveable_surface', 'other_flat', 'sidewalk', 'terrain', 'manmade',
    'vegetation', 'free',
]

# ``det_ind[i]`` is an occupancy label and ``occ_ind[i]`` the detection
# label of the same class name.
det_ind = [2, 3, 4, 5, 6, 7, 9, 10]
occ_ind = [5, 3, 0, 4, 6, 7, 2, 1]

# detection label -> occupancy label
detind2occind = {
    0: 4,
    1: 10,
    2: 9,
    3: 3,
    4: 5,
    5: 2,
    6: 6,
    7: 7,
    8: 8,
    9: 1,
}

# occupancy label -> detection label (inverse of ``detind2occind``)
occind2detind = {
    4: 0,
    10: 1,
    9: 2,
    3: 3,
    5: 4,
    2: 5,
    6: 6,
    7: 7,
    8: 8,
    1: 9,
}

# Dense lookup handed to ``nearest_assign``: position = occupancy label,
# value = detection label, -1 where no entry is provided.
occind2detind_cuda = [-1, -1, 5, 3, 0, 4, 6, 7, -1, 2, 1]

# All-zero float64 volume (same values as the former ``np.ones(...) * 0``).
inst_occ = np.zeros([200, 200, 16])

import torch

X1, Y1, Z1 = 200, 200, 16
coords_x = torch.arange(X1).float()
coords_y = torch.arange(Y1).float()
coords_z = torch.arange(Z1).float()
# ``indexing='ij'`` spells out the layout the implicit default produced, so
# ``coords[x, y, z] == (x, y, z)``.
coords = torch.stack(
    torch.meshgrid(coords_x, coords_y, coords_z, indexing='ij')
).permute(1, 2, 3, 0)  # W, H, D, 3
# coords = coords.cpu().numpy()

# Grid origin per axis.
st = [grid_config_occ['x'][0], grid_config_occ['y'][0], grid_config_occ['z'][0]]
# Voxel size per axis; z is the literal 0.4 rather than grid_config_occ['z'][2].
sx = [grid_config_occ['x'][2], grid_config_occ['y'][2], 0.4]
@DETECTORS.register_module()
class BEVDetOCC(BEVDet):
    """BEVDet variant that predicts occupancy through ``occ_head``.

    The parent is constructed from ``**kwargs``; afterwards
    ``pts_bbox_head`` is reset to ``None`` and an occupancy head is built
    from ``occ_head``.
    """

    def __init__(self, occ_head=None, upsample=False, **kwargs):
        """
        Args:
            occ_head: Config passed to ``build_head`` for the occupancy head.
            upsample (bool): If true, the BEV feature is bilinearly
                upsampled by 2 before it reaches the occupancy head.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self.occ_head = build_head(occ_head)
        self.pts_bbox_head = None
        self.upsample = upsample

    def _resize_bev(self, bev_feat):
        """Return ``bev_feat``, upsampled x2 (bilinear) when ``self.upsample``."""
        if not self.upsample:
            return bev_feat
        return F.interpolate(bev_feat, scale_factor=2, mode='bilinear',
                             align_corners=True)

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Compute the occupancy training losses.

        Only ``points``, ``img_inputs``, ``img_metas`` and ``kwargs`` are
        used; the remaining parameters exist for interface compatibility.

        Args:
            points: Forwarded to ``extract_feat``.
            img_metas: Forwarded to ``extract_feat``.
            img_inputs: Forwarded to ``extract_feat``.
            **kwargs: Forwarded to ``extract_feat``; must contain
                ``voxel_semantics`` and ``mask_camera`` (both annotated in
                the original as ``(B, Dx, Dy, Dz)``).

        Returns:
            dict: The entries returned by ``forward_occ_train``.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx), ]
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

        voxel_semantics = kwargs['voxel_semantics']  # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']          # (B, Dx, Dy, Dz)

        losses = dict()
        losses.update(
            self.forward_occ_train(self._resize_bev(img_feats[0]),
                                   voxel_semantics, mask_camera))
        return losses

    @torch.compile(mode="reduce-overhead")
    def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
        """Run the occupancy head and return its loss.

        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)

        Returns:
            Whatever ``self.occ_head.loss`` returns.
        """
        head_out = self.occ_head(img_feats)  # (B, Dx, Dy, Dz, n_cls)
        # assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        return self.occ_head.loss(head_out, voxel_semantics, mask_camera)

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        """Predict occupancy for one batch.

        ``rescale`` is accepted but not used.

        Returns:
            The value of ``simple_test_occ``; annotated in the original as
            ``List[(Dx, Dy, Dz), ...]``.
        """
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        return self.simple_test_occ(self._resize_bev(img_feats[0]), img_metas)

    def simple_test_occ(self, img_feats, img_metas=None):
        """Decode occupancy predictions from a BEV feature.

        ``get_occ_gpu`` is used when the head defines it, ``get_occ``
        otherwise.

        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            img_metas: Forwarded to the head's decoding method.

        Returns:
            occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        head_out = self.occ_head(img_feats)
        if hasattr(self.occ_head, "get_occ_gpu"):
            return self.occ_head.get_occ_gpu(head_out, img_metas)
        return self.occ_head.get_occ(head_out, img_metas)

    def forward_dummy(self,
                      points=None,
                      img_metas=None,
                      img_inputs=None,
                      **kwargs):
        """Return the raw occupancy-head output without any decoding."""
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        return self.occ_head(self._resize_bev(img_feats[0]))
@DETECTORS.register_module()
class BEVDepthOCC(BEVDepth):
    """BEVDepth variant that predicts occupancy through ``occ_head``.

    The parent is constructed from ``**kwargs``; afterwards
    ``pts_bbox_head`` is reset to ``None`` and an occupancy head is built
    from ``occ_head``. Training adds the view transformer's depth loss to
    the occupancy loss.
    """

    def __init__(self, occ_head=None, upsample=False, **kwargs):
        """
        Args:
            occ_head: Config passed to ``build_head`` for the occupancy head.
            upsample (bool): If true, the BEV feature is bilinearly
                upsampled by 2 before it reaches the occupancy head.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super(BEVDepthOCC, self).__init__(**kwargs)
        self.occ_head = build_head(occ_head)
        self.pts_bbox_head = None
        self.upsample = upsample

    def _maybe_upsample(self, bev_feat):
        """Return ``bev_feat``, upsampled x2 (bilinear) when ``self.upsample``."""
        if self.upsample:
            bev_feat = F.interpolate(bev_feat, scale_factor=2,
                                     mode='bilinear', align_corners=True)
        return bev_feat

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Compute depth and occupancy training losses.

        Only ``points``, ``img_inputs``, ``img_metas`` and ``kwargs`` are
        used; the remaining parameters exist for interface compatibility.

        Args:
            points: Forwarded to ``extract_feat``.
            img_metas: Forwarded to ``extract_feat``.
            img_inputs: Forwarded to ``extract_feat``.
            **kwargs: Forwarded to ``extract_feat``; must contain
                ``gt_depth`` (annotated as ``(B, N_views, img_H, img_W)``),
                ``voxel_semantics`` and ``mask_camera`` (both annotated as
                ``(B, Dx, Dy, Dz)``).

        Returns:
            dict: ``loss_depth`` plus the entries of ``forward_occ_train``.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx), ]
        # depth: (B*N_views, D, fH, fW)
        img_feats, _, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

        losses = dict()
        gt_depth = kwargs['gt_depth']  # (B, N_views, img_H, img_W)
        losses['loss_depth'] = \
            self.img_view_transformer.get_depth_loss(gt_depth, depth)

        voxel_semantics = kwargs['voxel_semantics']  # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']          # (B, Dx, Dy, Dz)
        occ_bev_feature = self._maybe_upsample(img_feats[0])
        losses.update(
            self.forward_occ_train(occ_bev_feature, voxel_semantics,
                                   mask_camera))
        return losses

    def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
        """Run the occupancy head and return its loss.

        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)

        Returns:
            Whatever ``self.occ_head.loss`` returns.
        """
        outs = self.occ_head(img_feats)  # (B, Dx, Dy, Dz, n_cls)
        # assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        loss_occ = self.occ_head.loss(outs, voxel_semantics, mask_camera)
        return loss_occ

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        """Predict occupancy for one batch.

        ``rescale`` is accepted but not used.

        Returns:
            The value of ``simple_test_occ``; annotated in the original as
            ``List[(Dx, Dy, Dz), ...]``.
        """
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        occ_bev_feature = self._maybe_upsample(img_feats[0])
        return self.simple_test_occ(occ_bev_feature, img_metas)

    def simple_test_occ(self, img_feats, img_metas=None):
        """Decode occupancy predictions from a BEV feature.

        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            img_metas: Forwarded to the head's decoding method.

        Returns:
            occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        outs = self.occ_head(img_feats)
        # Prefer the GPU decoder, but fall back to ``get_occ`` for heads that
        # do not implement it (same selection as BEVDetOCC.simple_test_occ)
        # instead of failing with AttributeError.
        if hasattr(self.occ_head, "get_occ_gpu"):
            return self.occ_head.get_occ_gpu(outs, img_metas)
        return self.occ_head.get_occ(outs, img_metas)

    def forward_dummy(self,
                      points=None,
                      img_metas=None,
                      img_inputs=None,
                      **kwargs):
        """Return the raw occupancy-head output without any decoding."""
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        return self.occ_head(self._maybe_upsample(img_feats[0]))
@DETECTORS.register_module()
class BEVDepthPano(BEVDepthOCC):
    """BEVDepthOCC plus an auxiliary centerness head for panoptic occupancy.

    At test time the instance centers from ``aux_centerness_head`` are
    combined with the semantic occupancy prediction by ``nearest_assign``
    to produce a per-voxel instance-id volume (``pano_inst``).
    """

    def __init__(self, aux_centerness_head=None, **kwargs):
        """
        Args:
            aux_centerness_head: Config of the auxiliary head. When truthy it
                is updated in place with ``train_cfg.pts`` / ``test_cfg.pts``
                (or ``None``) and built with ``build_head``; otherwise the
                attribute stays ``None``.
            **kwargs: Forwarded to ``BEVDepthOCC.__init__``. ``train_cfg``,
                ``test_cfg`` and ``inst_class_ids`` are additionally read
                here.
        """
        super(BEVDepthPano, self).__init__(**kwargs)
        self.aux_centerness_head = None
        if aux_centerness_head:
            # ``.get`` tolerates direct construction without these keys.
            train_cfg = kwargs.get('train_cfg')
            test_cfg = kwargs.get('test_cfg')
            pts_train_cfg = train_cfg.pts if train_cfg else None
            aux_centerness_head.update(train_cfg=pts_train_cfg)
            pts_test_cfg = test_cfg.pts if test_cfg else None
            aux_centerness_head.update(test_cfg=pts_test_cfg)
            self.aux_centerness_head = build_head(aux_centerness_head)

        # Occupancy labels that take part in the pure-torch fallback of the
        # instance assignment.
        self.inst_class_ids = kwargs.get('inst_class_ids',
                                         [2, 3, 4, 5, 6, 7, 9, 10])

        X1, Y1, Z1 = 200, 200, 16
        coords_x = torch.arange(X1).float()
        coords_y = torch.arange(Y1).float()
        coords_z = torch.arange(Z1).float()
        # Explicit 'ij' indexing == the former implicit default.
        self.coords = torch.stack(
            torch.meshgrid(coords_x, coords_y, coords_z, indexing='ij')
        ).permute(1, 2, 3, 0)  # W, H, D, 3
        # Grid origin and voxel size used to turn centers into voxel indices.
        self.st = torch.tensor([grid_config_occ['x'][0],
                                grid_config_occ['y'][0],
                                grid_config_occ['z'][0]])
        self.sx = torch.tensor([grid_config_occ['x'][2],
                                grid_config_occ['y'][2],
                                0.4])
        # Set once ``st`` / ``sx`` / ``coords`` were moved next to the
        # predictions (done lazily on the first panoptic call).
        self.is_to_d = False

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Compute depth, occupancy and auxiliary centerness losses.

        Args:
            points: Forwarded to ``extract_feat``.
            img_metas: Forwarded to ``extract_feat`` and the aux head loss.
            gt_bboxes_3d: Forwarded to the aux head loss.
            gt_labels_3d: Forwarded to the aux head loss.
            img_inputs: Forwarded to ``extract_feat``.
            gt_bboxes_ignore: Forwarded to ``forward_aux_centerness_train``.
            **kwargs: Forwarded to ``extract_feat``; must contain
                ``gt_depth``, ``voxel_semantics`` and ``mask_camera``.
            gt_labels, gt_bboxes, proposals: unused.

        Returns:
            dict: ``loss_depth`` plus the occupancy-head losses plus, when an
            aux head is configured, its losses.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx), ]
        # depth: (B*N_views, D, fH, fW)
        img_feats, _, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

        losses = dict()
        gt_depth = kwargs['gt_depth']  # (B, N_views, img_H, img_W)
        losses['loss_depth'] = \
            self.img_view_transformer.get_depth_loss(gt_depth, depth)

        voxel_semantics = kwargs['voxel_semantics']  # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']          # (B, Dx, Dy, Dz)

        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(
                occ_bev_feature, scale_factor=2, mode='bilinear',
                align_corners=True)

        losses.update(
            self.forward_occ_train(occ_bev_feature, voxel_semantics,
                                   mask_camera))

        # The constructor allows running without an aux head; skip its loss
        # then instead of calling ``None``.
        # NOTE(review): here the aux head sees the (optionally) upsampled
        # feature, while ``simple_test`` feeds it the feature *before*
        # upsampling - confirm this is intended when ``upsample=True``.
        if self.aux_centerness_head is not None:
            losses.update(
                self.forward_aux_centerness_train(
                    [occ_bev_feature], gt_bboxes_3d, gt_labels_3d, img_metas,
                    gt_bboxes_ignore))
        return losses

    def forward_aux_centerness_train(self,
                                     pts_feats,
                                     gt_bboxes_3d,
                                     gt_labels_3d,
                                     img_metas,
                                     gt_bboxes_ignore=None):
        """Run the aux head on ``pts_feats`` and return its loss.

        ``img_metas`` and ``gt_bboxes_ignore`` are accepted but not used.
        """
        outs = self.aux_centerness_head(pts_feats)
        return self.aux_centerness_head.loss(gt_bboxes_3d, gt_labels_3d, outs)

    def simple_test_aux_centerness(self, x, img_metas, rescale=False,
                                   **kwargs):
        """Predict instance centers from the BEV feature ``x[0]``.

        Only the shared conv and the ``reg`` / ``height`` / ``heatmap``
        branches of ``task_heads[0]`` are evaluated (not the full head
        forward); box decoding is skipped.

        Returns:
            tuple: ``(None, ins_cen_list)`` where ``ins_cen_list`` is the
            value of ``aux_centerness_head.get_centers``. The leading
            ``None`` is a placeholder for the box results that are not
            computed here.
        """
        head = self.aux_centerness_head
        tx = head.shared_conv(x[0])  # (B, C'=share_conv_channel, H, W)
        task_head = head.task_heads[0]
        outs = ([{
            "reg": task_head.reg(tx),
            "height": task_head.height(tx),
            "heatmap": task_head.heatmap(tx),
        }],)
        ins_cen_list = head.get_centers(outs, img_metas, rescale=rescale)
        return None, ins_cen_list

    @staticmethod
    def _first_positions(sorted_labels):
        """Map each distinct label to its first index in ``sorted_labels``.

        ``sorted_labels`` is a plain list sorted ascending; the returned dict
        keeps the labels in that order.
        """
        first_pos = {}
        for pos, label in enumerate(sorted_labels):
            first_pos.setdefault(label, pos)
        return first_pos

    def _assign_pano_instances(self, occ_pred, ins_cen_list):
        """Build the instance-id volume for one sample.

        Args:
            occ_pred: Semantic occupancy tensor of the sample (must support
                ``clone`` / ``to``, i.e. a torch tensor).
            ins_cen_list: Output of ``get_centers``; ``[0][0]`` is read as
                the center coordinates and ``[2][0]`` as the class labels.
                Presumably these are the entries of the first sample -
                TODO confirm against the head.

        Returns:
            A copy of ``occ_pred`` when there are no centers, otherwise the
            result of ``nearest_assign``.
        """
        inst_xyz = ins_cen_list[0][0]
        if not self.is_to_d:
            # One-off move of the grid helpers to the device/dtype of the
            # predicted centers.
            self.st = self.st.to(inst_xyz)
            self.sx = self.sx.to(inst_xyz)
            self.coords = self.coords.to(inst_xyz)
            self.is_to_d = True

        # Metric centers -> integer voxel indices.
        inst_xyz = ((inst_xyz - self.st) / self.sx).int()
        inst_cls = ins_cen_list[2][0].int()
        # Offset added to instance ids; equals len(occ_class_names), so the
        # ids presumably start after the semantic label range - TODO confirm.
        inst_num = 18

        inst_occ = occ_pred.clone().detach()
        if len(inst_cls) == 0:
            return inst_occ

        cls_sort, indices = inst_cls.sort()
        # detection label -> first position of that label in the sorted
        # order. One ``tolist()`` replaces the deprecated ``torch.range``
        # indexing and the per-element ``.item()`` calls.
        l2s = self._first_positions(cls_sort.tolist())

        # Hard-wired switch kept from the original: the custom kernel is
        # always used; the pure-torch branch below is a reference path.
        use_cuda_kernel = True
        if use_cuda_kernel:
            inst_id_list = indices + inst_num
            l2s_key = indices.new_tensor(
                [detind2occind[k] for k in l2s.keys()]).to(torch.int)
            inst_occ = nearest_assign(
                occ_pred.to(torch.int),
                l2s_key.to(torch.int),
                indices.new_tensor(occind2detind_cuda).to(torch.int),
                inst_cls.to(torch.int),
                inst_xyz.to(torch.int),
                inst_id_list.to(torch.int),
            )
        else:
            for cls_label_num_in_occ in self.inst_class_ids:
                mask = occ_pred == cls_label_num_in_occ
                if mask.sum() == 0:
                    continue
                cls_label_num_in_inst = occind2detind[cls_label_num_in_occ]
                select_mask = inst_cls == cls_label_num_in_inst
                if sum(select_mask) > 0:
                    voxel_coords = self.coords[mask]
                    inst_index_same_cls = inst_xyz[select_mask]
                    # Nearest same-class center by squared distance.
                    select_ind = (
                        (voxel_coords[:, None, :]
                         - inst_index_same_cls[None, :, :]) ** 2
                    ).sum(-1).argmin(axis=1).int()
                    inst_occ[mask] = \
                        select_ind + inst_num + l2s[cls_label_num_in_inst]
        return inst_occ

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        """Predict semantic occupancy and, optionally, panoptic instance ids.

        Args:
            points: Forwarded to ``extract_feat``.
            img_metas: One entry per sample; determines the result length.
            img: Forwarded to ``extract_feat`` as ``img_inputs``.
            rescale: Forwarded to the aux head's ``get_centers``.
            **kwargs: Forwarded to ``extract_feat``. Two optional flags are
                read here (both default to ``True``):
                ``w_pano`` - run the auxiliary centerness head;
                ``w_panoproc`` - run the instance assignment.

        Returns:
            list[dict]: One dict per sample with ``pred_occ``. The first
            dict additionally gets ``pano_inst`` when panoptic processing
            ran. Panoptic output is produced for the first sample only.
        """
        result_list = [dict() for _ in range(len(img_metas))]
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]

        # Centers are predicted from the feature before the optional
        # upsampling (see the note in ``forward_train``).
        ins_cen_list = None
        if kwargs.get('w_pano', True) and self.aux_centerness_head is not None:
            _, ins_cen_list = self.simple_test_aux_centerness(
                [occ_bev_feature], img_metas, rescale=rescale, **kwargs)

        if self.upsample:
            occ_bev_feature = F.interpolate(
                occ_bev_feature, scale_factor=2, mode='bilinear',
                align_corners=True)

        # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        occ_list = self.simple_test_occ(occ_bev_feature, img_metas)
        for result_dict, occ_pred in zip(result_list, occ_list):
            result_dict['pred_occ'] = occ_pred

        # Instance assignment needs the centers; without them (``w_pano``
        # off or no aux head) it is skipped rather than failing on an
        # unbound variable. The first sample's occupancy is paired with the
        # first sample's centers.
        if kwargs.get('w_panoproc', True) and ins_cen_list is not None:
            result_list[0]['pano_inst'] = self._assign_pano_instances(
                occ_list[0], ins_cen_list)

        return result_list
@DETECTORS.register_module()
class BEVDepth4DOCC(BEVDepth4D):
    """BEVDepth4D variant that predicts occupancy through ``occ_head``.

    The parent is constructed from ``**kwargs``; afterwards
    ``pts_bbox_head`` is reset to ``None`` and an occupancy head is built
    from ``occ_head``.

    NOTE(review): ``upsample`` is only honoured by ``forward_dummy``;
    ``forward_train`` and ``simple_test`` feed ``img_feats[0]`` to the head
    unchanged. Kept as is - confirm whether that asymmetry is intended.
    """

    def __init__(self, occ_head=None, upsample=False, **kwargs):
        """
        Args:
            occ_head: Config passed to ``build_head`` for the occupancy head.
            upsample (bool): Stored on the instance; see the class note.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super(BEVDepth4DOCC, self).__init__(**kwargs)
        self.occ_head = build_head(occ_head)
        self.pts_bbox_head = None
        self.upsample = upsample

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Compute depth and occupancy training losses.

        Only ``points``, ``img_inputs``, ``img_metas`` and ``kwargs`` are
        used; the remaining parameters exist for interface compatibility.

        Args:
            points: Forwarded to ``extract_feat``.
            img_metas: Forwarded to ``extract_feat``.
            img_inputs: Forwarded to ``extract_feat``.
            **kwargs: Forwarded to ``extract_feat``; must contain
                ``gt_depth`` (annotated as ``(B, N_views, img_H, img_W)``),
                ``voxel_semantics`` and ``mask_camera`` (both annotated as
                ``(B, Dx, Dy, Dz)``).

        Returns:
            dict: ``loss_depth`` plus the entries of ``forward_occ_train``.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx), ]
        # depth: (B*N_views, D, fH, fW)
        img_feats, _, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

        gt_depth = kwargs['gt_depth']  # (B, N_views, img_H, img_W)
        losses = dict()
        losses['loss_depth'] = \
            self.img_view_transformer.get_depth_loss(gt_depth, depth)

        voxel_semantics = kwargs['voxel_semantics']  # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']          # (B, Dx, Dy, Dz)
        losses.update(
            self.forward_occ_train(img_feats[0], voxel_semantics,
                                   mask_camera))
        return losses

    def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
        """Run the occupancy head and return its loss.

        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            voxel_semantics: (B, Dx, Dy, Dz)
            mask_camera: (B, Dx, Dy, Dz)

        Returns:
            Whatever ``self.occ_head.loss`` returns.
        """
        outs = self.occ_head(img_feats)  # (B, Dx, Dy, Dz, n_cls)
        # Sanity check on the label range (0..17); inactive under ``-O``.
        assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        loss_occ = self.occ_head.loss(outs, voxel_semantics, mask_camera)
        return loss_occ

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        """Predict occupancy for one batch.

        ``rescale`` is accepted but not used.

        Returns:
            The value of ``simple_test_occ``; annotated in the original as
            ``List[(Dx, Dy, Dz), ...]``.
        """
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        return self.simple_test_occ(img_feats[0], img_metas)

    def simple_test_occ(self, img_feats, img_metas=None):
        """Decode occupancy predictions from a BEV feature.

        Args:
            img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
            img_metas: Forwarded to the head's decoding method.

        Returns:
            occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        """
        outs = self.occ_head(img_feats)
        # Prefer the GPU decoder, but fall back to ``get_occ`` for heads that
        # do not implement it (same selection as BEVDetOCC.simple_test_occ)
        # instead of failing with AttributeError.
        if hasattr(self.occ_head, "get_occ_gpu"):
            return self.occ_head.get_occ_gpu(outs, img_metas)
        return self.occ_head.get_occ(outs, img_metas)

    def forward_dummy(self,
                      points=None,
                      img_metas=None,
                      img_inputs=None,
                      **kwargs):
        """Return the raw occupancy-head output without any decoding.

        Unlike ``forward_train`` / ``simple_test`` this applies the x2
        bilinear upsampling when ``self.upsample`` is set.
        """
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]
        if self.upsample:
            occ_bev_feature = F.interpolate(
                occ_bev_feature, scale_factor=2, mode='bilinear',
                align_corners=True)
        return self.occ_head(occ_bev_feature)
@DETECTORS.register_module()
class BEVDepth4DPano(BEVDepth4DOCC):
    """BEVDepth4DOCC plus an auxiliary centerness head for panoptic occupancy.

    At test time the instance centers from ``aux_centerness_head`` are
    combined with the semantic occupancy prediction by ``nearest_assign``
    to produce a per-voxel instance-id volume (``pano_inst``).
    """

    def __init__(self, aux_centerness_head=None, **kwargs):
        """
        Args:
            aux_centerness_head: Config of the auxiliary head. When truthy it
                is updated in place with ``train_cfg.pts`` / ``test_cfg.pts``
                (or ``None``) and built with ``build_head``; otherwise the
                attribute stays ``None``.
            **kwargs: Forwarded to ``BEVDepth4DOCC.__init__``. ``train_cfg``,
                ``test_cfg`` and ``inst_class_ids`` are additionally read
                here.
        """
        super(BEVDepth4DPano, self).__init__(**kwargs)
        self.aux_centerness_head = None
        if aux_centerness_head:
            # ``.get`` tolerates direct construction without these keys.
            train_cfg = kwargs.get('train_cfg')
            test_cfg = kwargs.get('test_cfg')
            pts_train_cfg = train_cfg.pts if train_cfg else None
            aux_centerness_head.update(train_cfg=pts_train_cfg)
            pts_test_cfg = test_cfg.pts if test_cfg else None
            aux_centerness_head.update(test_cfg=pts_test_cfg)
            self.aux_centerness_head = build_head(aux_centerness_head)

        # Occupancy labels that take part in the pure-torch fallback of the
        # instance assignment.
        self.inst_class_ids = kwargs.get('inst_class_ids',
                                         [2, 3, 4, 5, 6, 7, 9, 10])

        X1, Y1, Z1 = 200, 200, 16
        coords_x = torch.arange(X1).float()
        coords_y = torch.arange(Y1).float()
        coords_z = torch.arange(Z1).float()
        # Explicit 'ij' indexing == the former implicit default.
        self.coords = torch.stack(
            torch.meshgrid(coords_x, coords_y, coords_z, indexing='ij')
        ).permute(1, 2, 3, 0)  # W, H, D, 3
        # Grid origin and voxel size used to turn centers into voxel indices.
        self.st = torch.tensor([grid_config_occ['x'][0],
                                grid_config_occ['y'][0],
                                grid_config_occ['z'][0]])
        self.sx = torch.tensor([grid_config_occ['x'][2],
                                grid_config_occ['y'][2],
                                0.4])
        # Set once ``st`` / ``sx`` / ``coords`` were moved next to the
        # predictions (done lazily on the first panoptic call).
        self.is_to_d = False

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Compute depth, occupancy and auxiliary centerness losses.

        Args:
            points: Forwarded to ``extract_feat``.
            img_metas: Forwarded to ``extract_feat`` and the aux head loss.
            gt_bboxes_3d: Forwarded to the aux head loss.
            gt_labels_3d: Forwarded to the aux head loss.
            img_inputs: Forwarded to ``extract_feat``.
            gt_bboxes_ignore: Forwarded to ``forward_aux_centerness_train``.
            **kwargs: Forwarded to ``extract_feat``; must contain
                ``gt_depth``, ``voxel_semantics`` and ``mask_camera``.
            gt_labels, gt_bboxes, proposals: unused.

        Returns:
            dict: ``loss_depth`` plus the occupancy-head losses plus, when an
            aux head is configured, its losses.
        """
        # img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx), ]
        # depth: (B*N_views, D, fH, fW)
        img_feats, _, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

        gt_depth = kwargs['gt_depth']  # (B, N_views, img_H, img_W)
        losses = dict()
        losses['loss_depth'] = \
            self.img_view_transformer.get_depth_loss(gt_depth, depth)

        voxel_semantics = kwargs['voxel_semantics']  # (B, Dx, Dy, Dz)
        mask_camera = kwargs['mask_camera']          # (B, Dx, Dy, Dz)
        losses.update(
            self.forward_occ_train(img_feats[0], voxel_semantics,
                                   mask_camera))

        # The constructor allows running without an aux head; skip its loss
        # then instead of calling ``None``.
        if self.aux_centerness_head is not None:
            losses.update(
                self.forward_aux_centerness_train(
                    [img_feats[0]], gt_bboxes_3d, gt_labels_3d, img_metas,
                    gt_bboxes_ignore))
        return losses

    def forward_aux_centerness_train(self,
                                     pts_feats,
                                     gt_bboxes_3d,
                                     gt_labels_3d,
                                     img_metas,
                                     gt_bboxes_ignore=None):
        """Run the aux head on ``pts_feats`` and return its loss.

        ``img_metas`` and ``gt_bboxes_ignore`` are accepted but not used.
        """
        outs = self.aux_centerness_head(pts_feats)
        return self.aux_centerness_head.loss(gt_bboxes_3d, gt_labels_3d, outs)

    def simple_test_aux_centerness(self, x, img_metas, rescale=False,
                                   **kwargs):
        """Run the aux head and decode both boxes and instance centers.

        Returns:
            tuple: ``(bbox_results, ins_cen_list)`` - the per-sample outputs
            of ``bbox3d2result`` over ``get_bboxes`` and the value of
            ``get_centers``.
        """
        outs = self.aux_centerness_head(x)
        bbox_list = self.aux_centerness_head.get_bboxes(
            outs, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        ins_cen_list = self.aux_centerness_head.get_centers(
            outs, img_metas, rescale=rescale)
        return bbox_results, ins_cen_list

    @staticmethod
    def _first_positions(sorted_labels):
        """Map each distinct label to its first index in ``sorted_labels``.

        ``sorted_labels`` is a plain list sorted ascending; the returned dict
        keeps the labels in that order.
        """
        first_pos = {}
        for pos, label in enumerate(sorted_labels):
            first_pos.setdefault(label, pos)
        return first_pos

    def _assign_pano_instances(self, occ_pred, ins_cen_list):
        """Build the instance-id volume for one sample.

        Args:
            occ_pred: Semantic occupancy tensor of the sample (must support
                ``clone`` / ``to``, i.e. a torch tensor).
            ins_cen_list: Output of ``get_centers``; ``[0][0]`` is read as
                the center coordinates and ``[2][0]`` as the class labels.
                Presumably these are the entries of the first sample -
                TODO confirm against the head.

        Returns:
            A copy of ``occ_pred`` when there are no centers, otherwise the
            result of ``nearest_assign``.
        """
        inst_xyz = ins_cen_list[0][0]
        if not self.is_to_d:
            # One-off move of the grid helpers to the device/dtype of the
            # predicted centers.
            self.st = self.st.to(inst_xyz)
            self.sx = self.sx.to(inst_xyz)
            self.coords = self.coords.to(inst_xyz)
            self.is_to_d = True

        # Metric centers -> integer voxel indices.
        inst_xyz = ((inst_xyz - self.st) / self.sx).int()
        inst_cls = ins_cen_list[2][0].int()
        # Offset added to instance ids; equals len(occ_class_names), so the
        # ids presumably start after the semantic label range - TODO confirm.
        inst_num = 18

        inst_occ = occ_pred.clone().detach()
        if len(inst_cls) == 0:
            return inst_occ

        cls_sort, indices = inst_cls.sort()
        # detection label -> first position of that label in the sorted
        # order. One ``tolist()`` replaces the deprecated ``torch.range``
        # indexing and the per-element ``.item()`` calls.
        l2s = self._first_positions(cls_sort.tolist())

        # Hard-wired switch kept from the original: the custom kernel is
        # always used; the pure-torch branch below is a reference path.
        use_cuda_kernel = True
        if use_cuda_kernel:
            inst_id_list = indices + inst_num
            l2s_key = indices.new_tensor(
                [detind2occind[k] for k in l2s.keys()]).to(torch.int)
            inst_occ = nearest_assign(
                occ_pred.to(torch.int),
                l2s_key.to(torch.int),
                indices.new_tensor(occind2detind_cuda).to(torch.int),
                inst_cls.to(torch.int),
                inst_xyz.to(torch.int),
                inst_id_list.to(torch.int),
            )
        else:
            for cls_label_num_in_occ in self.inst_class_ids:
                mask = occ_pred == cls_label_num_in_occ
                if mask.sum() == 0:
                    continue
                cls_label_num_in_inst = occind2detind[cls_label_num_in_occ]
                select_mask = inst_cls == cls_label_num_in_inst
                if sum(select_mask) > 0:
                    voxel_coords = self.coords[mask]
                    inst_index_same_cls = inst_xyz[select_mask]
                    # Nearest same-class center by squared distance.
                    select_ind = (
                        (voxel_coords[:, None, :]
                         - inst_index_same_cls[None, :, :]) ** 2
                    ).sum(-1).argmin(axis=1).int()
                    inst_occ[mask] = \
                        select_ind + inst_num + l2s[cls_label_num_in_inst]
        return inst_occ

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        """Predict semantic occupancy and, optionally, panoptic instance ids.

        Args:
            points: Forwarded to ``extract_feat``.
            img_metas: One entry per sample; determines the result length.
            img: Forwarded to ``extract_feat`` as ``img_inputs``.
            rescale: Forwarded to the aux head decoding.
            **kwargs: Forwarded to ``extract_feat``. Two optional flags are
                read here (both default to ``True``):
                ``w_pano`` - run the auxiliary centerness head;
                ``w_panoproc`` - run the instance assignment.

        Returns:
            list[dict]: One dict per sample with ``pred_occ``. The first
            dict additionally gets ``pano_inst`` when panoptic processing
            ran. Panoptic output is produced for the first sample only.
        """
        result_list = [dict() for _ in range(len(img_metas))]
        img_feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        occ_bev_feature = img_feats[0]

        ins_cen_list = None
        if kwargs.get('w_pano', True) and self.aux_centerness_head is not None:
            # The decoded boxes are not used by this method.
            _, ins_cen_list = self.simple_test_aux_centerness(
                [occ_bev_feature], img_metas, rescale=rescale, **kwargs)

        # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
        occ_list = self.simple_test_occ(occ_bev_feature, img_metas)
        for result_dict, occ_pred in zip(result_list, occ_list):
            result_dict['pred_occ'] = occ_pred

        # Instance assignment needs the centers; without them (``w_pano``
        # off or no aux head) it is skipped rather than failing on an
        # unbound variable. The first sample's occupancy is paired with the
        # first sample's centers.
        if kwargs.get('w_panoproc', True) and ins_cen_list is not None:
            result_list[0]['pano_inst'] = self._assign_pano_instances(
                occ_list[0], ins_cen_list)

        return result_list
@DETECTORS.register_module()
class BEVStereo4DOCC(BEVStereo4D):
    """BEVStereo4D detector that predicts occupancy instead of 3D boxes.

    Args:
        occ_head (dict, optional): Config handed to ``build_head`` to create
            the occupancy head.
        upsample (bool): When truthy, ``forward_dummy`` bilinearly upsamples
            the BEV feature by a factor of 2 before the occupancy head.
        **kwargs: Forwarded unchanged to ``BEVStereo4D.__init__``.
    """

    def __init__(self, occ_head=None, upsample=False, **kwargs):
        super().__init__(**kwargs)
        self.occ_head = build_head(occ_head)
        # The box head slot is cleared; this class only uses ``occ_head``.
        self.pts_bbox_head = None
        self.upsample = upsample

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Compute the depth loss and the occupancy losses for one batch.

        Only ``points``, ``img_inputs``, ``img_metas`` and ``kwargs`` are
        read here; the remaining parameters are accepted for interface
        compatibility and ignored by this method.

        Args:
            points: Passed through to ``extract_feat``.
            img_metas: Passed through to ``extract_feat``.
            img_inputs: Passed through to ``extract_feat``.
            **kwargs: Must contain ``gt_depth``, ``voxel_semantics`` and
                ``mask_camera``; also forwarded to ``extract_feat``.

        Returns:
            dict: ``loss_depth`` plus every entry returned by
            ``forward_occ_train``.
        """
        # Per the original annotations: feats is a list whose first entry is
        # the BEV feature, depth_pred is (B*N_views, D, fH, fW).
        feats, _, depth_pred = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

        depth_target = kwargs['gt_depth']  # (B, N_views, img_H, img_W)
        losses = {
            'loss_depth': self.img_view_transformer.get_depth_loss(
                depth_target, depth_pred),
        }

        semantics = kwargs['voxel_semantics']  # (B, Dx, Dy, Dz)
        camera_mask = kwargs['mask_camera']  # (B, Dx, Dy, Dz)
        losses.update(
            self.forward_occ_train(feats[0], semantics, camera_mask))
        return losses

    def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
        """Run the occupancy head and return its loss dict.

        Args:
            img_feats: BEV feature fed to ``occ_head``.
            voxel_semantics: Ground-truth labels; asserted to lie in
                ``[0, 17]``.
            mask_camera: Mask forwarded to ``occ_head.loss``.

        Returns:
            Whatever ``occ_head.loss`` returns (merged into a loss dict by
            ``forward_train``).
        """
        head_out = self.occ_head(img_feats)
        assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
        return self.occ_head.loss(head_out, voxel_semantics, mask_camera)

    def simple_test(self,
                    points,
                    img_metas,
                    img=None,
                    rescale=False,
                    **kwargs):
        """Predict occupancy for a batch; ``rescale`` is not used here.

        Returns:
            The value of ``simple_test_occ`` for the first extracted feature.
        """
        feats, _, _ = self.extract_feat(
            points, img_inputs=img, img_metas=img_metas, **kwargs)
        return self.simple_test_occ(feats[0], img_metas)

    def simple_test_occ(self, img_feats, img_metas=None):
        """Run the occupancy head and decode with ``occ_head.get_occ_gpu``.

        Args:
            img_feats: BEV feature fed to ``occ_head``.
            img_metas: Forwarded to ``get_occ_gpu``.

        Returns:
            The decoded predictions returned by ``get_occ_gpu``.
        """
        head_out = self.occ_head(img_feats)
        return self.occ_head.get_occ_gpu(head_out, img_metas)

    def forward_dummy(self,
                      points=None,
                      img_metas=None,
                      img_inputs=None,
                      **kwargs):
        """Return the raw occupancy-head output without any decoding.

        Unlike the train/test paths in this class, this path honours
        ``self.upsample``.
        """
        feats, _, _ = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
        bev = feats[0]
        if self.upsample:
            bev = F.interpolate(
                bev, scale_factor=2, mode='bilinear', align_corners=True)
        return self.occ_head(bev)
@DETECTORS.register_module()
class BEVDetOCCTRT(BEVDetOCC):
    """BEVDetOCC with a tensor-in / tensor-list-out forward pipeline.

    View pooling goes through ``TRTBEVPoolv2`` and head outputs are flattened
    into a plain list (presumably for TensorRT/ONNX export - confirm).

    Args:
        wocc: The occupancy head runs only when ``wocc == True``.
        wdet3d: The 3D box head runs only when ``wdet3d == True``.
        uni_train: The optional 2x upsample before the occupancy head is
            applied only when ``uni_train == True`` and ``self.upsample``.
        **kwargs: Forwarded to ``BEVDetOCC.__init__``.
    """

    def __init__(self, wocc=True, wdet3d=True, uni_train=True, **kwargs):
        super().__init__(**kwargs)
        self.wocc = wocc
        self.wdet3d = wdet3d
        self.uni_train = uni_train

    def result_serialize(self, outs_det3d=None, outs_occ=None):
        """Flatten head outputs into one list.

        Box-head tensors come first (six per head, in the key order below),
        followed by the occupancy output when it is not None.
        """
        flat = []
        if outs_det3d is not None:
            for head_out in outs_det3d:
                flat.extend(
                    head_out[0][key]
                    for key in ('reg', 'height', 'dim', 'rot', 'vel',
                                'heatmap'))
        if outs_occ is not None:
            flat.append(outs_occ)
        return flat

    def result_deserialize(self, outs):
        """Regroup a flat list into ``[[{key: tensor}], ...]``, six per head.

        Trailing entries that do not fill a group of six are ignored.
        """
        names = ('reg', 'height', 'dim', 'rot', 'vel', 'heatmap')
        return [
            [{name: outs[head_id * 6 + pos]
              for pos, name in enumerate(names)}]
            for head_id in range(len(outs) // 6)
        ]

    def _trt_depth_and_feat(self, img):
        """Run backbone, neck and depth net; split depth probs / features.

        Returns ``(depth, tran_feat)`` where ``depth`` is the softmax over
        the first ``D`` channels and ``tran_feat`` holds the next
        ``out_channels`` channels, permuted to channel-last.
        """
        feats = self.img_neck(self.img_backbone(img))
        fused = self.img_view_transformer.depth_net(feats[0])
        num_bins = self.img_view_transformer.D
        depth = fused[:, :num_bins].softmax(dim=1)
        tran_feat = fused[
            :, num_bins:(num_bins + self.img_view_transformer.out_channels)]
        return depth, tran_feat.permute(0, 2, 3, 1)

    def _trt_bev_pool(self, depth, tran_feat, ranks_depth, ranks_feat,
                      ranks_bev, interval_starts, interval_lengths):
        """Call ``TRTBEVPoolv2`` with the transformer's grid size as ints."""
        return TRTBEVPoolv2.apply(
            depth.contiguous(),
            tran_feat.contiguous(),
            ranks_depth,
            ranks_feat,
            ranks_bev,
            interval_starts,
            interval_lengths,
            int(self.img_view_transformer.grid_size[0].item()),
            int(self.img_view_transformer.grid_size[1].item()),
            int(self.img_view_transformer.grid_size[2].item()),
        )

    def _trt_run_heads(self, bev):
        """Encode the BEV map, run the enabled heads, serialize the result."""
        occ_bev_feature = self.img_bev_encoder_neck(
            self.img_bev_encoder_backbone(bev))

        # NOTE: the ``== True`` comparisons are kept on purpose; switching to
        # plain truthiness would change which config values enable a head.
        outs_occ = None
        if self.wocc == True:
            if self.uni_train == True and self.upsample:
                occ_bev_feature = F.interpolate(
                    occ_bev_feature,
                    scale_factor=2,
                    mode='bilinear',
                    align_corners=True)
            outs_occ = self.occ_head(occ_bev_feature)

        outs_det3d = None
        if self.wdet3d == True:
            outs_det3d = self.pts_bbox_head([occ_bev_feature])

        return self.result_serialize(outs_det3d, outs_occ)

    def forward_part1(self, img, ):
        """Stage 1: image -> flattened ``(tran_feat, depth)``.

        ``tran_feat`` has its first three dims flattened; ``depth`` is 1-D.
        """
        depth, tran_feat = self._trt_depth_and_feat(img)
        return tran_feat.flatten(0, 2), depth.reshape(-1)

    def forward_part2(self, tran_feat, depth, ranks_depth, ranks_feat,
                      ranks_bev, interval_starts, interval_lengths, ):
        """Stage 2: BEV pooling on the stage-1 outputs; returns a 1-D tensor.

        The inputs are restored to the hard-coded shapes ``(6, 16, 44, 64)``
        and ``(6, 16, 44, 44)`` before pooling.
        """
        tran_feat = tran_feat.reshape(6, 16, 44, 64)
        depth = depth.reshape(6, 16, 44, 44)
        pooled = self._trt_bev_pool(depth, tran_feat, ranks_depth,
                                    ranks_feat, ranks_bev, interval_starts,
                                    interval_lengths)
        return pooled.reshape(-1)

    def forward_part3(self, x):
        """Stage 3: flat pooled BEV -> serialized head outputs.

        ``x`` is reshaped to the hard-coded ``(1, 200, 200, 64)`` layout and
        permuted to channel-first before the BEV encoder.
        """
        bev = x.reshape(1, 200, 200, 64).permute(0, 3, 1, 2).contiguous()
        return self._trt_run_heads(bev)

    def forward_ori(self, img, ranks_depth, ranks_feat, ranks_bev,
                    interval_starts, interval_lengths, ):
        """Single-call pipeline: image -> serialized head outputs."""
        depth, tran_feat = self._trt_depth_and_feat(img)
        pooled = self._trt_bev_pool(depth, tran_feat, ranks_depth,
                                    ranks_feat, ranks_bev, interval_starts,
                                    interval_lengths)
        bev = pooled.permute(0, 3, 1, 2).contiguous()
        return self._trt_run_heads(bev)

    def forward_with_argmax(self, img, ranks_depth, ranks_feat, ranks_bev,
                            interval_starts, interval_lengths, ):
        """Return ``argmax(-1)`` of the first ``forward_ori`` output.

        NOTE(review): ``result_serialize`` puts box-head tensors first, so
        index 0 is the occupancy output only when the box head is disabled.
        """
        serialized = self.forward_ori(img, ranks_depth, ranks_feat,
                                      ranks_bev, interval_starts,
                                      interval_lengths)
        return serialized[0].argmax(-1)

    def get_bev_pool_input(self, input):
        """Build the BEV-pool index tensors from raw image inputs."""
        prepared = self.prepare_inputs(input)
        coor = self.img_view_transformer.get_lidar_coor(*prepared[1:7])
        return self.img_view_transformer.voxel_pooling_prepare_v2(coor)
@DETECTORS.register_module()
class BEVDepthOCCTRT(BEVDetOCC):
    """BEVDetOCC export variant whose depth net also takes ``mlp_input``.

    View pooling goes through ``TRTBEVPoolv2`` and head outputs are flattened
    into a plain list (presumably for TensorRT/ONNX export - confirm).

    Args:
        wocc: The occupancy head runs only when ``wocc == True``.
        wdet3d: The 3D box head runs only when ``wdet3d == True``.
        uni_train: The optional 2x upsample before the occupancy head is
            applied only when ``uni_train == True`` and ``self.upsample``.
        **kwargs: Forwarded to ``BEVDetOCC.__init__``.
    """

    def __init__(self, wocc=True, wdet3d=True, uni_train=True, **kwargs):
        super().__init__(**kwargs)
        self.wocc = wocc
        self.wdet3d = wdet3d
        self.uni_train = uni_train

    def result_serialize(self, outs_det3d=None, outs_occ=None):
        """Flatten head outputs into one list.

        Box-head tensors come first (six per head, in the key order below),
        followed by the occupancy output when it is not None.
        """
        flat = []
        if outs_det3d is not None:
            for head_out in outs_det3d:
                flat.extend(
                    head_out[0][key]
                    for key in ('reg', 'height', 'dim', 'rot', 'vel',
                                'heatmap'))
        if outs_occ is not None:
            flat.append(outs_occ)
        return flat

    def result_deserialize(self, outs):
        """Regroup a flat list into ``[[{key: tensor}], ...]``, six per head.

        Trailing entries that do not fill a group of six are ignored.
        """
        names = ('reg', 'height', 'dim', 'rot', 'vel', 'heatmap')
        return [
            [{name: outs[head_id * 6 + pos]
              for pos, name in enumerate(names)}]
            for head_id in range(len(outs) // 6)
        ]

    def forward_ori(self, img, ranks_depth, ranks_feat, ranks_bev,
                    interval_starts, interval_lengths, mlp_input, ):
        """Single-call pipeline: image -> serialized head outputs.

        Args:
            img: Input to ``img_backbone``.
            ranks_depth, ranks_feat, ranks_bev, interval_starts,
            interval_lengths: Index tensors forwarded to ``TRTBEVPoolv2``.
            mlp_input: Extra input of the depth net.

        Returns:
            list: Output of ``result_serialize``.
        """
        feats = self.img_neck(self.img_backbone(img))
        fused = self.img_view_transformer.depth_net(feats[0], mlp_input)

        # First D channels are depth logits, the next out_channels channels
        # are the features to be lifted.
        num_bins = self.img_view_transformer.D
        depth = fused[:, :num_bins].softmax(dim=1)
        tran_feat = fused[
            :, num_bins:(num_bins + self.img_view_transformer.out_channels)]
        tran_feat = tran_feat.permute(0, 2, 3, 1)

        pooled = TRTBEVPoolv2.apply(
            depth.contiguous(),
            tran_feat.contiguous(),
            ranks_depth,
            ranks_feat,
            ranks_bev,
            interval_starts,
            interval_lengths,
            int(self.img_view_transformer.grid_size[0].item()),
            int(self.img_view_transformer.grid_size[1].item()),
            int(self.img_view_transformer.grid_size[2].item()),
        )
        bev = pooled.permute(0, 3, 1, 2).contiguous()

        occ_bev_feature = self.img_bev_encoder_neck(
            self.img_bev_encoder_backbone(bev))

        # NOTE: the ``== True`` comparisons are kept on purpose; switching to
        # plain truthiness would change which config values enable a head.
        outs_occ = None
        if self.wocc == True:
            if self.uni_train == True and self.upsample:
                occ_bev_feature = F.interpolate(
                    occ_bev_feature,
                    scale_factor=2,
                    mode='bilinear',
                    align_corners=True)
            outs_occ = self.occ_head(occ_bev_feature)

        outs_det3d = None
        if self.wdet3d == True:
            outs_det3d = self.pts_bbox_head([occ_bev_feature])

        return self.result_serialize(outs_det3d, outs_occ)

    def forward_with_argmax(self, img, ranks_depth, ranks_feat, ranks_bev,
                            interval_starts, interval_lengths, mlp_input, ):
        """Return ``argmax(-1)`` of the first ``forward_ori`` output.

        NOTE(review): ``result_serialize`` puts box-head tensors first, so
        index 0 is the occupancy output only when the box head is disabled.
        """
        serialized = self.forward_ori(img, ranks_depth, ranks_feat,
                                      ranks_bev, interval_starts,
                                      interval_lengths, mlp_input)
        return serialized[0].argmax(-1)

    def get_bev_pool_input(self, input):
        """Build ``(bev_pool_inputs, mlp_input)`` from raw image inputs.

        Elements 1..6 of the prepared inputs are passed to both
        ``get_lidar_coor`` and ``get_mlp_input``.
        """
        prepared = self.prepare_inputs(input)
        view_tf_args = prepared[1:7]
        coor = self.img_view_transformer.get_lidar_coor(*view_tf_args)
        mlp_input = self.img_view_transformer.get_mlp_input(*view_tf_args)
        pool_inputs = self.img_view_transformer.voxel_pooling_prepare_v2(coor)
        return pool_inputs, mlp_input
@DETECTORS.register_module()
class BEVDepthPanoTRT(BEVDepthPano):
    """BEVDepthPano with a tensor-in / tensor-list-out forward pipeline.

    View pooling goes through ``TRTBEVPoolv2``; head outputs are flattened
    into a plain list to which the ``reg``, ``height`` and ``heatmap``
    outputs of the first ``aux_centerness_head`` task head are appended
    (presumably for TensorRT/ONNX export - confirm).

    Args:
        wocc: The occupancy head runs only when ``wocc == True``.
        wdet3d: The 3D box head runs only when ``wdet3d == True``.
        uni_train: The optional 2x upsample before the occupancy head is
            applied only when ``uni_train == True`` and ``self.upsample``.
        **kwargs: Forwarded to ``BEVDepthPano.__init__``.
    """

    def __init__(self, wocc=True, wdet3d=True, uni_train=True, **kwargs):
        super().__init__(**kwargs)
        self.wocc = wocc
        self.wdet3d = wdet3d
        self.uni_train = uni_train

    def result_serialize(self, outs_det3d=None, outs_occ=None):
        """Flatten head outputs into one list.

        Box-head tensors come first (six per head, in the key order below),
        followed by the occupancy output when it is not None.
        """
        flat = []
        if outs_det3d is not None:
            for head_out in outs_det3d:
                flat.extend(
                    head_out[0][key]
                    for key in ('reg', 'height', 'dim', 'rot', 'vel',
                                'heatmap'))
        if outs_occ is not None:
            flat.append(outs_occ)
        return flat

    def result_deserialize(self, outs):
        """Regroup a flat list into ``[[{key: tensor}], ...]``, six per head.

        Trailing entries that do not fill a group of six are ignored.
        """
        names = ('reg', 'height', 'dim', 'rot', 'vel', 'heatmap')
        return [
            [{name: outs[head_id * 6 + pos]
              for pos, name in enumerate(names)}]
            for head_id in range(len(outs) // 6)
        ]

    def _trt_depth_and_feat(self, img, mlp_input):
        """Run backbone, neck and depth net; split depth probs / features.

        Returns ``(depth, tran_feat)`` where ``depth`` is the softmax over
        the first ``D`` channels and ``tran_feat`` holds the next
        ``out_channels`` channels, permuted to channel-last.
        """
        feats = self.img_neck(self.img_backbone(img))
        fused = self.img_view_transformer.depth_net(feats[0], mlp_input)
        num_bins = self.img_view_transformer.D
        depth = fused[:, :num_bins].softmax(dim=1)
        tran_feat = fused[
            :, num_bins:(num_bins + self.img_view_transformer.out_channels)]
        return depth, tran_feat.permute(0, 2, 3, 1)

    def _trt_bev_pool(self, depth, tran_feat, ranks_depth, ranks_feat,
                      ranks_bev, interval_starts, interval_lengths):
        """Call ``TRTBEVPoolv2`` with the transformer's grid size as ints."""
        return TRTBEVPoolv2.apply(
            depth.contiguous(),
            tran_feat.contiguous(),
            ranks_depth,
            ranks_feat,
            ranks_bev,
            interval_starts,
            interval_lengths,
            int(self.img_view_transformer.grid_size[0].item()),
            int(self.img_view_transformer.grid_size[1].item()),
            int(self.img_view_transformer.grid_size[2].item()),
        )

    def _trt_run_heads(self, bev):
        """Encode the BEV map, run all heads and return the flat output list.

        The list is ``result_serialize(det3d, occ)`` followed by the
        centerness ``reg``, ``height`` and ``heatmap`` tensors, in that order.
        """
        occ_bev_feature = self.img_bev_encoder_neck(
            self.img_bev_encoder_backbone(bev))

        # NOTE: the ``== True`` comparisons are kept on purpose; switching to
        # plain truthiness would change which config values enable a head.
        outs_occ = None
        if self.wocc == True:
            if self.uni_train == True and self.upsample:
                occ_bev_feature = F.interpolate(
                    occ_bev_feature,
                    scale_factor=2,
                    mode='bilinear',
                    align_corners=True)
            outs_occ = self.occ_head(occ_bev_feature)

        outs_det3d = None
        if self.wdet3d == True:
            outs_det3d = self.pts_bbox_head([occ_bev_feature])

        outs = self.result_serialize(outs_det3d, outs_occ)

        # Instead of calling aux_centerness_head as a whole, run its shared
        # conv once and then the individual branches of the first task head.
        shared = self.aux_centerness_head.shared_conv(occ_bev_feature)
        outs.append(self.aux_centerness_head.task_heads[0].reg(shared))
        outs.append(self.aux_centerness_head.task_heads[0].height(shared))
        outs.append(self.aux_centerness_head.task_heads[0].heatmap(shared))
        return outs

    def forward_part1(self, img, mlp_input, ):
        """Stage 1: image -> flattened ``(tran_feat, depth)``.

        ``tran_feat`` has its first three dims flattened; ``depth`` is 1-D.
        """
        depth, tran_feat = self._trt_depth_and_feat(img, mlp_input)
        return tran_feat.flatten(0, 2), depth.reshape(-1)

    def forward_part2(self, tran_feat, depth, ranks_depth, ranks_feat,
                      ranks_bev, interval_starts, interval_lengths, ):
        """Stage 2: BEV pooling on the stage-1 outputs; returns a 1-D tensor.

        The inputs are restored to the hard-coded shapes ``(6, 16, 44, 64)``
        and ``(6, 16, 44, 44)`` before pooling.
        """
        tran_feat = tran_feat.reshape(6, 16, 44, 64)
        depth = depth.reshape(6, 16, 44, 44)
        pooled = self._trt_bev_pool(depth, tran_feat, ranks_depth,
                                    ranks_feat, ranks_bev, interval_starts,
                                    interval_lengths)
        return pooled.reshape(-1)

    def forward_part3(self, x):
        """Stage 3: flat pooled BEV -> flat list of head outputs.

        ``x`` is reshaped to the hard-coded ``(1, 200, 200, 64)`` layout and
        permuted to channel-first before the BEV encoder.

        Returns:
            list: Same layout as ``forward_ori``. (The original version built
            this list but never returned it, so callers received ``None``.)
        """
        bev = x.reshape(1, 200, 200, 64).permute(0, 3, 1, 2).contiguous()
        return self._trt_run_heads(bev)

    def forward_ori(self, img, ranks_depth, ranks_feat, ranks_bev,
                    interval_starts, interval_lengths, mlp_input, ):
        """Single-call pipeline: image -> flat list of head outputs."""
        depth, tran_feat = self._trt_depth_and_feat(img, mlp_input)
        pooled = self._trt_bev_pool(depth, tran_feat, ranks_depth,
                                    ranks_feat, ranks_bev, interval_starts,
                                    interval_lengths)
        bev = pooled.permute(0, 3, 1, 2).contiguous()
        return self._trt_run_heads(bev)

    def forward_with_argmax(self, img, ranks_depth, ranks_feat, ranks_bev,
                            interval_starts, interval_lengths, mlp_input, ):
        """Return ``(outs[0].argmax(-1), *outs[1:])`` of ``forward_ori``.

        NOTE(review): ``result_serialize`` puts box-head tensors first, so
        index 0 is the occupancy output only when the box head is disabled.
        """
        outs = self.forward_ori(img, ranks_depth, ranks_feat, ranks_bev,
                                interval_starts, interval_lengths, mlp_input)
        pred_occ_label = outs[0].argmax(-1)
        return (pred_occ_label, *outs[1:])

    def get_bev_pool_input(self, input):
        """Build ``(bev_pool_inputs, mlp_input)`` from raw image inputs.

        Elements 1..6 of the prepared inputs are passed to both
        ``get_lidar_coor`` and ``get_mlp_input``.
        """
        prepared = self.prepare_inputs(input)
        view_tf_args = prepared[1:7]
        coor = self.img_view_transformer.get_lidar_coor(*view_tf_args)
        mlp_input = self.img_view_transformer.get_mlp_input(*view_tf_args)
        pool_inputs = self.img_view_transformer.voxel_pooling_prepare_v2(coor)
        return pool_inputs, mlp_input
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/detectors/bevstereo4d.py
0 → 100644
View file @
d2b71343
# Copyright (c) Phigent Robotics. All rights reserved.
import
torch
import
torch.nn.functional
as
F
from
mmcv.runner
import
force_fp32
from
mmdet3d.models
import
DETECTORS
from
mmdet3d.models
import
builder
from
.bevdepth4d
import
BEVDepth4D
from
mmdet.models.backbones.resnet
import
ResNet
@DETECTORS.register_module()
class BEVStereo4D(BEVDepth4D):
    """BEVDepth4D extended with a stereo cost-volume reference frame.

    NOTE(review): the statement nesting below was reconstructed from a
    whitespace-stripped listing - confirm against the repository file.
    """

    def __init__(self, **kwargs):
        super(BEVStereo4D, self).__init__(**kwargs)
        # One additional frame is processed only to provide stereo features.
        self.extra_ref_frames = 1
        # Frame count before the extra reference frame is added.
        self.temporal_frame = self.num_frame
        self.num_frame += self.extra_ref_frames

    def extract_stereo_ref_feat(self, x):
        """Extract the early backbone feature used as stereo reference.

        Args:
            x: (B, N_views, 3, H, W)
        Returns:
            x: (B*N_views, C_stereo, fH_stereo, fW_stereo)
        """
        B, N, C, imH, imW = x.shape
        x = x.view(B * N, C, imH, imW)  # (B*N_views, 3, H, W)
        if isinstance(self.img_backbone, ResNet):
            if self.img_backbone.deep_stem:
                x = self.img_backbone.stem(x)
            else:
                x = self.img_backbone.conv1(x)
                x = self.img_backbone.norm1(x)
                x = self.img_backbone.relu(x)
            x = self.img_backbone.maxpool(x)
            for i, layer_name in enumerate(self.img_backbone.res_layers):
                res_layer = getattr(self.img_backbone, layer_name)
                x = res_layer(x)
                # Returns after the first residual layer only.
                return x
        else:
            # Backbone exposing patch_embed / stages (presumably Swin).
            x = self.img_backbone.patch_embed(x)
            hw_shape = (self.img_backbone.patch_embed.DH,
                        self.img_backbone.patch_embed.DW)
            if self.img_backbone.use_abs_pos_embed:
                x = x + self.img_backbone.absolute_pos_embed
            x = self.img_backbone.drop_after_pos(x)

            for i, stage in enumerate(self.img_backbone.stages):
                x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
                out = out.view(-1, *out_hw_shape,
                               self.img_backbone.num_features[i])
                out = out.permute(0, 3, 1, 2).contiguous()
                # Returns after the first stage only.
                return out

    def prepare_bev_feat(self, img, sensor2keyego, ego2global, intrin,
                         post_rot, post_tran, bda, mlp_input, feat_prev_iv,
                         k2s_sensor, extra_ref_frame):
        """Compute BEV feature, depth and stereo feature for one frame.

        For the extra reference frame only the stereo feature is computed
        and the first two return values are None.

        Args:
            img: (B, N_views, 3, H, W)
            sensor2keyego: (B, N_views, 4, 4)
            ego2global: (B, N_views, 4, 4)
            intrin: (B, N_views, 3, 3)
            post_rot: (B, N_views, 3, 3)
            post_tran: (B, N_views, 3)
            bda: (B, 3, 3)
            mlp_input: (B, N_views, 27)
            feat_prev_iv: (B*N_views, C_stereo, fH_stereo, fW_stereo) or None
            k2s_sensor: (B, N_views, 4, 4) or None
            extra_ref_frame: Truthy when this frame is the stereo-only
                reference frame.
        Returns:
            bev_feat: (B, C, Dy, Dx)
            depth: (B*N, D, fH, fW)
            stereo_feat: (B*N_views, C_stereo, fH_stereo, fW_stereo)
        """
        if extra_ref_frame:
            stereo_feat = self.extract_stereo_ref_feat(img)  # (B*N_views, C_stereo, fH_stereo, fW_stereo)
            return None, None, stereo_feat
        # x: (B, N_views, C, fH, fW)
        # stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo)
        x, stereo_feat = self.image_encoder(img, stereo=True)

        # Information needed to build the cost volume.
        metas = dict(k2s_sensor=k2s_sensor,  # (B, N_views, 4, 4)
                     intrins=intrin,  # (B, N_views, 3, 3)
                     post_rots=post_rot,  # (B, N_views, 3, 3)
                     post_trans=post_tran,  # (B, N_views, 3)
                     frustum=self.img_view_transformer.cv_frustum.to(x),  # (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
                     cv_downsample=4,
                     downsample=self.img_view_transformer.downsample,
                     grid_config=self.img_view_transformer.grid_config,
                     cv_feat_list=[feat_prev_iv, stereo_feat]
                     )
        # bev_feat: (B, C * Dz(=1), Dy, Dx)
        # depth: (B * N, D, fH, fW)
        bev_feat, depth = self.img_view_transformer(
            [x, sensor2keyego, ego2global, intrin, post_rot, post_tran, bda,
             mlp_input], metas)
        if self.pre_process:
            bev_feat = self.pre_process_net(bev_feat)[0]  # (B, C, Dy, Dx)
        return bev_feat, depth, stereo_feat

    def extract_img_feat_sequential(self, inputs, feat_prev):
        """Encode the current frame and fuse it with cached previous BEVs.

        Args:
            inputs: Sequence of 11 items, in order:
                curr_img: (1, N_views, 3, H, W)
                sensor2keyegos_curr: (N_prev, N_views, 4, 4)
                ego2globals_curr: (N_prev, N_views, 4, 4)
                intrins: (1, N_views, 3, 3)
                sensor2keyegos_prev: (N_prev, N_views, 4, 4)
                ego2globals_prev: (N_prev, N_views, 4, 4) (unused here)
                post_rots: (1, N_views, 3, 3)
                post_trans: (1, N_views, 3, )
                bda_curr: (N_prev, 3, 3)
                feat_prev_iv:
                curr2adjsensor: (1, N_views, 4, 4)
            feat_prev: (N_prev, C, Dy, Dx)
        Returns:
            tuple: ``([x], depth)`` where ``x`` is the ``bev_encoder`` output
            and ``depth`` comes from the current frame.
        """
        imgs, sensor2keyegos_curr, ego2globals_curr, intrins = inputs[:4]
        sensor2keyegos_prev, _, post_rots, post_trans, bda = inputs[4:9]
        feat_prev_iv, curr2adjsensor = inputs[9:]
        bev_feat_list = []
        mlp_input = self.img_view_transformer.get_mlp_input(
            sensor2keyegos_curr[0:1, ...], ego2globals_curr[0:1, ...],
            intrins, post_rots, post_trans, bda[0:1, ...])
        # Last element (extra_ref_frame) is False: full BEV path is taken.
        inputs_curr = (imgs, sensor2keyegos_curr[0:1, ...],
                       ego2globals_curr[0:1, ...], intrins, post_rots,
                       post_trans, bda[0:1, ...], mlp_input, feat_prev_iv,
                       curr2adjsensor, False)
        # (1, C, Dx, Dy), (1*N, D, fH, fW)
        bev_feat, depth, _ = self.prepare_bev_feat(*inputs_curr)
        bev_feat_list.append(bev_feat)

        # align the feat_prev
        _, C, H, W = feat_prev.shape
        # feat_prev: (N_prev, C, Dy, Dx)
        feat_prev = \
            self.shift_feature(feat_prev,  # (N_prev, C, Dy, Dx)
                               [sensor2keyegos_curr,  # (N_prev, N_views, 4, 4)
                                sensor2keyegos_prev],  # (N_prev, N_views, 4, 4)
                               bda  # (N_prev, 3, 3)
                               )
        # Fold the previous frames into the channel dimension.
        bev_feat_list.append(feat_prev.view(1, (self.num_frame - 2) * C, H, W))  # (1, N_prev*C, Dy, Dx)

        bev_feat = torch.cat(bev_feat_list, dim=1)  # (1, N_frames*C, Dy, Dx)
        x = self.bev_encoder(bev_feat)
        return [x], depth

    def extract_img_feat(self,
                         img_inputs,
                         img_metas,
                         pred_prev=False,
                         sequential=False,
                         **kwargs):
        """Extract the fused multi-frame BEV feature.

        Args:
            img_inputs:
                imgs: (B, N, 3, H, W) # N = 6 * (N_history + 1)
                sensor2egos: (B, N, 4, 4)
                ego2globals: (B, N, 4, 4)
                intrins: (B, N, 3, 3)
                post_rots: (B, N, 3, 3)
                post_trans: (B, N, 3)
                bda_rot: (B, 3, 3)
            img_metas: Not read in this method.
            pred_prev: If truthy, return the cached previous-frame features
                and the inputs for a later sequential call instead of the
                encoded BEV feature.
            sequential: If truthy, delegate to
                ``extract_img_feat_sequential`` using ``kwargs['feat_prev']``.
            **kwargs:
        Returns:
            x: [(B, C', H', W'), ]
            depth: (B*N_views, D, fH, fW)
        """
        if sequential:
            return self.extract_img_feat_sequential(img_inputs, kwargs['feat_prev'])
        imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
            bda, curr2adjsensor = self.prepare_inputs(img_inputs, stereo=True)

        """Extract features of images."""
        bev_feat_list = []
        depth_key_frame = None
        feat_prev_iv = None
        # Frames are visited from the largest index down to 0 (key frame), so
        # each frame can use the stereo feature of the frame visited before it.
        for fid in range(self.num_frame - 1, -1, -1):
            img, sensor2keyego, ego2global, intrin, post_rot, post_tran = \
                imgs[fid], sensor2keyegos[fid], ego2globals[fid], intrins[fid], \
                post_rots[fid], post_trans[fid]
            key_frame = fid == 0
            extra_ref_frame = fid == self.num_frame - self.extra_ref_frames
            if key_frame or self.with_prev:
                if self.align_after_view_transfromation:
                    sensor2keyego, ego2global = sensor2keyegos[0], ego2globals[0]
                mlp_input = self.img_view_transformer.get_mlp_input(
                    sensor2keyegos[0], ego2globals[0], intrin,
                    post_rot, post_tran, bda)  # (B, N_views, 27)
                inputs_curr = (img, sensor2keyego, ego2global, intrin,
                               post_rot, post_tran, bda, mlp_input,
                               feat_prev_iv, curr2adjsensor[fid],
                               extra_ref_frame)
                if key_frame:
                    bev_feat, depth, feat_curr_iv = \
                        self.prepare_bev_feat(*inputs_curr)
                    depth_key_frame = depth
                else:
                    # Non-key frames are processed without gradients.
                    with torch.no_grad():
                        bev_feat, depth, feat_curr_iv = \
                            self.prepare_bev_feat(*inputs_curr)
                if not extra_ref_frame:
                    bev_feat_list.append(bev_feat)
                if not key_frame:
                    feat_prev_iv = feat_curr_iv
        if pred_prev:
            assert self.align_after_view_transfromation
            assert sensor2keyegos[0].shape[0] == 1  # batch_size = 1
            feat_prev = torch.cat(bev_feat_list[1:], dim=0)
            # (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
            ego2globals_curr = \
                ego2globals[0].repeat(self.num_frame - 2, 1, 1, 1)
            # (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
            sensor2keyegos_curr = \
                sensor2keyegos[0].repeat(self.num_frame - 2, 1, 1, 1)
            ego2globals_prev = torch.cat(ego2globals[1:-1], dim=0)  # (N_prev, N_views, 4, 4)
            sensor2keyegos_prev = torch.cat(sensor2keyegos[1:-1], dim=0)  # (N_prev, N_views, 4, 4)
            bda_curr = bda.repeat(self.num_frame - 2, 1, 1)  # (N_prev, 3, 3)
            # The list layout matches what extract_img_feat_sequential unpacks.
            return feat_prev, [imgs[0],  # (1, N_views, 3, H, W)
                               sensor2keyegos_curr,  # (N_prev, N_views, 4, 4)
                               ego2globals_curr,  # (N_prev, N_views, 4, 4)
                               intrins[0],  # (1, N_views, 3, 3)
                               sensor2keyegos_prev,  # (N_prev, N_views, 4, 4)
                               ego2globals_prev,  # (N_prev, N_views, 4, 4)
                               post_rots[0],  # (1, N_views, 3, 3)
                               post_trans[0],  # (1, N_views, 3, )
                               bda_curr,  # (N_prev, 3, 3)
                               feat_prev_iv,
                               curr2adjsensor[0]]
        if not self.with_prev:
            # Only the key frame was encoded: stand in zeros for the history
            # frames so the channel count seen by bev_encoder is unchanged.
            bev_feat_key = bev_feat_list[0]
            if len(bev_feat_key.shape) == 4:
                b, c, h, w = bev_feat_key.shape
                bev_feat_list = \
                    [torch.zeros([b,
                                  c * (self.num_frame -
                                       self.extra_ref_frames - 1),
                                  h, w]).to(bev_feat_key), bev_feat_key]
            else:
                b, c, z, h, w = bev_feat_key.shape
                bev_feat_list = \
                    [torch.zeros([b,
                                  c * (self.num_frame -
                                       self.extra_ref_frames - 1), z,
                                  h, w]).to(bev_feat_key), bev_feat_key]
        if self.align_after_view_transfromation:
            for adj_id in range(self.num_frame - 2):
                bev_feat_list[adj_id] = self.shift_feature(
                    bev_feat_list[adj_id],  # (B, C, Dy, Dx)
                    [sensor2keyegos[0],  # (B, N_views, 4, 4)
                     sensor2keyegos[self.num_frame - 2 - adj_id]],  # (B, N_views, 4, 4)
                    bda  # (B, 3, 3)
                )  # (B, C, Dy, Dx)
        bev_feat = torch.cat(bev_feat_list, dim=1)
        x = self.bev_encoder(bev_feat)
        return [x], depth_key_frame
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__init__.py
0 → 100644
View file @
d2b71343
# Loss classes re-exported as the public API of this package.
from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import CustomFocalLoss

__all__ = ['CrossEntropyLoss', 'CustomFocalLoss']
\ No newline at end of file
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
d2b71343
File added
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/cross_entropy_loss.cpython-310.pyc
0 → 100644
View file @
d2b71343
File added
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/focal_loss.cpython-310.pyc
0 → 100644
View file @
d2b71343
File added
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/lovasz_softmax.cpython-310.pyc
0 → 100644
View file @
d2b71343
File added
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/__pycache__/semkitti_loss.cpython-310.pyc
0 → 100644
View file @
d2b71343
File added
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/cross_entropy_loss.py
0 → 100644
View file @
d2b71343
# Copyright (c) OpenMMLab. All rights reserved.
import
warnings
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmdet.models.builder
import
LOSSES
from
mmdet.models.losses.utils
import
weight_reduce_loss
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Compute the softmax cross-entropy loss with optional re-weighting.

    Args:
        pred (torch.Tensor): Class scores, shape (N, C).
        label (torch.Tensor): Target class index of every prediction.
        weight (torch.Tensor, optional): Per-sample loss weight.
        reduction (str, optional): Reduction passed to
            ``weight_reduce_loss``.
        avg_factor (int, optional): Divisor used when averaging the loss.
            Defaults to None.
        class_weight (list[float], optional): Per-class weight.
        ignore_index (int | None): Label value to ignore. ``None`` is
            replaced by -100. Default: -100.
        avg_non_ignore (bool): If True (and ``reduction == 'mean'`` with no
            explicit ``avg_factor``), average over non-ignored targets only.
            Default: False.

    Returns:
        torch.Tensor: The reduced loss.
    """
    # None falls back to the same default F.cross_entropy uses.
    if ignore_index is None:
        ignore_index = -100

    per_element = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # PyTorch's own 'mean' averages over non-ignored elements; mirror that
    # only when explicitly requested, see
    # https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
    if avg_factor is None and avg_non_ignore and reduction == 'mean':
        num_ignored = (label == ignore_index).sum().item()
        avg_factor = label.numel() - num_ignored

    sample_weight = None if weight is None else weight.float()
    return weight_reduce_loss(
        per_element,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
    """Turn index labels into one-hot rows plus matching weights and mask.

    Args:
        labels (torch.Tensor): Label indices, shape (N, ).
        label_weights (torch.Tensor | None): Per-sample weights, or None.
        label_channels (int): Width of the one-hot encoding.
        ignore_index (int): Label value treated as invalid.

    Returns:
        tuple: ``(bin_labels, bin_label_weights, valid_mask)``, each of shape
        (N, label_channels). Rows of negative or ignored labels are zero in
        the mask and in the weights. When ``label_weights`` is None the
        weights are the mask itself.
    """
    num_samples = labels.size(0)
    onehot = labels.new_full((num_samples, label_channels), 0)

    keep = (labels >= 0) & (labels != ignore_index)
    # Labels >= label_channels stay valid in the mask but set no one-hot bit.
    hot_rows = torch.nonzero(keep & (labels < label_channels), as_tuple=False)
    if hot_rows.numel() > 0:
        onehot[hot_rows, labels[hot_rows]] = 1

    expanded_mask = keep.view(-1, 1).expand(num_samples,
                                            label_channels).float()
    if label_weights is None:
        return onehot, expanded_mask, expanded_mask

    expanded_weights = label_weights.view(-1, 1).repeat(1, label_channels)
    expanded_weights *= expanded_mask
    return onehot, expanded_weights, expanded_mask
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False):
    """Compute the sigmoid (binary) cross-entropy loss.

    Args:
        pred (torch.Tensor): Logits with shape (N, 1) or (N, ). If ``pred``
            and ``label`` differ in number of dimensions, ``label`` is
            expanded to one-hot form with ``pred.size(-1)`` channels;
            otherwise it is used as given.
        label (torch.Tensor): Targets, shape (N, ).
        weight (torch.Tensor, optional): Per-sample loss weight.
        reduction (str, optional): One of "none", "mean" and "sum".
        avg_factor (int, optional): Divisor used when averaging the loss.
            Defaults to None.
        class_weight (list[float], optional): Passed as ``pos_weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (int | None): Label value to ignore. ``None`` is
            replaced by -100. Default: -100.
        avg_non_ignore (bool): If True (and ``reduction == 'mean'`` with no
            explicit ``avg_factor``), average over non-ignored targets only.
            Default: False.

    Returns:
        torch.Tensor: The reduced loss.
    """
    # None falls back to the same default F.cross_entropy uses.
    if ignore_index is None:
        ignore_index = -100

    if pred.dim() == label.dim():
        # Same rank: just zero out negative / ignored targets.
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        # Out-of-place multiply on purpose: an in-place update raises a
        # broadcast shape error when weight and valid_mask differ in shape,
        # e.g. (B,N,1) vs (B,N,C).
        weight = valid_mask if weight is None else weight * valid_mask
    else:
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.size(-1), ignore_index)

    if avg_factor is None and avg_non_ignore and reduction == 'mean':
        avg_factor = valid_mask.sum().item()

    weight = weight.float()
    per_element = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    return weight_reduce_loss(
        per_element, weight, reduction=reduction, avg_factor=avg_factor)
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Compute the binary cross-entropy loss for per-class mask predictions.

    For every RoI the mask channel of its own class (given by ``label``) is
    selected and compared against ``target``.

    Args:
        pred (torch.Tensor): Mask logits with shape (N, C, *), C being the
            number of classes and * any trailing shape.
        target (torch.Tensor): Target masks for the selected channels.
        label (torch.Tensor): Class index of each RoI, used to pick the
            channel of ``pred``.
        reduction (str, optional): Must be "mean".
        avg_factor (int, optional): Must be None.
        class_weight (list[float], optional): Passed as ``weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (None): Placeholder for interface consistency; must be
            None.

    Returns:
        torch.Tensor: The mean loss with shape (1, ).

    Example:
        >>> N, C = 3, 11
        >>> H, W = 2, 2
        >>> pred = torch.randn(N, C, H, W) * 1000
        >>> target = torch.rand(N, H, W)
        >>> label = torch.randint(0, C, size=(N,))
        >>> loss = mask_cross_entropy(pred, target, label)
        >>> assert loss.shape == (1,)
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None

    roi_ids = torch.arange(
        0, pred.size()[0], dtype=torch.long, device=pred.device)
    # One mask per RoI: the channel that belongs to the RoI's class.
    selected = pred[roi_ids, label].squeeze(1)
    mean_loss = F.binary_cross_entropy_with_logits(
        selected, target, weight=class_weight, reduction='mean')
    return mean_loss[None]
@LOSSES.register_module(force=True)
class CrossEntropyLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0,
                 avg_non_ignore=False):
        """Configurable cross-entropy loss.

        Args:
            use_sigmoid (bool, optional): Select the binary (sigmoid)
                criterion instead of the softmax one. Defaults to False.
            use_mask (bool, optional): Select the mask cross-entropy
                criterion. Cannot be combined with ``use_sigmoid``.
                Defaults to False.
            reduction (str, optional): One of "none", "mean" and "sum".
                Defaults to 'mean'.
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            ignore_index (int | None): The label index to be ignored.
                Defaults to None.
            loss_weight (float, optional): Multiplier applied to the loss.
                Defaults to 1.0.
            avg_non_ignore (bool): Whether the loss is only averaged over
                non-ignored targets. Defaults to False.
        """
        super().__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.avg_non_ignore = avg_non_ignore

        averages_over_ignored = (
            ignore_index is not None and not self.avg_non_ignore
            and self.reduction == 'mean')
        if averages_over_ignored:
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        # Pick the functional criterion once, at construction time.
        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def extra_repr(self):
        """Extra repr."""
        return f'avg_non_ignore={self.avg_non_ignore}'

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=None,
                **kwargs):
        """Compute the weighted loss.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): Overrides the configured
                reduction. Options are "none", "mean" and "sum".
            ignore_index (int | None): Overrides the configured ignore index
                when not None. Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        if ignore_index is None:
            ignore_index = self.ignore_index

        # Materialise the class weights next to the scores.
        class_weight = None
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(
                self.class_weight, device=cls_score.device)

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            ignore_index=ignore_index,
            avg_non_ignore=self.avg_non_ignore,
            **kwargs)
        return self.loss_weight * raw_loss
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/focal_loss.py
0 → 100644
View file @
d2b71343
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.ops
import
sigmoid_focal_loss
as
_sigmoid_focal_loss
from
mmdet.models.builder
import
LOSSES
from
mmdet.models.losses.utils
import
weight_reduce_loss
import
numpy
as
np
# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    The element-wise focal loss is summed over the last axis and then
    averaged over the remaining elements.

    Args:
        pred (torch.Tensor): The prediction logits with shape (N, C), C is
            the number of classes.
        target (torch.Tensor): The learning target; it is cast to the dtype
            of ``pred`` and must match its shape.
        weight (torch.Tensor, optional): Loss weight. Accepted with the shape
            of the loss, with shape (N,), or flattened with N * C elements.
            When None the loss is left unweighted.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): Accepted for interface compatibility but
            not used; the reduction is always ``sum(-1).mean()``.
        avg_factor (int, optional): Accepted for interface compatibility but
            not used.

    Returns:
        torch.Tensor: Scalar loss.
    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt is the probability mass assigned to the wrong side of each target.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                # in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
        # Only multiply when a weight was supplied; the default None would
        # otherwise fail on ``loss * None``.
        loss = loss * weight
    # Fixed reduction used in place of ``weight_reduce_loss``.
    loss = loss.sum(-1).mean()
    return loss
def py_focal_loss_with_prob(pred,
                            target,
                            weight=None,
                            gamma=2.0,
                            alpha=0.25,
                            reduction='mean',
                            avg_factor=None):
    """Focal loss (https://arxiv.org/abs/1708.02002) on probabilities.

    Unlike ``py_sigmoid_focal_loss`` the prediction is consumed as a
    probability, and the integer target is one-hot encoded here.

    Args:
        pred (torch.Tensor): The prediction probability with shape (N, C),
            C is the number of classes.
        target (torch.Tensor): Integer class labels. A label equal to C
            becomes an all-zero row after encoding.
        weight (torch.Tensor, optional): Loss weight with the shape of the
            loss, with shape (N,), or flattened with N * C elements.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): Reduction forwarded to
            ``weight_reduce_loss``. Defaults to 'mean'.
        avg_factor (int, optional): Average factor forwarded to
            ``weight_reduce_loss``. Defaults to None.
    """
    n_cls = pred.size(1)
    # Encode with one spare slot, then drop it again.
    onehot = F.one_hot(target, num_classes=n_cls + 1)
    onehot = onehot[:, :n_cls]
    onehot = onehot.type_as(pred)

    miss_prob = (1 - pred) * onehot + pred * (1 - onehot)
    modulator = (alpha * onehot + (1 - alpha) *
                 (1 - onehot)) * miss_prob.pow(gamma)
    loss = F.binary_cross_entropy(
        pred, onehot, reduction='none') * modulator

    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) != loss.size(0):
                # Flattened per-element weight: fold it back to (N, C).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
            else:
                # Per-sample weight: add the class axis for broadcasting.
                weight = weight.view(-1, 1)
        assert weight.ndim == loss.ndim
    return weight_reduce_loss(loss, weight, reduction, avg_factor)
def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of cuda version `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    The element-wise loss from the compiled op is summed over the last axis
    and then averaged over the remaining elements.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Loss weight with the shape of the
            loss, with shape (N,), or flattened with N * C elements. When
            None the loss is left unweighted.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): Accepted for interface compatibility but
            not used; the reduction is always ``sum(-1).mean()``.
        avg_factor (int, optional): Accepted for interface compatibility but
            not used.

    Returns:
        torch.Tensor: Scalar loss.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
                               alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                # in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
        # Only multiply when a weight was supplied; the default None would
        # otherwise fail on ``loss * None``.
        loss = loss * weight
    # Fixed reduction used in place of ``weight_reduce_loss``.
    loss = loss.sum(-1).mean()
    return loss
@LOSSES.register_module()
class CustomFocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=100.0,
                 activated=False):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ with a radial
        per-cell weight over a 200 x 200 grid.

        Args:
            use_sigmoid (bool, optional): Must be True; only the sigmoid
                focal loss is implemented. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 100.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as logits.
                Defaults to False.
        """
        super(CustomFocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

        # Distance of every grid cell from the grid centre, scaled into
        # [1, 2]: cells farther from the centre receive a larger weight.
        H, W = 200, 200
        xy, yx = torch.meshgrid(
            [torch.arange(H) - H / 2,
             torch.arange(W) - W / 2])
        c = torch.stack([xy, yx], 2)
        c = torch.norm(c, 2, -1)
        c_max = c.max()
        c = c / c_max + 1
        # Keep the map on the GPU when one exists, but do not require CUDA:
        # forward() has a CPU code path that an unconditional .cuda() call
        # would make unreachable.
        self.c = c.cuda() if torch.cuda.is_available() else c

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                ignore_index=255,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction with shape (B, C, H, W, D).
            target (torch.Tensor): Integer labels with shape (B, H, W, D).
                H and W must match the 200 x 200 weight map built in
                ``__init__``.
            weight (torch.Tensor, optional): 1-D weight that is broadcast
                against the radial weight (presumably one entry per class -
                TODO confirm with callers). When None only the radial weight
                is applied.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            ignore_index (int, optional): Label value whose positions are
                removed before the loss is computed. Defaults to 255.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss
        """
        B, H, W, D = target.shape
        # Broadcast the (H, W) map to every batch item and depth slice and
        # make sure it lives on the same device as the labels it is indexed
        # with.
        c = self.c.to(target.device)
        c = c[None, :, :, None].repeat(B, 1, 1, D).reshape(-1)

        visible_mask = (target != ignore_index).reshape(-1).nonzero().squeeze(
            -1)
        radial_weight = c[visible_mask, None]
        if weight is None:
            weight_mask = radial_weight
        else:
            weight_mask = weight[None, :] * radial_weight

        num_classes = pred.size(1)
        pred = pred.permute(0, 2, 3, 4, 1).reshape(-1,
                                                   num_classes)[visible_mask]
        target = target.reshape(-1)[visible_mask]

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = py_focal_loss_with_prob
            else:
                if torch.cuda.is_available() and pred.is_cuda:
                    calculate_loss_func = sigmoid_focal_loss
                else:
                    num_classes = pred.size(1)
                    # F.one_hot only accepts int64 indices.
                    target = F.one_hot(
                        target.long(), num_classes=num_classes + 1)
                    target = target[:, :num_classes]
                    calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target.to(torch.long),
                weight_mask,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/lovasz_softmax.py
0 → 100644
View file @
d2b71343
# -*- coding:utf-8 -*-
# author: Xinge
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""
from
__future__
import
print_function
,
division
import
torch
from
torch.autograd
import
Variable
import
torch.nn.functional
as
F
import
numpy
as
np
try
:
from
itertools
import
ifilterfalse
except
ImportError
:
# py3k
from
itertools
import
filterfalse
as
ifilterfalse
from
torch.cuda.amp
import
autocast
def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension with respect to sorted errors.

    Implements Alg. 1 of the Lovasz-Softmax paper: the Jaccard index is
    evaluated on every prefix of ``gt_sorted`` and then differenced.

    Args:
        gt_sorted (torch.Tensor): 1-D ground-truth indicator ordered by
            decreasing error.

    Returns:
        torch.Tensor: 1-D tensor with one gradient entry per element.
    """
    n = len(gt_sorted)
    total_fg = gt_sorted.sum()
    inter = total_fg - gt_sorted.float().cumsum(0)
    uni = total_fg + (1 - gt_sorted).float().cumsum(0)
    grad = 1. - inter / uni
    if n > 1:
        # A single element needs no differencing.
        grad[1:n] = grad[1:n] - grad[0:-1]
    return grad
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """Foreground IoU (in percent) for binary masks.

    Args:
        preds: Predicted masks, 1 marks foreground.
        labels: Ground-truth masks, 1 marks foreground.
        EMPTY (float): Score used when the union is empty.
        ignore: Label value excluded from the union of predictions.
        per_image (bool): When True iterate over the leading axis and average
            the per-item scores; otherwise treat the input as one item.

    Returns:
        The averaged IoU multiplied by 100.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    scores = []
    for pred, label in zip(preds, labels):
        fg_label = label == 1
        fg_pred = pred == 1
        overlap = (fg_label & fg_pred).sum()
        combined = (fg_label | (fg_pred & (label != ignore))).sum()
        scores.append(float(overlap) / float(combined) if combined else EMPTY)
    # Average across the items that were scored.
    return 100 * mean(scores)
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """Per-class IoU (in percent) for every non-ignored class.

    Args:
        preds: Predicted class ids.
        labels: Ground-truth class ids.
        C (int): Number of classes; ids ``0 .. C-1`` are evaluated.
        EMPTY (float): Score used for a class whose union is empty.
        ignore: Class id that is skipped and excluded from unions.
        per_image (bool): When True iterate over the leading axis and average
            the per-item scores; otherwise treat the input as one item.

    Returns:
        numpy.ndarray: One averaged IoU per evaluated class, times 100.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    per_item = []
    for pred, label in zip(preds, labels):
        class_scores = []
        for cls in range(C):
            # The ignored label is sometimes among predicted classes
            # (ENet - CityScapes), so it is skipped explicitly.
            if cls == ignore:
                continue
            hit_label = label == cls
            hit_pred = pred == cls
            overlap = (hit_label & hit_pred).sum()
            combined = (hit_label | (hit_pred & (label != ignore))).sum()
            if combined:
                class_scores.append(float(overlap) / float(combined))
            else:
                class_scores.append(EMPTY)
        per_item.append(class_scores)
    # Transpose to per-class columns and average each across items.
    averaged = [mean(column) for column in zip(*per_item)]
    return 100 * np.array(averaged)
# --------------------------- BINARY LOSSES ---------------------------
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """Binary Lovasz hinge loss.

    Args:
        logits: [B, H, W] logits at each pixel.
        labels: [B, H, W] binary ground-truth masks (0 or 1).
        per_image (bool): Compute the loss per image and average, instead of
            once over the whole batch.
        ignore: Void class id removed before the loss is computed.

    Returns:
        The loss value.
    """
    if not per_image:
        return lovasz_hinge_flat(
            *flatten_binary_scores(logits, labels, ignore))
    # Each image is flattened on its own, then the per-image losses are
    # averaged lazily through the generator.
    per_image_losses = (
        lovasz_hinge_flat(*flatten_binary_scores(
            log.unsqueeze(0), lab.unsqueeze(0), ignore))
        for log, lab in zip(logits, labels))
    return mean(per_image_losses)
def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss on flattened inputs.

    Args:
        logits: [P] logits at each prediction.
        labels: [P] binary ground-truth labels (0 or 1).

    Returns:
        Scalar loss; zero (still attached to ``logits``) when there are no
        labels.
    """
    if len(labels) == 0:
        # Only void pixels: the gradients should be 0.
        return logits.sum() * 0.
    plus_minus = 2. * labels.float() - 1.
    hinge_errors = (1. - logits * Variable(plus_minus))
    sorted_errors, order = torch.sort(hinge_errors, dim=0, descending=True)
    order = order.data
    labels_in_order = labels[order]
    lovasz_weights = lovasz_grad(labels_in_order)
    return torch.dot(F.relu(sorted_errors), Variable(lovasz_weights))
def flatten_binary_scores(scores, labels, ignore=None):
    """Flatten scores and labels and drop entries labelled ``ignore``.

    Args:
        scores (torch.Tensor): Predictions of any shape.
        labels (torch.Tensor): Labels with the same number of elements.
        ignore: Label value to remove; None keeps everything.

    Returns:
        tuple: The flattened (and possibly filtered) scores and labels.
    """
    flat_scores = scores.view(-1)
    flat_labels = labels.view(-1)
    if ignore is not None:
        keep = (flat_labels != ignore)
        return flat_scores[keep], flat_labels[keep]
    return flat_scores, flat_labels
class StableBCELoss(torch.nn.Module):
    """Mean binary cross-entropy on logits in an overflow-safe form.

    Uses ``max(x, 0) - x * t + log(1 + exp(-|x|))`` so the exponential never
    receives a positive argument.
    """

    def forward(self, input, target):
        positive_part = input.clamp(min=0)
        log_term = (1 + (-input.abs()).exp()).log()
        per_element = positive_part - input * target + log_term
        return per_element.mean()
def binary_xloss(logits, labels, ignore=None):
    """Binary cross-entropy loss on logits.

    Args:
        logits: [B, H, W] logits at each pixel.
        labels: [B, H, W] binary ground-truth masks (0 or 1).
        ignore: Void class id removed before the loss is computed.

    Returns:
        The mean loss over the remaining elements.
    """
    flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
    criterion = StableBCELoss()
    return criterion(flat_logits, Variable(flat_labels.float()))
# --------------------------- MULTICLASS LOSSES ---------------------------
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probas: [B, C, H, W] class probabilities at each prediction. A
            [B, H, W] input is interpreted as binary (sigmoid) output.
        labels: [B, H, W] ground-truth labels (between 0 and C - 1).
        classes: 'all' for all classes, 'present' for classes present in the
            labels, or a list of classes to average.
        per_image (bool): Compute the loss per image and average, instead of
            once over the whole batch.
        ignore: Void class label removed before the loss is computed.

    Returns:
        The loss value.
    """
    if per_image:
        return mean(
            lovasz_softmax_flat(
                *flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
                classes=classes)
            for prob, lab in zip(probas, labels))
    # The batch-level path runs with autocast switched off.
    with autocast(False):
        flat_inputs = flatten_probas(probas, labels, ignore)
        loss = lovasz_softmax_flat(*flat_inputs, classes=classes)
    return loss
def lovasz_softmax_flat(probas, labels, classes='present'):
    """Multi-class Lovasz-Softmax loss on flattened inputs.

    Args:
        probas: [P, C] class probabilities at each prediction (between 0
            and 1).
        labels: [P] ground-truth labels (between 0 and C - 1).
        classes: 'all' for all classes, 'present' for classes present in the
            labels, or a list of classes to average.

    Returns:
        The loss averaged over the evaluated classes. ``probas * 0.`` is
        returned when ``probas`` is empty.

    Raises:
        ValueError: If ``probas`` has a single channel but more than one
            class is requested.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    # Compare by value: identity comparison against a string literal depends
    # on interning and is not a reliable equality test.
    only_present = isinstance(classes, str) and classes == 'present'
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        if C == 1:
            # Count the classes actually being summed; ``len(classes)`` would
            # measure the length of the string 'all' / 'present'.
            if len(class_to_sum) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)
def flatten_probas(probas, labels, ignore=None):
    """Flatten predictions in the batch to [P, C] and labels to [P].

    Args:
        probas (torch.Tensor): Probabilities. Accepted layouts are [P, C]
            (returned as is, apart from filtering), [B, H, W] (treated as a
            single channel), [B, C, H, W] and [B, C, L, H, W].
        labels (torch.Tensor): Labels with one entry per prediction.
        ignore: Label value whose entries are removed; None keeps everything.

    Returns:
        tuple: ``(probas, labels)`` with shapes [P, C] and [P].
    """
    if probas.dim() == 2:
        if ignore is not None:
            valid = (labels != ignore)
            probas = probas[valid]
            labels = labels[valid]
        return probas, labels
    elif probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    elif probas.dim() == 5:
        # 3D segmentation: fold the last two axes so the 4-D path applies.
        B, C, L, H, W = probas.size()
        probas = probas.contiguous().view(B, C, L, H * W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    # Boolean-mask indexing always keeps the row axis. Indexing with
    # ``valid.nonzero().squeeze()`` collapses to a 0-dim index when exactly
    # one element is valid and would return a [C] tensor instead of [1, C].
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels
def xloss(logits, labels, ignore=None):
    """Cross entropy loss.

    Args:
        logits (torch.Tensor): Class logits accepted by ``F.cross_entropy``.
        labels (torch.Tensor): Integer class labels.
        ignore (int, optional): Label value excluded from the loss. When
            None the value 255 is used, which is what this function has
            always ignored.

    Returns:
        torch.Tensor: The mean cross-entropy over non-ignored labels.
    """
    # Honour the caller's void label; it used to be accepted but unused.
    ignore_index = 255 if ignore is None else ignore
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)
def jaccard_loss(probas, labels, ignore=None, smooth=100, bk_class=None):
    """Smoothed soft-Jaccard style loss.

    NOTE(review): the original author marked this loss as "something wrong".
    The ratio below is evaluated as
    ``intersection + smooth / (cardinality - intersection + smooth)``, which
    is not the usual ``(I + smooth) / (U + smooth)`` form. The arithmetic is
    left unchanged here because fixing it changes the loss value - confirm
    the intended formula before relying on this function.

    Args:
        probas: [B, C, H, W] class probabilities at each prediction (between
            0 and 1).
        labels: [B, H, W] ground-truth labels (between 0 and C - 1).
        ignore: Void class label removed before the loss is computed.
        smooth (float): Smoothing constant, also used to scale the result.
        bk_class (int, optional): Class whose one-hot rows are zeroed out.
            None disables this masking.

    Returns:
        torch.Tensor: Scalar loss.
    """
    vprobas, vlabels = flatten_probas(probas, labels, ignore)

    # Build the identity on the labels' device so the lookup does not mix a
    # CPU table with GPU indices; index with int64 as integer indexing
    # expects.
    true_1_hot = torch.eye(vprobas.shape[1], device=vlabels.device)[vlabels.long()]

    # ``is not None`` so that class id 0 can be selected as well.
    if bk_class is not None:
        one_hot_assignment = torch.ones_like(vlabels)
        one_hot_assignment[vlabels == bk_class] = 0
        one_hot_assignment = one_hot_assignment.float().unsqueeze(1)
        true_1_hot = true_1_hot * one_hot_assignment

    true_1_hot = true_1_hot.to(vprobas.device)
    intersection = torch.sum(vprobas * true_1_hot)
    cardinality = torch.sum(vprobas + true_1_hot)
    loss = (intersection + smooth / (cardinality - intersection + smooth)).mean()
    return (1 - loss) * smooth
def hinge_jaccard_loss(probas, labels, ignore=None, classes='present', hinge=0.1, smooth=100):
    """Multi-class hinge Jaccard loss.

    For every class that occurs in the labels, soft TP / FN / FP counts are
    built from the clamped margin between the class score and the best
    competing score, and ``1 - TP / (TP + FP + FN)`` is averaged over those
    classes.

    Args:
        probas: [B, C, H, W] class probabilities at each prediction (between
            0 and 1).
        labels: [B, H, W] ground-truth labels (between 0 and C - 1).
        ignore: Void class label removed before the loss is computed.
        classes: 'all' or 'present' to iterate over every channel, or a list
            of class ids.
        hinge (float): Margin at which the score difference is clamped.
        smooth (float): Constant added to the true-positive count.

    Returns:
        The averaged loss, or 0 when none of the classes occurs in the
        labels.
    """
    vprobas, vlabels = flatten_probas(probas, labels, ignore)
    n_channels = vprobas.size(1)
    if classes in ['all', 'present']:
        class_to_sum = list(range(n_channels))
    else:
        class_to_sum = classes

    losses = []
    for c in class_to_sum:
        if c not in vlabels:
            continue
        is_c = vlabels == c
        rivals = np.array([a for a in class_to_sum if a != c])

        # Samples that belong to class c: true positives / false negatives.
        own_rows = vprobas[is_c, :]
        own_score = own_rows[:, c]
        own_best_rival = torch.max(own_rows[:, rivals], dim=1)[0]
        TP = torch.sum(torch.clamp(own_score - own_best_rival, max=hinge) + 1.) + smooth
        FN = torch.sum(torch.clamp(own_best_rival - own_score, min=-hinge) + hinge)

        # Samples of every other class: false positives.
        if (~is_c).sum() == 0:
            FP = 0
        else:
            other_rows = vprobas[~is_c, :]
            other_score = other_rows[:, c]
            other_best_rival = torch.max(other_rows[:, rivals], dim=1)[0]
            FP = torch.sum(torch.clamp(other_score - other_best_rival, max=hinge) + 1.)

        losses.append(1 - TP / (TP + FP + FN))

    if len(losses) == 0:
        return 0
    return mean(losses)
# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
    """Return ``x != x``; this is true for values that are not equal to themselves (NaN)."""
    return x != x
def mean(l, ignore_nan=False, empty=0):
    """Arithmetic mean that also accepts generators.

    Args:
        l: Iterable of values supporting ``+=`` and division.
        ignore_nan (bool): Drop values for which ``isnan`` is true.
        empty: Value returned for an empty iterable; the string 'raise'
            raises instead.

    Returns:
        The mean, the single element itself when there is only one, or
        ``empty`` when there are no elements.

    Raises:
        ValueError: If the iterable is empty and ``empty == 'raise'``.
    """
    values = iter(l)
    if ignore_nan:
        values = ifilterfalse(isnan, values)
    nothing = object()
    total = next(values, nothing)
    if total is nothing:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    count = 1
    for item in values:
        # In-place accumulation, exactly as the running sum is defined.
        total += item
        count += 1
    return total if count == 1 else total / count
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/losses/semkitti_loss.py
0 → 100644
View file @
d2b71343
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
numpy
as
np
# from mmcv.runner import BaseModule, force_fp32
from
torch.cuda.amp
import
autocast
# Per-class frequency counts, 20 entries.
# NOTE(review): presumably index-aligned with ``kitti_class_names`` below
# (both hold 20 entries) - confirm against the code that consumes them.
semantic_kitti_class_frequencies = np.array(
    [
        5.41773033e09,
        1.57835390e07,
        1.25136000e05,
        1.18809000e05,
        6.46799000e05,
        8.21951000e05,
        2.62978000e05,
        2.83696000e05,
        2.04750000e05,
        6.16887030e07,
        4.50296100e06,
        4.48836500e07,
        2.26992300e06,
        5.68402180e07,
        1.57196520e07,
        1.58442623e08,
        2.06162300e06,
        3.69705220e07,
        1.15198800e06,
        3.34146000e05,
    ]
)

# Class names, 20 entries, with "empty" at index 0.
kitti_class_names = [
    "empty",
    "car",
    "bicycle",
    "motorcycle",
    "truck",
    "other-vehicle",
    "person",
    "bicyclist",
    "motorcyclist",
    "road",
    "parking",
    "sidewalk",
    "other-ground",
    "building",
    "fence",
    "vegetation",
    "trunk",
    "terrain",
    "pole",
    "traffic-sign",
]
def inverse_sigmoid(x, sign='A'):
    """Logit of a scalar tensor after nudging it away from 0 and 1.

    The value is stepped by 1e-5 until it is below ``1 - 1e-5`` and at least
    1e-5, then ``-log(1 / x - 1)`` is returned. The ``while`` conditions use
    the tensor's truth value, so ``x`` has to hold a single element.

    Args:
        x (torch.Tensor): Single-element tensor; it is cast to float32.
        sign (str): Unused tag kept for the callers that pass it.

    Returns:
        torch.Tensor: The logit of the adjusted value.
    """
    step = 1e-5
    value = x.to(torch.float32)
    while value >= 1 - step:
        value = value - step
    while value < step:
        value = value + step
    inverse_odds = (1 / value) - 1
    return -torch.log(inverse_odds)
def KL_sep(p, target):
    """Summed KL divergence restricted to entries where ``target`` is non-zero.

    Args:
        p (torch.Tensor): Predicted probabilities.
        target (torch.Tensor): Target distribution with the same shape.

    Returns:
        torch.Tensor: Scalar KL term.
    """
    keep = target != 0
    log_p = torch.log(p[keep])
    return F.kl_div(log_p, target[keep], reduction="sum")
def geo_scal_loss(pred, ssc_target, ignore_index=255, non_empty_idx=0):
    """Loss on the soft precision, recall and specificity of occupancy.

    The channel ``non_empty_idx`` of the softmax is read as the probability
    of a voxel being empty; one minus it is the probability of being
    occupied. Each of the three ratios is pushed towards 1 through a
    logits-based BCE.

    Args:
        pred (torch.Tensor): Class logits with the class axis at dim 1.
        ssc_target (torch.Tensor): Integer labels.
        ignore_index (int): Label value excluded from all sums.
        non_empty_idx (int): Index of the channel / label treated as empty.

    Returns:
        torch.Tensor: Sum of the three BCE terms.
    """
    probs = F.softmax(pred, dim=1)

    p_empty = probs[:, non_empty_idx]
    p_occupied = 1 - p_empty

    # Drop voxels with unknown labels.
    known = ssc_target != ignore_index
    occupied_gt = (ssc_target != non_empty_idx)[known].float()
    p_occupied = p_occupied[known]
    p_empty = p_empty[known]

    eps = 1e-5
    overlap = (occupied_gt * p_occupied).sum()
    precision = overlap / (p_occupied.sum() + eps)
    recall = overlap / (occupied_gt.sum() + eps)
    spec = ((1 - occupied_gt) * (p_empty)).sum() / ((1 - occupied_gt).sum() + eps)

    with autocast(False):
        precision_term = F.binary_cross_entropy_with_logits(
            inverse_sigmoid(precision, 'A'), torch.ones_like(precision))
        recall_term = F.binary_cross_entropy_with_logits(
            inverse_sigmoid(recall, 'B'), torch.ones_like(recall))
        spec_term = F.binary_cross_entropy_with_logits(
            inverse_sigmoid(spec, 'C'), torch.ones_like(spec))
        return precision_term + recall_term + spec_term
def sem_scal_loss(pred_, ssc_target, ignore_index=255):
    """Class-wise loss on soft precision, recall and specificity.

    For every class except the last channel, the three ratios are computed
    over the non-ignored voxels and each is pushed towards 1 through a
    logits-based BCE. The per-class sums are averaged over the classes that
    occur in the target.

    Args:
        pred_ (torch.Tensor): Class logits with shape
            (B, n_class, Dx, Dy, Dz).
        ssc_target (torch.Tensor): Integer labels with shape
            (B, Dx, Dy, Dz).
        ignore_index (int): Label value excluded from all sums.

    Returns:
        torch.Tensor: Scalar loss. A zero that stays attached to ``pred_``
        is returned when none of the evaluated classes occurs in the target.

    Raises:
        FloatingPointError: If the averaged loss is NaN.
    """
    with autocast(False):
        # Get softmax probabilities
        pred = F.softmax(pred_, dim=1)  # (B, n_class, Dx, Dy, Dz)
        loss = 0
        count = 0
        mask = ssc_target != ignore_index
        # The filtered target does not depend on the class, so build it once.
        target = ssc_target[mask]
        n_classes = pred.shape[1]
        begin = 0
        # The last channel is deliberately left out of the loop.
        for i in range(begin, n_classes - 1):
            # Probability of class i on the known voxels.
            p = pred[:, i][mask]
            completion_target = (target == i).to(target.dtype)

            if torch.sum(completion_target) > 0:
                count += 1.0
                nominator = torch.sum(p * completion_target)
                loss_class = 0
                if torch.sum(p) > 0:
                    precision = nominator / (torch.sum(p) + 1e-5)
                    loss_precision = F.binary_cross_entropy_with_logits(
                        inverse_sigmoid(precision, 'D'),
                        torch.ones_like(precision))
                    loss_class += loss_precision

                # The class is present here, so recall is always defined.
                recall = nominator / (torch.sum(completion_target) + 1e-5)
                loss_recall = F.binary_cross_entropy_with_logits(
                    inverse_sigmoid(recall, 'E'), torch.ones_like(recall))
                loss_class += loss_recall

                if torch.sum(1 - completion_target) > 0:
                    specificity = torch.sum(
                        (1 - p) * (1 - completion_target)) / (
                            torch.sum(1 - completion_target) + 1e-5)
                    loss_specificity = F.binary_cross_entropy_with_logits(
                        inverse_sigmoid(specificity, 'F'),
                        torch.ones_like(specificity))
                    loss_class += loss_specificity
                loss += loss_class

        if count == 0:
            # No evaluated class is present: return a zero that keeps the
            # graph connected instead of dividing by zero.
            return pred_.sum() * 0.
        l = loss / count
        if torch.isnan(l):
            # Fail loudly instead of dropping into an interactive debugger.
            raise FloatingPointError('sem_scal_loss produced NaN')
        return l
def CE_ssc_loss(pred, target, class_weights=None, ignore_index=255):
    """Mean cross-entropy with optional class weights.

    Args:
        pred (torch.Tensor): The predicted tensor, must be [BS, C, ...].
        target (torch.Tensor): Labels; cast to int64 before use.
        class_weights (torch.Tensor, optional): Per-class weight passed to
            ``nn.CrossEntropyLoss``.
        ignore_index (int): Label value excluded from the loss.

    Returns:
        torch.Tensor: Scalar loss.
    """
    ce = nn.CrossEntropyLoss(
        weight=class_weights, ignore_index=ignore_index, reduction="mean")
    # The criterion itself is evaluated with autocast switched off.
    with autocast(False):
        return ce(pred, target.long())
def vel_loss(pred, gt):
    """Mean absolute error between ``pred`` and ``gt``, with autocast off."""
    with autocast(False):
        l1 = F.l1_loss(pred, gt)
    return l1
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/model_utils/__init__.py
0 → 100644
View file @
d2b71343
from
.depthnet
import
DepthNet
__all__
=
[
'DepthNet'
]
\ No newline at end of file
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/model_utils/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
d2b71343
File added
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/model_utils/__pycache__/depthnet.cpython-310.pyc
0 → 100644
View file @
d2b71343
File added
docker-hub/FlashOCC/Flashocc/projects/mmdet3d_plugin/models/model_utils/depthnet.py
0 → 100644
View file @
d2b71343
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmdet.models.backbones.resnet
import
BasicBlock
from
mmcv.cnn
import
build_conv_layer
from
torch.cuda.amp.autocast_mode
import
autocast
from
torch.utils.checkpoint
import
checkpoint
class _ASPPModule(nn.Module):
    """One ASPP branch: dilated convolution, normalisation, ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation,
                 BatchNorm):
        super().__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        """Apply convolution, normalisation and activation in sequence."""
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        """Kaiming-initialise convolutions and reset BatchNorm2d layers."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
class ASPP(nn.Module):
    """Atrous spatial pyramid pooling with four dilated branches and a
    global-average branch, fused back to ``inplanes`` channels."""

    def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
        super().__init__()

        dilations = [1, 6, 12, 18]
        # (kernel size, padding, dilation) of the branches aspp1 .. aspp4.
        branch_specs = [
            (1, 0, dilations[0]),
            (3, dilations[1], dilations[1]),
            (3, dilations[2], dilations[2]),
            (3, dilations[3], dilations[3]),
        ]
        for number, (kernel, pad, rate) in enumerate(branch_specs, 1):
            setattr(
                self, f'aspp{number}',
                _ASPPModule(
                    inplanes,
                    mid_channels,
                    kernel,
                    padding=pad,
                    dilation=rate,
                    BatchNorm=BatchNorm))

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
            BatchNorm(mid_channels),
            nn.ReLU(),
        )
        self.conv1 = nn.Conv2d(
            int(mid_channels * 5), inplanes, 1, bias=False)
        self.bn1 = BatchNorm(inplanes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        """
        Args:
            x: (B*N, C, fH, fW)
        Returns:
            x: (B*N, C, fH, fW)
        """
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Stretch the 1x1 pooled map back to the branch resolution.
        pooled = F.interpolate(
            pooled,
            size=branches[-1].size()[2:],
            mode='bilinear',
            align_corners=True)
        fused = torch.cat((*branches, pooled), dim=1)  # (B*N, 5*C', fH, fW)

        fused = self.relu(self.bn1(self.conv1(fused)))  # (B*N, C, fH, fW)
        return self.dropout(fused)

    def _init_weight(self):
        """Kaiming-initialise convolutions and reset BatchNorm2d layers."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
class Mlp(nn.Module):
    """Two-layer perceptron: linear, activation, dropout, linear, dropout."""

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.ReLU,
                 drop=0.0):
        super().__init__()
        # Missing (falsy) sizes fall back to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x: (B*N_views, 27)
        Returns:
            x: (B*N_views, C)
        """
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
class SELayer(nn.Module):
    """Channel gating: scales ``x`` by a gate computed from a separate vector."""

    def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
        """
        Args:
            channels: channel count of both ``x`` and ``x_se``.
            act_layer: activation class placed between the two 1x1 convs.
            gate_layer: class producing the final gate (Sigmoid by default).
        """
        super().__init__()
        self.conv_reduce = nn.Conv2d(channels, channels, kernel_size=1, bias=True)
        self.act1 = act_layer()
        self.conv_expand = nn.Conv2d(channels, channels, kernel_size=1, bias=True)
        self.gate = gate_layer()

    def forward(self, x, x_se):
        """
        Args:
            x: (B*N_views, C_mid, fH, fW)
            x_se: (B*N_views, C_mid, 1, 1)
        Returns:
            x: (B*N_views, C_mid, fH, fW)
        """
        # conv -> act -> conv keeps the (B*N_views, C_mid, 1, 1) shape.
        excitation = self.conv_expand(self.act1(self.conv_reduce(x_se)))
        attention = self.gate(excitation)
        # Broadcast the per-channel gate over the spatial dimensions.
        return x * attention
class DepthNet(nn.Module):
    """Camera-aware depth / context head with an optional stereo cost volume.

    ``forward`` returns the depth-bin features and the context features
    concatenated along the channel axis.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 context_channels,
                 depth_channels,
                 use_dcn=True,
                 use_aspp=True,
                 with_cp=False,
                 stereo=False,
                 bias=0.0,
                 aspp_mid_channels=-1):
        """
        Args:
            in_channels: channels of the input image feature.
            mid_channels: width of the shared trunk and of the depth branch.
            context_channels: channels of the context output.
            depth_channels: number of depth bins D.
            use_dcn: append a DCN layer to the depth branch.
            use_aspp: append an ASPP module to the depth branch.
            with_cp: run the depth branch through ``checkpoint``.
            stereo: build the layers that consume a stereo cost volume.
            bias: added to the matching cost wherever the warped previous
                feature is exactly zero (only stored when ``stereo``);
                0 disables it.
            aspp_mid_channels: ASPP branch width; a negative value means
                ``mid_channels``.
        """
        super(DepthNet, self).__init__()
        self.reduce_conv = nn.Sequential(
            nn.Conv2d(
                in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        # Produces the context feature.
        self.context_conv = nn.Conv2d(
            mid_channels, context_channels, kernel_size=1, stride=1, padding=0)
        self.bn = nn.BatchNorm1d(27)
        self.depth_mlp = Mlp(in_features=27,
                             hidden_features=mid_channels,
                             out_features=mid_channels)
        self.depth_se = SELayer(channels=mid_channels)  # NOTE: add camera-aware
        self.context_mlp = Mlp(in_features=27,
                               hidden_features=mid_channels,
                               out_features=mid_channels)
        self.context_se = SELayer(channels=mid_channels)  # NOTE: add camera-aware
        depth_conv_input_channels = mid_channels
        downsample = None

        if stereo:
            # The cost volume is concatenated to the depth feature, so the
            # first residual block needs a projection on its skip path.
            depth_conv_input_channels += depth_channels
            downsample = nn.Conv2d(depth_conv_input_channels,
                                   mid_channels, 1, 1, 0)
            cost_volumn_net = []
            for _ in range(2):
                cost_volumn_net.extend([
                    nn.Conv2d(depth_channels, depth_channels, kernel_size=3,
                              stride=2, padding=1),
                    nn.BatchNorm2d(depth_channels)])
            self.cost_volumn_net = nn.Sequential(*cost_volumn_net)
            self.bias = bias

        # Three residual blocks.
        depth_conv_list = [BasicBlock(depth_conv_input_channels, mid_channels,
                                      downsample=downsample),
                           BasicBlock(mid_channels, mid_channels),
                           BasicBlock(mid_channels, mid_channels)]
        if use_aspp:
            if aspp_mid_channels < 0:
                aspp_mid_channels = mid_channels
            depth_conv_list.append(ASPP(mid_channels, aspp_mid_channels))
        if use_dcn:
            depth_conv_list.append(
                build_conv_layer(
                    cfg=dict(
                        type='DCN',
                        in_channels=mid_channels,
                        out_channels=mid_channels,
                        kernel_size=3,
                        padding=1,
                        groups=4,
                        im2col_step=128,
                    )))
        depth_conv_list.append(
            nn.Conv2d(
                mid_channels,
                depth_channels,
                kernel_size=1,
                stride=1,
                padding=0))
        self.depth_conv = nn.Sequential(*depth_conv_list)
        self.with_cp = with_cp
        self.depth_channels = depth_channels

    # ------------------------- used to build the cost volume -------------------------
    def gen_grid(self, metas, B, N, D, H, W, hi, wi):
        """Build the ``grid_sample`` grid that warps previous-frame features.

        Args:
            metas: dict{
                k2s_sensor: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                frustum: (D, fH_stereo, fW_stereo, 3)  3:(u, v, d)
                cv_downsample: 4,
                downsample: self.img_view_transformer.downsample=16,
                grid_config: self.img_view_transformer.grid_config,
                cv_feat_list: [feat_prev_iv, stereo_feat]
            }
            B: batchsize
            N: N_views
            D: D
            H: fH_stereo
            W: fW_stereo
            hi: H_img
            wi: W_img
        Returns:
            grid: (B*N_views, D*fH_stereo, fW_stereo, 2)
        """
        frustum = metas['frustum']  # (D, fH_stereo, fW_stereo, 3)  3:(u, v, d)
        # Undo the image augmentation.
        points = frustum - metas['post_trans'].view(B, N, 1, 1, 1, 3)
        points = torch.inverse(metas['post_rots']).view(B, N, 1, 1, 1, 3, 3) \
            .matmul(points.unsqueeze(-1))  # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)

        # (u, v, d) --> (du, dv, d)
        # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
        points = torch.cat(
            (points[..., :2, :] * points[..., 2:3, :], points[..., 2:3, :]), 5)

        # cur_pixel --> curr_camera --> prev_camera
        rots = metas['k2s_sensor'][:, :, :3, :3].contiguous()
        trans = metas['k2s_sensor'][:, :, :3, 3].contiguous()
        combine = rots.matmul(torch.inverse(metas['intrins']))
        points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points)
        points += trans.view(B, N, 1, 1, 1, 3, 1)  # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
        # Points with (near-)non-positive depth in the target camera.
        neg_mask = points[..., 2, 0] < 1e-3

        # prev_camera --> prev_pixel
        points = metas['intrins'].view(B, N, 1, 1, 1, 3, 3).matmul(points)
        # (du, dv, d) --> (u, v)   (B, N_views, D, fH_stereo, fW_stereo, 2, 1)
        points = points[..., :2, :] / points[..., 2:3, :]

        # Re-apply the image augmentation.
        points = metas['post_rots'][..., :2, :2].view(B, N, 1, 1, 1, 2, 2) \
            .matmul(points).squeeze(-1)
        points += metas['post_trans'][..., :2].view(B, N, 1, 1, 1, 2)  # (B, N_views, D, fH_stereo, fW_stereo, 2)

        # Normalise pixel coordinates to [-1, 1] for grid_sample.
        px = points[..., 0] / (wi - 1.0) * 2.0 - 1.0
        py = points[..., 1] / (hi - 1.0) * 2.0 - 1.0
        # Push masked points outside [-1, 1] so they sample the zero padding.
        px[neg_mask] = -2
        py[neg_mask] = -2
        grid = torch.stack([px, py], dim=-1)  # (B, N_views, D, fH_stereo, fW_stereo, 2)
        grid = grid.view(B * N, D * H, W, 2)  # (B*N_views, D*fH_stereo, fW_stereo, 2)
        return grid

    def calculate_cost_volumn(self, metas):
        """Compute a per-depth matching probability between two feature maps.

        Args:
            metas: dict{
                k2s_sensor: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                frustum: (D, fH_stereo, fW_stereo, 3)  3:(u, v, d)
                cv_downsample: 4,
                downsample: self.img_view_transformer.downsample=16,
                grid_config: self.img_view_transformer.grid_config,
                cv_feat_list: [feat_prev_iv, stereo_feat]
            }
        Returns:
            cost_volumn: (B*N_views, D, fH_stereo, fW_stereo)
        Raises:
            ValueError: if the stereo features have no channels.
        """
        prev, curr = metas['cv_feat_list']  # (B*N_views, C_stereo, fH_stereo, fW_stereo)
        group_size = 4
        _, _, hf, wf = curr.shape
        # H_img, W_img: feature size times the stereo feature stride. Falls
        # back to the previously hard-coded stride of 4 when the key is absent.
        cv_downsample = metas.get('cv_downsample', 4)
        hi, wi = hf * cv_downsample, wf * cv_downsample
        B, N, _ = metas['post_trans'].shape
        D, H, W, _ = metas['frustum'].shape
        grid = self.gen_grid(metas, B, N, D, H, W, hi, wi).to(curr.dtype)  # (B*N_views, D*fH_stereo, fW_stereo, 2)

        prev = prev.view(B * N, -1, H, W)  # (B*N_views, C_stereo, fH_stereo, fW_stereo)
        curr = curr.view(B * N, -1, H, W)  # (B*N_views, C_stereo, fH_stereo, fW_stereo)

        cost_volumn = None
        wrap_prev = None
        # Process group-wise to save memory. Stepping by ``group_size`` (rather
        # than iterating ``C // group_size`` times) also covers a trailing
        # partial group, so no channel is dropped from the sum.
        for start in range(0, curr.shape[1], group_size):
            stop = start + group_size
            # (B*N_views, <=group_size, D*fH_stereo, fW_stereo)
            wrap_prev = F.grid_sample(prev[:, start:stop, ...], grid,
                                      align_corners=True,
                                      padding_mode='zeros')
            # (B*N_views, g, 1, fH_stereo, fW_stereo) - (B*N_views, g, D, fH_stereo, fW_stereo)
            # https://github.com/HuangJunJie2017/BEVDet/issues/278
            group_cost = curr[:, start:stop, ...].unsqueeze(2) - \
                wrap_prev.view(B * N, -1, D, H, W)
            group_cost = group_cost.abs().sum(dim=1)  # (B*N_views, D, fH_stereo, fW_stereo)
            if cost_volumn is None:
                cost_volumn = group_cost
            else:
                cost_volumn += group_cost

        if cost_volumn is None:
            raise ValueError(
                'cv_feat_list features must have at least one channel')

        if self.bias != 0:
            # Exact zeros in the first channel of the last warped group are
            # treated as invalid samples and penalised.
            invalid = wrap_prev[:, 0, ...].view(B * N, D, H, W) == 0
            cost_volumn[invalid] = cost_volumn[invalid] + self.bias

        # matching cost --> prob
        cost_volumn = -cost_volumn
        cost_volumn = cost_volumn.softmax(dim=1)
        return cost_volumn

    # ------------------------- used to build the cost volume -------------------------
    def forward(self, x, mlp_input, stereo_metas=None):
        """
        Args:
            x: (B*N_views, C, fH, fW)
            mlp_input: (B, N_views, 27)
            stereo_metas: None or dict{
                k2s_sensor: (B, N_views, 4, 4)
                intrins: (B, N_views, 3, 3)
                post_rots: (B, N_views, 3, 3)
                post_trans: (B, N_views, 3)
                frustum: (D, fH_stereo, fW_stereo, 3)  3:(u, v, d)
                cv_downsample: 4,
                downsample: self.img_view_transformer.downsample=16,
                grid_config: self.img_view_transformer.grid_config,
                cv_feat_list: [feat_prev_iv, stereo_feat]
            }
        Returns:
            x: (B*N_views, D+C_context, fH, fW)
        """
        mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))  # (B*N_views, 27)
        x = self.reduce_conv(x)  # (B*N_views, C_mid, fH, fW)

        # (B*N_views, 27) --> (B*N_views, C_mid) --> (B*N_views, C_mid, 1, 1)
        context_se = self.context_mlp(mlp_input)[..., None, None]
        context = self.context_se(x, context_se)  # (B*N_views, C_mid, fH, fW)
        context = self.context_conv(context)  # (B*N_views, C_context, fH, fW)

        # (B*N_views, 27) --> (B*N_views, C_mid) --> (B*N_views, C_mid, 1, 1)
        depth_se = self.depth_mlp(mlp_input)[..., None, None]
        depth = self.depth_se(x, depth_se)  # (B*N_views, C_mid, fH, fW)

        if stereo_metas is not None:
            if stereo_metas['cv_feat_list'][0] is None:
                # No previous-frame feature: feed an all-zero cost volume at
                # the stereo resolution, created directly with x's dtype and
                # device.
                BN, _, H, W = x.shape
                scale_factor = float(stereo_metas['downsample']) / \
                    stereo_metas['cv_downsample']
                cost_volumn = x.new_zeros((BN, self.depth_channels,
                                           int(H * scale_factor),
                                           int(W * scale_factor)))
            else:
                with torch.no_grad():
                    # https://github.com/HuangJunJie2017/BEVDet/issues/278
                    cost_volumn = self.calculate_cost_volumn(stereo_metas)  # (B*N_views, D, fH_stereo, fW_stereo)
            cost_volumn = self.cost_volumn_net(cost_volumn)  # (B*N_views, D, fH, fW)
            depth = torch.cat([depth, cost_volumn], dim=1)  # (B*N_views, C_mid+D, fH, fW)

        if self.with_cp:
            depth = checkpoint(self.depth_conv, depth)
        else:
            # 3 residual blocks + ASPP/DCN + Conv(C_mid --> D)
            depth = self.depth_conv(depth)  # (B*N_views, D, fH, fW)
        return torch.cat([depth, context], dim=1)
class DepthAggregation(nn.Module):
    """Residual stack of 3x3 convolutions followed by an output projection."""

    def __init__(self, in_channels, mid_channels, out_channels):
        """
        Args:
            in_channels: channels of the input feature map.
            mid_channels: width of the residual stack.
            out_channels: channels produced by the final 3x3 conv.
        """
        super(DepthAggregation, self).__init__()

        def conv_bn_relu(c_in, c_out):
            # One bias-free 3x3 conv (stride 1, padding 1) + BN + in-place ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1,
                          bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        self.reduce_conv = nn.Sequential(
            *conv_bn_relu(in_channels, mid_channels))

        self.conv = nn.Sequential(
            *conv_bn_relu(mid_channels, mid_channels),
            *conv_bn_relu(mid_channels, mid_channels),
        )

        # Final projection keeps its bias and has no norm / activation.
        self.out_conv = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1,
                      padding=1, bias=True),
        )

    @autocast(False)
    def forward(self, x):
        """Return ``out_conv(reduce_conv(x) + conv(reduce_conv(x)))``.

        Both ``reduce_conv`` and ``conv`` are run through ``checkpoint``.
        """
        reduced = checkpoint(self.reduce_conv, x)
        refined = checkpoint(self.conv, reduced)
        return self.out_conv(reduced + refined)
\ No newline at end of file
Prev
1
…
3
4
5
6
7
8
9
10
11
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment