Commit fdeee889 authored by limm's avatar limm
Browse files

release v1.6.1 of mmcv

parent df465820
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel
from ..scatter_gather import scatter_kwargs
class MPSDataParallel(MMDataParallel):
"""The MPSDataParallel module that supports DataContainer.
MPSDataParallel is a class inherited from MMDataParall, which supports
MPS training and inference only.
The main differences with MMDataParallel:
- It only supports single-card of MPS, and only use first card to
run training and inference.
- It uses direct host-to-device copy instead of stream-background
scatter.
Args:
module (:class:`nn.Module`): Module to be encapsulated.
dim (int): Dimension used to scatter the data. Defaults to 0.
"""
def __init__(self, *args, dim=0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.device_ids = [0]
self.src_device_obj = torch.device('mps:0')
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel.data_container import DataContainer
from mmcv.utils import deprecated_api_warning
from ._functions import Scatter
from .utils import get_device
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter(inputs, target_devices, dim=0):
"""Scatter inputs to target devices.
The only difference from original :func:`scatter` is to add support for
:type:`~mmcv.parallel.DataContainer`.
"""
current_device = get_device()
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if target_devices != [-1]:
obj = obj.to(current_device)
return [obj]
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_devices, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_devices, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for _ in target_devices]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_devices, dim) if inputs else []
kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def get_device() -> str:
"""Returns the currently existing device type.
Returns:
str: cuda | mlu | mps | cpu.
"""
if IS_CUDA_AVAILABLE:
return 'cuda'
elif IS_MLU_AVAILABLE:
return 'mlu'
elif IS_MPS_AVAILABLE:
return 'mps'
else:
return 'cpu'
...@@ -4,15 +4,18 @@ import pickle ...@@ -4,15 +4,18 @@ import pickle
import shutil import shutil
import tempfile import tempfile
import time import time
from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader
import mmcv import mmcv
from mmcv.runner import get_dist_info from mmcv.runner import get_dist_info
def single_gpu_test(model, data_loader): def single_gpu_test(model: nn.Module, data_loader: DataLoader) -> list:
"""Test model with a single gpu. """Test model with a single gpu.
This method tests model with a single gpu and displays test progress bar. This method tests model with a single gpu and displays test progress bar.
...@@ -41,7 +44,10 @@ def single_gpu_test(model, data_loader): ...@@ -41,7 +44,10 @@ def single_gpu_test(model, data_loader):
return results return results
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): def multi_gpu_test(model: nn.Module,
data_loader: DataLoader,
tmpdir: Optional[str] = None,
gpu_collect: bool = False) -> Optional[list]:
"""Test model with multiple gpus. """Test model with multiple gpus.
This method tests model with multiple gpus and collects the results This method tests model with multiple gpus and collects the results
...@@ -82,13 +88,15 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): ...@@ -82,13 +88,15 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
# collect results from all ranks # collect results from all ranks
if gpu_collect: if gpu_collect:
results = collect_results_gpu(results, len(dataset)) result_from_ranks = collect_results_gpu(results, len(dataset))
else: else:
results = collect_results_cpu(results, len(dataset), tmpdir) result_from_ranks = collect_results_cpu(results, len(dataset), tmpdir)
return results return result_from_ranks
def collect_results_cpu(result_part, size, tmpdir=None): def collect_results_cpu(result_part: list,
size: int,
tmpdir: Optional[str] = None) -> Optional[list]:
"""Collect results under cpu mode. """Collect results under cpu mode.
On cpu mode, this function will save the results on different gpus to On cpu mode, this function will save the results on different gpus to
...@@ -126,7 +134,8 @@ def collect_results_cpu(result_part, size, tmpdir=None): ...@@ -126,7 +134,8 @@ def collect_results_cpu(result_part, size, tmpdir=None):
else: else:
mmcv.mkdir_or_exist(tmpdir) mmcv.mkdir_or_exist(tmpdir)
# dump the part result to the dir # dump the part result to the dir
mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl')) part_file = osp.join(tmpdir, f'part_{rank}.pkl') # type: ignore
mmcv.dump(result_part, part_file)
dist.barrier() dist.barrier()
# collect all parts # collect all parts
if rank != 0: if rank != 0:
...@@ -135,7 +144,7 @@ def collect_results_cpu(result_part, size, tmpdir=None): ...@@ -135,7 +144,7 @@ def collect_results_cpu(result_part, size, tmpdir=None):
# load results of all parts from tmp dir # load results of all parts from tmp dir
part_list = [] part_list = []
for i in range(world_size): for i in range(world_size):
part_file = osp.join(tmpdir, f'part_{i}.pkl') part_file = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore
part_result = mmcv.load(part_file) part_result = mmcv.load(part_file)
# When data is severely insufficient, an empty part_result # When data is severely insufficient, an empty part_result
# on a certain gpu could makes the overall outputs empty. # on a certain gpu could makes the overall outputs empty.
...@@ -148,11 +157,11 @@ def collect_results_cpu(result_part, size, tmpdir=None): ...@@ -148,11 +157,11 @@ def collect_results_cpu(result_part, size, tmpdir=None):
# the dataloader may pad some samples # the dataloader may pad some samples
ordered_results = ordered_results[:size] ordered_results = ordered_results[:size]
# remove tmp dir # remove tmp dir
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir) # type: ignore
return ordered_results return ordered_results
def collect_results_gpu(result_part, size): def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
"""Collect results under gpu mode. """Collect results under gpu mode.
On gpu mode, this function will encode results to gpu tensors and use gpu On gpu mode, this function will encode results to gpu tensors and use gpu
...@@ -200,3 +209,5 @@ def collect_results_gpu(result_part, size): ...@@ -200,3 +209,5 @@ def collect_results_gpu(result_part, size):
# the dataloader may pad some samples # the dataloader may pad some samples
ordered_results = ordered_results[:size] ordered_results = ordered_results[:size]
return ordered_results return ordered_results
else:
return None
...@@ -8,7 +8,7 @@ import warnings ...@@ -8,7 +8,7 @@ import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Iterable, Iterator, Optional, Tuple, Union from typing import Any, Generator, Iterator, Optional, Tuple, Union
from urllib.request import urlopen from urllib.request import urlopen
import mmcv import mmcv
...@@ -64,7 +64,8 @@ class CephBackend(BaseStorageBackend): ...@@ -64,7 +64,8 @@ class CephBackend(BaseStorageBackend):
raise ImportError('Please install ceph to enable CephBackend.') raise ImportError('Please install ceph to enable CephBackend.')
warnings.warn( warnings.warn(
'CephBackend will be deprecated, please use PetrelBackend instead') 'CephBackend will be deprecated, please use PetrelBackend instead',
DeprecationWarning)
self._client = ceph.S3Client() self._client = ceph.S3Client()
assert isinstance(path_mapping, dict) or path_mapping is None assert isinstance(path_mapping, dict) or path_mapping is None
self.path_mapping = path_mapping self.path_mapping = path_mapping
...@@ -209,9 +210,9 @@ class PetrelBackend(BaseStorageBackend): ...@@ -209,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
""" """
if not has_method(self._client, 'delete'): if not has_method(self._client, 'delete'):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `delete` method, please use a higher version or dev' 'the `delete` method, please use a higher version or dev'
' branch instead.')) ' branch instead.')
filepath = self._map_path(filepath) filepath = self._map_path(filepath)
filepath = self._format_path(filepath) filepath = self._format_path(filepath)
...@@ -229,9 +230,9 @@ class PetrelBackend(BaseStorageBackend): ...@@ -229,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
if not (has_method(self._client, 'contains') if not (has_method(self._client, 'contains')
and has_method(self._client, 'isdir')): and has_method(self._client, 'isdir')):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `contains` and `isdir` methods, please use a higher' 'the `contains` and `isdir` methods, please use a higher'
'version or dev branch instead.')) 'version or dev branch instead.')
filepath = self._map_path(filepath) filepath = self._map_path(filepath)
filepath = self._format_path(filepath) filepath = self._format_path(filepath)
...@@ -246,13 +247,13 @@ class PetrelBackend(BaseStorageBackend): ...@@ -246,13 +247,13 @@ class PetrelBackend(BaseStorageBackend):
Returns: Returns:
bool: Return ``True`` if ``filepath`` points to a directory, bool: Return ``True`` if ``filepath`` points to a directory,
``False`` otherwise. ``False`` otherwise.
""" """
if not has_method(self._client, 'isdir'): if not has_method(self._client, 'isdir'):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `isdir` method, please use a higher version or dev' 'the `isdir` method, please use a higher version or dev'
' branch instead.')) ' branch instead.')
filepath = self._map_path(filepath) filepath = self._map_path(filepath)
filepath = self._format_path(filepath) filepath = self._format_path(filepath)
...@@ -266,13 +267,13 @@ class PetrelBackend(BaseStorageBackend): ...@@ -266,13 +267,13 @@ class PetrelBackend(BaseStorageBackend):
Returns: Returns:
bool: Return ``True`` if ``filepath`` points to a file, ``False`` bool: Return ``True`` if ``filepath`` points to a file, ``False``
otherwise. otherwise.
""" """
if not has_method(self._client, 'contains'): if not has_method(self._client, 'contains'):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `contains` method, please use a higher version or ' 'the `contains` method, please use a higher version or '
'dev branch instead.')) 'dev branch instead.')
filepath = self._map_path(filepath) filepath = self._map_path(filepath)
filepath = self._format_path(filepath) filepath = self._format_path(filepath)
...@@ -297,7 +298,10 @@ class PetrelBackend(BaseStorageBackend): ...@@ -297,7 +298,10 @@ class PetrelBackend(BaseStorageBackend):
return '/'.join(formatted_paths) return '/'.join(formatted_paths)
@contextmanager @contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: def get_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath`` and return a temporary path. """Download a file from ``filepath`` and return a temporary path.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
...@@ -362,9 +366,9 @@ class PetrelBackend(BaseStorageBackend): ...@@ -362,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
""" """
if not has_method(self._client, 'list'): if not has_method(self._client, 'list'):
raise NotImplementedError( raise NotImplementedError(
('Current version of Petrel Python SDK has not supported ' 'Current version of Petrel Python SDK has not supported '
'the `list` method, please use a higher version or dev' 'the `list` method, please use a higher version or dev'
' branch instead.')) ' branch instead.')
dir_path = self._map_path(dir_path) dir_path = self._map_path(dir_path)
dir_path = self._format_path(dir_path) dir_path = self._format_path(dir_path)
...@@ -473,17 +477,16 @@ class LmdbBackend(BaseStorageBackend): ...@@ -473,17 +477,16 @@ class LmdbBackend(BaseStorageBackend):
readahead=False, readahead=False,
**kwargs): **kwargs):
try: try:
import lmdb import lmdb # NOQA
except ImportError: except ImportError:
raise ImportError('Please install lmdb to enable LmdbBackend.') raise ImportError('Please install lmdb to enable LmdbBackend.')
self.db_path = str(db_path) self.db_path = str(db_path)
self._client = lmdb.open( self.readonly = readonly
self.db_path, self.lock = lock
readonly=readonly, self.readahead = readahead
lock=lock, self.kwargs = kwargs
readahead=readahead, self._client = None
**kwargs)
def get(self, filepath): def get(self, filepath):
"""Get values according to the filepath. """Get values according to the filepath.
...@@ -491,14 +494,29 @@ class LmdbBackend(BaseStorageBackend): ...@@ -491,14 +494,29 @@ class LmdbBackend(BaseStorageBackend):
Args: Args:
filepath (str | obj:`Path`): Here, filepath is the lmdb key. filepath (str | obj:`Path`): Here, filepath is the lmdb key.
""" """
filepath = str(filepath) if self._client is None:
self._client = self._get_client()
with self._client.begin(write=False) as txn: with self._client.begin(write=False) as txn:
value_buf = txn.get(filepath.encode('ascii')) value_buf = txn.get(str(filepath).encode('utf-8'))
return value_buf return value_buf
def get_text(self, filepath, encoding=None): def get_text(self, filepath, encoding=None):
raise NotImplementedError raise NotImplementedError
def _get_client(self):
import lmdb
return lmdb.open(
self.db_path,
readonly=self.readonly,
lock=self.lock,
readahead=self.readahead,
**self.kwargs)
def __del__(self):
self._client.close()
class HardDiskBackend(BaseStorageBackend): class HardDiskBackend(BaseStorageBackend):
"""Raw hard disks storage backend.""" """Raw hard disks storage backend."""
...@@ -531,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend): ...@@ -531,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
Returns: Returns:
str: Expected text reading from ``filepath``. str: Expected text reading from ``filepath``.
""" """
with open(filepath, 'r', encoding=encoding) as f: with open(filepath, encoding=encoding) as f:
value_buf = f.read() value_buf = f.read()
return value_buf return value_buf
...@@ -598,7 +616,7 @@ class HardDiskBackend(BaseStorageBackend): ...@@ -598,7 +616,7 @@ class HardDiskBackend(BaseStorageBackend):
Returns: Returns:
bool: Return ``True`` if ``filepath`` points to a directory, bool: Return ``True`` if ``filepath`` points to a directory,
``False`` otherwise. ``False`` otherwise.
""" """
return osp.isdir(filepath) return osp.isdir(filepath)
...@@ -610,7 +628,7 @@ class HardDiskBackend(BaseStorageBackend): ...@@ -610,7 +628,7 @@ class HardDiskBackend(BaseStorageBackend):
Returns: Returns:
bool: Return ``True`` if ``filepath`` points to a file, ``False`` bool: Return ``True`` if ``filepath`` points to a file, ``False``
otherwise. otherwise.
""" """
return osp.isfile(filepath) return osp.isfile(filepath)
...@@ -631,7 +649,9 @@ class HardDiskBackend(BaseStorageBackend): ...@@ -631,7 +649,9 @@ class HardDiskBackend(BaseStorageBackend):
@contextmanager @contextmanager
def get_local_path( def get_local_path(
self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]: self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Only for unified API and do nothing.""" """Only for unified API and do nothing."""
yield filepath yield filepath
...@@ -700,7 +720,8 @@ class HTTPBackend(BaseStorageBackend): ...@@ -700,7 +720,8 @@ class HTTPBackend(BaseStorageBackend):
return value_buf.decode(encoding) return value_buf.decode(encoding)
@contextmanager @contextmanager
def get_local_path(self, filepath: str) -> Iterable[str]: def get_local_path(
self, filepath: str) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath``. """Download a file from ``filepath``.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
...@@ -770,19 +791,16 @@ class FileClient: ...@@ -770,19 +791,16 @@ class FileClient:
'petrel': PetrelBackend, 'petrel': PetrelBackend,
'http': HTTPBackend, 'http': HTTPBackend,
} }
# This collection is used to record the overridden backends, and when a
# backend appears in the collection, the singleton pattern is disabled for
# that backend, because if the singleton pattern is used, then the object
# returned will be the backend before overwriting
_overridden_backends = set()
_prefix_to_backends = { _prefix_to_backends = {
's3': PetrelBackend, 's3': PetrelBackend,
'http': HTTPBackend, 'http': HTTPBackend,
'https': HTTPBackend, 'https': HTTPBackend,
} }
_overridden_prefixes = set()
_instances = {} _instances: dict = {}
client: Any
def __new__(cls, backend=None, prefix=None, **kwargs): def __new__(cls, backend=None, prefix=None, **kwargs):
if backend is None and prefix is None: if backend is None and prefix is None:
...@@ -802,10 +820,7 @@ class FileClient: ...@@ -802,10 +820,7 @@ class FileClient:
for key, value in kwargs.items(): for key, value in kwargs.items():
arg_key += f':{key}:{value}' arg_key += f':{key}:{value}'
# if a backend was overridden, it will create a new object if arg_key in cls._instances:
if (arg_key in cls._instances
and backend not in cls._overridden_backends
and prefix not in cls._overridden_prefixes):
_instance = cls._instances[arg_key] _instance = cls._instances[arg_key]
else: else:
# create a new object and put it to _instance # create a new object and put it to _instance
...@@ -839,8 +854,8 @@ class FileClient: ...@@ -839,8 +854,8 @@ class FileClient:
's3' 's3'
Returns: Returns:
str | None: Return the prefix of uri if the uri contains '://' str | None: Return the prefix of uri if the uri contains '://' else
else ``None``. ``None``.
""" """
assert is_filepath(uri) assert is_filepath(uri)
uri = str(uri) uri = str(uri)
...@@ -899,7 +914,9 @@ class FileClient: ...@@ -899,7 +914,9 @@ class FileClient:
'add "force=True" if you want to override it') 'add "force=True" if you want to override it')
if name in cls._backends and force: if name in cls._backends and force:
cls._overridden_backends.add(name) for arg_key, instance in list(cls._instances.items()):
if isinstance(instance.client, cls._backends[name]):
cls._instances.pop(arg_key)
cls._backends[name] = backend cls._backends[name] = backend
if prefixes is not None: if prefixes is not None:
...@@ -911,7 +928,12 @@ class FileClient: ...@@ -911,7 +928,12 @@ class FileClient:
if prefix not in cls._prefix_to_backends: if prefix not in cls._prefix_to_backends:
cls._prefix_to_backends[prefix] = backend cls._prefix_to_backends[prefix] = backend
elif (prefix in cls._prefix_to_backends) and force: elif (prefix in cls._prefix_to_backends) and force:
cls._overridden_prefixes.add(prefix) overridden_backend = cls._prefix_to_backends[prefix]
if isinstance(overridden_backend, list):
overridden_backend = tuple(overridden_backend)
for arg_key, instance in list(cls._instances.items()):
if isinstance(instance.client, overridden_backend):
cls._instances.pop(arg_key)
cls._prefix_to_backends[prefix] = backend cls._prefix_to_backends[prefix] = backend
else: else:
raise KeyError( raise KeyError(
...@@ -987,7 +1009,7 @@ class FileClient: ...@@ -987,7 +1009,7 @@ class FileClient:
Returns: Returns:
bytes | memoryview: Expected bytes object or a memory view of the bytes | memoryview: Expected bytes object or a memory view of the
bytes object. bytes object.
""" """
return self.client.get(filepath) return self.client.get(filepath)
...@@ -1060,7 +1082,7 @@ class FileClient: ...@@ -1060,7 +1082,7 @@ class FileClient:
Returns: Returns:
bool: Return ``True`` if ``filepath`` points to a directory, bool: Return ``True`` if ``filepath`` points to a directory,
``False`` otherwise. ``False`` otherwise.
""" """
return self.client.isdir(filepath) return self.client.isdir(filepath)
...@@ -1072,7 +1094,7 @@ class FileClient: ...@@ -1072,7 +1094,7 @@ class FileClient:
Returns: Returns:
bool: Return ``True`` if ``filepath`` points to a file, ``False`` bool: Return ``True`` if ``filepath`` points to a file, ``False``
otherwise. otherwise.
""" """
return self.client.isfile(filepath) return self.client.isfile(filepath)
...@@ -1092,7 +1114,10 @@ class FileClient: ...@@ -1092,7 +1114,10 @@ class FileClient:
return self.client.join_path(filepath, *filepaths) return self.client.join_path(filepath, *filepaths)
@contextmanager @contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]: def get_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Download data from ``filepath`` and write the data to local path. """Download data from ``filepath`` and write the data to local path.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
......
...@@ -21,10 +21,10 @@ class BaseFileHandler(metaclass=ABCMeta): ...@@ -21,10 +21,10 @@ class BaseFileHandler(metaclass=ABCMeta):
def dump_to_str(self, obj, **kwargs): def dump_to_str(self, obj, **kwargs):
pass pass
def load_from_path(self, filepath, mode='r', **kwargs): def load_from_path(self, filepath: str, mode: str = 'r', **kwargs):
with open(filepath, mode) as f: with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs) return self.load_from_fileobj(f, **kwargs)
def dump_to_path(self, obj, filepath, mode='w', **kwargs): def dump_to_path(self, obj, filepath: str, mode: str = 'w', **kwargs):
with open(filepath, mode) as f: with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs) self.dump_to_fileobj(obj, f, **kwargs)
...@@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler): ...@@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
return pickle.load(file, **kwargs) return pickle.load(file, **kwargs)
def load_from_path(self, filepath, **kwargs): def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path( return super().load_from_path(filepath, mode='rb', **kwargs)
filepath, mode='rb', **kwargs)
def dump_to_str(self, obj, **kwargs): def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('protocol', 2) kwargs.setdefault('protocol', 2)
...@@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler): ...@@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
pickle.dump(obj, file, **kwargs) pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs): def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path( super().dump_to_path(obj, filepath, mode='wb', **kwargs)
obj, filepath, mode='wb', **kwargs)
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
import yaml import yaml
try: try:
from yaml import CLoader as Loader, CDumper as Dumper from yaml import CDumper as Dumper
from yaml import CLoader as Loader
except ImportError: except ImportError:
from yaml import Loader, Dumper from yaml import Loader, Dumper # type: ignore
from .base import BaseFileHandler # isort:skip from .base import BaseFileHandler # isort:skip
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from io import BytesIO, StringIO from io import BytesIO, StringIO
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TextIO, Union
from ..utils import is_list_of, is_str from ..utils import is_list_of
from .file_client import FileClient from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
FileLikeObject = Union[TextIO, StringIO, BytesIO]
file_handlers = { file_handlers = {
'json': JsonHandler(), 'json': JsonHandler(),
'yaml': YamlHandler(), 'yaml': YamlHandler(),
...@@ -15,7 +18,10 @@ file_handlers = { ...@@ -15,7 +18,10 @@ file_handlers = {
} }
def load(file, file_format=None, file_client_args=None, **kwargs): def load(file: Union[str, Path, FileLikeObject],
file_format: Optional[str] = None,
file_client_args: Optional[Dict] = None,
**kwargs):
"""Load data from json/yaml/pickle files. """Load data from json/yaml/pickle files.
This method provides a unified api for loading data from serialized files. This method provides a unified api for loading data from serialized files.
...@@ -45,13 +51,14 @@ def load(file, file_format=None, file_client_args=None, **kwargs): ...@@ -45,13 +51,14 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
""" """
if isinstance(file, Path): if isinstance(file, Path):
file = str(file) file = str(file)
if file_format is None and is_str(file): if file_format is None and isinstance(file, str):
file_format = file.split('.')[-1] file_format = file.split('.')[-1]
if file_format not in file_handlers: if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}') raise TypeError(f'Unsupported format: {file_format}')
handler = file_handlers[file_format] handler = file_handlers[file_format]
if is_str(file): f: FileLikeObject
if isinstance(file, str):
file_client = FileClient.infer_client(file_client_args, file) file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like: if handler.str_like:
with StringIO(file_client.get_text(file)) as f: with StringIO(file_client.get_text(file)) as f:
...@@ -66,7 +73,11 @@ def load(file, file_format=None, file_client_args=None, **kwargs): ...@@ -66,7 +73,11 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
return obj return obj
def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): def dump(obj: Any,
file: Optional[Union[str, Path, FileLikeObject]] = None,
file_format: Optional[str] = None,
file_client_args: Optional[Dict] = None,
**kwargs):
"""Dump data to json/yaml/pickle strings or files. """Dump data to json/yaml/pickle strings or files.
This method provides a unified api for dumping data as strings or to files, This method provides a unified api for dumping data as strings or to files,
...@@ -96,18 +107,18 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): ...@@ -96,18 +107,18 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
if isinstance(file, Path): if isinstance(file, Path):
file = str(file) file = str(file)
if file_format is None: if file_format is None:
if is_str(file): if isinstance(file, str):
file_format = file.split('.')[-1] file_format = file.split('.')[-1]
elif file is None: elif file is None:
raise ValueError( raise ValueError(
'file_format must be specified since file is None') 'file_format must be specified since file is None')
if file_format not in file_handlers: if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}') raise TypeError(f'Unsupported format: {file_format}')
f: FileLikeObject
handler = file_handlers[file_format] handler = file_handlers[file_format]
if file is None: if file is None:
return handler.dump_to_str(obj, **kwargs) return handler.dump_to_str(obj, **kwargs)
elif is_str(file): elif isinstance(file, str):
file_client = FileClient.infer_client(file_client_args, file) file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like: if handler.str_like:
with StringIO() as f: with StringIO() as f:
...@@ -123,7 +134,8 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): ...@@ -123,7 +134,8 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
raise TypeError('"file" must be a filename str or a file-object') raise TypeError('"file" must be a filename str or a file-object')
def _register_handler(handler, file_formats): def _register_handler(handler: BaseFileHandler,
file_formats: Union[str, List[str]]) -> None:
"""Register a handler for some file extensions. """Register a handler for some file extensions.
Args: Args:
...@@ -142,7 +154,7 @@ def _register_handler(handler, file_formats): ...@@ -142,7 +154,7 @@ def _register_handler(handler, file_formats):
file_handlers[ext] = handler file_handlers[ext] = handler
def register_handler(file_formats, **kwargs): def register_handler(file_formats: Union[str, list], **kwargs) -> Callable:
def wrap(cls): def wrap(cls):
_register_handler(cls(**kwargs), file_formats) _register_handler(cls(**kwargs), file_formats)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from io import StringIO from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional, Union
from .file_client import FileClient from .file_client import FileClient
def list_from_file(filename, def list_from_file(filename: Union[str, Path],
prefix='', prefix: str = '',
offset=0, offset: int = 0,
max_num=0, max_num: int = 0,
encoding='utf-8', encoding: str = 'utf-8',
file_client_args=None): file_client_args: Optional[Dict] = None) -> List:
"""Load a text file and parse the content as a list of strings. """Load a text file and parse the content as a list of strings.
Note: Note:
...@@ -52,10 +54,10 @@ def list_from_file(filename, ...@@ -52,10 +54,10 @@ def list_from_file(filename,
return item_list return item_list
def dict_from_file(filename, def dict_from_file(filename: Union[str, Path],
key_type=str, key_type: type = str,
encoding='utf-8', encoding: str = 'utf-8',
file_client_args=None): file_client_args: Optional[Dict] = None) -> Dict:
"""Load a text file and parse the content as a dict. """Load a text file and parse the content as a dict.
Each line of the text file will be two or more columns split by Each line of the text file will be two or more columns split by
......
...@@ -9,10 +9,10 @@ from .geometric import (cutout, imcrop, imflip, imflip_, impad, ...@@ -9,10 +9,10 @@ from .geometric import (cutout, imcrop, imflip, imflip_, impad,
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
from .misc import tensor2imgs from .misc import tensor2imgs
from .photometric import (adjust_brightness, adjust_color, adjust_contrast, from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
adjust_lighting, adjust_sharpness, auto_contrast, adjust_hue, adjust_lighting, adjust_sharpness,
clahe, imdenormalize, imequalize, iminvert, auto_contrast, clahe, imdenormalize, imequalize,
imnormalize, imnormalize_, lut_transform, posterize, iminvert, imnormalize, imnormalize_, lut_transform,
solarize) posterize, solarize)
__all__ = [ __all__ = [
'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb', 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
...@@ -24,5 +24,6 @@ __all__ = [ ...@@ -24,5 +24,6 @@ __all__ = [
'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr', 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize', 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe', 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting' 'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting',
'adjust_hue'
] ]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Union
import cv2 import cv2
import numpy as np import numpy as np
def imconvert(img, src, dst): def imconvert(img: np.ndarray, src: str, dst: str) -> np.ndarray:
"""Convert an image from the src colorspace to dst colorspace. """Convert an image from the src colorspace to dst colorspace.
Args: Args:
...@@ -19,7 +21,7 @@ def imconvert(img, src, dst): ...@@ -19,7 +21,7 @@ def imconvert(img, src, dst):
return out_img return out_img
def bgr2gray(img, keepdim=False): def bgr2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
"""Convert a BGR image to grayscale image. """Convert a BGR image to grayscale image.
Args: Args:
...@@ -36,7 +38,7 @@ def bgr2gray(img, keepdim=False): ...@@ -36,7 +38,7 @@ def bgr2gray(img, keepdim=False):
return out_img return out_img
def rgb2gray(img, keepdim=False): def rgb2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
"""Convert a RGB image to grayscale image. """Convert a RGB image to grayscale image.
Args: Args:
...@@ -53,7 +55,7 @@ def rgb2gray(img, keepdim=False): ...@@ -53,7 +55,7 @@ def rgb2gray(img, keepdim=False):
return out_img return out_img
def gray2bgr(img): def gray2bgr(img: np.ndarray) -> np.ndarray:
"""Convert a grayscale image to BGR image. """Convert a grayscale image to BGR image.
Args: Args:
...@@ -67,7 +69,7 @@ def gray2bgr(img): ...@@ -67,7 +69,7 @@ def gray2bgr(img):
return out_img return out_img
def gray2rgb(img): def gray2rgb(img: np.ndarray) -> np.ndarray:
"""Convert a grayscale image to RGB image. """Convert a grayscale image to RGB image.
Args: Args:
...@@ -81,7 +83,7 @@ def gray2rgb(img): ...@@ -81,7 +83,7 @@ def gray2rgb(img):
return out_img return out_img
def _convert_input_type_range(img): def _convert_input_type_range(img: np.ndarray) -> np.ndarray:
"""Convert the type and range of the input image. """Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1]. It converts the input image to np.float32 type and range of [0, 1].
...@@ -109,7 +111,8 @@ def _convert_input_type_range(img): ...@@ -109,7 +111,8 @@ def _convert_input_type_range(img):
return img return img
def _convert_output_type_range(img, dst_type): def _convert_output_type_range(
img: np.ndarray, dst_type: Union[np.uint8, np.float32]) -> np.ndarray:
"""Convert the type and range of the image according to dst_type. """Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8, It converts the image to desired type and range. If `dst_type` is np.uint8,
...@@ -140,7 +143,7 @@ def _convert_output_type_range(img, dst_type): ...@@ -140,7 +143,7 @@ def _convert_output_type_range(img, dst_type):
return img.astype(dst_type) return img.astype(dst_type)
def rgb2ycbcr(img, y_only=False): def rgb2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
"""Convert a RGB image to YCbCr image. """Convert a RGB image to YCbCr image.
This function produces the same results as Matlab's `rgb2ycbcr` function. This function produces the same results as Matlab's `rgb2ycbcr` function.
...@@ -160,7 +163,7 @@ def rgb2ycbcr(img, y_only=False): ...@@ -160,7 +163,7 @@ def rgb2ycbcr(img, y_only=False):
Returns: Returns:
ndarray: The converted YCbCr image. The output image has the same type ndarray: The converted YCbCr image. The output image has the same type
and range as input image. and range as input image.
""" """
img_type = img.dtype img_type = img.dtype
img = _convert_input_type_range(img) img = _convert_input_type_range(img)
...@@ -174,7 +177,7 @@ def rgb2ycbcr(img, y_only=False): ...@@ -174,7 +177,7 @@ def rgb2ycbcr(img, y_only=False):
return out_img return out_img
def bgr2ycbcr(img, y_only=False): def bgr2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
"""Convert a BGR image to YCbCr image. """Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr. The bgr version of rgb2ycbcr.
...@@ -194,7 +197,7 @@ def bgr2ycbcr(img, y_only=False): ...@@ -194,7 +197,7 @@ def bgr2ycbcr(img, y_only=False):
Returns: Returns:
ndarray: The converted YCbCr image. The output image has the same type ndarray: The converted YCbCr image. The output image has the same type
and range as input image. and range as input image.
""" """
img_type = img.dtype img_type = img.dtype
img = _convert_input_type_range(img) img = _convert_input_type_range(img)
...@@ -208,7 +211,7 @@ def bgr2ycbcr(img, y_only=False): ...@@ -208,7 +211,7 @@ def bgr2ycbcr(img, y_only=False):
return out_img return out_img
def ycbcr2rgb(img): def ycbcr2rgb(img: np.ndarray) -> np.ndarray:
"""Convert a YCbCr image to RGB image. """Convert a YCbCr image to RGB image.
This function produces the same results as Matlab's ycbcr2rgb function. This function produces the same results as Matlab's ycbcr2rgb function.
...@@ -227,7 +230,7 @@ def ycbcr2rgb(img): ...@@ -227,7 +230,7 @@ def ycbcr2rgb(img):
Returns: Returns:
ndarray: The converted RGB image. The output image has the same type ndarray: The converted RGB image. The output image has the same type
and range as input image. and range as input image.
""" """
img_type = img.dtype img_type = img.dtype
img = _convert_input_type_range(img) * 255 img = _convert_input_type_range(img) * 255
...@@ -240,7 +243,7 @@ def ycbcr2rgb(img): ...@@ -240,7 +243,7 @@ def ycbcr2rgb(img):
return out_img return out_img
def ycbcr2bgr(img): def ycbcr2bgr(img: np.ndarray) -> np.ndarray:
"""Convert a YCbCr image to BGR image. """Convert a YCbCr image to BGR image.
The bgr version of ycbcr2rgb. The bgr version of ycbcr2rgb.
...@@ -259,7 +262,7 @@ def ycbcr2bgr(img): ...@@ -259,7 +262,7 @@ def ycbcr2bgr(img):
Returns: Returns:
ndarray: The converted BGR image. The output image has the same type ndarray: The converted BGR image. The output image has the same type
and range as input image. and range as input image.
""" """
img_type = img.dtype img_type = img.dtype
img = _convert_input_type_range(img) * 255 img = _convert_input_type_range(img) * 255
...@@ -272,11 +275,11 @@ def ycbcr2bgr(img): ...@@ -272,11 +275,11 @@ def ycbcr2bgr(img):
return out_img return out_img
def convert_color_factory(src, dst): def convert_color_factory(src: str, dst: str) -> Callable:
code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}') code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
def convert_color(img): def convert_color(img: np.ndarray) -> np.ndarray:
out_img = cv2.cvtColor(img, code) out_img = cv2.cvtColor(img, code)
return out_img return out_img
......
...@@ -37,15 +37,27 @@ cv2_interp_codes = { ...@@ -37,15 +37,27 @@ cv2_interp_codes = {
'lanczos': cv2.INTER_LANCZOS4 'lanczos': cv2.INTER_LANCZOS4
} }
# Pillow >=v9.1.0 use a slightly different naming scheme for filters.
# Set pillow_interp_codes according to the naming scheme used.
if Image is not None: if Image is not None:
pillow_interp_codes = { if hasattr(Image, 'Resampling'):
'nearest': Image.NEAREST, pillow_interp_codes = {
'bilinear': Image.BILINEAR, 'nearest': Image.Resampling.NEAREST,
'bicubic': Image.BICUBIC, 'bilinear': Image.Resampling.BILINEAR,
'box': Image.BOX, 'bicubic': Image.Resampling.BICUBIC,
'lanczos': Image.LANCZOS, 'box': Image.Resampling.BOX,
'hamming': Image.HAMMING 'lanczos': Image.Resampling.LANCZOS,
} 'hamming': Image.Resampling.HAMMING
}
else:
pillow_interp_codes = {
'nearest': Image.NEAREST,
'bilinear': Image.BILINEAR,
'bicubic': Image.BICUBIC,
'box': Image.BOX,
'lanczos': Image.LANCZOS,
'hamming': Image.HAMMING
}
def imresize(img, def imresize(img,
...@@ -70,7 +82,7 @@ def imresize(img, ...@@ -70,7 +82,7 @@ def imresize(img,
Returns: Returns:
tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`. `resized_img`.
""" """
h, w = img.shape[:2] h, w = img.shape[:2]
if backend is None: if backend is None:
...@@ -130,7 +142,7 @@ def imresize_to_multiple(img, ...@@ -130,7 +142,7 @@ def imresize_to_multiple(img,
Returns: Returns:
tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`. `resized_img`.
""" """
h, w = img.shape[:2] h, w = img.shape[:2]
if size is not None and scale_factor is not None: if size is not None and scale_factor is not None:
...@@ -145,7 +157,7 @@ def imresize_to_multiple(img, ...@@ -145,7 +157,7 @@ def imresize_to_multiple(img,
size = _scale_size((w, h), scale_factor) size = _scale_size((w, h), scale_factor)
divisor = to_2tuple(divisor) divisor = to_2tuple(divisor)
size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)]) size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
resized_img, w_scale, h_scale = imresize( resized_img, w_scale, h_scale = imresize(
img, img,
size, size,
...@@ -175,7 +187,7 @@ def imresize_like(img, ...@@ -175,7 +187,7 @@ def imresize_like(img,
Returns: Returns:
tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`. `resized_img`.
""" """
h, w = dst_img.shape[:2] h, w = dst_img.shape[:2]
return imresize(img, (w, h), return_scale, interpolation, backend=backend) return imresize(img, (w, h), return_scale, interpolation, backend=backend)
...@@ -460,18 +472,17 @@ def impad(img, ...@@ -460,18 +472,17 @@ def impad(img,
areas when padding_mode is 'constant'. Default: 0. areas when padding_mode is 'constant'. Default: 0.
padding_mode (str): Type of padding. Should be: constant, edge, padding_mode (str): Type of padding. Should be: constant, edge,
reflect or symmetric. Default: constant. reflect or symmetric. Default: constant.
- constant: pads with a constant value, this value is specified - constant: pads with a constant value, this value is specified
with pad_val. with pad_val.
- edge: pads with the last value at the edge of the image. - edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the - reflect: pads with reflection of image without repeating the last
last value on the edge. For example, padding [1, 2, 3, 4] value on the edge. For example, padding [1, 2, 3, 4] with 2
with 2 elements on both sides in reflect mode will result elements on both sides in reflect mode will result in
in [3, 2, 1, 2, 3, 4, 3, 2]. [3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last - symmetric: pads with reflection of image repeating the last value
value on the edge. For example, padding [1, 2, 3, 4] with on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
2 elements on both sides in symmetric mode will result in both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3] [2, 1, 1, 2, 3, 4, 4, 3]
Returns: Returns:
ndarray: The padded image. ndarray: The padded image.
...@@ -479,7 +490,9 @@ def impad(img, ...@@ -479,7 +490,9 @@ def impad(img,
assert (shape is not None) ^ (padding is not None) assert (shape is not None) ^ (padding is not None)
if shape is not None: if shape is not None:
padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0]) width = max(shape[1] - img.shape[1], 0)
height = max(shape[0] - img.shape[0], 0)
padding = (0, 0, width, height)
# check pad_val # check pad_val
if isinstance(pad_val, tuple): if isinstance(pad_val, tuple):
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import io import io
import os.path as osp import os.path as osp
import warnings
from pathlib import Path from pathlib import Path
import cv2 import cv2
...@@ -8,7 +9,8 @@ import numpy as np ...@@ -8,7 +9,8 @@ import numpy as np
from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION, from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
IMREAD_UNCHANGED) IMREAD_UNCHANGED)
from mmcv.utils import check_file_exist, is_str, mkdir_or_exist from mmcv.fileio import FileClient
from mmcv.utils import is_filepath, is_str
try: try:
from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
...@@ -137,9 +139,16 @@ def _pillow2array(img, flag='color', channel_order='bgr'): ...@@ -137,9 +139,16 @@ def _pillow2array(img, flag='color', channel_order='bgr'):
return array return array
def imread(img_or_path, flag='color', channel_order='bgr', backend=None): def imread(img_or_path,
flag='color',
channel_order='bgr',
backend=None,
file_client_args=None):
"""Read an image. """Read an image.
Note:
In v1.4.1 and later, add `file_client_args` parameters.
Args: Args:
img_or_path (ndarray or str or Path): Either a numpy array or str or img_or_path (ndarray or str or Path): Either a numpy array or str or
pathlib.Path. If it is a numpy array (loaded image), then pathlib.Path. If it is a numpy array (loaded image), then
...@@ -157,44 +166,42 @@ def imread(img_or_path, flag='color', channel_order='bgr', backend=None): ...@@ -157,44 +166,42 @@ def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
`cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
If backend is None, the global imread_backend specified by If backend is None, the global imread_backend specified by
``mmcv.use_backend()`` will be used. Default: None. ``mmcv.use_backend()`` will be used. Default: None.
file_client_args (dict | None): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Returns: Returns:
ndarray: Loaded image array. ndarray: Loaded image array.
Examples:
>>> import mmcv
>>> img_path = '/path/to/img.jpg'
>>> img = mmcv.imread(img_path)
>>> img = mmcv.imread(img_path, flag='color', channel_order='rgb',
... backend='cv2')
>>> img = mmcv.imread(img_path, flag='color', channel_order='bgr',
... backend='pillow')
>>> s3_img_path = 's3://bucket/img.jpg'
>>> # infer the file backend by the prefix s3
>>> img = mmcv.imread(s3_img_path)
>>> # manually set the file backend petrel
>>> img = mmcv.imread(s3_img_path, file_client_args={
... 'backend': 'petrel'})
>>> http_img_path = 'http://path/to/img.jpg'
>>> img = mmcv.imread(http_img_path)
>>> img = mmcv.imread(http_img_path, file_client_args={
... 'backend': 'http'})
""" """
if backend is None:
backend = imread_backend
if backend not in supported_backends:
raise ValueError(f'backend: {backend} is not supported. Supported '
"backends are 'cv2', 'turbojpeg', 'pillow'")
if isinstance(img_or_path, Path): if isinstance(img_or_path, Path):
img_or_path = str(img_or_path) img_or_path = str(img_or_path)
if isinstance(img_or_path, np.ndarray): if isinstance(img_or_path, np.ndarray):
return img_or_path return img_or_path
elif is_str(img_or_path): elif is_str(img_or_path):
check_file_exist(img_or_path, file_client = FileClient.infer_client(file_client_args, img_or_path)
f'img file does not exist: {img_or_path}') img_bytes = file_client.get(img_or_path)
if backend == 'turbojpeg': return imfrombytes(img_bytes, flag, channel_order, backend)
with open(img_or_path, 'rb') as in_file:
img = jpeg.decode(in_file.read(),
_jpegflag(flag, channel_order))
if img.shape[-1] == 1:
img = img[:, :, 0]
return img
elif backend == 'pillow':
img = Image.open(img_or_path)
img = _pillow2array(img, flag, channel_order)
return img
elif backend == 'tifffile':
img = tifffile.imread(img_or_path)
return img
else:
flag = imread_flags[flag] if is_str(flag) else flag
img = cv2.imread(img_or_path, flag)
if flag == IMREAD_COLOR and channel_order == 'rgb':
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
return img
else: else:
raise TypeError('"img" must be a numpy array or a str or ' raise TypeError('"img" must be a numpy array or a str or '
'a pathlib.Path object') 'a pathlib.Path object')
...@@ -206,29 +213,45 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None): ...@@ -206,29 +213,45 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
Args: Args:
content (bytes): Image bytes got from files or other streams. content (bytes): Image bytes got from files or other streams.
flag (str): Same as :func:`imread`. flag (str): Same as :func:`imread`.
channel_order (str): The channel order of the output, candidates
are 'bgr' and 'rgb'. Default to 'bgr'.
backend (str | None): The image decoding backend type. Options are backend (str | None): The image decoding backend type. Options are
`cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. If backend is
global imread_backend specified by ``mmcv.use_backend()`` will be None, the global imread_backend specified by ``mmcv.use_backend()``
used. Default: None. will be used. Default: None.
Returns: Returns:
ndarray: Loaded image array. ndarray: Loaded image array.
Examples:
>>> img_path = '/path/to/img.jpg'
>>> with open(img_path, 'rb') as f:
>>> img_buff = f.read()
>>> img = mmcv.imfrombytes(img_buff)
>>> img = mmcv.imfrombytes(img_buff, flag='color', channel_order='rgb')
>>> img = mmcv.imfrombytes(img_buff, backend='pillow')
>>> img = mmcv.imfrombytes(img_buff, backend='cv2')
""" """
if backend is None: if backend is None:
backend = imread_backend backend = imread_backend
if backend not in supported_backends: if backend not in supported_backends:
raise ValueError(f'backend: {backend} is not supported. Supported ' raise ValueError(
"backends are 'cv2', 'turbojpeg', 'pillow'") f'backend: {backend} is not supported. Supported '
"backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'")
if backend == 'turbojpeg': if backend == 'turbojpeg':
img = jpeg.decode(content, _jpegflag(flag, channel_order)) img = jpeg.decode(content, _jpegflag(flag, channel_order))
if img.shape[-1] == 1: if img.shape[-1] == 1:
img = img[:, :, 0] img = img[:, :, 0]
return img return img
elif backend == 'pillow': elif backend == 'pillow':
buff = io.BytesIO(content) with io.BytesIO(content) as buff:
img = Image.open(buff) img = Image.open(buff)
img = _pillow2array(img, flag, channel_order) img = _pillow2array(img, flag, channel_order)
return img
elif backend == 'tifffile':
with io.BytesIO(content) as buff:
img = tifffile.imread(buff)
return img return img
else: else:
img_np = np.frombuffer(content, np.uint8) img_np = np.frombuffer(content, np.uint8)
...@@ -239,20 +262,53 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None): ...@@ -239,20 +262,53 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
return img return img
def imwrite(img, file_path, params=None, auto_mkdir=True): def imwrite(img,
file_path,
params=None,
auto_mkdir=None,
file_client_args=None):
"""Write image to file. """Write image to file.
Note:
In v1.4.1 and later, add `file_client_args` parameters.
Warning:
The parameter `auto_mkdir` will be deprecated in the future and every
file clients will make directory automatically.
Args: Args:
img (ndarray): Image array to be written. img (ndarray): Image array to be written.
file_path (str): Image file path. file_path (str): Image file path.
params (None or list): Same as opencv :func:`imwrite` interface. params (None or list): Same as opencv :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist, auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically. whether to create it automatically. It will be deprecated.
file_client_args (dict | None): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Returns: Returns:
bool: Successful or not. bool: Successful or not.
Examples:
>>> # write to hard disk client
>>> ret = mmcv.imwrite(img, '/path/to/img.jpg')
>>> # infer the file backend by the prefix s3
>>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg')
>>> # manually set the file backend petrel
>>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg', file_client_args={
... 'backend': 'petrel'})
""" """
if auto_mkdir: assert is_filepath(file_path)
dir_name = osp.abspath(osp.dirname(file_path)) file_path = str(file_path)
mkdir_or_exist(dir_name) if auto_mkdir is not None:
return cv2.imwrite(file_path, img, params) warnings.warn(
'The parameter `auto_mkdir` will be deprecated in the future and '
'every file clients will make directory automatically.')
file_client = FileClient.infer_client(file_client_args, file_path)
img_ext = osp.splitext(file_path)[-1]
# Encode image according to image suffix.
# For example, if image path is '/path/your/img.jpg', the encode
# format is '.jpg'.
flag, img_buff = cv2.imencode(img_ext, img, params)
file_client.put(img_buff.tobytes(), file_path)
return flag
...@@ -9,18 +9,21 @@ except ImportError: ...@@ -9,18 +9,21 @@ except ImportError:
torch = None torch = None
def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True): def tensor2imgs(tensor, mean=None, std=None, to_rgb=True):
"""Convert tensor to 3-channel images. """Convert tensor to 3-channel images or 1-channel gray images.
Args: Args:
tensor (torch.Tensor): Tensor that contains multiple images, shape ( tensor (torch.Tensor): Tensor that contains multiple images, shape (
N, C, H, W). N, C, H, W). :math:`C` can be either 3 or 1.
mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0). mean (tuple[float], optional): Mean of images. If None,
std (tuple[float], optional): Standard deviation of images. (0, 0, 0) will be used for tensor with 3-channel,
Defaults to (1, 1, 1). while (0, ) for tensor with 1-channel. Defaults to None.
std (tuple[float], optional): Standard deviation of images. If None,
(1, 1, 1) will be used for tensor with 3-channel,
while (1, ) for tensor with 1-channel. Defaults to None.
to_rgb (bool, optional): Whether the tensor was converted to RGB to_rgb (bool, optional): Whether the tensor was converted to RGB
format in the first place. If so, convert it back to BGR. format in the first place. If so, convert it back to BGR.
Defaults to True. For the tensor with 1 channel, it must be False. Defaults to True.
Returns: Returns:
list[np.ndarray]: A list that contains multiple images. list[np.ndarray]: A list that contains multiple images.
...@@ -29,8 +32,14 @@ def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True): ...@@ -29,8 +32,14 @@ def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
if torch is None: if torch is None:
raise RuntimeError('pytorch is not installed') raise RuntimeError('pytorch is not installed')
assert torch.is_tensor(tensor) and tensor.ndim == 4 assert torch.is_tensor(tensor) and tensor.ndim == 4
assert len(mean) == 3 channels = tensor.size(1)
assert len(std) == 3 assert channels in [1, 3]
if mean is None:
mean = (0, ) * channels
if std is None:
std = (1, ) * channels
assert (channels == len(mean) == len(std) == 3) or \
(channels == len(mean) == len(std) == 1 and not to_rgb)
num_imgs = tensor.size(0) num_imgs = tensor.size(0)
mean = np.array(mean, dtype=np.float32) mean = np.array(mean, dtype=np.float32)
......
...@@ -426,3 +426,46 @@ def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)): ...@@ -426,3 +426,46 @@ def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
clahe = cv2.createCLAHE(clip_limit, tile_grid_size) clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
return clahe.apply(np.array(img, dtype=np.uint8)) return clahe.apply(np.array(img, dtype=np.uint8))
def adjust_hue(img: np.ndarray, hue_factor: float) -> np.ndarray:
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and cyclically
shifting the intensities in the hue channel (H). The image is then
converted back to original image mode.
`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.
Modified from
https://github.com/pytorch/vision/blob/main/torchvision/
transforms/functional.py
Args:
img (ndarray): Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.
Returns:
ndarray: Hue adjusted image.
"""
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f'hue_factor:{hue_factor} is not in [-0.5, 0.5].')
if not (isinstance(img, np.ndarray) and (img.ndim in {2, 3})):
raise TypeError('img should be ndarray with dim=[2 or 3].')
dtype = img.dtype
img = img.astype(np.uint8)
hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL)
h, s, v = cv2.split(hsv_img)
h = h.astype(np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over='ignore'):
h += np.uint8(hue_factor * 255)
hsv_img = cv2.merge([h, s, v])
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB_FULL).astype(dtype)
{
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
"shufflenetv2_x1.5": null,
"shufflenetv2_x2.0": null,
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
}
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os import os
import warnings
import torch import torch
def is_custom_op_loaded(): def is_custom_op_loaded() -> bool:
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
flag = False flag = False
try: try:
from ..tensorrt import is_tensorrt_plugin_loaded from ..tensorrt import is_tensorrt_plugin_loaded
......
...@@ -59,7 +59,7 @@ def _parse_arg(value, desc): ...@@ -59,7 +59,7 @@ def _parse_arg(value, desc):
raise RuntimeError( raise RuntimeError(
"ONNX symbolic doesn't know to interpret ListConstruct node") "ONNX symbolic doesn't know to interpret ListConstruct node")
raise RuntimeError('Unexpected node type: {}'.format(value.node().kind())) raise RuntimeError(f'Unexpected node type: {value.node().kind()}')
def _maybe_get_const(value, desc): def _maybe_get_const(value, desc):
...@@ -328,4 +328,4 @@ cast_pytorch_to_onnx = { ...@@ -328,4 +328,4 @@ cast_pytorch_to_onnx = {
# Global set to store the list of quantized operators in the network. # Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT # This is currently only used in the conversion of quantized ops from PT
# -> C2 via ONNX. # -> C2 via ONNX.
_quantized_ops = set() _quantized_ops: set = set()
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/pytorch/pytorch.""" """Modified from https://github.com/pytorch/pytorch."""
import os import os
import warnings
import numpy as np import numpy as np
import torch import torch
...@@ -409,8 +410,8 @@ def cummin(g, input, dim): ...@@ -409,8 +410,8 @@ def cummin(g, input, dim):
@parse_args('v', 'v', 'is') @parse_args('v', 'v', 'is')
def roll(g, input, shifts, dims): def roll(g, input, shifts, dims):
from torch.onnx.symbolic_opset9 import squeeze
from packaging import version from packaging import version
from torch.onnx.symbolic_opset9 import squeeze
input_shape = g.op('Shape', input) input_shape = g.op('Shape', input)
need_flatten = len(dims) == 0 need_flatten = len(dims) == 0
...@@ -467,6 +468,18 @@ def roll(g, input, shifts, dims): ...@@ -467,6 +468,18 @@ def roll(g, input, shifts, dims):
def register_extra_symbolics(opset=11): def register_extra_symbolics(opset=11):
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
register_op('one_hot', one_hot, '', opset) register_op('one_hot', one_hot, '', opset)
register_op('im2col', im2col, '', opset) register_op('im2col', im2col, '', opset)
register_op('topk', topk, '', opset) register_op('topk', topk, '', opset)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment