Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
e2ead7e9
Commit
e2ead7e9
authored
Jun 01, 2022
by
ZCMax
Committed by
ChaimZhu
Jul 20, 2022
Browse files
Refactor SMOKEHEAD
parent
5db1ead3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
121 additions
and
115 deletions
+121
-115
mmdet3d/models/dense_heads/smoke_mono3d_head.py
mmdet3d/models/dense_heads/smoke_mono3d_head.py
+121
-115
No files found.
mmdet3d/models/dense_heads/smoke_mono3d_head.py
View file @
e2ead7e9
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
mmcv.runner
import
force_fp32
from
mmengine.config
import
ConfigDict
from
torch
import
Tensor
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
mmdet3d.registry
import
MODELS
from
mmdet3d.registry
import
MODELS
...
@@ -30,8 +35,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -30,8 +35,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
regression heatmap channels.
regression heatmap channels.
ori_channel (list[int]): indices of orientation offset pred in
ori_channel (list[int]): indices of orientation offset pred in
regression heatmap channels.
regression heatmap channels.
bbox_coder (:obj:`CameraInstance3DBoxes`): Bbox coder
bbox_coder (dict): Bbox coder for encoding and decoding boxes.
for encoding and decoding boxes.
loss_cls (dict, optional): Config of classification loss.
loss_cls (dict, optional): Config of classification loss.
Default: loss_cls=dict(type='GaussionFocalLoss', loss_weight=1.0).
Default: loss_cls=dict(type='GaussionFocalLoss', loss_weight=1.0).
loss_bbox (dict, optional): Config of localization loss.
loss_bbox (dict, optional): Config of localization loss.
...
@@ -47,18 +51,20 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -47,18 +51,20 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
"""
# noqa: E501
"""
# noqa: E501
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
num_classes
:
int
,
in_channels
,
in_channels
:
int
,
dim_channel
,
dim_channel
:
List
[
int
],
ori_channel
,
ori_channel
:
List
[
int
],
bbox_coder
,
bbox_coder
:
dict
,
loss_cls
=
dict
(
type
=
'GaussionFocalLoss'
,
loss_weight
=
1.0
),
loss_cls
:
dict
=
dict
(
loss_bbox
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
type
=
'GaussionFocalLoss'
,
loss_weight
=
1.0
),
loss_dir
=
None
,
loss_bbox
:
dict
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
loss_attr
=
None
,
loss_dir
:
Optional
[
dict
]
=
None
,
norm_cfg
=
dict
(
type
=
'GN'
,
num_groups
=
32
,
requires_grad
=
True
),
loss_attr
:
Optional
[
dict
]
=
None
,
init_cfg
=
None
,
norm_cfg
:
dict
=
dict
(
**
kwargs
):
type
=
'GN'
,
num_groups
=
32
,
requires_grad
=
True
),
init_cfg
:
Optional
[
Union
[
ConfigDict
,
dict
]]
=
None
,
**
kwargs
)
->
None
:
super
().
__init__
(
super
().
__init__
(
num_classes
,
num_classes
,
in_channels
,
in_channels
,
...
@@ -73,7 +79,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -73,7 +79,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
self
.
ori_channel
=
ori_channel
self
.
ori_channel
=
ori_channel
self
.
bbox_coder
=
build_bbox_coder
(
bbox_coder
)
self
.
bbox_coder
=
build_bbox_coder
(
bbox_coder
)
def
forward
(
self
,
feats
):
def
forward
(
self
,
feats
:
Tuple
[
Tensor
]
):
"""Forward features from the upstream network.
"""Forward features from the upstream network.
Args:
Args:
...
@@ -91,7 +97,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -91,7 +97,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
"""
"""
return
multi_apply
(
self
.
forward_single
,
feats
)
return
multi_apply
(
self
.
forward_single
,
feats
)
def
forward_single
(
self
,
x
)
:
def
forward_single
(
self
,
x
:
Tensor
)
->
Union
[
Tensor
,
Tensor
]
:
"""Forward features of a single scale level.
"""Forward features of a single scale level.
Args:
Args:
...
@@ -112,13 +118,18 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -112,13 +118,18 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
bbox_pred
[:,
self
.
ori_channel
,
...]
=
F
.
normalize
(
vector_ori
)
bbox_pred
[:,
self
.
ori_channel
,
...]
=
F
.
normalize
(
vector_ori
)
return
cls_score
,
bbox_pred
return
cls_score
,
bbox_pred
def
get_bboxes
(
self
,
cls_scores
,
bbox_preds
,
img_metas
,
rescale
=
None
):
@
force_fp32
(
apply_to
=
(
'cls_scores'
,
'bbox_preds'
))
def
get_results
(
self
,
cls_scores
,
bbox_preds
,
batch_img_metas
,
rescale
=
None
):
"""Generate bboxes from bbox head predictions.
"""Generate bboxes from bbox head predictions.
Args:
Args:
cls_scores (list[Tensor]): Box scores for each scale level.
cls_scores (list[Tensor]): Box scores for each scale level.
bbox_preds (list[Tensor]): Box regression for each scale.
bbox_preds (list[Tensor]): Box regression for each scale.
img_metas (list[dict]): Meta information of each image, e.g.,
batch_
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
rescale (bool): If True, return boxes in original image space.
rescale (bool): If True, return boxes in original image space.
...
@@ -128,24 +139,24 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -128,24 +139,24 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
"""
"""
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
cam2imgs
=
torch
.
stack
([
cam2imgs
=
torch
.
stack
([
cls_scores
[
0
].
new_tensor
(
img_meta
[
'cam2img'
])
cls_scores
[
0
].
new_tensor
(
img_meta
s
[
'cam2img'
])
for
img_meta
in
img_metas
for
img_meta
s
in
batch_
img_metas
])
])
trans_mats
=
torch
.
stack
([
trans_mats
=
torch
.
stack
([
cls_scores
[
0
].
new_tensor
(
img_meta
[
'trans_mat'
])
cls_scores
[
0
].
new_tensor
(
img_meta
s
[
'trans_mat'
])
for
img_meta
in
img_metas
for
img_meta
s
in
batch_
img_metas
])
])
batch_bboxes
,
batch_scores
,
batch_topk_labels
=
self
.
decode_heatmap
(
batch_bboxes
,
batch_scores
,
batch_topk_labels
=
self
.
decode_heatmap
(
cls_scores
[
0
],
cls_scores
[
0
],
bbox_preds
[
0
],
bbox_preds
[
0
],
img_metas
,
batch_
img_metas
,
cam2imgs
=
cam2imgs
,
cam2imgs
=
cam2imgs
,
trans_mats
=
trans_mats
,
trans_mats
=
trans_mats
,
topk
=
100
,
topk
=
100
,
kernel
=
3
)
kernel
=
3
)
result_list
=
[]
result_list
=
[]
for
img_id
in
range
(
len
(
img_metas
)):
for
img_id
in
range
(
len
(
batch_
img_metas
)):
bboxes
=
batch_bboxes
[
img_id
]
bboxes
=
batch_bboxes
[
img_id
]
scores
=
batch_scores
[
img_id
]
scores
=
batch_scores
[
img_id
]
...
@@ -156,7 +167,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -156,7 +167,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
scores
=
scores
[
keep_idx
]
scores
=
scores
[
keep_idx
]
labels
=
labels
[
keep_idx
]
labels
=
labels
[
keep_idx
]
bboxes
=
img_metas
[
img_id
][
'box_type_3d'
](
bboxes
=
batch_
img_metas
[
img_id
][
'box_type_3d'
](
bboxes
,
box_dim
=
self
.
bbox_code_size
,
origin
=
(
0.5
,
0.5
,
0.5
))
bboxes
,
box_dim
=
self
.
bbox_code_size
,
origin
=
(
0.5
,
0.5
,
0.5
))
attrs
=
None
attrs
=
None
result_list
.
append
((
bboxes
,
scores
,
labels
,
attrs
))
result_list
.
append
((
bboxes
,
scores
,
labels
,
attrs
))
...
@@ -166,7 +177,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -166,7 +177,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
def
decode_heatmap
(
self
,
def
decode_heatmap
(
self
,
cls_score
,
cls_score
,
reg_pred
,
reg_pred
,
img_metas
,
batch_
img_metas
,
cam2imgs
,
cam2imgs
,
trans_mats
,
trans_mats
,
topk
=
100
,
topk
=
100
,
...
@@ -178,7 +189,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -178,7 +189,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
shape (B, num_classes, H, W).
shape (B, num_classes, H, W).
reg_pred (Tensor): Box regression map.
reg_pred (Tensor): Box regression map.
shape (B, channel, H , W).
shape (B, channel, H , W).
img_metas (
L
ist[dict]): Meta information of each image, e.g.,
batch_
img_metas (
l
ist[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
cam2imgs (Tensor): Camera intrinsic matrices.
cam2imgs (Tensor): Camera intrinsic matrices.
shape (B, 4, 4)
shape (B, 4, 4)
...
@@ -199,7 +210,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -199,7 +210,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
- batch_topk_labels (Tensor): Categories of each 3D box.
- batch_topk_labels (Tensor): Categories of each 3D box.
shape (B, k)
shape (B, k)
"""
"""
img_h
,
img_w
=
img_metas
[
0
][
'pad_shape'
][:
2
]
img_h
,
img_w
=
batch_
img_metas
[
0
][
'pad_shape'
][:
2
]
bs
,
_
,
feat_h
,
feat_w
=
cls_score
.
shape
bs
,
_
,
feat_h
,
feat_w
=
cls_score
.
shape
center_heatmap_pred
=
get_local_maximum
(
cls_score
,
kernel
=
kernel
)
center_heatmap_pred
=
get_local_maximum
(
cls_score
,
kernel
=
kernel
)
...
@@ -221,14 +232,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -221,14 +232,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
batch_bboxes
=
batch_bboxes
.
view
(
bs
,
-
1
,
self
.
bbox_code_size
)
batch_bboxes
=
batch_bboxes
.
view
(
bs
,
-
1
,
self
.
bbox_code_size
)
return
batch_bboxes
,
batch_scores
,
batch_topk_labels
return
batch_bboxes
,
batch_scores
,
batch_topk_labels
def
get_predictions
(
self
,
labels3d
,
centers2d
,
gt_locations
,
gt_dimensions
,
def
get_predictions
(
self
,
labels_3d
,
centers_2d
,
gt_locations
,
gt_orientations
,
indices
,
img_metas
,
pred_reg
):
gt_dimensions
,
gt_orientations
,
indices
,
batch_img_metas
,
pred_reg
):
"""Prepare predictions for computing loss.
"""Prepare predictions for computing loss.
Args:
Args:
labels3d (Tensor): Labels of each 3D box.
labels
_
3d (Tensor): Labels of each 3D box.
shape (B, max_objs, )
shape (B, max_objs, )
centers2d (Tensor): Coords of each projected 3D box
centers
_
2d (Tensor): Coords of each projected 3D box
center on image. shape (B * max_objs, 2)
center on image. shape (B * max_objs, 2)
gt_locations (Tensor): Coords of each 3D box's location.
gt_locations (Tensor): Coords of each 3D box's location.
shape (B * max_objs, 3)
shape (B * max_objs, 3)
...
@@ -238,8 +250,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -238,8 +250,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
shape (N, 1)
shape (N, 1)
indices (Tensor): Indices of the existence of the 3D box.
indices (Tensor): Indices of the existence of the 3D box.
shape (B * max_objs, )
shape (B * max_objs, )
img_metas (list[dict]): Meta information of each image,
batch_
img_metas (list[dict]): Meta information of each image,
e.g.,
e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
pred_reg (Tensor): Box regression map.
pred_reg (Tensor): Box regression map.
shape (B, channel, H , W).
shape (B, channel, H , W).
...
@@ -255,19 +267,19 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -255,19 +267,19 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
batch
,
channel
=
pred_reg
.
shape
[
0
],
pred_reg
.
shape
[
1
]
batch
,
channel
=
pred_reg
.
shape
[
0
],
pred_reg
.
shape
[
1
]
w
=
pred_reg
.
shape
[
3
]
w
=
pred_reg
.
shape
[
3
]
cam2imgs
=
torch
.
stack
([
cam2imgs
=
torch
.
stack
([
gt_locations
.
new_tensor
(
img_meta
[
'cam2img'
])
gt_locations
.
new_tensor
(
img_meta
s
[
'cam2img'
])
for
img_meta
in
img_metas
for
img_meta
s
in
batch_
img_metas
])
])
trans_mats
=
torch
.
stack
([
trans_mats
=
torch
.
stack
([
gt_locations
.
new_tensor
(
img_meta
[
'trans_mat'
])
gt_locations
.
new_tensor
(
img_meta
s
[
'trans_mat'
])
for
img_meta
in
img_metas
for
img_meta
s
in
batch_
img_metas
])
])
centers2d_inds
=
centers2d
[:,
1
]
*
w
+
centers2d
[:,
0
]
centers
_
2d_inds
=
centers
_
2d
[:,
1
]
*
w
+
centers
_
2d
[:,
0
]
centers2d_inds
=
centers2d_inds
.
view
(
batch
,
-
1
)
centers
_
2d_inds
=
centers
_
2d_inds
.
view
(
batch
,
-
1
)
pred_regression
=
transpose_and_gather_feat
(
pred_reg
,
centers2d_inds
)
pred_regression
=
transpose_and_gather_feat
(
pred_reg
,
centers
_
2d_inds
)
pred_regression_pois
=
pred_regression
.
view
(
-
1
,
channel
)
pred_regression_pois
=
pred_regression
.
view
(
-
1
,
channel
)
locations
,
dimensions
,
orientations
=
self
.
bbox_coder
.
decode
(
locations
,
dimensions
,
orientations
=
self
.
bbox_coder
.
decode
(
pred_regression_pois
,
centers2d
,
labels3d
,
cam2imgs
,
trans_mats
,
pred_regression_pois
,
centers
_
2d
,
labels
_
3d
,
cam2imgs
,
trans_mats
,
gt_locations
)
gt_locations
)
locations
,
dimensions
,
orientations
=
locations
[
indices
],
dimensions
[
locations
,
dimensions
,
orientations
=
locations
[
indices
],
dimensions
[
...
@@ -281,44 +293,35 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -281,44 +293,35 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
assert
len
(
dimensions
)
==
len
(
gt_dimensions
)
assert
len
(
dimensions
)
==
len
(
gt_dimensions
)
assert
len
(
orientations
)
==
len
(
gt_orientations
)
assert
len
(
orientations
)
==
len
(
gt_orientations
)
bbox3d_yaws
=
self
.
bbox_coder
.
encode
(
gt_locations
,
gt_dimensions
,
bbox3d_yaws
=
self
.
bbox_coder
.
encode
(
gt_locations
,
gt_dimensions
,
orientations
,
img_metas
)
orientations
,
batch_
img_metas
)
bbox3d_dims
=
self
.
bbox_coder
.
encode
(
gt_locations
,
dimensions
,
bbox3d_dims
=
self
.
bbox_coder
.
encode
(
gt_locations
,
dimensions
,
gt_orientations
,
img_metas
)
gt_orientations
,
batch_
img_metas
)
bbox3d_locs
=
self
.
bbox_coder
.
encode
(
locations
,
gt_dimensions
,
bbox3d_locs
=
self
.
bbox_coder
.
encode
(
locations
,
gt_dimensions
,
gt_orientations
,
img_metas
)
gt_orientations
,
batch_
img_metas
)
pred_bboxes
=
dict
(
ori
=
bbox3d_yaws
,
dim
=
bbox3d_dims
,
loc
=
bbox3d_locs
)
pred_bboxes
=
dict
(
ori
=
bbox3d_yaws
,
dim
=
bbox3d_dims
,
loc
=
bbox3d_locs
)
return
pred_bboxes
return
pred_bboxes
def
get_targets
(
self
,
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
gt_labels_3d
,
def
get_targets
(
self
,
batch_gt_instances_3d
,
feat_shape
,
batch_img_metas
):
centers2d
,
feat_shape
,
img_shape
,
img_metas
):
"""Get training targets for batch images.
"""Get training targets for batch images.
Args:
Args:
gt_bboxes (list[Tensor]): Ground truth bboxes of each image,
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
shape (num_gt, 4).
gt_instance_3d. It usually includes ``bboxes``, ``labels``,
gt_labels (list[Tensor]): Ground truth labels of each box,
``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
shape (num_gt,).
attributes.
gt_bboxes_3d (list[:obj:`CameraInstance3DBoxes`]): 3D Ground
truth bboxes of each image,
shape (num_gt, bbox_code_size).
gt_labels_3d (list[Tensor]): 3D Ground truth labels of each
box, shape (num_gt,).
centers2d (list[Tensor]): Projected 3D centers onto 2D image,
shape (num_gt, 2).
feat_shape (tuple[int]): Feature map shape with value,
feat_shape (tuple[int]): Feature map shape with value,
shape (B, _, H, W).
shape (B, _, H, W).
img_shape (tuple[int]): Image shape in [h, w] format.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
Returns:
Returns:
tuple[Tensor, dict]: The Tensor value is the targets of
tuple[Tensor, dict]: The Tensor value is the targets of
center heatmap, the dict has components below:
center heatmap, the dict has components below:
- gt_centers2d (Tensor): Coords of each projected 3D box
- gt_centers
_
2d (Tensor): Coords of each projected 3D box
center on image. shape (B * max_objs, 2)
center on image. shape (B * max_objs, 2)
- gt_labels3d (Tensor): Labels of each 3D box.
- gt_labels
_
3d (Tensor): Labels of each 3D box.
shape (B, max_objs, )
shape (B, max_objs, )
- indices (Tensor): Indices of the existence of the 3D box.
- indices (Tensor): Indices of the existence of the 3D box.
shape (B * max_objs, )
shape (B * max_objs, )
...
@@ -334,10 +337,30 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -334,10 +337,30 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
shape (N, 8, 3)
shape (N, 8, 3)
"""
"""
gt_bboxes
=
[
gt_instances_3d
.
bboxes
for
gt_instances_3d
in
batch_gt_instances_3d
]
gt_labels
=
[
gt_instances_3d
.
labels
for
gt_instances_3d
in
batch_gt_instances_3d
]
gt_bboxes_3d
=
[
gt_instances_3d
.
bboxes_3d
for
gt_instances_3d
in
batch_gt_instances_3d
]
gt_labels_3d
=
[
gt_instances_3d
.
labels_3d
for
gt_instances_3d
in
batch_gt_instances_3d
]
centers_2d
=
[
gt_instances_3d
.
centers_2d
for
gt_instances_3d
in
batch_gt_instances_3d
]
img_shape
=
batch_img_metas
[
0
][
'pad_shape'
]
reg_mask
=
torch
.
stack
([
reg_mask
=
torch
.
stack
([
gt_bboxes
[
0
].
new_tensor
(
gt_bboxes
[
0
].
new_tensor
(
not
img_meta
[
'affine_aug'
],
dtype
=
torch
.
bool
)
not
img_meta
s
[
'affine_aug'
],
dtype
=
torch
.
bool
)
for
img_meta
in
img_metas
for
img_meta
s
in
batch_
img_metas
])
])
img_h
,
img_w
=
img_shape
[:
2
]
img_h
,
img_w
=
img_shape
[:
2
]
...
@@ -351,15 +374,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -351,15 +374,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
center_heatmap_target
=
gt_bboxes
[
-
1
].
new_zeros
(
center_heatmap_target
=
gt_bboxes
[
-
1
].
new_zeros
(
[
bs
,
self
.
num_classes
,
feat_h
,
feat_w
])
[
bs
,
self
.
num_classes
,
feat_h
,
feat_w
])
gt_centers2d
=
centers2d
.
copy
()
gt_centers
_
2d
=
centers
_
2d
.
copy
()
for
batch_id
in
range
(
bs
):
for
batch_id
in
range
(
bs
):
gt_bbox
=
gt_bboxes
[
batch_id
]
gt_bbox
=
gt_bboxes
[
batch_id
]
gt_label
=
gt_labels
[
batch_id
]
gt_label
=
gt_labels
[
batch_id
]
# project centers2d from input image to feat map
# project centers
_
2d from input image to feat map
gt_center2d
=
gt_centers2d
[
batch_id
]
*
width_ratio
gt_center
_
2d
=
gt_centers
_
2d
[
batch_id
]
*
width_ratio
for
j
,
center
in
enumerate
(
gt_center2d
):
for
j
,
center
in
enumerate
(
gt_center
_
2d
):
center_x_int
,
center_y_int
=
center
.
int
()
center_x_int
,
center_y_int
=
center
.
int
()
scale_box_h
=
(
gt_bbox
[
j
][
3
]
-
gt_bbox
[
j
][
1
])
*
height_ratio
scale_box_h
=
(
gt_bbox
[
j
][
3
]
-
gt_bbox
[
j
][
1
])
*
height_ratio
scale_box_w
=
(
gt_bbox
[
j
][
2
]
-
gt_bbox
[
j
][
0
])
*
width_ratio
scale_box_w
=
(
gt_bbox
[
j
][
2
]
-
gt_bbox
[
j
][
0
])
*
width_ratio
...
@@ -371,33 +394,33 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -371,33 +394,33 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
[
center_x_int
,
center_y_int
],
radius
)
[
center_x_int
,
center_y_int
],
radius
)
avg_factor
=
max
(
1
,
center_heatmap_target
.
eq
(
1
).
sum
())
avg_factor
=
max
(
1
,
center_heatmap_target
.
eq
(
1
).
sum
())
num_ctrs
=
[
center2d
.
shape
[
0
]
for
center2d
in
centers2d
]
num_ctrs
=
[
center
_
2d
.
shape
[
0
]
for
center
_
2d
in
centers
_
2d
]
max_objs
=
max
(
num_ctrs
)
max_objs
=
max
(
num_ctrs
)
reg_inds
=
torch
.
cat
(
reg_inds
=
torch
.
cat
(
[
reg_mask
[
i
].
repeat
(
num_ctrs
[
i
])
for
i
in
range
(
bs
)])
[
reg_mask
[
i
].
repeat
(
num_ctrs
[
i
])
for
i
in
range
(
bs
)])
inds
=
torch
.
zeros
((
bs
,
max_objs
),
inds
=
torch
.
zeros
((
bs
,
max_objs
),
dtype
=
torch
.
bool
).
to
(
centers2d
[
0
].
device
)
dtype
=
torch
.
bool
).
to
(
centers
_
2d
[
0
].
device
)
# put gt 3d bboxes to gpu
# put gt 3d bboxes to gpu
gt_bboxes_3d
=
[
gt_bboxes_3d
=
[
gt_bbox_3d
.
to
(
centers2d
[
0
].
device
)
for
gt_bbox_3d
in
gt_bboxes_3d
gt_bbox_3d
.
to
(
centers
_
2d
[
0
].
device
)
for
gt_bbox_3d
in
gt_bboxes_3d
]
]
batch_centers2d
=
centers2d
[
0
].
new_zeros
((
bs
,
max_objs
,
2
))
batch_centers
_
2d
=
centers
_
2d
[
0
].
new_zeros
((
bs
,
max_objs
,
2
))
batch_labels_3d
=
gt_labels_3d
[
0
].
new_zeros
((
bs
,
max_objs
))
batch_labels_3d
=
gt_labels_3d
[
0
].
new_zeros
((
bs
,
max_objs
))
batch_gt_locations
=
\
batch_gt_locations
=
\
gt_bboxes_3d
[
0
].
tensor
.
new_zeros
((
bs
,
max_objs
,
3
))
gt_bboxes_3d
[
0
].
tensor
.
new_zeros
((
bs
,
max_objs
,
3
))
for
i
in
range
(
bs
):
for
i
in
range
(
bs
):
inds
[
i
,
:
num_ctrs
[
i
]]
=
1
inds
[
i
,
:
num_ctrs
[
i
]]
=
1
batch_centers2d
[
i
,
:
num_ctrs
[
i
]]
=
centers2d
[
i
]
batch_centers
_
2d
[
i
,
:
num_ctrs
[
i
]]
=
centers
_
2d
[
i
]
batch_labels_3d
[
i
,
:
num_ctrs
[
i
]]
=
gt_labels_3d
[
i
]
batch_labels_3d
[
i
,
:
num_ctrs
[
i
]]
=
gt_labels_3d
[
i
]
batch_gt_locations
[
i
,
:
num_ctrs
[
i
]]
=
\
batch_gt_locations
[
i
,
:
num_ctrs
[
i
]]
=
\
gt_bboxes_3d
[
i
].
tensor
[:,
:
3
]
gt_bboxes_3d
[
i
].
tensor
[:,
:
3
]
inds
=
inds
.
flatten
()
inds
=
inds
.
flatten
()
batch_centers2d
=
batch_centers2d
.
view
(
-
1
,
2
)
*
width_ratio
batch_centers
_
2d
=
batch_centers
_
2d
.
view
(
-
1
,
2
)
*
width_ratio
batch_gt_locations
=
batch_gt_locations
.
view
(
-
1
,
3
)
batch_gt_locations
=
batch_gt_locations
.
view
(
-
1
,
3
)
# filter the empty image, without gt_bboxes_3d
# filter the empty image, without gt_bboxes_3d
...
@@ -416,8 +439,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -416,8 +439,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
[
gt_bbox_3d
.
corners
for
gt_bbox_3d
in
gt_bboxes_3d
])
[
gt_bbox_3d
.
corners
for
gt_bbox_3d
in
gt_bboxes_3d
])
target_labels
=
dict
(
target_labels
=
dict
(
gt_centers2d
=
batch_centers2d
.
long
(),
gt_centers
_
2d
=
batch_centers
_
2d
.
long
(),
gt_labels3d
=
batch_labels_3d
,
gt_labels
_
3d
=
batch_labels_3d
,
indices
=
inds
,
indices
=
inds
,
reg_indices
=
reg_inds
,
reg_indices
=
reg_inds
,
gt_locs
=
batch_gt_locations
,
gt_locs
=
batch_gt_locations
,
...
@@ -430,15 +453,9 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -430,15 +453,9 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
def
loss
(
self
,
def
loss
(
self
,
cls_scores
,
cls_scores
,
bbox_preds
,
bbox_preds
,
gt_bboxes
,
batch_gt_instances_3d
,
gt_labels
,
batch_img_metas
,
gt_bboxes_3d
,
batch_gt_instances_ignore
=
None
):
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
,
img_metas
,
gt_bboxes_ignore
=
None
):
"""Compute loss of the head.
"""Compute loss of the head.
Args:
Args:
...
@@ -447,53 +464,42 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
...
@@ -447,53 +464,42 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel
bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel
number is bbox_code_size.
number is bbox_code_size.
shape (B, 7, H, W).
shape (B, 7, H, W).
gt_bboxes (list[Tensor]): Ground truth bboxes for each image.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_instance_3d. It usually includes ``bboxes``, ``labels``,
gt_labels (list[Tensor]): Class indices corresponding to each box.
``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
shape (num_gts, ).
attributes.
gt_bboxes_3d (list[:obj:`CameraInstance3DBoxes`]): 3D boxes ground
batch_img_metas (list[dict]): Meta information of each image, e.g.,
truth. it is the flipped gt_bboxes
gt_labels_3d (list[Tensor]): Same as gt_labels.
centers2d (list[Tensor]): 2D centers on the image.
shape (num_gts, 2).
depths (list[Tensor]): Depth ground truth.
shape (num_gts, ).
attr_labels (list[Tensor]): Attributes indices of each box.
In kitti it's None.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
boxes can be ignored when computing the loss.
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
Default: None.
data that is ignored during training and testing.
Defaults to None.
Returns:
Returns:
dict[str, Tensor]: A dictionary of loss components.
dict[str, Tensor]: A dictionary of loss components.
"""
"""
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
assert
attr_labels
is
None
assert
batch_gt_instances_ignore
is
None
assert
gt_bboxes_ignore
is
None
center_2d_heatmap
=
cls_scores
[
0
]
center2d_heatmap
=
cls_scores
[
0
]
pred_reg
=
bbox_preds
[
0
]
pred_reg
=
bbox_preds
[
0
]
center2d_heatmap_target
,
avg_factor
,
target_labels
=
\
center_2d_heatmap_target
,
avg_factor
,
target_labels
=
\
self
.
get_targets
(
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
self
.
get_targets
(
batch_gt_instances_3d
,
gt_labels_3d
,
centers2d
,
center_2d_heatmap
.
shape
,
center2d_heatmap
.
shape
,
batch_img_metas
)
img_metas
[
0
][
'pad_shape'
],
img_metas
)
pred_bboxes
=
self
.
get_predictions
(
pred_bboxes
=
self
.
get_predictions
(
labels3d
=
target_labels
[
'gt_labels3d'
],
labels
_
3d
=
target_labels
[
'gt_labels
_
3d'
],
centers2d
=
target_labels
[
'gt_centers2d'
],
centers
_
2d
=
target_labels
[
'gt_centers
_
2d'
],
gt_locations
=
target_labels
[
'gt_locs'
],
gt_locations
=
target_labels
[
'gt_locs'
],
gt_dimensions
=
target_labels
[
'gt_dims'
],
gt_dimensions
=
target_labels
[
'gt_dims'
],
gt_orientations
=
target_labels
[
'gt_yaws'
],
gt_orientations
=
target_labels
[
'gt_yaws'
],
indices
=
target_labels
[
'indices'
],
indices
=
target_labels
[
'indices'
],
img_metas
=
img_metas
,
batch_
img_metas
=
batch_
img_metas
,
pred_reg
=
pred_reg
)
pred_reg
=
pred_reg
)
loss_cls
=
self
.
loss_cls
(
loss_cls
=
self
.
loss_cls
(
center2d_heatmap
,
center2d_heatmap_target
,
avg_factor
=
avg_factor
)
center
_
2d_heatmap
,
center
_
2d_heatmap_target
,
avg_factor
=
avg_factor
)
reg_inds
=
target_labels
[
'reg_indices'
]
reg_inds
=
target_labels
[
'reg_indices'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment