Commit f90567a0 authored by liukuikun's avatar liukuikun Committed by zhouzaida
Browse files

[Fix] LoadImageFromFile

parent 864942be
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import numpy as np
import mmcv
......@@ -33,22 +35,24 @@ class LoadImageFromFile(BaseTransform):
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
ignore_empty (bool): Whether to allow loading empty image or file path
not existent. Defaults to False.
"""
def __init__(
self,
def __init__(self,
to_float32: bool = False,
color_type: str = 'color',
imdecode_backend: str = 'cv2',
file_client_args: dict = dict(backend='disk')
) -> None:
file_client_args: dict = dict(backend='disk'),
ignore_empty: bool = False) -> None:
self.ignore_empty = ignore_empty
self.to_float32 = to_float32
self.color_type = color_type
self.imdecode_backend = imdecode_backend
self.file_client_args = file_client_args.copy()
self.file_client = mmcv.FileClient(**self.file_client_args)
def transform(self, results: dict) -> dict:
def transform(self, results: dict) -> Optional[dict]:
"""Functions to load image.
Args:
......@@ -59,9 +63,15 @@ class LoadImageFromFile(BaseTransform):
"""
filename = results['img_path']
try:
img_bytes = self.file_client.get(filename)
img = mmcv.imfrombytes(
img_bytes, flag=self.color_type, backend=self.imdecode_backend)
except Exception as e:
if self.ignore_empty:
return None
else:
raise e
if self.to_float32:
img = img.astype(np.float32)
......@@ -72,6 +82,7 @@ class LoadImageFromFile(BaseTransform):
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'ignore_empty={self.ignore_empty}, '
f'to_float32={self.to_float32}, '
f"color_type='{self.color_type}', "
f"imdecode_backend='{self.imdecode_backend}', "
......
......@@ -3,6 +3,7 @@ import copy
import os.path as osp
import numpy as np
import pytest
from mmcv.transforms import LoadAnnotations, LoadImageFromFile
......@@ -21,7 +22,7 @@ class TestLoadImageFromFile:
assert results['img_shape'] == (300, 400)
assert results['ori_shape'] == (300, 400)
assert repr(transform) == transform.__class__.__name__ + \
"(to_float32=False, color_type='color', " + \
"(ignore_empty=False, to_float32=False, color_type='color', " + \
"imdecode_backend='cv2', file_client_args={'backend': 'disk'})"
# to_float32
......@@ -41,6 +42,15 @@ class TestLoadImageFromFile:
assert results['img'].shape == (300, 400)
assert results['img'].dtype == np.uint8
# test load empty
fake_img_path = osp.join(data_prefix, 'fake.jpg')
results['img_path'] = fake_img_path
transform = LoadImageFromFile(ignore_empty=False)
with pytest.raises(FileNotFoundError):
transform(copy.deepcopy(results))
transform = LoadImageFromFile(ignore_empty=True)
assert transform(copy.deepcopy(results)) is None
class TestLoadAnnotations:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment