Commit f90567a0 authored by liukuikun's avatar liukuikun Committed by zhouzaida
Browse files

[Fix] LoadImageFromFile

parent 864942be
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import numpy as np import numpy as np
import mmcv import mmcv
...@@ -33,22 +35,24 @@ class LoadImageFromFile(BaseTransform): ...@@ -33,22 +35,24 @@ class LoadImageFromFile(BaseTransform):
file_client_args (dict): Arguments to instantiate a FileClient. file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``. Defaults to ``dict(backend='disk')``.
ignore_empty (bool): Whether to allow loading empty image or file path
not existent. Defaults to False.
""" """
def __init__( def __init__(self,
self,
to_float32: bool = False, to_float32: bool = False,
color_type: str = 'color', color_type: str = 'color',
imdecode_backend: str = 'cv2', imdecode_backend: str = 'cv2',
file_client_args: dict = dict(backend='disk') file_client_args: dict = dict(backend='disk'),
) -> None: ignore_empty: bool = False) -> None:
self.ignore_empty = ignore_empty
self.to_float32 = to_float32 self.to_float32 = to_float32
self.color_type = color_type self.color_type = color_type
self.imdecode_backend = imdecode_backend self.imdecode_backend = imdecode_backend
self.file_client_args = file_client_args.copy() self.file_client_args = file_client_args.copy()
self.file_client = mmcv.FileClient(**self.file_client_args) self.file_client = mmcv.FileClient(**self.file_client_args)
def transform(self, results: dict) -> dict: def transform(self, results: dict) -> Optional[dict]:
"""Functions to load image. """Functions to load image.
Args: Args:
...@@ -59,9 +63,15 @@ class LoadImageFromFile(BaseTransform): ...@@ -59,9 +63,15 @@ class LoadImageFromFile(BaseTransform):
""" """
filename = results['img_path'] filename = results['img_path']
try:
img_bytes = self.file_client.get(filename) img_bytes = self.file_client.get(filename)
img = mmcv.imfrombytes( img = mmcv.imfrombytes(
img_bytes, flag=self.color_type, backend=self.imdecode_backend) img_bytes, flag=self.color_type, backend=self.imdecode_backend)
except Exception as e:
if self.ignore_empty:
return None
else:
raise e
if self.to_float32: if self.to_float32:
img = img.astype(np.float32) img = img.astype(np.float32)
...@@ -72,6 +82,7 @@ class LoadImageFromFile(BaseTransform): ...@@ -72,6 +82,7 @@ class LoadImageFromFile(BaseTransform):
def __repr__(self): def __repr__(self):
repr_str = (f'{self.__class__.__name__}(' repr_str = (f'{self.__class__.__name__}('
f'ignore_empty={self.ignore_empty}, '
f'to_float32={self.to_float32}, ' f'to_float32={self.to_float32}, '
f"color_type='{self.color_type}', " f"color_type='{self.color_type}', "
f"imdecode_backend='{self.imdecode_backend}', " f"imdecode_backend='{self.imdecode_backend}', "
......
...@@ -3,6 +3,7 @@ import copy ...@@ -3,6 +3,7 @@ import copy
import os.path as osp import os.path as osp
import numpy as np import numpy as np
import pytest
from mmcv.transforms import LoadAnnotations, LoadImageFromFile from mmcv.transforms import LoadAnnotations, LoadImageFromFile
...@@ -21,7 +22,7 @@ class TestLoadImageFromFile: ...@@ -21,7 +22,7 @@ class TestLoadImageFromFile:
assert results['img_shape'] == (300, 400) assert results['img_shape'] == (300, 400)
assert results['ori_shape'] == (300, 400) assert results['ori_shape'] == (300, 400)
assert repr(transform) == transform.__class__.__name__ + \ assert repr(transform) == transform.__class__.__name__ + \
"(to_float32=False, color_type='color', " + \ "(ignore_empty=False, to_float32=False, color_type='color', " + \
"imdecode_backend='cv2', file_client_args={'backend': 'disk'})" "imdecode_backend='cv2', file_client_args={'backend': 'disk'})"
# to_float32 # to_float32
...@@ -41,6 +42,15 @@ class TestLoadImageFromFile: ...@@ -41,6 +42,15 @@ class TestLoadImageFromFile:
assert results['img'].shape == (300, 400) assert results['img'].shape == (300, 400)
assert results['img'].dtype == np.uint8 assert results['img'].dtype == np.uint8
# test load empty
fake_img_path = osp.join(data_prefix, 'fake.jpg')
results['img_path'] = fake_img_path
transform = LoadImageFromFile(ignore_empty=False)
with pytest.raises(FileNotFoundError):
transform(copy.deepcopy(results))
transform = LoadImageFromFile(ignore_empty=True)
assert transform(copy.deepcopy(results)) is None
class TestLoadAnnotations: class TestLoadAnnotations:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment