Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
e2ead7e9
Commit
e2ead7e9
authored
Jun 01, 2022
by
ZCMax
Committed by
ChaimZhu
Jul 20, 2022
Browse files
Refactor SMOKEHEAD
parent
5db1ead3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
121 additions
and
115 deletions
+121
-115
mmdet3d/models/dense_heads/smoke_mono3d_head.py
mmdet3d/models/dense_heads/smoke_mono3d_head.py
+121
-115
No files found.
mmdet3d/models/dense_heads/smoke_mono3d_head.py
View file @
e2ead7e9
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
from
mmcv.runner
import
force_fp32
from
mmengine.config
import
ConfigDict
from
torch
import
Tensor
from
torch.nn
import
functional
as
F
from
mmdet3d.registry
import
MODELS
...
...
@@ -30,8 +35,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
regression heatmap channels.
ori_channel (list[int]): indices of orientation offset pred in
regression heatmap channels.
bbox_coder (:obj:`CameraInstance3DBoxes`): Bbox coder
for encoding and decoding boxes.
bbox_coder (dict): Bbox coder for encoding and decoding boxes.
loss_cls (dict, optional): Config of classification loss.
Default: loss_cls=dict(type='GaussionFocalLoss', loss_weight=1.0).
loss_bbox (dict, optional): Config of localization loss.
...
...
@@ -47,18 +51,20 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
"""
# noqa: E501
def
__init__
(
self
,
num_classes
,
in_channels
,
dim_channel
,
ori_channel
,
bbox_coder
,
loss_cls
=
dict
(
type
=
'GaussionFocalLoss'
,
loss_weight
=
1.0
),
loss_bbox
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
loss_dir
=
None
,
loss_attr
=
None
,
norm_cfg
=
dict
(
type
=
'GN'
,
num_groups
=
32
,
requires_grad
=
True
),
init_cfg
=
None
,
**
kwargs
):
num_classes
:
int
,
in_channels
:
int
,
dim_channel
:
List
[
int
],
ori_channel
:
List
[
int
],
bbox_coder
:
dict
,
loss_cls
:
dict
=
dict
(
type
=
'GaussionFocalLoss'
,
loss_weight
=
1.0
),
loss_bbox
:
dict
=
dict
(
type
=
'L1Loss'
,
loss_weight
=
0.1
),
loss_dir
:
Optional
[
dict
]
=
None
,
loss_attr
:
Optional
[
dict
]
=
None
,
norm_cfg
:
dict
=
dict
(
type
=
'GN'
,
num_groups
=
32
,
requires_grad
=
True
),
init_cfg
:
Optional
[
Union
[
ConfigDict
,
dict
]]
=
None
,
**
kwargs
)
->
None
:
super
().
__init__
(
num_classes
,
in_channels
,
...
...
@@ -73,7 +79,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
self
.
ori_channel
=
ori_channel
self
.
bbox_coder
=
build_bbox_coder
(
bbox_coder
)
def
forward
(
self
,
feats
):
def
forward
(
self
,
feats
:
Tuple
[
Tensor
]
):
"""Forward features from the upstream network.
Args:
...
...
@@ -91,7 +97,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
"""
return
multi_apply
(
self
.
forward_single
,
feats
)
def
forward_single
(
self
,
x
)
:
def
forward_single
(
self
,
x
:
Tensor
)
->
Union
[
Tensor
,
Tensor
]
:
"""Forward features of a single scale level.
Args:
...
...
@@ -112,13 +118,18 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
bbox_pred
[:,
self
.
ori_channel
,
...]
=
F
.
normalize
(
vector_ori
)
return
cls_score
,
bbox_pred
def
get_bboxes
(
self
,
cls_scores
,
bbox_preds
,
img_metas
,
rescale
=
None
):
@
force_fp32
(
apply_to
=
(
'cls_scores'
,
'bbox_preds'
))
def
get_results
(
self
,
cls_scores
,
bbox_preds
,
batch_img_metas
,
rescale
=
None
):
"""Generate bboxes from bbox head predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level.
bbox_preds (list[Tensor]): Box regression for each scale.
img_metas (list[dict]): Meta information of each image, e.g.,
batch_
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
rescale (bool): If True, return boxes in original image space.
...
...
@@ -128,24 +139,24 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
"""
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
cam2imgs
=
torch
.
stack
([
cls_scores
[
0
].
new_tensor
(
img_meta
[
'cam2img'
])
for
img_meta
in
img_metas
cls_scores
[
0
].
new_tensor
(
img_meta
s
[
'cam2img'
])
for
img_meta
s
in
batch_
img_metas
])
trans_mats
=
torch
.
stack
([
cls_scores
[
0
].
new_tensor
(
img_meta
[
'trans_mat'
])
for
img_meta
in
img_metas
cls_scores
[
0
].
new_tensor
(
img_meta
s
[
'trans_mat'
])
for
img_meta
s
in
batch_
img_metas
])
batch_bboxes
,
batch_scores
,
batch_topk_labels
=
self
.
decode_heatmap
(
cls_scores
[
0
],
bbox_preds
[
0
],
img_metas
,
batch_
img_metas
,
cam2imgs
=
cam2imgs
,
trans_mats
=
trans_mats
,
topk
=
100
,
kernel
=
3
)
result_list
=
[]
for
img_id
in
range
(
len
(
img_metas
)):
for
img_id
in
range
(
len
(
batch_
img_metas
)):
bboxes
=
batch_bboxes
[
img_id
]
scores
=
batch_scores
[
img_id
]
...
...
@@ -156,7 +167,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
scores
=
scores
[
keep_idx
]
labels
=
labels
[
keep_idx
]
bboxes
=
img_metas
[
img_id
][
'box_type_3d'
](
bboxes
=
batch_
img_metas
[
img_id
][
'box_type_3d'
](
bboxes
,
box_dim
=
self
.
bbox_code_size
,
origin
=
(
0.5
,
0.5
,
0.5
))
attrs
=
None
result_list
.
append
((
bboxes
,
scores
,
labels
,
attrs
))
...
...
@@ -166,7 +177,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
def
decode_heatmap
(
self
,
cls_score
,
reg_pred
,
img_metas
,
batch_
img_metas
,
cam2imgs
,
trans_mats
,
topk
=
100
,
...
...
@@ -178,7 +189,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
shape (B, num_classes, H, W).
reg_pred (Tensor): Box regression map.
shape (B, channel, H , W).
img_metas (
L
ist[dict]): Meta information of each image, e.g.,
batch_
img_metas (
l
ist[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cam2imgs (Tensor): Camera intrinsic matrixs.
shape (B, 4, 4)
...
...
@@ -199,7 +210,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
- batch_topk_labels (Tensor): Categories of each 3D box.
shape (B, k)
"""
img_h
,
img_w
=
img_metas
[
0
][
'pad_shape'
][:
2
]
img_h
,
img_w
=
batch_
img_metas
[
0
][
'pad_shape'
][:
2
]
bs
,
_
,
feat_h
,
feat_w
=
cls_score
.
shape
center_heatmap_pred
=
get_local_maximum
(
cls_score
,
kernel
=
kernel
)
...
...
@@ -221,14 +232,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
batch_bboxes
=
batch_bboxes
.
view
(
bs
,
-
1
,
self
.
bbox_code_size
)
return
batch_bboxes
,
batch_scores
,
batch_topk_labels
def
get_predictions
(
self
,
labels3d
,
centers2d
,
gt_locations
,
gt_dimensions
,
gt_orientations
,
indices
,
img_metas
,
pred_reg
):
def
get_predictions
(
self
,
labels_3d
,
centers_2d
,
gt_locations
,
gt_dimensions
,
gt_orientations
,
indices
,
batch_img_metas
,
pred_reg
):
"""Prepare predictions for computing loss.
Args:
labels3d (Tensor): Labels of each 3D box.
labels
_
3d (Tensor): Labels of each 3D box.
shape (B, max_objs, )
centers2d (Tensor): Coords of each projected 3D box
centers
_
2d (Tensor): Coords of each projected 3D box
center on image. shape (B * max_objs, 2)
gt_locations (Tensor): Coords of each 3D box's location.
shape (B * max_objs, 3)
...
...
@@ -238,8 +250,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
shape (N, 1)
indices (Tensor): Indices of the existence of the 3D box.
shape (B * max_objs, )
img_metas (list[dict]): Meta information of each image,
e.g.,
image size, scaling factor, etc.
batch_
img_metas (list[dict]): Meta information of each image,
e.g.,
image size, scaling factor, etc.
pre_reg (Tensor): Box regression map.
shape (B, channel, H , W).
...
...
@@ -255,19 +267,19 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
batch
,
channel
=
pred_reg
.
shape
[
0
],
pred_reg
.
shape
[
1
]
w
=
pred_reg
.
shape
[
3
]
cam2imgs
=
torch
.
stack
([
gt_locations
.
new_tensor
(
img_meta
[
'cam2img'
])
for
img_meta
in
img_metas
gt_locations
.
new_tensor
(
img_meta
s
[
'cam2img'
])
for
img_meta
s
in
batch_
img_metas
])
trans_mats
=
torch
.
stack
([
gt_locations
.
new_tensor
(
img_meta
[
'trans_mat'
])
for
img_meta
in
img_metas
gt_locations
.
new_tensor
(
img_meta
s
[
'trans_mat'
])
for
img_meta
s
in
batch_
img_metas
])
centers2d_inds
=
centers2d
[:,
1
]
*
w
+
centers2d
[:,
0
]
centers2d_inds
=
centers2d_inds
.
view
(
batch
,
-
1
)
pred_regression
=
transpose_and_gather_feat
(
pred_reg
,
centers2d_inds
)
centers
_
2d_inds
=
centers
_
2d
[:,
1
]
*
w
+
centers
_
2d
[:,
0
]
centers
_
2d_inds
=
centers
_
2d_inds
.
view
(
batch
,
-
1
)
pred_regression
=
transpose_and_gather_feat
(
pred_reg
,
centers
_
2d_inds
)
pred_regression_pois
=
pred_regression
.
view
(
-
1
,
channel
)
locations
,
dimensions
,
orientations
=
self
.
bbox_coder
.
decode
(
pred_regression_pois
,
centers2d
,
labels3d
,
cam2imgs
,
trans_mats
,
pred_regression_pois
,
centers
_
2d
,
labels
_
3d
,
cam2imgs
,
trans_mats
,
gt_locations
)
locations
,
dimensions
,
orientations
=
locations
[
indices
],
dimensions
[
...
...
@@ -281,44 +293,35 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
assert
len
(
dimensions
)
==
len
(
gt_dimensions
)
assert
len
(
orientations
)
==
len
(
gt_orientations
)
bbox3d_yaws
=
self
.
bbox_coder
.
encode
(
gt_locations
,
gt_dimensions
,
orientations
,
img_metas
)
orientations
,
batch_
img_metas
)
bbox3d_dims
=
self
.
bbox_coder
.
encode
(
gt_locations
,
dimensions
,
gt_orientations
,
img_metas
)
gt_orientations
,
batch_
img_metas
)
bbox3d_locs
=
self
.
bbox_coder
.
encode
(
locations
,
gt_dimensions
,
gt_orientations
,
img_metas
)
gt_orientations
,
batch_
img_metas
)
pred_bboxes
=
dict
(
ori
=
bbox3d_yaws
,
dim
=
bbox3d_dims
,
loc
=
bbox3d_locs
)
return
pred_bboxes
def
get_targets
(
self
,
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
gt_labels_3d
,
centers2d
,
feat_shape
,
img_shape
,
img_metas
):
def
get_targets
(
self
,
batch_gt_instances_3d
,
feat_shape
,
batch_img_metas
):
"""Get training targets for batch images.
Args:
gt_bboxes (list[Tensor]): Ground truth bboxes of each image,
shape (num_gt, 4).
gt_labels (list[Tensor]): Ground truth labels of each box,
shape (num_gt,).
gt_bboxes_3d (list[:obj:`CameraInstance3DBoxes`]): 3D Ground
truth bboxes of each image,
shape (num_gt, bbox_code_size).
gt_labels_3d (list[Tensor]): 3D Ground truth labels of each
box, shape (num_gt,).
centers2d (list[Tensor]): Projected 3D centers onto 2D image,
shape (num_gt, 2).
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instance_3d. It usually includes ``bboxes``、``labels``
、``bboxes_3d``、``labels_3d``、``depths``、``centers_2d`` and
attributes.
feat_shape (tuple[int]): Feature map shape with value,
shape (B, _, H, W).
img_shape (tuple[int]): Image shape in [h, w] format.
img_metas (list[dict]): Meta information of each image, e.g.,
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
Returns:
tuple[Tensor, dict]: The Tensor value is the targets of
center heatmap, the dict has components below:
- gt_centers2d (Tensor): Coords of each projected 3D box
- gt_centers
_
2d (Tensor): Coords of each projected 3D box
center on image. shape (B * max_objs, 2)
- gt_labels3d (Tensor): Labels of each 3D box.
- gt_labels
_
3d (Tensor): Labels of each 3D box.
shape (B, max_objs, )
- indices (Tensor): Indices of the existence of the 3D box.
shape (B * max_objs, )
...
...
@@ -334,10 +337,30 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
shape (N, 8, 3)
"""
gt_bboxes
=
[
gt_instances_3d
.
bboxes
for
gt_instances_3d
in
batch_gt_instances_3d
]
gt_labels
=
[
gt_instances_3d
.
labels
for
gt_instances_3d
in
batch_gt_instances_3d
]
gt_bboxes_3d
=
[
gt_instances_3d
.
bboxes_3d
for
gt_instances_3d
in
batch_gt_instances_3d
]
gt_labels_3d
=
[
gt_instances_3d
.
labels_3d
for
gt_instances_3d
in
batch_gt_instances_3d
]
centers_2d
=
[
gt_instances_3d
.
centers_2d
for
gt_instances_3d
in
batch_gt_instances_3d
]
img_shape
=
batch_img_metas
[
0
][
'pad_shape'
]
reg_mask
=
torch
.
stack
([
gt_bboxes
[
0
].
new_tensor
(
not
img_meta
[
'affine_aug'
],
dtype
=
torch
.
bool
)
for
img_meta
in
img_metas
not
img_meta
s
[
'affine_aug'
],
dtype
=
torch
.
bool
)
for
img_meta
s
in
batch_
img_metas
])
img_h
,
img_w
=
img_shape
[:
2
]
...
...
@@ -351,15 +374,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
center_heatmap_target
=
gt_bboxes
[
-
1
].
new_zeros
(
[
bs
,
self
.
num_classes
,
feat_h
,
feat_w
])
gt_centers2d
=
centers2d
.
copy
()
gt_centers
_
2d
=
centers
_
2d
.
copy
()
for
batch_id
in
range
(
bs
):
gt_bbox
=
gt_bboxes
[
batch_id
]
gt_label
=
gt_labels
[
batch_id
]
# project centers2d from input image to feat map
gt_center2d
=
gt_centers2d
[
batch_id
]
*
width_ratio
# project centers
_
2d from input image to feat map
gt_center
_
2d
=
gt_centers
_
2d
[
batch_id
]
*
width_ratio
for
j
,
center
in
enumerate
(
gt_center2d
):
for
j
,
center
in
enumerate
(
gt_center
_
2d
):
center_x_int
,
center_y_int
=
center
.
int
()
scale_box_h
=
(
gt_bbox
[
j
][
3
]
-
gt_bbox
[
j
][
1
])
*
height_ratio
scale_box_w
=
(
gt_bbox
[
j
][
2
]
-
gt_bbox
[
j
][
0
])
*
width_ratio
...
...
@@ -371,33 +394,33 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
[
center_x_int
,
center_y_int
],
radius
)
avg_factor
=
max
(
1
,
center_heatmap_target
.
eq
(
1
).
sum
())
num_ctrs
=
[
center2d
.
shape
[
0
]
for
center2d
in
centers2d
]
num_ctrs
=
[
center
_
2d
.
shape
[
0
]
for
center
_
2d
in
centers
_
2d
]
max_objs
=
max
(
num_ctrs
)
reg_inds
=
torch
.
cat
(
[
reg_mask
[
i
].
repeat
(
num_ctrs
[
i
])
for
i
in
range
(
bs
)])
inds
=
torch
.
zeros
((
bs
,
max_objs
),
dtype
=
torch
.
bool
).
to
(
centers2d
[
0
].
device
)
dtype
=
torch
.
bool
).
to
(
centers
_
2d
[
0
].
device
)
# put gt 3d bboxes to gpu
gt_bboxes_3d
=
[
gt_bbox_3d
.
to
(
centers2d
[
0
].
device
)
for
gt_bbox_3d
in
gt_bboxes_3d
gt_bbox_3d
.
to
(
centers
_
2d
[
0
].
device
)
for
gt_bbox_3d
in
gt_bboxes_3d
]
batch_centers2d
=
centers2d
[
0
].
new_zeros
((
bs
,
max_objs
,
2
))
batch_centers
_
2d
=
centers
_
2d
[
0
].
new_zeros
((
bs
,
max_objs
,
2
))
batch_labels_3d
=
gt_labels_3d
[
0
].
new_zeros
((
bs
,
max_objs
))
batch_gt_locations
=
\
gt_bboxes_3d
[
0
].
tensor
.
new_zeros
((
bs
,
max_objs
,
3
))
for
i
in
range
(
bs
):
inds
[
i
,
:
num_ctrs
[
i
]]
=
1
batch_centers2d
[
i
,
:
num_ctrs
[
i
]]
=
centers2d
[
i
]
batch_centers
_
2d
[
i
,
:
num_ctrs
[
i
]]
=
centers
_
2d
[
i
]
batch_labels_3d
[
i
,
:
num_ctrs
[
i
]]
=
gt_labels_3d
[
i
]
batch_gt_locations
[
i
,
:
num_ctrs
[
i
]]
=
\
gt_bboxes_3d
[
i
].
tensor
[:,
:
3
]
inds
=
inds
.
flatten
()
batch_centers2d
=
batch_centers2d
.
view
(
-
1
,
2
)
*
width_ratio
batch_centers
_
2d
=
batch_centers
_
2d
.
view
(
-
1
,
2
)
*
width_ratio
batch_gt_locations
=
batch_gt_locations
.
view
(
-
1
,
3
)
# filter the empty image, without gt_bboxes_3d
...
...
@@ -416,8 +439,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
[
gt_bbox_3d
.
corners
for
gt_bbox_3d
in
gt_bboxes_3d
])
target_labels
=
dict
(
gt_centers2d
=
batch_centers2d
.
long
(),
gt_labels3d
=
batch_labels_3d
,
gt_centers
_
2d
=
batch_centers
_
2d
.
long
(),
gt_labels
_
3d
=
batch_labels_3d
,
indices
=
inds
,
reg_indices
=
reg_inds
,
gt_locs
=
batch_gt_locations
,
...
...
@@ -430,15 +453,9 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
def
loss
(
self
,
cls_scores
,
bbox_preds
,
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
gt_labels_3d
,
centers2d
,
depths
,
attr_labels
,
img_metas
,
gt_bboxes_ignore
=
None
):
batch_gt_instances_3d
,
batch_img_metas
,
batch_gt_instances_ignore
=
None
):
"""Compute loss of the head.
Args:
...
...
@@ -447,53 +464,42 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel
number is bbox_code_size.
shape (B, 7, H, W).
gt_bboxes (list[Tensor]): Ground truth bboxes for each image.
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): Class indices corresponding to each box.
shape (num_gts, ).
gt_bboxes_3d (list[:obj:`CameraInstance3DBoxes`]): 3D boxes ground
truth. it is the flipped gt_bboxes
gt_labels_3d (list[Tensor]): Same as gt_labels.
centers2d (list[Tensor]): 2D centers on the image.
shape (num_gts, 2).
depths (list[Tensor]): Depth ground truth.
shape (num_gts, ).
attr_labels (list[Tensor]): Attributes indices of each box.
In kitti it's None.
img_metas (list[dict]): Meta information of each image, e.g.,
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instance_3d. It usually includes ``bboxes``、``labels``
、``bboxes_3d``、``labels_3d``、``depths``、``centers_2d`` and
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.
Default: None.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert
len
(
cls_scores
)
==
len
(
bbox_preds
)
==
1
assert
attr_labels
is
None
assert
gt_bboxes_ignore
is
None
center2d_heatmap
=
cls_scores
[
0
]
assert
batch_gt_instances_ignore
is
None
center_2d_heatmap
=
cls_scores
[
0
]
pred_reg
=
bbox_preds
[
0
]
center2d_heatmap_target
,
avg_factor
,
target_labels
=
\
self
.
get_targets
(
gt_bboxes
,
gt_labels
,
gt_bboxes_3d
,
gt_labels_3d
,
centers2d
,
center2d_heatmap
.
shape
,
img_metas
[
0
][
'pad_shape'
],
img_metas
)
center_2d_heatmap_target
,
avg_factor
,
target_labels
=
\
self
.
get_targets
(
batch_gt_instances_3d
,
center_2d_heatmap
.
shape
,
batch_img_metas
)
pred_bboxes
=
self
.
get_predictions
(
labels3d
=
target_labels
[
'gt_labels3d'
],
centers2d
=
target_labels
[
'gt_centers2d'
],
labels
_
3d
=
target_labels
[
'gt_labels
_
3d'
],
centers
_
2d
=
target_labels
[
'gt_centers
_
2d'
],
gt_locations
=
target_labels
[
'gt_locs'
],
gt_dimensions
=
target_labels
[
'gt_dims'
],
gt_orientations
=
target_labels
[
'gt_yaws'
],
indices
=
target_labels
[
'indices'
],
img_metas
=
img_metas
,
batch_
img_metas
=
batch_
img_metas
,
pred_reg
=
pred_reg
)
loss_cls
=
self
.
loss_cls
(
center2d_heatmap
,
center2d_heatmap_target
,
avg_factor
=
avg_factor
)
center
_
2d_heatmap
,
center
_
2d_heatmap_target
,
avg_factor
=
avg_factor
)
reg_inds
=
target_labels
[
'reg_indices'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment