"vscode:/vscode.git/clone" did not exist on "c267b1a02c952b68a897c96201f32ad57e0b955e"
Commit fdeee889 authored by limm's avatar limm
Browse files

release v1.6.1 of mmcv

parent df465820
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel
from ..scatter_gather import scatter_kwargs
class MPSDataParallel(MMDataParallel):
    """The MPSDataParallel module that supports DataContainer.

    MPSDataParallel is a class inherited from MMDataParallel, which
    supports MPS training and inference only.

    The main differences with MMDataParallel:

    - It only supports a single MPS card, and only uses the first card
      to run training and inference.
    - It uses direct host-to-device copy instead of stream-background
      scatter.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)
        # MPS exposes exactly one device, so pin everything to card 0.
        self.device_ids = [0]
        self.src_device_obj = torch.device('mps:0')

    def scatter(self, inputs, kwargs, device_ids):
        # Delegate to the DataContainer-aware scatter helper.
        scattered = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
        return scattered
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel.data_container import DataContainer
from mmcv.utils import deprecated_api_warning
from ._functions import Scatter
from .utils import get_device
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter(inputs, target_devices, dim=0):
    """Scatter inputs to target devices.

    The only difference from the original :func:`scatter` is to add
    support for :class:`~mmcv.parallel.DataContainer`.

    Args:
        inputs: Object to scatter. May be a tensor, a ``DataContainer``,
            or a (nested) tuple/list/dict of such objects.
        target_devices (list[int]): Target device ids. ``[-1]`` selects
            the self-implemented CPU ``Scatter`` path instead of a
            host-to-device copy.
        dim (int): Dimension used to scatter the data. Defaults to 0.
            NOTE(review): ``dim`` is not used in this body — presumably
            kept for API compatibility with the original ``scatter``;
            confirm against callers.

    Returns:
        list: One entry per target device with the scattered data.
    """
    current_device = get_device()

    def scatter_map(obj):
        # Tensors: either copy to the current device or fall back to the
        # self-implemented Scatter for CPU ([-1]) targets.
        if isinstance(obj, torch.Tensor):
            if target_devices != [-1]:
                obj = obj.to(current_device)
                return [obj]
            else:
                # For CPU inference we use the self-implemented scatter.
                return Scatter.forward(target_devices, obj)
        # DataContainers: cpu_only data is passed through untouched;
        # otherwise its payload is scattered like a plain tensor.
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return Scatter.forward(target_devices, obj.data)
        # Containers: scatter each element recursively, then regroup so
        # the result is one container per device.
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            # type(obj) rebuilds the same mapping type from (k, v) pairs.
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
        # Anything else (including empty containers) is replicated as-is
        # to every target device.
        return [obj for _ in target_devices]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
    """Scatter with support for kwargs dictionary."""
    scattered_inputs = scatter(inputs, target_devices, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
    # Pad the shorter sequence so both have one entry per device.
    length_gap = len(scattered_inputs) - len(scattered_kwargs)
    if length_gap < 0:
        scattered_inputs.extend(() for _ in range(-length_gap))
    elif length_gap > 0:
        scattered_kwargs.extend({} for _ in range(length_gap))
    return tuple(scattered_inputs), tuple(scattered_kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def get_device() -> str:
    """Return the currently existing device type.

    Returns:
        str: One of ``'cuda'``, ``'mlu'``, ``'mps'`` or ``'cpu'``.
    """
    # Probe accelerators in priority order; fall back to CPU.
    priority = (
        (IS_CUDA_AVAILABLE, 'cuda'),
        (IS_MLU_AVAILABLE, 'mlu'),
        (IS_MPS_AVAILABLE, 'mps'),
    )
    for available, device in priority:
        if available:
            return device
    return 'cpu'
......@@ -4,15 +4,18 @@ import pickle
import shutil
import tempfile
import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader
import mmcv
from mmcv.runner import get_dist_info
def single_gpu_test(model, data_loader):
def single_gpu_test(model: nn.Module, data_loader: DataLoader) -> list:
"""Test model with a single gpu.
This method tests model with a single gpu and displays test progress bar.
......@@ -41,7 +44,10 @@ def single_gpu_test(model, data_loader):
return results
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
def multi_gpu_test(model: nn.Module,
data_loader: DataLoader,
tmpdir: Optional[str] = None,
gpu_collect: bool = False) -> Optional[list]:
"""Test model with multiple gpus.
This method tests model with multiple gpus and collects the results
......@@ -82,13 +88,15 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
# collect results from all ranks
if gpu_collect:
results = collect_results_gpu(results, len(dataset))
result_from_ranks = collect_results_gpu(results, len(dataset))
else:
results = collect_results_cpu(results, len(dataset), tmpdir)
return results
result_from_ranks = collect_results_cpu(results, len(dataset), tmpdir)
return result_from_ranks
def collect_results_cpu(result_part, size, tmpdir=None):
def collect_results_cpu(result_part: list,
size: int,
tmpdir: Optional[str] = None) -> Optional[list]:
"""Collect results under cpu mode.
On cpu mode, this function will save the results on different gpus to
......@@ -126,7 +134,8 @@ def collect_results_cpu(result_part, size, tmpdir=None):
else:
mmcv.mkdir_or_exist(tmpdir)
# dump the part result to the dir
mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
part_file = osp.join(tmpdir, f'part_{rank}.pkl') # type: ignore
mmcv.dump(result_part, part_file)
dist.barrier()
# collect all parts
if rank != 0:
......@@ -135,7 +144,7 @@ def collect_results_cpu(result_part, size, tmpdir=None):
# load results of all parts from tmp dir
part_list = []
for i in range(world_size):
part_file = osp.join(tmpdir, f'part_{i}.pkl')
part_file = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore
part_result = mmcv.load(part_file)
# When data is severely insufficient, an empty part_result
# on a certain gpu could makes the overall outputs empty.
......@@ -148,11 +157,11 @@ def collect_results_cpu(result_part, size, tmpdir=None):
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
shutil.rmtree(tmpdir)
shutil.rmtree(tmpdir) # type: ignore
return ordered_results
def collect_results_gpu(result_part, size):
def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
"""Collect results under gpu mode.
On gpu mode, this function will encode results to gpu tensors and use gpu
......@@ -200,3 +209,5 @@ def collect_results_gpu(result_part, size):
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
return ordered_results
else:
return None
......@@ -8,7 +8,7 @@ import warnings
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable, Iterator, Optional, Tuple, Union
from typing import Any, Generator, Iterator, Optional, Tuple, Union
from urllib.request import urlopen
import mmcv
......@@ -64,7 +64,8 @@ class CephBackend(BaseStorageBackend):
raise ImportError('Please install ceph to enable CephBackend.')
warnings.warn(
'CephBackend will be deprecated, please use PetrelBackend instead')
'CephBackend will be deprecated, please use PetrelBackend instead',
DeprecationWarning)
self._client = ceph.S3Client()
assert isinstance(path_mapping, dict) or path_mapping is None
self.path_mapping = path_mapping
......@@ -209,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
"""
if not has_method(self._client, 'delete'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `delete` method, please use a higher version or dev'
' branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `delete` method, please use a higher version or dev'
' branch instead.')
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
......@@ -229,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
if not (has_method(self._client, 'contains')
and has_method(self._client, 'isdir')):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `contains` and `isdir` methods, please use a higher'
'version or dev branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `contains` and `isdir` methods, please use a higher'
'version or dev branch instead.')
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
......@@ -246,13 +247,13 @@ class PetrelBackend(BaseStorageBackend):
Returns:
bool: Return ``True`` if ``filepath`` points to a directory,
``False`` otherwise.
``False`` otherwise.
"""
if not has_method(self._client, 'isdir'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `isdir` method, please use a higher version or dev'
' branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `isdir` method, please use a higher version or dev'
' branch instead.')
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
......@@ -266,13 +267,13 @@ class PetrelBackend(BaseStorageBackend):
Returns:
bool: Return ``True`` if ``filepath`` points to a file, ``False``
otherwise.
otherwise.
"""
if not has_method(self._client, 'contains'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `contains` method, please use a higher version or '
'dev branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `contains` method, please use a higher version or '
'dev branch instead.')
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
......@@ -297,7 +298,10 @@ class PetrelBackend(BaseStorageBackend):
return '/'.join(formatted_paths)
@contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
def get_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath`` and return a temporary path.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
......@@ -362,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
"""
if not has_method(self._client, 'list'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `list` method, please use a higher version or dev'
' branch instead.'))
'Current version of Petrel Python SDK has not supported '
'the `list` method, please use a higher version or dev'
' branch instead.')
dir_path = self._map_path(dir_path)
dir_path = self._format_path(dir_path)
......@@ -473,17 +477,16 @@ class LmdbBackend(BaseStorageBackend):
readahead=False,
**kwargs):
try:
import lmdb
import lmdb # NOQA
except ImportError:
raise ImportError('Please install lmdb to enable LmdbBackend.')
self.db_path = str(db_path)
self._client = lmdb.open(
self.db_path,
readonly=readonly,
lock=lock,
readahead=readahead,
**kwargs)
self.readonly = readonly
self.lock = lock
self.readahead = readahead
self.kwargs = kwargs
self._client = None
def get(self, filepath):
"""Get values according to the filepath.
......@@ -491,14 +494,29 @@ class LmdbBackend(BaseStorageBackend):
Args:
filepath (str | obj:`Path`): Here, filepath is the lmdb key.
"""
filepath = str(filepath)
if self._client is None:
self._client = self._get_client()
with self._client.begin(write=False) as txn:
value_buf = txn.get(filepath.encode('ascii'))
value_buf = txn.get(str(filepath).encode('utf-8'))
return value_buf
    def get_text(self, filepath, encoding=None):
        # Text reading is deliberately unsupported by this backend;
        # values are raw bytes, so callers should use ``get`` instead.
        raise NotImplementedError
def _get_client(self):
import lmdb
return lmdb.open(
self.db_path,
readonly=self.readonly,
lock=self.lock,
readahead=self.readahead,
**self.kwargs)
def __del__(self):
self._client.close()
class HardDiskBackend(BaseStorageBackend):
"""Raw hard disks storage backend."""
......@@ -531,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
Returns:
str: Expected text reading from ``filepath``.
"""
with open(filepath, 'r', encoding=encoding) as f:
with open(filepath, encoding=encoding) as f:
value_buf = f.read()
return value_buf
......@@ -598,7 +616,7 @@ class HardDiskBackend(BaseStorageBackend):
Returns:
bool: Return ``True`` if ``filepath`` points to a directory,
``False`` otherwise.
``False`` otherwise.
"""
return osp.isdir(filepath)
......@@ -610,7 +628,7 @@ class HardDiskBackend(BaseStorageBackend):
Returns:
bool: Return ``True`` if ``filepath`` points to a file, ``False``
otherwise.
otherwise.
"""
return osp.isfile(filepath)
......@@ -631,7 +649,9 @@ class HardDiskBackend(BaseStorageBackend):
@contextmanager
def get_local_path(
self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Only for unified API and do nothing."""
yield filepath
......@@ -700,7 +720,8 @@ class HTTPBackend(BaseStorageBackend):
return value_buf.decode(encoding)
@contextmanager
def get_local_path(self, filepath: str) -> Iterable[str]:
def get_local_path(
self, filepath: str) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath``.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
......@@ -770,19 +791,16 @@ class FileClient:
'petrel': PetrelBackend,
'http': HTTPBackend,
}
# This collection is used to record the overridden backends, and when a
# backend appears in the collection, the singleton pattern is disabled for
# that backend, because if the singleton pattern is used, then the object
# returned will be the backend before overwriting
_overridden_backends = set()
_prefix_to_backends = {
's3': PetrelBackend,
'http': HTTPBackend,
'https': HTTPBackend,
}
_overridden_prefixes = set()
_instances = {}
_instances: dict = {}
client: Any
def __new__(cls, backend=None, prefix=None, **kwargs):
if backend is None and prefix is None:
......@@ -802,10 +820,7 @@ class FileClient:
for key, value in kwargs.items():
arg_key += f':{key}:{value}'
# if a backend was overridden, it will create a new object
if (arg_key in cls._instances
and backend not in cls._overridden_backends
and prefix not in cls._overridden_prefixes):
if arg_key in cls._instances:
_instance = cls._instances[arg_key]
else:
# create a new object and put it to _instance
......@@ -839,8 +854,8 @@ class FileClient:
's3'
Returns:
str | None: Return the prefix of uri if the uri contains '://'
else ``None``.
str | None: Return the prefix of uri if the uri contains '://' else
``None``.
"""
assert is_filepath(uri)
uri = str(uri)
......@@ -899,7 +914,9 @@ class FileClient:
'add "force=True" if you want to override it')
if name in cls._backends and force:
cls._overridden_backends.add(name)
for arg_key, instance in list(cls._instances.items()):
if isinstance(instance.client, cls._backends[name]):
cls._instances.pop(arg_key)
cls._backends[name] = backend
if prefixes is not None:
......@@ -911,7 +928,12 @@ class FileClient:
if prefix not in cls._prefix_to_backends:
cls._prefix_to_backends[prefix] = backend
elif (prefix in cls._prefix_to_backends) and force:
cls._overridden_prefixes.add(prefix)
overridden_backend = cls._prefix_to_backends[prefix]
if isinstance(overridden_backend, list):
overridden_backend = tuple(overridden_backend)
for arg_key, instance in list(cls._instances.items()):
if isinstance(instance.client, overridden_backend):
cls._instances.pop(arg_key)
cls._prefix_to_backends[prefix] = backend
else:
raise KeyError(
......@@ -987,7 +1009,7 @@ class FileClient:
Returns:
bytes | memoryview: Expected bytes object or a memory view of the
bytes object.
bytes object.
"""
return self.client.get(filepath)
......@@ -1060,7 +1082,7 @@ class FileClient:
Returns:
bool: Return ``True`` if ``filepath`` points to a directory,
``False`` otherwise.
``False`` otherwise.
"""
return self.client.isdir(filepath)
......@@ -1072,7 +1094,7 @@ class FileClient:
Returns:
bool: Return ``True`` if ``filepath`` points to a file, ``False``
otherwise.
otherwise.
"""
return self.client.isfile(filepath)
......@@ -1092,7 +1114,10 @@ class FileClient:
return self.client.join_path(filepath, *filepaths)
@contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
def get_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Download data from ``filepath`` and write the data to local path.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
......
......@@ -21,10 +21,10 @@ class BaseFileHandler(metaclass=ABCMeta):
def dump_to_str(self, obj, **kwargs):
pass
def load_from_path(self, filepath, mode='r', **kwargs):
def load_from_path(self, filepath: str, mode: str = 'r', **kwargs):
with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs)
def dump_to_path(self, obj, filepath, mode='w', **kwargs):
def dump_to_path(self, obj, filepath: str, mode: str = 'w', **kwargs):
with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs)
......@@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
return pickle.load(file, **kwargs)
def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path(
filepath, mode='rb', **kwargs)
return super().load_from_path(filepath, mode='rb', **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('protocol', 2)
......@@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path(
obj, filepath, mode='wb', **kwargs)
super().dump_to_path(obj, filepath, mode='wb', **kwargs)
......@@ -2,9 +2,10 @@
import yaml
try:
from yaml import CLoader as Loader, CDumper as Dumper
from yaml import CDumper as Dumper
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader, Dumper
from yaml import Loader, Dumper # type: ignore
from .base import BaseFileHandler # isort:skip
......
# Copyright (c) OpenMMLab. All rights reserved.
from io import BytesIO, StringIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TextIO, Union
from ..utils import is_list_of, is_str
from ..utils import is_list_of
from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
FileLikeObject = Union[TextIO, StringIO, BytesIO]
file_handlers = {
'json': JsonHandler(),
'yaml': YamlHandler(),
......@@ -15,7 +18,10 @@ file_handlers = {
}
def load(file, file_format=None, file_client_args=None, **kwargs):
def load(file: Union[str, Path, FileLikeObject],
file_format: Optional[str] = None,
file_client_args: Optional[Dict] = None,
**kwargs):
"""Load data from json/yaml/pickle files.
This method provides a unified api for loading data from serialized files.
......@@ -45,13 +51,14 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
"""
if isinstance(file, Path):
file = str(file)
if file_format is None and is_str(file):
if file_format is None and isinstance(file, str):
file_format = file.split('.')[-1]
if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}')
handler = file_handlers[file_format]
if is_str(file):
f: FileLikeObject
if isinstance(file, str):
file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like:
with StringIO(file_client.get_text(file)) as f:
......@@ -66,7 +73,11 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
return obj
def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
def dump(obj: Any,
file: Optional[Union[str, Path, FileLikeObject]] = None,
file_format: Optional[str] = None,
file_client_args: Optional[Dict] = None,
**kwargs):
"""Dump data to json/yaml/pickle strings or files.
This method provides a unified api for dumping data as strings or to files,
......@@ -96,18 +107,18 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
if isinstance(file, Path):
file = str(file)
if file_format is None:
if is_str(file):
if isinstance(file, str):
file_format = file.split('.')[-1]
elif file is None:
raise ValueError(
'file_format must be specified since file is None')
if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}')
f: FileLikeObject
handler = file_handlers[file_format]
if file is None:
return handler.dump_to_str(obj, **kwargs)
elif is_str(file):
elif isinstance(file, str):
file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like:
with StringIO() as f:
......@@ -123,7 +134,8 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
raise TypeError('"file" must be a filename str or a file-object')
def _register_handler(handler, file_formats):
def _register_handler(handler: BaseFileHandler,
file_formats: Union[str, List[str]]) -> None:
"""Register a handler for some file extensions.
Args:
......@@ -142,7 +154,7 @@ def _register_handler(handler, file_formats):
file_handlers[ext] = handler
def register_handler(file_formats, **kwargs):
def register_handler(file_formats: Union[str, list], **kwargs) -> Callable:
def wrap(cls):
_register_handler(cls(**kwargs), file_formats)
......
# Copyright (c) OpenMMLab. All rights reserved.
from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional, Union
from .file_client import FileClient
def list_from_file(filename,
prefix='',
offset=0,
max_num=0,
encoding='utf-8',
file_client_args=None):
def list_from_file(filename: Union[str, Path],
prefix: str = '',
offset: int = 0,
max_num: int = 0,
encoding: str = 'utf-8',
file_client_args: Optional[Dict] = None) -> List:
"""Load a text file and parse the content as a list of strings.
Note:
......@@ -52,10 +54,10 @@ def list_from_file(filename,
return item_list
def dict_from_file(filename,
key_type=str,
encoding='utf-8',
file_client_args=None):
def dict_from_file(filename: Union[str, Path],
key_type: type = str,
encoding: str = 'utf-8',
file_client_args: Optional[Dict] = None) -> Dict:
"""Load a text file and parse the content as a dict.
Each line of the text file will be two or more columns split by
......
......@@ -9,10 +9,10 @@ from .geometric import (cutout, imcrop, imflip, imflip_, impad,
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
from .misc import tensor2imgs
from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
adjust_lighting, adjust_sharpness, auto_contrast,
clahe, imdenormalize, imequalize, iminvert,
imnormalize, imnormalize_, lut_transform, posterize,
solarize)
adjust_hue, adjust_lighting, adjust_sharpness,
auto_contrast, clahe, imdenormalize, imequalize,
iminvert, imnormalize, imnormalize_, lut_transform,
posterize, solarize)
__all__ = [
'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
......@@ -24,5 +24,6 @@ __all__ = [
'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting',
'adjust_hue'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Union
import cv2
import numpy as np
def imconvert(img, src, dst):
def imconvert(img: np.ndarray, src: str, dst: str) -> np.ndarray:
"""Convert an image from the src colorspace to dst colorspace.
Args:
......@@ -19,7 +21,7 @@ def imconvert(img, src, dst):
return out_img
def bgr2gray(img, keepdim=False):
def bgr2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
"""Convert a BGR image to grayscale image.
Args:
......@@ -36,7 +38,7 @@ def bgr2gray(img, keepdim=False):
return out_img
def rgb2gray(img, keepdim=False):
def rgb2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
"""Convert a RGB image to grayscale image.
Args:
......@@ -53,7 +55,7 @@ def rgb2gray(img, keepdim=False):
return out_img
def gray2bgr(img):
def gray2bgr(img: np.ndarray) -> np.ndarray:
"""Convert a grayscale image to BGR image.
Args:
......@@ -67,7 +69,7 @@ def gray2bgr(img):
return out_img
def gray2rgb(img):
def gray2rgb(img: np.ndarray) -> np.ndarray:
"""Convert a grayscale image to RGB image.
Args:
......@@ -81,7 +83,7 @@ def gray2rgb(img):
return out_img
def _convert_input_type_range(img):
def _convert_input_type_range(img: np.ndarray) -> np.ndarray:
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
......@@ -109,7 +111,8 @@ def _convert_input_type_range(img):
return img
def _convert_output_type_range(img, dst_type):
def _convert_output_type_range(
img: np.ndarray, dst_type: Union[np.uint8, np.float32]) -> np.ndarray:
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
......@@ -140,7 +143,7 @@ def _convert_output_type_range(img, dst_type):
return img.astype(dst_type)
def rgb2ycbcr(img, y_only=False):
def rgb2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
"""Convert a RGB image to YCbCr image.
This function produces the same results as Matlab's `rgb2ycbcr` function.
......@@ -160,7 +163,7 @@ def rgb2ycbcr(img, y_only=False):
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
......@@ -174,7 +177,7 @@ def rgb2ycbcr(img, y_only=False):
return out_img
def bgr2ycbcr(img, y_only=False):
def bgr2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
......@@ -194,7 +197,7 @@ def bgr2ycbcr(img, y_only=False):
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
......@@ -208,7 +211,7 @@ def bgr2ycbcr(img, y_only=False):
return out_img
def ycbcr2rgb(img):
def ycbcr2rgb(img: np.ndarray) -> np.ndarray:
"""Convert a YCbCr image to RGB image.
This function produces the same results as Matlab's ycbcr2rgb function.
......@@ -227,7 +230,7 @@ def ycbcr2rgb(img):
Returns:
ndarray: The converted RGB image. The output image has the same type
and range as input image.
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
......@@ -240,7 +243,7 @@ def ycbcr2rgb(img):
return out_img
def ycbcr2bgr(img):
def ycbcr2bgr(img: np.ndarray) -> np.ndarray:
"""Convert a YCbCr image to BGR image.
The bgr version of ycbcr2rgb.
......@@ -259,7 +262,7 @@ def ycbcr2bgr(img):
Returns:
ndarray: The converted BGR image. The output image has the same type
and range as input image.
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
......@@ -272,11 +275,11 @@ def ycbcr2bgr(img):
return out_img
def convert_color_factory(src, dst):
def convert_color_factory(src: str, dst: str) -> Callable:
code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
def convert_color(img):
def convert_color(img: np.ndarray) -> np.ndarray:
out_img = cv2.cvtColor(img, code)
return out_img
......
......@@ -37,15 +37,27 @@ cv2_interp_codes = {
'lanczos': cv2.INTER_LANCZOS4
}
# Pillow >=v9.1.0 use a slightly different naming scheme for filters.
# Set pillow_interp_codes according to the naming scheme used.
if Image is not None:
pillow_interp_codes = {
'nearest': Image.NEAREST,
'bilinear': Image.BILINEAR,
'bicubic': Image.BICUBIC,
'box': Image.BOX,
'lanczos': Image.LANCZOS,
'hamming': Image.HAMMING
}
if hasattr(Image, 'Resampling'):
pillow_interp_codes = {
'nearest': Image.Resampling.NEAREST,
'bilinear': Image.Resampling.BILINEAR,
'bicubic': Image.Resampling.BICUBIC,
'box': Image.Resampling.BOX,
'lanczos': Image.Resampling.LANCZOS,
'hamming': Image.Resampling.HAMMING
}
else:
pillow_interp_codes = {
'nearest': Image.NEAREST,
'bilinear': Image.BILINEAR,
'bicubic': Image.BICUBIC,
'box': Image.BOX,
'lanczos': Image.LANCZOS,
'hamming': Image.HAMMING
}
def imresize(img,
......@@ -70,7 +82,7 @@ def imresize(img,
Returns:
tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`.
`resized_img`.
"""
h, w = img.shape[:2]
if backend is None:
......@@ -130,7 +142,7 @@ def imresize_to_multiple(img,
Returns:
tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`.
`resized_img`.
"""
h, w = img.shape[:2]
if size is not None and scale_factor is not None:
......@@ -145,7 +157,7 @@ def imresize_to_multiple(img,
size = _scale_size((w, h), scale_factor)
divisor = to_2tuple(divisor)
size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
resized_img, w_scale, h_scale = imresize(
img,
size,
......@@ -175,7 +187,7 @@ def imresize_like(img,
Returns:
tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`.
`resized_img`.
"""
h, w = dst_img.shape[:2]
return imresize(img, (w, h), return_scale, interpolation, backend=backend)
......@@ -460,18 +472,17 @@ def impad(img,
areas when padding_mode is 'constant'. Default: 0.
padding_mode (str): Type of padding. Should be: constant, edge,
reflect or symmetric. Default: constant.
- constant: pads with a constant value, this value is specified
with pad_val.
with pad_val.
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the
last value on the edge. For example, padding [1, 2, 3, 4]
with 2 elements on both sides in reflect mode will result
in [3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with
2 elements on both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
- reflect: pads with reflection of image without repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with 2
elements on both sides in reflect mode will result in
[3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last value
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
Returns:
ndarray: The padded image.
......@@ -479,7 +490,9 @@ def impad(img,
assert (shape is not None) ^ (padding is not None)
if shape is not None:
padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0])
width = max(shape[1] - img.shape[1], 0)
height = max(shape[0] - img.shape[0], 0)
padding = (0, 0, width, height)
# check pad_val
if isinstance(pad_val, tuple):
......
# Copyright (c) OpenMMLab. All rights reserved.
import io
import os.path as osp
import warnings
from pathlib import Path
import cv2
......@@ -8,7 +9,8 @@ import numpy as np
from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
IMREAD_UNCHANGED)
from mmcv.utils import check_file_exist, is_str, mkdir_or_exist
from mmcv.fileio import FileClient
from mmcv.utils import is_filepath, is_str
try:
from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
......@@ -137,9 +139,16 @@ def _pillow2array(img, flag='color', channel_order='bgr'):
return array
def imread(img_or_path,
           flag='color',
           channel_order='bgr',
           backend=None,
           file_client_args=None):
    """Read an image.

    Note:
        In v1.4.1 and later, add `file_client_args` parameters.

    Args:
        img_or_path (ndarray or str or Path): Either a numpy array or str or
            pathlib.Path. If it is a numpy array (loaded image), then
            it will be returned as is.
        flag (str): Flags specifying the color type of a loaded image.
            Default: 'color'.
        channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
            Default: 'bgr'.
        backend (str | None): The image decoding backend type. Options are
            `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
            If backend is None, the global imread_backend specified by
            ``mmcv.use_backend()`` will be used. Default: None.
        file_client_args (dict | None): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.

    Returns:
        ndarray: Loaded image array.

    Raises:
        TypeError: If ``img_or_path`` is neither an ndarray, a str, nor a
            pathlib.Path.

    Examples:
        >>> import mmcv
        >>> img_path = '/path/to/img.jpg'
        >>> img = mmcv.imread(img_path)
        >>> img = mmcv.imread(img_path, flag='color', channel_order='rgb',
        ...     backend='cv2')
        >>> img = mmcv.imread(img_path, flag='color', channel_order='bgr',
        ...     backend='pillow')
        >>> s3_img_path = 's3://bucket/img.jpg'
        >>> # infer the file backend by the prefix s3
        >>> img = mmcv.imread(s3_img_path)
        >>> # manually set the file backend petrel
        >>> img = mmcv.imread(s3_img_path, file_client_args={
        ...     'backend': 'petrel'})
        >>> http_img_path = 'http://path/to/img.jpg'
        >>> img = mmcv.imread(http_img_path)
        >>> img = mmcv.imread(http_img_path, file_client_args={
        ...     'backend': 'http'})
    """
    if isinstance(img_or_path, Path):
        img_or_path = str(img_or_path)

    if isinstance(img_or_path, np.ndarray):
        # Already a decoded image; return it unchanged.
        return img_or_path
    elif is_str(img_or_path):
        # Delegate file access to FileClient so that local disk, petrel,
        # http, etc. are handled transparently; decoding (including the
        # ``backend`` validation) happens in :func:`imfrombytes`.
        file_client = FileClient.infer_client(file_client_args, img_or_path)
        img_bytes = file_client.get(img_or_path)
        return imfrombytes(img_bytes, flag, channel_order, backend)
    else:
        raise TypeError('"img" must be a numpy array or a str or '
                        'a pathlib.Path object')
......@@ -206,29 +213,45 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
Args:
content (bytes): Image bytes got from files or other streams.
flag (str): Same as :func:`imread`.
channel_order (str): The channel order of the output, candidates
are 'bgr' and 'rgb'. Default to 'bgr'.
backend (str | None): The image decoding backend type. Options are
`cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
global imread_backend specified by ``mmcv.use_backend()`` will be
used. Default: None.
`cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. If backend is
None, the global imread_backend specified by ``mmcv.use_backend()``
will be used. Default: None.
Returns:
ndarray: Loaded image array.
Examples:
>>> img_path = '/path/to/img.jpg'
>>> with open(img_path, 'rb') as f:
>>> img_buff = f.read()
>>> img = mmcv.imfrombytes(img_buff)
>>> img = mmcv.imfrombytes(img_buff, flag='color', channel_order='rgb')
>>> img = mmcv.imfrombytes(img_buff, backend='pillow')
>>> img = mmcv.imfrombytes(img_buff, backend='cv2')
"""
if backend is None:
backend = imread_backend
if backend not in supported_backends:
raise ValueError(f'backend: {backend} is not supported. Supported '
"backends are 'cv2', 'turbojpeg', 'pillow'")
raise ValueError(
f'backend: {backend} is not supported. Supported '
"backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'")
if backend == 'turbojpeg':
img = jpeg.decode(content, _jpegflag(flag, channel_order))
if img.shape[-1] == 1:
img = img[:, :, 0]
return img
elif backend == 'pillow':
buff = io.BytesIO(content)
img = Image.open(buff)
img = _pillow2array(img, flag, channel_order)
with io.BytesIO(content) as buff:
img = Image.open(buff)
img = _pillow2array(img, flag, channel_order)
return img
elif backend == 'tifffile':
with io.BytesIO(content) as buff:
img = tifffile.imread(buff)
return img
else:
img_np = np.frombuffer(content, np.uint8)
......@@ -239,20 +262,53 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
return img
def imwrite(img,
            file_path,
            params=None,
            auto_mkdir=None,
            file_client_args=None):
    """Write image to file.

    Note:
        In v1.4.1 and later, add `file_client_args` parameters.

    Warning:
        The parameter `auto_mkdir` will be deprecated in the future and every
        file clients will make directory automatically.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically. It will be deprecated.
        file_client_args (dict | None): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.

    Returns:
        bool: Successful or not.

    Examples:
        >>> # write to hard disk client
        >>> ret = mmcv.imwrite(img, '/path/to/img.jpg')
        >>> # infer the file backend by the prefix s3
        >>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg')
        >>> # manually set the file backend petrel
        >>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg', file_client_args={
        ...     'backend': 'petrel'})
    """
    assert is_filepath(file_path)
    file_path = str(file_path)
    if auto_mkdir is not None:
        # Kept only for backward compatibility; file clients create
        # directories themselves now.
        warnings.warn(
            'The parameter `auto_mkdir` will be deprecated in the future and '
            'every file clients will make directory automatically.')

    file_client = FileClient.infer_client(file_client_args, file_path)
    img_ext = osp.splitext(file_path)[-1]
    # Encode image according to image suffix.
    # For example, if image path is '/path/your/img.jpg', the encode
    # format is '.jpg'.
    flag, img_buff = cv2.imencode(img_ext, img, params)
    file_client.put(img_buff.tobytes(), file_path)
    return flag
......@@ -9,18 +9,21 @@ except ImportError:
torch = None
def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
"""Convert tensor to 3-channel images.
def tensor2imgs(tensor, mean=None, std=None, to_rgb=True):
"""Convert tensor to 3-channel images or 1-channel gray images.
Args:
tensor (torch.Tensor): Tensor that contains multiple images, shape (
N, C, H, W).
mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
std (tuple[float], optional): Standard deviation of images.
Defaults to (1, 1, 1).
N, C, H, W). :math:`C` can be either 3 or 1.
mean (tuple[float], optional): Mean of images. If None,
(0, 0, 0) will be used for tensor with 3-channel,
while (0, ) for tensor with 1-channel. Defaults to None.
std (tuple[float], optional): Standard deviation of images. If None,
(1, 1, 1) will be used for tensor with 3-channel,
while (1, ) for tensor with 1-channel. Defaults to None.
to_rgb (bool, optional): Whether the tensor was converted to RGB
format in the first place. If so, convert it back to BGR.
Defaults to True.
For the tensor with 1 channel, it must be False. Defaults to True.
Returns:
list[np.ndarray]: A list that contains multiple images.
......@@ -29,8 +32,14 @@ def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
if torch is None:
raise RuntimeError('pytorch is not installed')
assert torch.is_tensor(tensor) and tensor.ndim == 4
assert len(mean) == 3
assert len(std) == 3
channels = tensor.size(1)
assert channels in [1, 3]
if mean is None:
mean = (0, ) * channels
if std is None:
std = (1, ) * channels
assert (channels == len(mean) == len(std) == 3) or \
(channels == len(mean) == len(std) == 1 and not to_rgb)
num_imgs = tensor.size(0)
mean = np.array(mean, dtype=np.float32)
......
......@@ -426,3 +426,46 @@ def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
return clahe.apply(np.array(img, dtype=np.uint8))
def adjust_hue(img: np.ndarray, hue_factor: float) -> np.ndarray:
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and cyclically
    shifting the intensities in the hue channel (H). The image is then
    converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    Modified from
    https://github.com/pytorch/vision/blob/main/torchvision/
    transforms/functional.py

    Args:
        img (ndarray): Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        ndarray: Hue adjusted image.

    Raises:
        ValueError: If ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: If ``img`` is not a 2- or 3-dim ndarray.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f'hue_factor:{hue_factor} is not in [-0.5, 0.5].')
    if not (isinstance(img, np.ndarray) and (img.ndim in {2, 3})):
        raise TypeError('img should be ndarray with dim=[2 or 3].')

    dtype = img.dtype
    img = img.astype(np.uint8)
    hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL)
    h, s, v = cv2.split(hsv_img)
    h = h.astype(np.uint8)
    # Casting a negative float directly to uint8 is undefined in C and
    # platform-dependent in NumPy; reduce the shift modulo 256 first so the
    # result is well defined everywhere (HSV_FULL hue spans 0..255).
    shift = np.uint8(int(hue_factor * 255) % 256)
    # uint8 addition takes care of rotation across boundaries (mod 256)
    with np.errstate(over='ignore'):
        h += shift
    hsv_img = cv2.merge([h, s, v])
    return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB_FULL).astype(dtype)
{
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
"shufflenetv2_x1.5": null,
"shufflenetv2_x2.0": null,
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
}
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
import torch
def is_custom_op_loaded():
def is_custom_op_loaded() -> bool:
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
flag = False
try:
from ..tensorrt import is_tensorrt_plugin_loaded
......
......@@ -59,7 +59,7 @@ def _parse_arg(value, desc):
raise RuntimeError(
"ONNX symbolic doesn't know to interpret ListConstruct node")
raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
raise RuntimeError(f'Unexpected node type: {value.node().kind()}')
def _maybe_get_const(value, desc):
......@@ -328,4 +328,4 @@ cast_pytorch_to_onnx = {
# Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT
# -> C2 via ONNX.
_quantized_ops = set()
_quantized_ops: set = set()
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/pytorch/pytorch."""
import os
import warnings
import numpy as np
import torch
......@@ -409,8 +410,8 @@ def cummin(g, input, dim):
@parse_args('v', 'v', 'is')
def roll(g, input, shifts, dims):
from torch.onnx.symbolic_opset9 import squeeze
from packaging import version
from torch.onnx.symbolic_opset9 import squeeze
input_shape = g.op('Shape', input)
need_flatten = len(dims) == 0
......@@ -467,6 +468,18 @@ def roll(g, input, shifts, dims):
def register_extra_symbolics(opset=11):
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This function will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
register_op('one_hot', one_hot, '', opset)
register_op('im2col', im2col, '', opset)
register_op('topk', topk, '', opset)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment