"vscode:/vscode.git/clone" did not exist on "853f8459ed195928d9025f5b617d6f073e2bf6d7"
Unverified commit 7fec1d53, authored by Ziyi Wu and committed by GitHub
Browse files

[Fix] Fix `LoadMultiViewImageFromFiles` to be compatible with `DefaultFormatBundle` (#611)

* unravel multi-view img to list

* add unit test
parent 07590418
...@@ -41,12 +41,15 @@ class LoadMultiViewImageFromFiles(object): ...@@ -41,12 +41,15 @@ class LoadMultiViewImageFromFiles(object):
- img_norm_cfg (dict): Normalization configuration of images. - img_norm_cfg (dict): Normalization configuration of images.
""" """
filename = results['img_filename'] filename = results['img_filename']
# img is of shape (h, w, c, num_views)
img = np.stack( img = np.stack(
[mmcv.imread(name, self.color_type) for name in filename], axis=-1) [mmcv.imread(name, self.color_type) for name in filename], axis=-1)
if self.to_float32: if self.to_float32:
img = img.astype(np.float32) img = img.astype(np.float32)
results['filename'] = filename results['filename'] = filename
results['img'] = img # unravel to list, see `DefaultFormatBundle` in formating.py
# which will transpose each image separately and then stack into array
results['img'] = [img[..., i] for i in range(img.shape[-1])]
results['img_shape'] = img.shape results['img_shape'] = img.shape
results['ori_shape'] = img.shape results['ori_shape'] = img.shape
# Set initial values for default meta_keys # Set initial values for default meta_keys
...@@ -61,8 +64,10 @@ class LoadMultiViewImageFromFiles(object): ...@@ -61,8 +64,10 @@ class LoadMultiViewImageFromFiles(object):
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module.""" """str: Return a string that describes the module."""
return f'{self.__class__.__name__} (to_float32={self.to_float32}, '\ repr_str = self.__class__.__name__
f"color_type='{self.color_type}')" repr_str += f'(to_float32={self.to_float32}, '
repr_str += f"color_type='{self.color_type}')"
return repr_str
@PIPELINES.register_module() @PIPELINES.register_module()
......
import numpy as np
import torch
from mmcv.parallel import DataContainer
from mmdet3d.datasets.pipelines import (DefaultFormatBundle,
LoadMultiViewImageFromFiles)
def test_load_multi_view_image_from_files():
    """Check `LoadMultiViewImageFromFiles` and its `DefaultFormatBundle` fit.

    The same sample image is loaded once per view. The test verifies the
    per-view image list and the meta entries written by the loader, the
    loader's ``repr``, and finally that `DefaultFormatBundle` turns the
    list into a single `DataContainer`.
    """
    view_count = 6
    sample_path = 'tests/data/waymo/kitti_format/training/image_0/0000000.png'
    view_paths = [sample_path] * view_count

    loader = LoadMultiViewImageFromFiles(to_float32=True)
    results = loader({'img_filename': view_paths})

    views = results['img']
    first_view = views[0]
    norm_cfg = results['img_norm_cfg']

    # The loader is expected to hand back one float32 array per view.
    assert isinstance(views, list)
    assert len(views) == view_count
    assert first_view.dtype == np.float32
    assert results['filename'] == view_paths

    # All shape entries still describe the stacked (h, w, c, views) layout.
    expected_shape = (1280, 1920, 3, view_count)
    assert (results['img_shape'] == results['ori_shape'] ==
            results['pad_shape'] == expected_shape)
    assert results['scale_factor'] == 1.0

    # Default normalisation meta: zero mean, unit std, no channel swap.
    assert np.all(norm_cfg['mean'] == np.zeros(3, dtype=np.float32))
    assert np.all(norm_cfg['std'] == np.ones(3, dtype=np.float32))
    assert not norm_cfg['to_rgb']

    expected_repr = ('LoadMultiViewImageFromFiles(to_float32=True, '
                     "color_type='unchanged')")
    assert repr(loader) == expected_repr

    # test LoadMultiViewImageFromFiles's compatibility with DefaultFormatBundle
    # refer to https://github.com/open-mmlab/mmdetection3d/issues/227
    bundled = DefaultFormatBundle()(results)
    packed = bundled['img']
    assert isinstance(packed, DataContainer)
    assert packed._data.shape == torch.Size((view_count, 3, 1280, 1920))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment