Commit 0fd8347d authored by unknown's avatar unknown
Browse files

Add mmclassification-0.24.1 code; remove mmclassification-speed-benchmark

parent cc567e9e
import os # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union
import numpy as np
from .base_dataset import BaseDataset
from .builder import DATASETS from .builder import DATASETS
from .custom import CustomDataset
def has_file_allowed_extension(filename, extensions): @DATASETS.register_module()
"""Checks if a file is an allowed extension. class ImageNet(CustomDataset):
"""`ImageNet <http://www.image-net.org>`_ Dataset.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in extensions)
def find_folders(root):
"""Find classes by folders under a root.
Args:
root (string): root directory of folders
Returns:
folder_to_idx (dict): the map from folder name to class idx
"""
folders = [
d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
]
folders.sort()
folder_to_idx = {folders[i]: i for i in range(len(folders))}
return folder_to_idx
def get_samples(root, folder_to_idx, extensions): The dataset supports two kinds of annotation format. More details can be
"""Make dataset by walking all images under a root. found in :class:`CustomDataset`.
Args: Args:
root (string): root directory of folders data_prefix (str): The path of data directory.
folder_to_idx (dict): the map from class name to class idx pipeline (Sequence[dict]): A list of dict, where each element
extensions (tuple): allowed extensions represents a operation defined in :mod:`mmcls.datasets.pipelines`.
Defaults to an empty tuple.
Returns: classes (str | Sequence[str], optional): Specify names of classes.
samples (list): a list of tuple where each element is (image, label)
"""
samples = []
root = os.path.expanduser(root)
for folder_name in sorted(os.listdir(root)):
_dir = os.path.join(root, folder_name)
if not os.path.isdir(_dir):
continue
for _, _, fns in sorted(os.walk(_dir)):
for fn in sorted(fns):
if has_file_allowed_extension(fn, extensions):
path = os.path.join(folder_name, fn)
item = (path, folder_to_idx[folder_name])
samples.append(item)
return samples
- If is string, it should be a file path, and the every line of
the file is a name of a class.
- If is a sequence of string, every item is a name of class.
- If is None, use the default ImageNet-1k classes names.
@DATASETS.register_module() Defaults to None.
class ImageNet(BaseDataset): ann_file (str, optional): The annotation file. If is string, read
"""`ImageNet <http://www.image-net.org>`_ Dataset. samples paths from the ann_file. If is None, find samples in
``data_prefix``. Defaults to None.
This implementation is modified from extensions (Sequence[str]): A sequence of allowed extensions. Defaults
https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py # noqa: E501 to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
""" test_mode (bool): In train mode or test mode. It's only a mark and
won't be used in this class. Defaults to False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
If None, automatically inference from the specified path.
Defaults to None.
""" # noqa: E501
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
CLASSES = [ CLASSES = [
...@@ -1075,31 +1042,18 @@ class ImageNet(BaseDataset): ...@@ -1075,31 +1042,18 @@ class ImageNet(BaseDataset):
'toilet tissue, toilet paper, bathroom tissue' 'toilet tissue, toilet paper, bathroom tissue'
] ]
def load_annotations(self): def __init__(self,
if self.ann_file is None: data_prefix: str,
folder_to_idx = find_folders(self.data_prefix) pipeline: Sequence = (),
samples = get_samples( classes: Union[str, Sequence[str], None] = None,
self.data_prefix, ann_file: Optional[str] = None,
folder_to_idx, test_mode: bool = False,
extensions=self.IMG_EXTENSIONS) file_client_args: Optional[dict] = None):
if len(samples) == 0: super().__init__(
raise (RuntimeError('Found 0 files in subfolders of: ' data_prefix=data_prefix,
f'{self.data_prefix}. ' pipeline=pipeline,
'Supported extensions are: ' classes=classes,
f'{",".join(self.IMG_EXTENSIONS)}')) ann_file=ann_file,
extensions=self.IMG_EXTENSIONS,
self.folder_to_idx = folder_to_idx test_mode=test_mode,
elif isinstance(self.ann_file, str): file_client_args=file_client_args)
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
else:
raise TypeError('ann_file must be a str or None')
self.samples = samples
data_infos = []
for filename, gt_label in self.samples:
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_infos.append(info)
return data_infos
# Copyright (c) OpenMMLab. All rights reserved.
import gc
import pickle
import warnings
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class ImageNet21k(CustomDataset):
    """ImageNet21k Dataset.

    Since the dataset ImageNet21k is extremely big, contains 21k+ classes
    and 1.4B files. This class has improved the following points on the
    basis of the class ``ImageNet``, in order to save memory, we enable the
    ``serialize_data`` optional by default. With this option, the annotation
    won't be stored in the list ``data_infos``, but be serialized as an
    array.

    Args:
        data_prefix (str): The path of data directory.
        pipeline (Sequence[dict]): A list of dict, where each element
            represents a operation defined in :mod:`mmcls.datasets.pipelines`.
            Defaults to an empty tuple.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If is string, it should be a file path, and the every line of
              the file is a name of a class.
            - If is a sequence of string, every item is a name of class.
            - If is None, the object won't have category information.
              (Not recommended)

            Defaults to None.
        ann_file (str, optional): The annotation file. If is string, read
            samples paths from the ann_file. If is None, find samples in
            ``data_prefix``. Defaults to None.
        serialize_data (bool): Whether to hold memory using serialized objects,
            when enabled, data loader workers can use shared RAM from master
            process instead of making a copy. Defaults to True.
        multi_label (bool): Not implement by now. Use multi label or not.
            Defaults to False.
        recursion_subdir (bool): Deprecated, and the dataset will recursively
            get all images now.
        test_mode (bool): In train mode or test mode. It's only a mark and
            won't be used in this class. Defaults to False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            If None, automatically inference from the specified path.
            Defaults to None.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    # ImageNet-21k has no bundled default class list (too large); users
    # supply ``classes`` explicitly if category names are needed.
    CLASSES = None

    def __init__(self,
                 data_prefix: str,
                 pipeline: Sequence = (),
                 classes: Union[str, Sequence[str], None] = None,
                 ann_file: Optional[str] = None,
                 serialize_data: bool = True,
                 multi_label: bool = False,
                 recursion_subdir: bool = True,
                 test_mode: bool = False,
                 file_client_args: Optional[dict] = None):
        assert recursion_subdir, 'The `recursion_subdir` option is ' \
            'deprecated. Now the dataset will recursively get all images.'
        if multi_label:
            raise NotImplementedError(
                'The `multi_label` option is not supported by now.')
        self.multi_label = multi_label
        self.serialize_data = serialize_data

        if ann_file is None:
            warnings.warn(
                'The ImageNet21k dataset is large, and scanning directory may '
                'consume long time. Considering to specify the `ann_file` to '
                'accelerate the initialization.', UserWarning)

        if classes is None:
            warnings.warn(
                'The CLASSES is not stored in the `ImageNet21k` class. '
                'Considering to specify the `classes` argument if you need '
                'do inference on the ImageNet-21k dataset', UserWarning)

        super().__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            classes=classes,
            ann_file=ann_file,
            extensions=self.IMG_EXTENSIONS,
            test_mode=test_mode,
            file_client_args=file_client_args)

        if self.serialize_data:
            self.data_infos_bytes, self.data_address = self._serialize_data()
            # Empty cache for preventing making multiple copies of
            # `self.data_infos` when loading data multi-processes.
            self.data_infos.clear()
            gc.collect()

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category id by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image category of specified index.
        """
        return [int(self.get_data_info(idx)['gt_label'])]

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): The index of data.

        Returns:
            dict: The idx-th annotation of the dataset.
        """
        if self.serialize_data:
            # ``data_address`` holds cumulative byte offsets: entry ``idx``
            # is stored in ``data_infos_bytes[start_addr:end_addr)``.
            start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
            end_addr = self.data_address[idx].item()
            # memoryview gives a zero-copy slice; renamed from ``bytes`` to
            # avoid shadowing the builtin.
            info_bytes = memoryview(self.data_infos_bytes[start_addr:end_addr])
            data_info = pickle.loads(info_bytes)
        else:
            data_info = self.data_infos[idx]

        return data_info

    def prepare_data(self, idx):
        """Fetch the ``idx``-th annotation and run it through the pipeline.

        Args:
            idx (int): The index of data.

        Returns:
            The output of ``self.pipeline`` for the idx-th sample.
        """
        data_info = self.get_data_info(idx)
        return self.pipeline(data_info)

    def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Serialize ``self.data_infos`` to save memory when launching multiple
        workers in data loading. This function will be called in ``full_init``.

        Hold memory using serialized objects, and data loader workers can use
        shared RAM from master process instead of making a copy.

        Returns:
            Tuple[np.ndarray, np.ndarray]: serialize result and corresponding
            address.
        """

        def _serialize(data):
            # protocol=4 supports large objects and is available on all
            # supported Python versions.
            buffer = pickle.dumps(data, protocol=4)
            return np.frombuffer(buffer, dtype=np.uint8)

        serialized_data_infos_list = [_serialize(x) for x in self.data_infos]
        address_list = np.asarray([len(x) for x in serialized_data_infos_list],
                                  dtype=np.int64)
        # Cumulative sum turns per-item lengths into end offsets, used by
        # ``get_data_info`` to slice each record back out.
        data_address: np.ndarray = np.cumsum(address_list)
        serialized_data_infos = np.concatenate(serialized_data_infos_list)

        return serialized_data_infos, data_address

    def __len__(self) -> int:
        """Get the length of filtered dataset and automatically call
        ``full_init`` if the dataset has not been fully init.

        Returns:
            int: The length of filtered dataset.
        """
        if self.serialize_data:
            return len(self.data_address)
        else:
            return len(self.data_infos)
# Copyright (c) OpenMMLab. All rights reserved.
import codecs import codecs
import os import os
import os.path as osp import os.path as osp
...@@ -17,8 +18,8 @@ class MNIST(BaseDataset): ...@@ -17,8 +18,8 @@ class MNIST(BaseDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset. """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
This implementation is modified from This implementation is modified from
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py # noqa: E501 https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
""" """ # noqa: E501
resource_prefix = 'http://yann.lecun.com/exdb/mnist/' resource_prefix = 'http://yann.lecun.com/exdb/mnist/'
resources = { resources = {
......
import warnings # Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import numpy as np import numpy as np
...@@ -9,25 +10,25 @@ from .base_dataset import BaseDataset ...@@ -9,25 +10,25 @@ from .base_dataset import BaseDataset
class MultiLabelDataset(BaseDataset): class MultiLabelDataset(BaseDataset):
"""Multi-label Dataset.""" """Multi-label Dataset."""
def get_cat_ids(self, idx): def get_cat_ids(self, idx: int) -> List[int]:
"""Get category ids by index. """Get category ids by index.
Args: Args:
idx (int): Index of data. idx (int): Index of data.
Returns: Returns:
np.ndarray: Image categories of specified index. cat_ids (List[int]): Image categories of specified index.
""" """
gt_labels = self.data_infos[idx]['gt_label'] gt_labels = self.data_infos[idx]['gt_label']
cat_ids = np.where(gt_labels == 1)[0] cat_ids = np.where(gt_labels == 1)[0].tolist()
return cat_ids return cat_ids
def evaluate(self, def evaluate(self,
results, results,
metric='mAP', metric='mAP',
metric_options=None, metric_options=None,
logger=None, indices=None,
**deprecated_kwargs): logger=None):
"""Evaluate the dataset. """Evaluate the dataset.
Args: Args:
...@@ -39,19 +40,13 @@ class MultiLabelDataset(BaseDataset): ...@@ -39,19 +40,13 @@ class MultiLabelDataset(BaseDataset):
Allowed keys are 'k' and 'thr'. Defaults to None Allowed keys are 'k' and 'thr'. Defaults to None
logger (logging.Logger | str, optional): Logger used for printing logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Defaults to None. related information during evaluation. Defaults to None.
deprecated_kwargs (dict): Used for containing deprecated arguments.
Returns: Returns:
dict: evaluation results dict: evaluation results
""" """
if metric_options is None: if metric_options is None or metric_options == {}:
metric_options = {'thr': 0.5} metric_options = {'thr': 0.5}
if deprecated_kwargs != {}:
warnings.warn('Option arguments for metrics has been changed to '
'`metric_options`.')
metric_options = {**deprecated_kwargs}
if isinstance(metric, str): if isinstance(metric, str):
metrics = [metric] metrics = [metric]
else: else:
...@@ -60,6 +55,8 @@ class MultiLabelDataset(BaseDataset): ...@@ -60,6 +55,8 @@ class MultiLabelDataset(BaseDataset):
eval_results = {} eval_results = {}
results = np.vstack(results) results = np.vstack(results)
gt_labels = self.get_gt_labels() gt_labels = self.get_gt_labels()
if indices is not None:
gt_labels = gt_labels[indices]
num_imgs = len(results) num_imgs = len(results)
assert len(gt_labels) == num_imgs, 'dataset testing results should '\ assert len(gt_labels) == num_imgs, 'dataset testing results should '\
'be of the same length as gt_labels.' 'be of the same length as gt_labels.'
......
# Copyright (c) OpenMMLab. All rights reserved.
from .auto_augment import (AutoAugment, AutoContrast, Brightness,
ColorTransform, Contrast, Cutout, Equalize, Invert,
Posterize, RandAugment, Rotate, Sharpness, Shear,
Solarize, SolarizeAdd, Translate)
from .compose import Compose
from .formatting import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
Transpose, to_tensor)
from .loading import LoadImageFromFile
from .transforms import (CenterCrop, ColorJitter, Lighting, Normalize, Pad,
RandomCrop, RandomErasing, RandomFlip,
RandomGrayscale, RandomResizedCrop, Resize)
# Public names re-exported by this package via ``from ... import *``;
# every symbol imported above must appear here.
__all__ = [
    'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
    'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
    'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
    'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
    'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
    'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
    'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import inspect
import random import random
from math import ceil
from numbers import Number from numbers import Number
from typing import Sequence from typing import Sequence
...@@ -9,18 +12,43 @@ import numpy as np ...@@ -9,18 +12,43 @@ import numpy as np
from ..builder import PIPELINES from ..builder import PIPELINES
from .compose import Compose from .compose import Compose
# Default hyperparameters for all Ops
_HPARAMS_DEFAULT = dict(pad_val=128)
def random_negative(value, random_negative_prob): def random_negative(value, random_negative_prob):
"""Randomly negate value based on random_negative_prob.""" """Randomly negate value based on random_negative_prob."""
return -value if np.random.rand() < random_negative_prob else value return -value if np.random.rand() < random_negative_prob else value
def merge_hparams(policy: dict, hparams: dict):
"""Merge hyperparameters into policy config.
Only merge partial hyperparameters required of the policy.
Args:
policy (dict): Original policy config dict.
hparams (dict): Hyperparameters need to be merged.
Returns:
dict: Policy config dict after adding ``hparams``.
"""
op = PIPELINES.get(policy['type'])
assert op is not None, f'Invalid policy type "{policy["type"]}".'
for key, value in hparams.items():
if policy.get(key, None) is not None:
continue
if key in inspect.getfullargspec(op.__init__).args:
policy[key] = value
return policy
@PIPELINES.register_module() @PIPELINES.register_module()
class AutoAugment(object): class AutoAugment(object):
"""Auto augmentation. This data augmentation is proposed in `AutoAugment: """Auto augmentation.
Learning Augmentation Policies from Data.
<https://arxiv.org/abs/1805.09501>`_. This data augmentation is proposed in `AutoAugment: Learning Augmentation
Policies from Data <https://arxiv.org/abs/1805.09501>`_.
Args: Args:
policies (list[list[dict]]): The policies of auto augmentation. Each policies (list[list[dict]]): The policies of auto augmentation. Each
...@@ -28,9 +56,12 @@ class AutoAugment(object): ...@@ -28,9 +56,12 @@ class AutoAugment(object):
composed by several augmentations (dict). When AutoAugment is composed by several augmentations (dict). When AutoAugment is
called, a random policy in ``policies`` will be selected to called, a random policy in ``policies`` will be selected to
augment images. augment images.
hparams (dict): Configs of hyperparameters. Hyperparameters will be
used in policies that require these arguments if these arguments
are not set in policy dicts. Defaults to use _HPARAMS_DEFAULT.
""" """
def __init__(self, policies): def __init__(self, policies, hparams=_HPARAMS_DEFAULT):
assert isinstance(policies, list) and len(policies) > 0, \ assert isinstance(policies, list) and len(policies) > 0, \
'Policies must be a non-empty list.' 'Policies must be a non-empty list.'
for policy in policies: for policy in policies:
...@@ -41,7 +72,13 @@ class AutoAugment(object): ...@@ -41,7 +72,13 @@ class AutoAugment(object):
'Each specific augmentation must be a dict with key' \ 'Each specific augmentation must be a dict with key' \
' "type".' ' "type".'
self.policies = copy.deepcopy(policies) self.hparams = hparams
policies = copy.deepcopy(policies)
self.policies = []
for sub in policies:
merged_sub = [merge_hparams(policy, hparams) for policy in sub]
self.policies.append(merged_sub)
self.sub_policy = [Compose(policy) for policy in self.policies] self.sub_policy = [Compose(policy) for policy in self.policies]
def __call__(self, results): def __call__(self, results):
...@@ -56,9 +93,10 @@ class AutoAugment(object): ...@@ -56,9 +93,10 @@ class AutoAugment(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class RandAugment(object): class RandAugment(object):
"""Random augmentation. This data augmentation is proposed in `RandAugment: r"""Random augmentation.
Practical automated data augmentation with a reduced search space.
This data augmentation is proposed in `RandAugment: Practical automated
data augmentation with a reduced search space
<https://arxiv.org/abs/1909.13719>`_. <https://arxiv.org/abs/1909.13719>`_.
Args: Args:
...@@ -78,19 +116,26 @@ class RandAugment(object): ...@@ -78,19 +116,26 @@ class RandAugment(object):
total_level (int | float): Total level for the magnitude. Defaults to total_level (int | float): Total level for the magnitude. Defaults to
30. 30.
magnitude_std (Number | str): Deviation of magnitude noise applied. magnitude_std (Number | str): Deviation of magnitude noise applied.
If positive number, magnitude is sampled from normal distribution
(mean=magnitude, std=magnitude_std). - If positive number, magnitude is sampled from normal distribution
If 0 or negative number, magnitude remains unchanged. (mean=magnitude, std=magnitude_std).
If str "inf", magnitude is sampled from uniform distribution - If 0 or negative number, magnitude remains unchanged.
(range=[min, magnitude]). - If str "inf", magnitude is sampled from uniform distribution
(range=[min, magnitude]).
hparams (dict): Configs of hyperparameters. Hyperparameters will be
used in policies that require these arguments if these arguments
are not set in policy dicts. Defaults to use _HPARAMS_DEFAULT.
Note: Note:
`magnitude_std` will introduce some randomness to policy, modified by `magnitude_std` will introduce some randomness to policy, modified by
https://github.com/rwightman/pytorch-image-models https://github.com/rwightman/pytorch-image-models.
When magnitude_std=0, we calculate the magnitude as follows: When magnitude_std=0, we calculate the magnitude as follows:
.. math:: .. math::
magnitude = magnitude_level / total_level * (val2 - val1) + val1 \text{magnitude} = \frac{\text{magnitude_level}}
{\text{totallevel}} \times (\text{val2} - \text{val1})
+ \text{val1}
""" """
def __init__(self, def __init__(self,
...@@ -98,7 +143,8 @@ class RandAugment(object): ...@@ -98,7 +143,8 @@ class RandAugment(object):
num_policies, num_policies,
magnitude_level, magnitude_level,
magnitude_std=0., magnitude_std=0.,
total_level=30): total_level=30,
hparams=_HPARAMS_DEFAULT):
assert isinstance(num_policies, int), 'Number of policies must be ' \ assert isinstance(num_policies, int), 'Number of policies must be ' \
f'of int type, got {type(num_policies)} instead.' f'of int type, got {type(num_policies)} instead.'
assert isinstance(magnitude_level, (int, float)), \ assert isinstance(magnitude_level, (int, float)), \
...@@ -125,8 +171,10 @@ class RandAugment(object): ...@@ -125,8 +171,10 @@ class RandAugment(object):
self.magnitude_level = magnitude_level self.magnitude_level = magnitude_level
self.magnitude_std = magnitude_std self.magnitude_std = magnitude_std
self.total_level = total_level self.total_level = total_level
self.policies = policies self.hparams = hparams
self._check_policies(self.policies) policies = copy.deepcopy(policies)
self._check_policies(policies)
self.policies = [merge_hparams(policy, hparams) for policy in policies]
def _check_policies(self, policies): def _check_policies(self, policies):
for policy in policies: for policy in policies:
...@@ -190,8 +238,8 @@ class Shear(object): ...@@ -190,8 +238,8 @@ class Shear(object):
Args: Args:
magnitude (int | float): The magnitude used for shear. magnitude (int | float): The magnitude used for shear.
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If a pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
tuple of length 3, it is used to pad_val R, G, B channels If a sequence of length 3, it is used to pad_val R, G, B channels
respectively. Defaults to 128. respectively. Defaults to 128.
prob (float): The probability for performing Shear therefore should be prob (float): The probability for performing Shear therefore should be
in range [0, 1]. Defaults to 0.5. in range [0, 1]. Defaults to 0.5.
...@@ -214,7 +262,7 @@ class Shear(object): ...@@ -214,7 +262,7 @@ class Shear(object):
f'be int or float, but got {type(magnitude)} instead.' f'be int or float, but got {type(magnitude)} instead.'
if isinstance(pad_val, int): if isinstance(pad_val, int):
pad_val = tuple([pad_val] * 3) pad_val = tuple([pad_val] * 3)
elif isinstance(pad_val, tuple): elif isinstance(pad_val, Sequence):
assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \ assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
f'elements, got {len(pad_val)} instead.' f'elements, got {len(pad_val)} instead.'
assert all(isinstance(i, int) for i in pad_val), 'pad_val as a '\ assert all(isinstance(i, int) for i in pad_val), 'pad_val as a '\
...@@ -229,7 +277,7 @@ class Shear(object): ...@@ -229,7 +277,7 @@ class Shear(object):
f'should be in range [0,1], got {random_negative_prob} instead.' f'should be in range [0,1], got {random_negative_prob} instead.'
self.magnitude = magnitude self.magnitude = magnitude
self.pad_val = pad_val self.pad_val = tuple(pad_val)
self.prob = prob self.prob = prob
self.direction = direction self.direction = direction
self.random_negative_prob = random_negative_prob self.random_negative_prob = random_negative_prob
...@@ -269,9 +317,9 @@ class Translate(object): ...@@ -269,9 +317,9 @@ class Translate(object):
magnitude (int | float): The magnitude used for translate. Note that magnitude (int | float): The magnitude used for translate. Note that
the offset is calculated by magnitude * size in the corresponding the offset is calculated by magnitude * size in the corresponding
direction. With a magnitude of 1, the whole image will be moved out direction. With a magnitude of 1, the whole image will be moved out
of the range. of the range.
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If a pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
tuple of length 3, it is used to pad_val R, G, B channels If a sequence of length 3, it is used to pad_val R, G, B channels
respectively. Defaults to 128. respectively. Defaults to 128.
prob (float): The probability for performing translate therefore should prob (float): The probability for performing translate therefore should
be in range [0, 1]. Defaults to 0.5. be in range [0, 1]. Defaults to 0.5.
...@@ -294,7 +342,7 @@ class Translate(object): ...@@ -294,7 +342,7 @@ class Translate(object):
f'be int or float, but got {type(magnitude)} instead.' f'be int or float, but got {type(magnitude)} instead.'
if isinstance(pad_val, int): if isinstance(pad_val, int):
pad_val = tuple([pad_val] * 3) pad_val = tuple([pad_val] * 3)
elif isinstance(pad_val, tuple): elif isinstance(pad_val, Sequence):
assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \ assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
f'elements, got {len(pad_val)} instead.' f'elements, got {len(pad_val)} instead.'
assert all(isinstance(i, int) for i in pad_val), 'pad_val as a '\ assert all(isinstance(i, int) for i in pad_val), 'pad_val as a '\
...@@ -309,7 +357,7 @@ class Translate(object): ...@@ -309,7 +357,7 @@ class Translate(object):
f'should be in range [0,1], got {random_negative_prob} instead.' f'should be in range [0,1], got {random_negative_prob} instead.'
self.magnitude = magnitude self.magnitude = magnitude
self.pad_val = pad_val self.pad_val = tuple(pad_val)
self.prob = prob self.prob = prob
self.direction = direction self.direction = direction
self.random_negative_prob = random_negative_prob self.random_negative_prob = random_negative_prob
...@@ -354,11 +402,11 @@ class Rotate(object): ...@@ -354,11 +402,11 @@ class Rotate(object):
angle (float): The angle used for rotate. Positive values stand for angle (float): The angle used for rotate. Positive values stand for
clockwise rotation. clockwise rotation.
center (tuple[float], optional): Center point (w, h) of the rotation in center (tuple[float], optional): Center point (w, h) of the rotation in
the source image. If None, the center of the image will be used. the source image. If None, the center of the image will be used.
defaults to None. Defaults to None.
scale (float): Isotropic scale factor. Defaults to 1.0. scale (float): Isotropic scale factor. Defaults to 1.0.
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If a pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
tuple of length 3, it is used to pad_val R, G, B channels If a sequence of length 3, it is used to pad_val R, G, B channels
respectively. Defaults to 128. respectively. Defaults to 128.
prob (float): The probability for performing Rotate therefore should be prob (float): The probability for performing Rotate therefore should be
in range [0, 1]. Defaults to 0.5. in range [0, 1]. Defaults to 0.5.
...@@ -388,7 +436,7 @@ class Rotate(object): ...@@ -388,7 +436,7 @@ class Rotate(object):
f'got {type(scale)} instead.' f'got {type(scale)} instead.'
if isinstance(pad_val, int): if isinstance(pad_val, int):
pad_val = tuple([pad_val] * 3) pad_val = tuple([pad_val] * 3)
elif isinstance(pad_val, tuple): elif isinstance(pad_val, Sequence):
assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \ assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
f'elements, got {len(pad_val)} instead.' f'elements, got {len(pad_val)} instead.'
assert all(isinstance(i, int) for i in pad_val), 'pad_val as a '\ assert all(isinstance(i, int) for i in pad_val), 'pad_val as a '\
...@@ -403,7 +451,7 @@ class Rotate(object): ...@@ -403,7 +451,7 @@ class Rotate(object):
self.angle = angle self.angle = angle
self.center = center self.center = center
self.scale = scale self.scale = scale
self.pad_val = pad_val self.pad_val = tuple(pad_val)
self.prob = prob self.prob = prob
self.random_negative_prob = random_negative_prob self.random_negative_prob = random_negative_prob
self.interpolation = interpolation self.interpolation = interpolation
...@@ -621,7 +669,8 @@ class Posterize(object): ...@@ -621,7 +669,8 @@ class Posterize(object):
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.' f'got {prob} instead.'
self.bits = int(bits) # To align timm version, we need to round up to integer here.
self.bits = ceil(bits)
self.prob = prob self.prob = prob
def __call__(self, results): def __call__(self, results):
...@@ -692,7 +741,7 @@ class ColorTransform(object): ...@@ -692,7 +741,7 @@ class ColorTransform(object):
Args: Args:
magnitude (int | float): The magnitude used for color transform. A magnitude (int | float): The magnitude used for color transform. A
positive magnitude would enhance the color and a negative magnitude positive magnitude would enhance the color and a negative magnitude
would make the image grayer. A magnitude=0 gives the origin img. would make the image grayer. A magnitude=0 gives the origin img.
prob (float): The probability for performing ColorTransform therefore prob (float): The probability for performing ColorTransform therefore
should be in range [0, 1]. Defaults to 0.5. should be in range [0, 1]. Defaults to 0.5.
random_negative_prob (float): The probability that turns the magnitude random_negative_prob (float): The probability that turns the magnitude
...@@ -827,8 +876,8 @@ class Cutout(object): ...@@ -827,8 +876,8 @@ class Cutout(object):
shape (int | float | tuple(int | float)): Expected cutout shape (h, w). shape (int | float | tuple(int | float)): Expected cutout shape (h, w).
If given as a single value, the value will be used for If given as a single value, the value will be used for
both h and w. both h and w.
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
it is a tuple, it must have the same length with the image If it is a sequence, it must have the same length with the image
channels. Defaults to 128. channels. Defaults to 128.
prob (float): The probability for performing cutout therefore should prob (float): The probability for performing cutout therefore should
be in range [0, 1]. Defaults to 0.5. be in range [0, 1]. Defaults to 0.5.
...@@ -843,11 +892,16 @@ class Cutout(object): ...@@ -843,11 +892,16 @@ class Cutout(object):
raise TypeError( raise TypeError(
'shape must be of ' 'shape must be of '
f'type int, float or tuple, got {type(shape)} instead') f'type int, float or tuple, got {type(shape)} instead')
if isinstance(pad_val, int):
pad_val = tuple([pad_val] * 3)
elif isinstance(pad_val, Sequence):
assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
f'elements, got {len(pad_val)} instead.'
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.' f'got {prob} instead.'
self.shape = shape self.shape = shape
self.pad_val = pad_val self.pad_val = tuple(pad_val)
self.prob = prob self.prob = prob
def __call__(self, results): def __call__(self, results):
......
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence from collections.abc import Sequence
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
......
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from PIL import Image
from ..builder import PIPELINES
def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data
            to be converted.

    Returns:
        torch.Tensor: The converted tensor.

    Raises:
        TypeError: If ``data`` is of an unsupported type.
    """
    # Ordered (predicate, converter) pairs; the first match wins, which
    # mirrors the precedence of the original if/elif chain (tensors pass
    # through untouched, strings are explicitly excluded from Sequence).
    converters = [
        (lambda d: isinstance(d, torch.Tensor), lambda d: d),
        (lambda d: isinstance(d, np.ndarray), torch.from_numpy),
        (lambda d: isinstance(d, Sequence) and not mmcv.is_str(d),
         torch.tensor),
        (lambda d: isinstance(d, int), lambda d: torch.LongTensor([d])),
        (lambda d: isinstance(d, float), lambda d: torch.FloatTensor([d])),
    ]
    for matches, convert in converters:
        if matches(data):
            return convert(data)
    raise TypeError(
        f'Type {type(data)} cannot be converted to tensor.'
        'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
        '`Sequence`, `int` and `float`')
@PIPELINES.register_module()
class ToTensor(object):
    """Convert the values of the given result keys to :obj:`torch.Tensor`.

    Args:
        keys (Sequence[str]): Keys of the results dict whose values will be
            converted with :func:`to_tensor`.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Convert ``results[key]`` to tensor for every configured key."""
        converted = {key: to_tensor(results[key]) for key in self.keys}
        results.update(converted)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@PIPELINES.register_module()
class ImageToTensor(object):
    """Convert HWC images in the results dict to CHW :obj:`torch.Tensor`.

    A 2-D (grayscale) array first gets a trailing singleton channel axis;
    the HWC array is then transposed to CHW and converted via
    :func:`to_tensor`.

    Args:
        keys (Sequence[str]): Keys of the results dict holding images to
            convert.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        for key in self.keys:
            image = results[key]
            if len(image.shape) < 3:
                # Grayscale image: append a channel axis so the transpose
                # below always sees an HWC layout.
                image = np.expand_dims(image, -1)
            results[key] = to_tensor(image.transpose(2, 0, 1))
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@PIPELINES.register_module()
class Transpose(object):
    """Transpose array-valued fields of the results dict.

    Args:
        keys (Sequence[str]): Keys of the results dict whose values are
            transposed.
        order (Sequence[int]): The target axis order, passed directly to
            ``ndarray.transpose``.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        for key in self.keys:
            value = results[key]
            results[key] = value.transpose(self.order)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(keys={self.keys}, order={self.order})')
@PIPELINES.register_module()
class ToPIL(object):
    """Convert the ndarray at ``results['img']`` to a ``PIL.Image``."""

    def __init__(self):
        pass

    def __call__(self, results):
        """Replace ``results['img']`` with its PIL Image equivalent."""
        results['img'] = Image.fromarray(results['img'])
        return results
@PIPELINES.register_module()
class ToNumpy(object):
    """Convert ``results['img']`` to a float32 numpy array."""

    def __init__(self):
        pass

    def __call__(self, results):
        """Replace ``results['img']`` with a ``float32`` ndarray copy."""
        converted = np.array(results['img'], dtype=np.float32)
        results['img'] = converted
        return results
@PIPELINES.register_module()
class Collect(object):
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically
    ``keys`` is set to some subset of "img" and "gt_label".

    Args:
        keys (Sequence[str]): Keys of results to be collected in ``data``.
        meta_keys (Sequence[str], optional): Meta keys to be converted to
            ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
            Defaults to ('filename', 'ori_filename', 'ori_shape',
            'img_shape', 'flip', 'flip_direction', 'img_norm_cfg').

    Returns:
        dict: The result dict contains the following keys

            - keys in ``self.keys``
            - ``img_metas`` if available
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'img_shape', 'flip', 'flip_direction',
                            'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        # Gather only the meta fields actually present in this sample.
        img_meta = {
            key: results[key]
            for key in self.meta_keys if key in results
        }
        data = {'img_metas': DC(img_meta, cpu_only=True)}
        data.update((key, results[key]) for key in self.keys)
        return data

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(keys={self.keys}, meta_keys={self.meta_keys})')
@PIPELINES.register_module()
class WrapFieldsToLists(object):
    """Wrap fields of the data dictionary into lists for evaluation.

    This class can be used as a last step of a test or validation
    pipeline for single image evaluation or inference.

    Example:
        >>> test_pipeline = [
        >>>    dict(type='LoadImageFromFile'),
        >>>    dict(type='Normalize',
                    mean=[123.675, 116.28, 103.53],
                    std=[58.395, 57.12, 57.375],
                    to_rgb=True),
        >>>    dict(type='ImageToTensor', keys=['img']),
        >>>    dict(type='Collect', keys=['img']),
        >>>    dict(type='WrapFieldsToLists')
        >>> ]
    """

    def __call__(self, results):
        # Wrap every field into a single-element list, in place.
        for key in list(results):
            results[key] = [results[key]]
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}()'
@PIPELINES.register_module()
class ToHalf(object):
    """Cast the given result fields to half precision.

    torch tensors are cast to ``torch.half``; any other value is assumed to
    be a numpy array and cast to ``np.float16``.

    Args:
        keys (Sequence[str]): Keys of the results dict to cast.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        for key in self.keys:
            value = results[key]
            if isinstance(value, torch.Tensor):
                results[key] = value.to(torch.half)
            else:
                results[key] = value.astype(np.float16)
        return results
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect import inspect
import math import math
import random import random
...@@ -36,18 +38,19 @@ class RandomCrop(object): ...@@ -36,18 +38,19 @@ class RandomCrop(object):
pad_val (Number | Sequence[Number]): Pixel pad_val value for constant pad_val (Number | Sequence[Number]): Pixel pad_val value for constant
fill. If a tuple of length 3, it is used to pad_val R, G, B fill. If a tuple of length 3, it is used to pad_val R, G, B
channels respectively. Default: 0. channels respectively. Default: 0.
padding_mode (str): Type of padding. Should be: constant, edge, padding_mode (str): Type of padding. Defaults to "constant". Should
reflect or symmetric. Default: constant. be one of the following:
-constant: Pads with a constant value, this value is specified
- constant: Pads with a constant value, this value is specified \
with pad_val. with pad_val.
-edge: pads with the last value at the edge of the image. - edge: pads with the last value at the edge of the image.
-reflect: Pads with reflection of image without repeating the - reflect: Pads with reflection of image without repeating the \
last value on the edge. For example, padding [1, 2, 3, 4] last value on the edge. For example, padding [1, 2, 3, 4] \
with 2 elements on both sides in reflect mode will result with 2 elements on both sides in reflect mode will result \
in [3, 2, 1, 2, 3, 4, 3, 2]. in [3, 2, 1, 2, 3, 4, 3, 2].
-symmetric: Pads with reflection of image repeating the last - symmetric: Pads with reflection of image repeating the last \
value on the edge. For example, padding [1, 2, 3, 4] with value on the edge. For example, padding [1, 2, 3, 4] with \
2 elements on both sides in symmetric mode will result in 2 elements on both sides in symmetric mode will result in \
[2, 1, 1, 2, 3, 4, 4, 3]. [2, 1, 1, 2, 3, 4, 4, 3].
""" """
...@@ -151,7 +154,7 @@ class RandomResizedCrop(object): ...@@ -151,7 +154,7 @@ class RandomResizedCrop(object):
to the original image. Defaults to (0.08, 1.0). to the original image. Defaults to (0.08, 1.0).
ratio (tuple): Range of the random aspect ratio of the cropped image ratio (tuple): Range of the random aspect ratio of the cropped image
compared to the original image. Defaults to (3. / 4., 4. / 3.). compared to the original image. Defaults to (3. / 4., 4. / 3.).
max_attempts (int): Maxinum number of attempts before falling back to max_attempts (int): Maximum number of attempts before falling back to
Central Crop. Defaults to 10. Central Crop. Defaults to 10.
efficientnet_style (bool): Whether to use efficientnet style Random efficientnet_style (bool): Whether to use efficientnet style Random
ResizedCrop. Defaults to False. ResizedCrop. Defaults to False.
...@@ -163,7 +166,7 @@ class RandomResizedCrop(object): ...@@ -163,7 +166,7 @@ class RandomResizedCrop(object):
interpolation (str): Interpolation method, accepted values are interpolation (str): Interpolation method, accepted values are
'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
'bilinear'. 'bilinear'.
backend (str): The image resize backend type, accpeted values are backend (str): The image resize backend type, accepted values are
`cv2` and `pillow`. Defaults to `cv2`. `cv2` and `pillow`. Defaults to `cv2`.
""" """
...@@ -191,7 +194,7 @@ class RandomResizedCrop(object): ...@@ -191,7 +194,7 @@ class RandomResizedCrop(object):
f'But received scale {scale} and rato {ratio}.') f'But received scale {scale} and rato {ratio}.')
assert min_covered >= 0, 'min_covered should be no less than 0.' assert min_covered >= 0, 'min_covered should be no less than 0.'
assert isinstance(max_attempts, int) and max_attempts >= 0, \ assert isinstance(max_attempts, int) and max_attempts >= 0, \
'max_attempts mush be of typle int and no less than 0.' 'max_attempts mush be int and no less than 0.'
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area', assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
'lanczos') 'lanczos')
if backend not in ['cv2', 'pillow']: if backend not in ['cv2', 'pillow']:
...@@ -217,7 +220,7 @@ class RandomResizedCrop(object): ...@@ -217,7 +220,7 @@ class RandomResizedCrop(object):
compared to the original image size. compared to the original image size.
ratio (tuple): Range of the random aspect ratio of the cropped ratio (tuple): Range of the random aspect ratio of the cropped
image compared to the original image area. image compared to the original image area.
max_attempts (int): Maxinum number of attempts before falling back max_attempts (int): Maximum number of attempts before falling back
to central crop. Defaults to 10. to central crop. Defaults to 10.
Returns: Returns:
...@@ -279,7 +282,7 @@ class RandomResizedCrop(object): ...@@ -279,7 +282,7 @@ class RandomResizedCrop(object):
compared to the original image size. compared to the original image size.
ratio (tuple): Range of the random aspect ratio of the cropped ratio (tuple): Range of the random aspect ratio of the cropped
image compared to the original image area. image compared to the original image area.
max_attempts (int): Maxinum number of attempts before falling back max_attempts (int): Maximum number of attempts before falling back
to central crop. Defaults to 10. to central crop. Defaults to 10.
min_covered (Number): Minimum ratio of the cropped area to the min_covered (Number): Minimum ratio of the cropped area to the
original area. Only valid if efficientnet_style is true. original area. Only valid if efficientnet_style is true.
...@@ -311,7 +314,7 @@ class RandomResizedCrop(object): ...@@ -311,7 +314,7 @@ class RandomResizedCrop(object):
max_target_height = min(max_target_height, height) max_target_height = min(max_target_height, height)
min_target_height = min(max_target_height, min_target_height) min_target_height = min(max_target_height, min_target_height)
# slightly differs from tf inplementation # slightly differs from tf implementation
target_height = int( target_height = int(
round(random.uniform(min_target_height, max_target_height))) round(random.uniform(min_target_height, max_target_height)))
target_width = int(round(target_height * aspect_ratio)) target_width = int(round(target_height * aspect_ratio))
...@@ -393,11 +396,12 @@ class RandomGrayscale(object): ...@@ -393,11 +396,12 @@ class RandomGrayscale(object):
grayscale. Default: 0.1. grayscale. Default: 0.1.
Returns: Returns:
ndarray: Grayscale version of the input image with probability ndarray: Image after randomly grayscale transform.
gray_prob and unchanged with probability (1-gray_prob).
- If input image is 1 channel: grayscale version is 1 channel. Notes:
- If input image is 3 channel: grayscale version is 3 channel - If input image is 1 channel: grayscale version is 1 channel.
with r == g == b. - If input image is 3 channel: grayscale version is 3 channel
with r == g == b.
""" """
def __init__(self, gray_prob=0.1): def __init__(self, gray_prob=0.1):
...@@ -484,20 +488,24 @@ class RandomErasing(object): ...@@ -484,20 +488,24 @@ class RandomErasing(object):
if float, it will be converted to (aspect_ratio, 1/aspect_ratio) if float, it will be converted to (aspect_ratio, 1/aspect_ratio)
Default: (3/10, 10/3) Default: (3/10, 10/3)
mode (str): Fill method in erased area, can be: mode (str): Fill method in erased area, can be:
- 'const' (default): All pixels are assign with the same value.
- 'rand': each pixel is assigned with a random value in [0, 255] - const (default): All pixels are assign with the same value.
- rand: each pixel is assigned with a random value in [0, 255]
fill_color (sequence | Number): Base color filled in erased area. fill_color (sequence | Number): Base color filled in erased area.
Default: (128, 128, 128) Defaults to (128, 128, 128).
fill_std (sequence | Number, optional): If set and mode='rand', fill fill_std (sequence | Number, optional): If set and ``mode`` is 'rand',
erased area with random color from normal distribution fill erased area with random color from normal distribution
(mean=fill_color, std=fill_std); If not set, fill erased area with (mean=fill_color, std=fill_std); If not set, fill erased area with
random color from uniform distribution (0~255) random color from uniform distribution (0~255). Defaults to None.
Default: None
Note: Note:
See https://arxiv.org/pdf/1708.04896.pdf See `Random Erasing Data Augmentation
<https://arxiv.org/pdf/1708.04896.pdf>`_
This paper provided 4 modes: RE-R, RE-M, RE-0, RE-255, and use RE-M as This paper provided 4 modes: RE-R, RE-M, RE-0, RE-255, and use RE-M as
default. default. The config of these 4 modes are:
- RE-R: RandomErasing(mode='rand') - RE-R: RandomErasing(mode='rand')
- RE-M: RandomErasing(mode='const', fill_color=(123.67, 116.3, 103.5)) - RE-M: RandomErasing(mode='const', fill_color=(123.67, 116.3, 103.5))
- RE-0: RandomErasing(mode='const', fill_color=0) - RE-0: RandomErasing(mode='const', fill_color=0)
...@@ -605,6 +613,58 @@ class RandomErasing(object): ...@@ -605,6 +613,58 @@ class RandomErasing(object):
return repr_str return repr_str
@PIPELINES.register_module()
class Pad(object):
    """Pad images.

    Args:
        size (tuple[int] | None): Expected padding size (h, w). Conflicts
            with pad_to_square. Defaults to None.
        pad_to_square (bool): Pad any image to square shape. Defaults to
            False.
        pad_val (Number | Sequence[Number]): Values to be filled in padding
            areas when padding_mode is 'constant'. Defaults to 0.
        padding_mode (str): Type of padding. Should be: constant, edge,
            reflect or symmetric. Defaults to "constant".

    Raises:
        AssertionError: If neither or both of ``size`` and
            ``pad_to_square`` are specified.
    """

    def __init__(self,
                 size=None,
                 pad_to_square=False,
                 pad_val=0,
                 padding_mode='constant'):
        # Exactly one of `size` / `pad_to_square` must be provided.
        assert (size is None) ^ (pad_to_square is False), \
            'Only one of [size, pad_to_square] should be given, ' \
            f'but get {(size is not None) + (pad_to_square is not False)}'
        self.size = size
        self.pad_to_square = pad_to_square
        self.pad_val = pad_val
        self.padding_mode = padding_mode

    def __call__(self, results):
        for key in results.get('img_fields', ['img']):
            img = results[key]
            if self.pad_to_square:
                # Pad both dimensions up to the longer image edge.
                target_size = tuple(
                    max(img.shape[0], img.shape[1]) for _ in range(2))
            else:
                target_size = self.size
            img = mmcv.impad(
                img,
                shape=target_size,
                pad_val=self.pad_val,
                padding_mode=self.padding_mode)
            results[key] = img
            results['img_shape'] = img.shape
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(size={self.size}, '
        # Bug fix: the original emitted a stray '(' before pad_val,
        # producing a malformed repr like "Pad(size=..., (pad_val=...".
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'padding_mode={self.padding_mode})'
        return repr_str
@PIPELINES.register_module() @PIPELINES.register_module()
class Resize(object): class Resize(object):
"""Resize images. """Resize images.
...@@ -613,35 +673,49 @@ class Resize(object): ...@@ -613,35 +673,49 @@ class Resize(object):
size (int | tuple): Images scales for resizing (h, w). size (int | tuple): Images scales for resizing (h, w).
When size is int, the default behavior is to resize an image When size is int, the default behavior is to resize an image
to (size, size). When size is tuple and the second value is -1, to (size, size). When size is tuple and the second value is -1,
the short edge of an image is resized to its first value. the image will be resized according to adaptive_side. For example,
For example, when size is 224, the image is resized to 224x224. when size is 224, the image is resized to 224x224. When size is
When size is (224, -1), the short side is resized to 224 and the (224, -1) and adaptive_size is "short", the short side is resized
other side is computed based on the short side, maintaining the to 224 and the other side is computed based on the short side,
aspect ratio. maintaining the aspect ratio.
interpolation (str): Interpolation method, accepted values are interpolation (str): Interpolation method. For "cv2" backend, accepted
"nearest", "bilinear", "bicubic", "area", "lanczos". values are "nearest", "bilinear", "bicubic", "area", "lanczos". For
"pillow" backend, accepted values are "nearest", "bilinear",
"bicubic", "box", "lanczos", "hamming".
More details can be found in `mmcv.image.geometric`. More details can be found in `mmcv.image.geometric`.
backend (str): The image resize backend type, accpeted values are adaptive_side(str): Adaptive resize policy, accepted values are
"short", "long", "height", "width". Default to "short".
backend (str): The image resize backend type, accepted values are
`cv2` and `pillow`. Default: `cv2`. `cv2` and `pillow`. Default: `cv2`.
""" """
def __init__(self, size, interpolation='bilinear', backend='cv2'): def __init__(self,
size,
interpolation='bilinear',
adaptive_side='short',
backend='cv2'):
assert isinstance(size, int) or (isinstance(size, tuple) assert isinstance(size, int) or (isinstance(size, tuple)
and len(size) == 2) and len(size) == 2)
self.resize_w_short_side = False assert adaptive_side in {'short', 'long', 'height', 'width'}
self.adaptive_side = adaptive_side
self.adaptive_resize = False
if isinstance(size, int): if isinstance(size, int):
assert size > 0 assert size > 0
size = (size, size) size = (size, size)
else: else:
assert size[0] > 0 and (size[1] > 0 or size[1] == -1) assert size[0] > 0 and (size[1] > 0 or size[1] == -1)
if size[1] == -1: if size[1] == -1:
self.resize_w_short_side = True self.adaptive_resize = True
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
'lanczos')
if backend not in ['cv2', 'pillow']: if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported for resize.' raise ValueError(f'backend: {backend} is not supported for resize.'
'Supported backends are "cv2", "pillow"') 'Supported backends are "cv2", "pillow"')
if backend == 'cv2':
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
'lanczos')
else:
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'box',
'lanczos', 'hamming')
self.size = size self.size = size
self.interpolation = interpolation self.interpolation = interpolation
self.backend = backend self.backend = backend
...@@ -650,19 +724,29 @@ class Resize(object): ...@@ -650,19 +724,29 @@ class Resize(object):
for key in results.get('img_fields', ['img']): for key in results.get('img_fields', ['img']):
img = results[key] img = results[key]
ignore_resize = False ignore_resize = False
if self.resize_w_short_side: if self.adaptive_resize:
h, w = img.shape[:2] h, w = img.shape[:2]
short_side = self.size[0] target_size = self.size[0]
if (w <= h and w == short_side) or (h <= w
and h == short_side): condition_ignore_resize = {
'short': min(h, w) == target_size,
'long': max(h, w) == target_size,
'height': h == target_size,
'width': w == target_size
}
if condition_ignore_resize[self.adaptive_side]:
ignore_resize = True ignore_resize = True
elif any([
self.adaptive_side == 'short' and w < h,
self.adaptive_side == 'long' and w > h,
self.adaptive_side == 'width',
]):
width = target_size
height = int(target_size * h / w)
else: else:
if w < h: height = target_size
width = short_side width = int(target_size * w / h)
height = int(short_side * h / w)
else:
height = short_side
width = int(short_side * w / h)
else: else:
height, width = self.size height, width = self.size
if not ignore_resize: if not ignore_resize:
...@@ -700,21 +784,23 @@ class CenterCrop(object): ...@@ -700,21 +784,23 @@ class CenterCrop(object):
32. 32.
interpolation (str): Interpolation method, accepted values are interpolation (str): Interpolation method, accepted values are
'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Only valid if 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Only valid if
efficientnet style is True. Defaults to 'bilinear'. ``efficientnet_style`` is True. Defaults to 'bilinear'.
backend (str): The image resize backend type, accpeted values are backend (str): The image resize backend type, accepted values are
`cv2` and `pillow`. Only valid if efficientnet style is True. `cv2` and `pillow`. Only valid if efficientnet style is True.
Defaults to `cv2`. Defaults to `cv2`.
Notes: Notes:
If the image is smaller than the crop size, return the original image. - If the image is smaller than the crop size, return the original
If efficientnet_style is set to False, the pipeline would be a simple image.
center crop using the crop_size. - If efficientnet_style is set to False, the pipeline would be a simple
If efficientnet_style is set to True, the pipeline will be to first to center crop using the crop_size.
perform the center crop with the crop_size_ as: - If efficientnet_style is set to True, the pipeline will be to first
to perform the center crop with the ``crop_size_`` as:
.. math:: .. math::
crop\_size\_ = crop\_size / (crop\_size + crop\_padding) * short\_edge \text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} +
\text{crop_padding}} \times \text{short_edge}
And then the pipeline resizes the img to the input crop size. And then the pipeline resizes the img to the input crop size.
""" """
...@@ -886,7 +972,7 @@ class Lighting(object): ...@@ -886,7 +972,7 @@ class Lighting(object):
eigvec (list[list]): the eigenvector of the convariance matrix of pixel eigvec (list[list]): the eigenvector of the convariance matrix of pixel
values, respectively. values, respectively.
alphastd (float): The standard deviation for distribution of alpha. alphastd (float): The standard deviation for distribution of alpha.
Dafaults to 0.1 Defaults to 0.1
to_rgb (bool): Whether to convert img to rgb. to_rgb (bool): Whether to convert img to rgb.
""" """
...@@ -1032,19 +1118,23 @@ class Albu(object): ...@@ -1032,19 +1118,23 @@ class Albu(object):
return updated_dict return updated_dict
def __call__(self, results): def __call__(self, results):
# backup gt_label in case Albu modify it.
_gt_label = copy.deepcopy(results.get('gt_label', None))
# dict to albumentations format # dict to albumentations format
results = self.mapper(results, self.keymap_to_albu) results = self.mapper(results, self.keymap_to_albu)
# process aug
results = self.aug(**results) results = self.aug(**results)
if 'gt_labels' in results:
if isinstance(results['gt_labels'], list):
results['gt_labels'] = np.array(results['gt_labels'])
results['gt_labels'] = results['gt_labels'].astype(np.int64)
# back to the original format # back to the original format
results = self.mapper(results, self.keymap_back) results = self.mapper(results, self.keymap_back)
if _gt_label is not None:
# recover backup gt_label
results.update({'gt_label': _gt_label})
# update final shape # update final shape
if self.update_pad_shape: if self.update_pad_shape:
results['pad_shape'] = results['img'].shape results['pad_shape'] = results['img'].shape
......
# Copyright (c) OpenMMLab. All rights reserved.
from .distributed_sampler import DistributedSampler
from .repeat_aug import RepeatAugSampler
__all__ = ('DistributedSampler', 'RepeatAugSampler')
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.utils.data import DistributedSampler as _DistributedSampler
from mmcls.core.utils import sync_random_seed
from mmcls.datasets import SAMPLERS
from mmcls.utils import auto_select_device
@SAMPLERS.register_module()
class DistributedSampler(_DistributedSampler):
    """Distributed sampler with optional round-up and a synchronized seed.

    Extends :class:`torch.utils.data.DistributedSampler` with:

    - ``round_up``: when True, indices are repeated so that
      ``total_size`` is evenly divisible by ``num_replicas`` (every rank
      gets the same number of samples); when False, the last rank may get
      fewer samples.
    - a random seed synchronized across ranks via ``sync_random_seed`` so
      all replicas shuffle in the same order each epoch.

    Args:
        dataset (Dataset): Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. Defaults to None (use dist world size).
        rank (int, optional): Rank of the current process. Defaults to None
            (use dist rank).
        shuffle (bool): Whether to shuffle the indices each epoch.
            Defaults to True.
        round_up (bool): Whether to pad indices up to an even multiple of
            ``num_replicas``. Defaults to True.
        seed (int): Base random seed shared by all ranks. Defaults to 0.
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 round_up=True,
                 seed=0):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        self.round_up = round_up
        if self.round_up:
            # Padded length: every rank draws exactly num_samples indices.
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(self.dataset)

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed, device=auto_select_device())

    def __iter__(self):
        """Yield this rank's sample indices for the current epoch."""
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]
            assert len(indices) == self.total_size

        # subsample: rank r takes every num_replicas-th index starting at r,
        # so the ranks partition the index list without overlap
        indices = indices[self.rank:self.total_size:self.num_replicas]
        if self.round_up:
            assert len(indices) == self.num_samples

        return iter(indices)
import math
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler
from mmcls.core.utils import sync_random_seed
from mmcls.datasets import SAMPLERS
@SAMPLERS.register_module()
class RepeatAugSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset for
    distributed training, with repeated augmentation.

    Each sample is repeated ``num_repeats`` times in the index list, and the
    per-rank subsampling ensures that different augmented versions of the
    same sample are seen by different processes (GPUs). Heavily based on
    ``torch.utils.data.DistributedSampler``.

    This sampler was taken from
    https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
    Copyright (c) 2015-present, Facebook, Inc.

    Args:
        dataset (Dataset): Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. Defaults to None (use dist world size).
        rank (int, optional): Rank of the current process. Defaults to None
            (use dist rank).
        shuffle (bool): Whether to shuffle the indices each epoch.
            Defaults to True.
        num_repeats (int): How many times each sample is repeated.
            Defaults to 3.
        selected_round (int): Round the number of selected samples down to a
            multiple of this value before dividing by ``selected_ratio``;
            0 disables rounding. Defaults to 256.
        selected_ratio (int): Ratio to reduce the number of selected samples
            by; 0 means use ``num_replicas``. Defaults to 0.
        seed (int): Base random seed shared by all ranks. Defaults to 0.
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 num_repeats=3,
                 selected_round=256,
                 selected_ratio=0,
                 seed=0):
        default_rank, default_world_size = get_dist_info()
        rank = default_rank if rank is None else rank
        num_replicas = (
            default_world_size if num_replicas is None else num_replicas)
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.num_repeats = num_repeats
        self.epoch = 0
        # Per-rank share of the repeated index list.
        self.num_samples = int(
            math.ceil(len(self.dataset) * num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Determine the number of samples to select per epoch for each rank.
        # num_selected logic defaults to be the same as original RASampler
        # impl, but this one can be tweaked
        # via selected_ratio and selected_round args.
        selected_ratio = selected_ratio or num_replicas  # ratio to reduce
        # selected samples by, num_replicas if 0
        if selected_round:
            self.num_selected_samples = int(
                math.floor(
                    len(self.dataset) // selected_round * selected_round /
                    selected_ratio))
        else:
            self.num_selected_samples = int(
                math.ceil(len(self.dataset) / selected_ratio))
        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self):
        """Yield this rank's sample indices (with repeats) for the epoch."""
        # deterministically shuffle based on epoch
        if self.shuffle:
            if self.num_replicas > 1:  # In distributed environment
                # deterministically shuffle based on epoch
                g = torch.Generator()
                # When :attr:`shuffle=True`, this ensures all replicas
                # use a different random ordering for each epoch.
                # Otherwise, the next iteration of this sampler will
                # yield the same ordering.
                g.manual_seed(self.epoch + self.seed)
                indices = torch.randperm(
                    len(self.dataset), generator=g).tolist()
            else:
                indices = torch.randperm(len(self.dataset)).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
        indices = [x for x in indices for _ in range(self.num_repeats)]
        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == self.total_size
        # subsample per rank: adjacent repeats land on different ranks, so
        # each process sees a different augmentation of the same sample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        # return up to num selected samples
        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        """Number of indices yielded per epoch on this rank."""
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch number, which seeds the per-epoch shuffle."""
        self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Optional
import numpy as np
from .base_dataset import BaseDataset
from .builder import DATASETS
@DATASETS.register_module()
class StanfordCars(BaseDataset):
    """`Stanford Cars`_ Dataset.

    After downloading and decompression, the dataset
    directory structure is as follows.

    Stanford Cars dataset directory::

        Stanford Cars
        ├── cars_train
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        ├── cars_test
        │   ├── 00001.jpg
        │   ├── 00002.jpg
        │   └── ...
        └── devkit
            ├── cars_meta.mat
            ├── cars_train_annos.mat
            ├── cars_test_annos.mat
            ├── cars_test_annos_withlabels.mat
            ├── eval_train.m
            └── train_perfect_preds.txt

    .. _Stanford Cars: https://ai.stanford.edu/~jkrause/cars/car_dataset.html

    Args:
        data_prefix (str): The prefix of the data path, i.e. the dataset
            root directory shown above.
        test_mode (bool): ``test_mode=True`` means in the test phase. It
            determines whether to use the training set or the test set.
        ann_file (str, optional): The annotation file. If is string, read
            samples paths from the ann_file. If is None, read samples path
            from the ``cars_train_annos.mat`` file (train) or the
            ``cars_test_annos_withlabels.mat`` file (test).
            Defaults to None.
    """ # noqa: E501

    # The 196 fine-grained car classes, in the order used by the official
    # annotation files (their 1-based ids map to indices 0..195 here).
    CLASSES = [
        'AM General Hummer SUV 2000', 'Acura RL Sedan 2012',
        'Acura TL Sedan 2012', 'Acura TL Type-S 2008', 'Acura TSX Sedan 2012',
        'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012',
        'Aston Martin V8 Vantage Convertible 2012',
        'Aston Martin V8 Vantage Coupe 2012',
        'Aston Martin Virage Convertible 2012',
        'Aston Martin Virage Coupe 2012', 'Audi RS 4 Convertible 2008',
        'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012', 'Audi R8 Coupe 2012',
        'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994', 'Audi 100 Wagon 1994',
        'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011',
        'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012',
        'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012',
        'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012',
        'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012',
        'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007',
        'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012',
        'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012',
        'BMW Z4 Convertible 2012',
        'Bentley Continental Supersports Conv. Convertible 2012',
        'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011',
        'Bentley Continental GT Coupe 2012',
        'Bentley Continental GT Coupe 2007',
        'Bentley Continental Flying Spur Sedan 2007',
        'Bugatti Veyron 16.4 Convertible 2009',
        'Bugatti Veyron 16.4 Coupe 2009', 'Buick Regal GS 2012',
        'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012',
        'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012',
        'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007',
        'Chevrolet Silverado 1500 Hybrid Crew Cab 2012',
        'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012',
        'Chevrolet Corvette Ron Fellows Edition Z06 2007',
        'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012',
        'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007',
        'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012',
        'Chevrolet Express Cargo Van 2007',
        'Chevrolet Avalanche Crew Cab 2012', 'Chevrolet Cobalt SS 2010',
        'Chevrolet Malibu Hybrid Sedan 2010', 'Chevrolet TrailBlazer SS 2009',
        'Chevrolet Silverado 2500HD Regular Cab 2012',
        'Chevrolet Silverado 1500 Classic Extended Cab 2007',
        'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007',
        'Chevrolet Malibu Sedan 2007',
        'Chevrolet Silverado 1500 Extended Cab 2012',
        'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009',
        'Chrysler Sebring Convertible 2010',
        'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010',
        'Chrysler Crossfire Convertible 2008',
        'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002',
        'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007',
        'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010',
        'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009',
        'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010',
        'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008',
        'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012',
        'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012',
        'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998',
        'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012',
        'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012',
        'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012',
        'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012',
        'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007',
        'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012',
        'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006',
        'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007',
        'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012',
        'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012',
        'GMC Savana Van 2012', 'GMC Yukon Hybrid SUV 2012',
        'GMC Acadia SUV 2012', 'GMC Canyon Extended Cab 2012',
        'Geo Metro Convertible 1993', 'HUMMER H3T Crew Cab 2010',
        'HUMMER H2 SUT Crew Cab 2009', 'Honda Odyssey Minivan 2012',
        'Honda Odyssey Minivan 2007', 'Honda Accord Coupe 2012',
        'Honda Accord Sedan 2012', 'Hyundai Veloster Hatchback 2012',
        'Hyundai Santa Fe SUV 2012', 'Hyundai Tucson SUV 2012',
        'Hyundai Veracruz SUV 2012', 'Hyundai Sonata Hybrid Sedan 2012',
        'Hyundai Elantra Sedan 2007', 'Hyundai Accent Sedan 2012',
        'Hyundai Genesis Sedan 2012', 'Hyundai Sonata Sedan 2012',
        'Hyundai Elantra Touring Hatchback 2012', 'Hyundai Azera Sedan 2012',
        'Infiniti G Coupe IPL 2012', 'Infiniti QX56 SUV 2011',
        'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012',
        'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012',
        'Jeep Liberty SUV 2012', 'Jeep Grand Cherokee SUV 2012',
        'Jeep Compass SUV 2012', 'Lamborghini Reventon Coupe 2008',
        'Lamborghini Aventador Coupe 2012',
        'Lamborghini Gallardo LP 570-4 Superleggera 2012',
        'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012',
        'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011',
        'MINI Cooper Roadster Convertible 2012',
        'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011',
        'McLaren MP4-12C Coupe 2012',
        'Mercedes-Benz 300-Class Convertible 1993',
        'Mercedes-Benz C-Class Sedan 2012',
        'Mercedes-Benz SL-Class Coupe 2009',
        'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012',
        'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012',
        'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012',
        'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998',
        'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012',
        'Ram C/V Cargo Van Minivan 2012',
        'Rolls-Royce Phantom Drophead Coupe Convertible 2012',
        'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012',
        'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009',
        'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007',
        'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012',
        'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012',
        'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012',
        'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012',
        'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991',
        'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012',
        'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007',
        'smart fortwo Convertible 2012'
    ]

    def __init__(self,
                 data_prefix: str,
                 test_mode: bool,
                 ann_file: Optional[str] = None,
                 **kwargs):
        # Resolve the split-specific annotation file, then point
        # ``data_prefix`` at the matching image folder for that split.
        if test_mode:
            if ann_file is not None:
                self.test_ann_file = ann_file
            else:
                self.test_ann_file = osp.join(
                    data_prefix, 'devkit/cars_test_annos_withlabels.mat')
            data_prefix = osp.join(data_prefix, 'cars_test')
        else:
            if ann_file is not None:
                self.train_ann_file = ann_file
            else:
                self.train_ann_file = osp.join(data_prefix,
                                               'devkit/cars_train_annos.mat')
            data_prefix = osp.join(data_prefix, 'cars_train')
        # NOTE(review): the original (possibly None) ``ann_file`` is still
        # forwarded to the base class; actual loading below reads the
        # resolved ``self.{train,test}_ann_file`` instead.
        super(StanfordCars, self).__init__(
            ann_file=ann_file,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_annotations(self):
        """Load image filenames and labels from the ``.mat`` annotations.

        Returns:
            list[dict]: One dict per image with ``img_prefix``,
            ``img_info`` (filename) and a 0-based int64 ``gt_label``.

        Raises:
            ImportError: If ``scipy`` is not installed.
        """
        try:
            import scipy.io as sio
        except ImportError:
            raise ImportError(
                'please run `pip install scipy` to install package `scipy`.')
        data_infos = []
        if self.test_mode:
            data = sio.loadmat(self.test_ann_file)
        else:
            data = sio.loadmat(self.train_ann_file)
        for img in data['annotations'][0]:
            info = {'img_prefix': self.data_prefix}
            # The organization of each record is as follows,
            # 0: bbox_x1 of each image
            # 1: bbox_y1 of each image
            # 2: bbox_x2 of each image
            # 3: bbox_y2 of each image
            # 4: class_id; it starts from 1 in the .mat file, so
            #    here we need the '- 1' to make labels start from 0
            # 5: file name of each image
            info['img_info'] = {'filename': img[5][0]}
            info['gt_label'] = np.array(img[4][0][0] - 1, dtype=np.int64)
            data_infos.append(info)
        return data_infos
# Copyright (c) OpenMMLab. All rights reserved.
import gzip
import hashlib
import os
import os.path
import shutil
import tarfile
import urllib.error
import urllib.request
import zipfile
__all__ = ['rm_suffix', 'check_integrity', 'download_and_extract_archive']
def rm_suffix(s, suffix=None):
    """Remove a suffix from a string.

    Args:
        s (str): The input string.
        suffix (str, optional): The suffix to cut at. If None, everything
            from the last ``'.'`` (inclusive) is removed. Defaults to None.

    Returns:
        str: ``s`` truncated at the last occurrence of ``suffix`` (or of
        ``'.'``). If the suffix is not present, ``s`` is returned
        unchanged.
    """
    idx = s.rfind('.' if suffix is None else suffix)
    # str.rfind returns -1 when not found; the previous ``s[:-1]`` slice
    # silently dropped the last character in that case.
    return s if idx == -1 else s[:idx]
def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Compute the MD5 hex digest of a file, reading it chunk by chunk.

    Args:
        fpath (str): Path of the file to hash.
        chunk_size (int): Bytes read per iteration. Defaults to 1 MiB.

    Returns:
        str: The hexadecimal MD5 digest of the file contents.
    """
    hasher = hashlib.md5()
    with open(fpath, 'rb') as f:
        while True:
            block = f.read(chunk_size)
            if not block:
                break
            hasher.update(block)
    return hasher.hexdigest()
def check_md5(fpath, md5, **kwargs):
    """Return True when the file at ``fpath`` has MD5 digest ``md5``.

    Extra keyword arguments are forwarded to :func:`calculate_md5`.
    """
    return calculate_md5(fpath, **kwargs) == md5
def check_integrity(fpath, md5=None):
    """Check that ``fpath`` exists and, when given, matches ``md5``.

    Args:
        fpath (str): Path of the file to verify.
        md5 (str, optional): Expected MD5 digest; when None only the
            file's existence is checked. Defaults to None.

    Returns:
        bool: True if the file exists (and its digest matches).
    """
    if not os.path.isfile(fpath):
        return False
    return True if md5 is None else check_md5(fpath, md5)
def download_url_to_file(url, fpath):
    """Stream the resource at ``url`` into the local file ``fpath``."""
    response = urllib.request.urlopen(url)
    try:
        with open(fpath, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
    finally:
        response.close()
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from.
        root (str): Directory to place downloaded file in.
        filename (str | None): Name to save the file under.
            If filename is None, use the basename of the URL.
        md5 (str | None): MD5 checksum of the download.
            If md5 is None, download without md5 check.

    Raises:
        RuntimeError: If the downloaded file is missing or fails the MD5
            check.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)
    # Skip the download entirely when a verified copy already exists.
    if check_integrity(fpath, md5):
        print(f'Using downloaded and verified file: {fpath}')
    else:
        try:
            print(f'Downloading {url} to {fpath}')
            download_url_to_file(url, fpath)
        except (urllib.error.URLError, IOError) as e:
            # On failure, retry exactly once over plain http — only when
            # the original URL was https; otherwise re-raise.
            if url[:5] == 'https':
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      f' Downloading {url} to {fpath}')
                download_url_to_file(url, fpath)
            else:
                raise e
        # check integrity of downloaded file
        if not check_integrity(fpath, md5):
            raise RuntimeError('File not found or corrupted.')
def _is_tarxz(filename):
return filename.endswith('.tar.xz')
def _is_tar(filename):
return filename.endswith('.tar')
def _is_targz(filename):
return filename.endswith('.tar.gz')
def _is_tgz(filename):
return filename.endswith('.tgz')
def _is_gzip(filename):
return filename.endswith('.gz') and not filename.endswith('.tar.gz')
def _is_zip(filename):
return filename.endswith('.zip')
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract an archive into ``to_path`` (default: alongside itself).

    Supports ``.tar``, ``.tar.gz``/``.tgz``, ``.tar.xz``, bare ``.gz``
    (single compressed file) and ``.zip``.

    Args:
        from_path (str): Path of the archive to extract.
        to_path (str, optional): Destination directory. Defaults to the
            directory containing ``from_path``.
        remove_finished (bool): Whether to delete the archive after a
            successful extraction. Defaults to False.

    Raises:
        ValueError: If the archive format is not recognized.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # Map recognized tarball suffixes to the tarfile open mode.
    tar_modes = (('.tar.gz', 'r:gz'), ('.tgz', 'r:gz'),
                 ('.tar.xz', 'r:xz'), ('.tar', 'r'))
    mode = next((m for suf, m in tar_modes if from_path.endswith(suf)), None)

    if mode is not None:
        with tarfile.open(from_path, mode) as archive:
            archive.extractall(path=to_path)
    elif from_path.endswith('.gz'):
        # A bare .gz wraps a single file; decompress it under the
        # archive's stem inside ``to_path``. (.tar.gz was caught above.)
        target = os.path.join(
            to_path,
            os.path.splitext(os.path.basename(from_path))[0])
        with gzip.GzipFile(from_path) as gz_in, open(target, 'wb') as out_f:
            out_f.write(gz_in.read())
    elif from_path.endswith('.zip'):
        with zipfile.ZipFile(from_path, 'r') as zf:
            zf.extractall(to_path)
    else:
        raise ValueError(f'Extraction of {from_path} not supported')

    if remove_finished:
        os.remove(from_path)
def download_and_extract_archive(url,
                                 download_root,
                                 extract_root=None,
                                 filename=None,
                                 md5=None,
                                 remove_finished=False):
    """Download an archive from ``url``, verify it and extract it.

    Args:
        url (str): URL to download the archive from.
        download_root (str): Directory to place the downloaded archive in.
        extract_root (str, optional): Directory to extract into. Defaults
            to ``download_root``.
        filename (str, optional): Name to save the archive under. Defaults
            to the basename of the URL.
        md5 (str, optional): Expected MD5 checksum of the download;
            skipped when None. Defaults to None.
        remove_finished (bool): Whether to delete the archive after
            extraction. Defaults to False.
    """
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)
    download_url(url, download_root, filename, md5)
    archive = os.path.join(download_root, filename)
    print(f'Extracting {archive} to {extract_root}')
    extract_archive(archive, extract_root, remove_finished)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import xml.etree.ElementTree as ET
import mmcv
import numpy as np
from .builder import DATASETS
from .multi_label import MultiLabelDataset
@DATASETS.register_module()
class VOC(MultiLabelDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.

    Args:
        data_prefix (str): The prefix of the data path, e.g. the
            ``VOCdevkit/VOC2007`` directory.
        pipeline (list): A list of dicts, where each element represents an
            operation defined in ``mmcls.datasets.pipelines``.
        ann_file (str | None): The annotation file (image-id list). When
            ``ann_file`` is a str, the subclass reads from it; when it is
            None, the subclass reads according to ``data_prefix``.
        difficult_as_postive (bool, optional): Whether to map the
            difficult labels as positive ones. If True, difficult
            examples are mapped to positive (1); if False, to
            negative (0). Defaults to None, in which case difficult
            labels are set to ``-1``.
    """

    # The 20 Pascal VOC object categories.
    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, difficult_as_postive=None, **kwargs):
        # NOTE(review): the parameter name keeps the upstream (misspelled)
        # form "difficult_as_postive" so existing configs keep working.
        self.difficult_as_postive = difficult_as_postive
        super(VOC, self).__init__(**kwargs)
        # Only VOC2007 is recognized; the year is inferred from the path.
        if 'VOC2007' in self.data_prefix:
            self.year = 2007
        else:
            raise ValueError('Cannot infer dataset year from img_prefix.')

    def load_annotations(self):
        """Load annotations.

        Returns:
            list[dict]: Annotation info from XML file; one dict per image
            with ``img_prefix``, ``img_info`` and a multi-hot
            ``gt_label`` over ``CLASSES``.
        """
        data_infos = []
        img_ids = mmcv.list_from_file(self.ann_file)
        for img_id in img_ids:
            filename = f'JPEGImages/{img_id}.jpg'
            xml_path = osp.join(self.data_prefix, 'Annotations',
                                f'{img_id}.xml')
            tree = ET.parse(xml_path)
            root = tree.getroot()
            labels = []
            labels_difficult = []
            for obj in root.findall('object'):
                label_name = obj.find('name').text
                # in case customized dataset has wrong labels
                # or CLASSES has been override.
                if label_name not in self.CLASSES:
                    continue
                label = self.class_to_idx[label_name]
                difficult = int(obj.find('difficult').text)
                if difficult:
                    labels_difficult.append(label)
                else:
                    labels.append(label)
            gt_label = np.zeros(len(self.CLASSES))
            # Set difficult examples first, then set positive examples.
            # The order cannot be swapped for the case where multiple objects
            # of the same kind exist and some are difficult.
            if self.difficult_as_postive is None:
                # map difficult examples to -1,
                # it may be used in evaluation to ignore difficult targets.
                gt_label[labels_difficult] = -1
            elif self.difficult_as_postive:
                # map difficult examples to positive ones(1).
                gt_label[labels_difficult] = 1
            else:
                # map difficult examples to negative ones(0).
                gt_label[labels_difficult] = 0
            gt_label[labels] = 1
            info = dict(
                img_prefix=self.data_prefix,
                img_info=dict(filename=filename),
                gt_label=gt_label.astype(np.int8))
            data_infos.append(info)
        return data_infos
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403
from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS,
build_backbone, build_classifier, build_head, build_loss,
build_neck)
from .classifiers import * # noqa: F401,F403
from .heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
# Public API of mmcls.models: the registries and their build functions.
__all__ = [
    'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone',
    'build_head', 'build_neck', 'build_loss', 'build_classifier'
]
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
from .conformer import Conformer
from .convmixer import ConvMixer
from .convnext import ConvNeXt
from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
from .deit import DistilledVisionTransformer
from .densenet import DenseNet
from .efficientformer import EfficientFormer
from .efficientnet import EfficientNet
from .hornet import HorNet
from .hrnet import HRNet
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mvit import MViT
from .poolformer import PoolFormer
from .regnet import RegNet
from .repmlp import RepMLPNet
from .repvgg import RepVGG
from .res2net import Res2Net
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnet_cifar import ResNet_CIFAR
from .resnext import ResNeXt
from .seresnet import SEResNet
from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .swin_transformer import SwinTransformer
from .swin_transformer_v2 import SwinTransformerV2
from .t2t_vit import T2T_ViT
from .timm_backbone import TIMMBackbone
from .tnt import TNT
from .twins import PCPVT, SVT
from .van import VAN
from .vgg import VGG
from .vision_transformer import VisionTransformer
# Backbone classes re-exported as the public API of this subpackage.
__all__ = [
    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
    'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
    'SwinTransformer', 'SwinTransformerV2', 'TNT', 'TIMMBackbone', 'T2T_ViT',
    'Res2Net', 'RepVGG', 'Conformer', 'MlpMixer', 'DistilledVisionTransformer',
    'PCPVT', 'SVT', 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c',
    'ConvMixer', 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet',
    'RepMLPNet', 'PoolFormer', 'DenseNet', 'VAN', 'MViT', 'EfficientFormer',
    'HorNet'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn import torch.nn as nn
from ..builder import BACKBONES from ..builder import BACKBONES
...@@ -52,4 +53,4 @@ class AlexNet(BaseBackbone): ...@@ -52,4 +53,4 @@ class AlexNet(BaseBackbone):
x = x.view(x.size(0), 256 * 6 * 6) x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x) x = self.classifier(x)
return x return (x, )
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment