"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "e259f156ffe33e64eccd64d3aa2402d54d979384"
Unverified Commit fc301b98 authored by Robin Karlsson's avatar Robin Karlsson Committed by GitHub
Browse files

[Fix] Centerpoint head nested list transpose (#879)

* FIX Transpose nested lists without Numpy

* Removed unused Numpy import
parent 319b0e36
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import numpy as np
import torch import torch
from mmcv.cnn import ConvModule, build_conv_layer from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule, force_fp32 from mmcv.runner import BaseModule, force_fp32
...@@ -386,6 +385,17 @@ class CenterHead(BaseModule): ...@@ -386,6 +385,17 @@ class CenterHead(BaseModule):
def get_targets(self, gt_bboxes_3d, gt_labels_3d): def get_targets(self, gt_bboxes_3d, gt_labels_3d):
"""Generate targets. """Generate targets.
How each output is transformed:
Each nested list is transposed so that all same-index elements in
each sub-list (1, ..., N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args: Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes. truth gt boxes.
...@@ -405,18 +415,17 @@ class CenterHead(BaseModule): ...@@ -405,18 +415,17 @@ class CenterHead(BaseModule):
""" """
heatmaps, anno_boxes, inds, masks = multi_apply( heatmaps, anno_boxes, inds, masks = multi_apply(
self.get_targets_single, gt_bboxes_3d, gt_labels_3d) self.get_targets_single, gt_bboxes_3d, gt_labels_3d)
# transpose heatmaps, because the dimension of tensors in each task is # Transpose heatmaps
# different, we have to use numpy instead of torch to do the transpose. heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = np.array(heatmaps).transpose(1, 0).tolist()
heatmaps = [torch.stack(hms_) for hms_ in heatmaps] heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
# transpose anno_boxes # Transpose anno_boxes
anno_boxes = np.array(anno_boxes).transpose(1, 0).tolist() anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes] anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
# transpose inds # Transpose inds
inds = np.array(inds).transpose(1, 0).tolist() inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds] inds = [torch.stack(inds_) for inds_ in inds]
# transpose inds # Transpose inds
masks = np.array(masks).transpose(1, 0).tolist() masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks] masks = [torch.stack(masks_) for masks_ in masks]
return heatmaps, anno_boxes, inds, masks return heatmaps, anno_boxes, inds, masks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment