Commit cbc25585 authored by limm's avatar limm
Browse files

add mmpretrain/ part

parent 1baf0566
Pipeline #2801 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import math
import numbers
import re
import string
from enum import EnumMeta
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union
import mmcv
import mmengine
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as F
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from PIL import Image
from torchvision import transforms
from torchvision.transforms.transforms import InterpolationMode
from mmpretrain.registry import TRANSFORMS
try:
import albumentations
except ImportError:
albumentations = None
def _str_to_torch_dtype(t: str):
"""mapping str format dtype to torch.dtype."""
import torch # noqa: F401,F403
return eval(f'torch.{t}')
def _interpolation_modes_from_str(t: str):
    """Map an interpolation name to a torchvision ``InterpolationMode``.

    Args:
        t (str): Interpolation name, case-insensitive (e.g. ``'bilinear'``).

    Returns:
        InterpolationMode: The matching torchvision interpolation mode.

    Raises:
        KeyError: If ``t`` is not a recognized interpolation name.
    """
    t = t.lower()
    inverse_modes_mapping = {
        'nearest': InterpolationMode.NEAREST,
        'bilinear': InterpolationMode.BILINEAR,
        'bicubic': InterpolationMode.BICUBIC,
        'box': InterpolationMode.BOX,
        'hamming': InterpolationMode.HAMMING,
        # Keep the historical misspelling so existing configs keep working.
        'hammimg': InterpolationMode.HAMMING,
        'lanczos': InterpolationMode.LANCZOS,
    }
    return inverse_modes_mapping[t]
class TorchVisonTransformWrapper:
    """Adapt a torchvision transform class to the mmpretrain dict pipeline.

    The wrapped transform receives ``results['img']`` instead of a bare
    image, and string-valued ``interpolation``/``dtype`` kwargs are converted
    to their torchvision/torch counterparts before instantiation.
    """

    def __init__(self, transform, *args, **kwargs):
        interpolation = kwargs.get('interpolation')
        if isinstance(interpolation, str):
            kwargs['interpolation'] = _interpolation_modes_from_str(
                interpolation)
        dtype = kwargs.get('dtype')
        if isinstance(dtype, str):
            kwargs['dtype'] = _str_to_torch_dtype(dtype)
        self.t = transform(*args, **kwargs)

    def __call__(self, results):
        results['img'] = self.t(results['img'])
        return results

    def __repr__(self) -> str:
        return f'TorchVision{repr(self.t)}'
def register_vision_transforms() -> List[str]:
    """Register transforms in ``torchvision.transforms`` to the ``TRANSFORMS``
    registry.

    Returns:
        List[str]: A list of registered transforms' name.
    """
    from functools import partial

    registered = []
    for name in dir(torchvision.transforms):
        # Transform classes follow CapWords; skip helpers and constants.
        if not re.match('[A-Z]', name):
            continue
        obj = getattr(torchvision.transforms, name)
        is_transform_cls = (
            inspect.isclass(obj) and callable(obj)
            and not isinstance(obj, (EnumMeta)))
        if is_transform_cls:
            TRANSFORMS.register_module(
                module=partial(TorchVisonTransformWrapper, transform=obj),
                name=f'torchvision/{name}')
            registered.append(f'torchvision/{name}')
    return registered
# Register all the transforms in torchvision by using the transform wrapper.
# The returned names are kept so callers can introspect what is available.
VISION_TRANSFORMS = register_vision_transforms()
@TRANSFORMS.register_module()
class RandomCrop(BaseTransform):
    """Crop the given Image at a random location.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img
    - img_shape

    Args:
        crop_size (int | Sequence): Desired output size of the crop. If
            crop_size is an int instead of sequence like (h, w), a square crop
            (crop_size, crop_size) is made.
        padding (int | Sequence, optional): Optional padding on each border
            of the image. If a sequence of length 4 is provided, it is used to
            pad left, top, right, bottom borders respectively. If a sequence
            of length 2 is provided, it is used to pad left/right, top/bottom
            borders, respectively. Default: None, which means no padding.
        pad_if_needed (bool): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
            Default: False.
        pad_val (Number | Sequence[Number]): Pixel pad_val value for constant
            fill. If a tuple of length 3, it is used to pad_val R, G, B
            channels respectively. Default: 0.
        padding_mode (str): Type of padding. Defaults to "constant". Should
            be one of the following:

            - ``constant``: Pads with a constant value, this value is specified
              with pad_val.
            - ``edge``: pads with the last value at the edge of the image.
            - ``reflect``: Pads with reflection of image without repeating the
              last value on the edge. For example, padding [1, 2, 3, 4]
              with 2 elements on both sides in reflect mode will result
              in [3, 2, 1, 2, 3, 4, 3, 2].
            - ``symmetric``: Pads with reflection of image repeating the last
              value on the edge. For example, padding [1, 2, 3, 4] with
              2 elements on both sides in symmetric mode will result in
              [2, 1, 1, 2, 3, 4, 4, 3].
    """

    def __init__(self,
                 crop_size: Union[Sequence, int],
                 padding: Optional[Union[Sequence, int]] = None,
                 pad_if_needed: bool = False,
                 pad_val: Union[Number, Sequence[Number]] = 0,
                 padding_mode: str = 'constant'):
        # Normalize ``crop_size`` to an (h, w) pair.
        if isinstance(crop_size, Sequence):
            assert len(crop_size) == 2
            assert crop_size[0] > 0 and crop_size[1] > 0
            self.crop_size = crop_size
        else:
            assert crop_size > 0
            self.crop_size = (crop_size, crop_size)
        # check padding mode
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.pad_val = pad_val
        self.padding_mode = padding_mode

    @cache_randomness
    def rand_crop_params(self, img: np.ndarray):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (ndarray): Image to be cropped.

        Returns:
            tuple: Params (offset_h, offset_w, target_h, target_w) to be
                passed to ``crop`` for random crop.
        """
        h, w = img.shape[:2]
        target_h, target_w = self.crop_size
        if w == target_w and h == target_h:
            # Exact fit: no randomness needed.
            return 0, 0, h, w
        elif w < target_w or h < target_h:
            # Image is smaller than the crop on some axis; clamp the target so
            # the random offset ranges below stay non-empty.
            target_w = min(w, target_w)
            target_h = min(h, target_h)
        offset_h = np.random.randint(0, h - target_h + 1)
        offset_w = np.random.randint(0, w - target_w + 1)
        return offset_h, offset_w, target_h, target_w

    def transform(self, results: dict) -> dict:
        """Transform function to randomly crop images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly cropped results, 'img_shape'
                key in result dict is updated according to crop size.
        """
        img = results['img']
        if self.padding is not None:
            img = mmcv.impad(img, padding=self.padding, pad_val=self.pad_val)

        # pad img if needed
        if self.pad_if_needed:
            # Pad symmetrically so the image reaches at least crop_size;
            # ``ceil`` of half the deficit covers odd deficits too.
            h_pad = math.ceil(max(0, self.crop_size[0] - img.shape[0]) / 2)
            w_pad = math.ceil(max(0, self.crop_size[1] - img.shape[1]) / 2)

            img = mmcv.impad(
                img,
                padding=(w_pad, h_pad, w_pad, h_pad),
                pad_val=self.pad_val,
                padding_mode=self.padding_mode)

        offset_h, offset_w, target_h, target_w = self.rand_crop_params(img)
        # mmcv.imcrop bboxes are inclusive, hence the ``- 1``.
        img = mmcv.imcrop(
            img,
            np.array([
                offset_w,
                offset_h,
                offset_w + target_w - 1,
                offset_h + target_h - 1,
            ]))
        results['img'] = img
        results['img_shape'] = img.shape

        return results

    def __repr__(self):
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}'
        repr_str += f', padding={self.padding}'
        repr_str += f', pad_if_needed={self.pad_if_needed}'
        repr_str += f', pad_val={self.pad_val}'
        repr_str += f', padding_mode={self.padding_mode})'
        return repr_str
@TRANSFORMS.register_module()
class RandomResizedCrop(BaseTransform):
    """Crop the given image to random scale and aspect ratio.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a
    random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
    is made. This crop is finally resized to given size.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img
    - img_shape

    Args:
        scale (sequence | int): Desired output scale of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        crop_ratio_range (tuple): Range of the random size of the cropped
            image compared to the original image. Defaults to (0.08, 1.0).
        aspect_ratio_range (tuple): Range of the random aspect ratio of the
            cropped image compared to the original image.
            Defaults to (3. / 4., 4. / 3.).
        max_attempts (int): Maximum number of attempts before falling back to
            Central Crop. Defaults to 10.
        interpolation (str): Interpolation method, accepted values are
            'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
            'bilinear'.
        backend (str): The image resize backend type, accepted values are
            'cv2' and 'pillow'. Defaults to 'cv2'.
    """

    def __init__(self,
                 scale: Union[Sequence, int],
                 crop_ratio_range: Tuple[float, float] = (0.08, 1.0),
                 aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. / 3.),
                 max_attempts: int = 10,
                 interpolation: str = 'bilinear',
                 backend: str = 'cv2') -> None:
        # Normalize ``scale`` to an (h, w) pair.
        if isinstance(scale, Sequence):
            assert len(scale) == 2
            assert scale[0] > 0 and scale[1] > 0
            self.scale = scale
        else:
            assert scale > 0
            self.scale = (scale, scale)
        if (crop_ratio_range[0] > crop_ratio_range[1]) or (
                aspect_ratio_range[0] > aspect_ratio_range[1]):
            raise ValueError(
                'range should be of kind (min, max). '
                f'But received crop_ratio_range {crop_ratio_range} '
                f'and aspect_ratio_range {aspect_ratio_range}.')
        assert isinstance(max_attempts, int) and max_attempts >= 0, \
            'max_attempts mush be int and no less than 0.'
        assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
                                 'lanczos')

        self.crop_ratio_range = crop_ratio_range
        self.aspect_ratio_range = aspect_ratio_range
        self.max_attempts = max_attempts
        self.interpolation = interpolation
        self.backend = backend

    @cache_randomness
    def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (ndarray): Image to be cropped.

        Returns:
            tuple: Params (offset_h, offset_w, target_h, target_w) to be
                passed to `crop` for a random sized crop.
        """
        h, w = img.shape[:2]
        area = h * w

        for _ in range(self.max_attempts):
            target_area = np.random.uniform(*self.crop_ratio_range) * area
            # Sample the aspect ratio in log space so ratios r and 1/r are
            # equally likely.
            log_ratio = (math.log(self.aspect_ratio_range[0]),
                         math.log(self.aspect_ratio_range[1]))
            aspect_ratio = math.exp(np.random.uniform(*log_ratio))
            target_w = int(round(math.sqrt(target_area * aspect_ratio)))
            target_h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < target_w <= w and 0 < target_h <= h:
                offset_h = np.random.randint(0, h - target_h + 1)
                offset_w = np.random.randint(0, w - target_w + 1)
                return offset_h, offset_w, target_h, target_w

        # Fallback to central crop
        in_ratio = float(w) / float(h)
        if in_ratio < min(self.aspect_ratio_range):
            # Image is too tall: take the full width, clamp the height.
            target_w = w
            target_h = int(round(target_w / min(self.aspect_ratio_range)))
        elif in_ratio > max(self.aspect_ratio_range):
            # Image is too wide: take the full height, clamp the width.
            target_h = h
            target_w = int(round(target_h * max(self.aspect_ratio_range)))
        else:  # whole image
            target_w = w
            target_h = h
        offset_h = (h - target_h) // 2
        offset_w = (w - target_w) // 2
        return offset_h, offset_w, target_h, target_w

    def transform(self, results: dict) -> dict:
        """Transform function to randomly resized crop images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly resized cropped results, 'img_shape'
                key in result dict is updated according to crop size.
        """
        img = results['img']
        offset_h, offset_w, target_h, target_w = self.rand_crop_params(img)
        # mmcv.imcrop bboxes are inclusive, hence the ``- 1``.
        img = mmcv.imcrop(
            img,
            bboxes=np.array([
                offset_w, offset_h, offset_w + target_w - 1,
                offset_h + target_h - 1
            ]))
        # ``self.scale`` is (h, w); mmcv.imresize expects (w, h).
        img = mmcv.imresize(
            img,
            tuple(self.scale[::-1]),
            interpolation=self.interpolation,
            backend=self.backend)
        results['img'] = img
        results['img_shape'] = img.shape

        return results

    def __repr__(self):
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__ + f'(scale={self.scale}'
        repr_str += ', crop_ratio_range='
        repr_str += f'{tuple(round(s, 4) for s in self.crop_ratio_range)}'
        repr_str += ', aspect_ratio_range='
        repr_str += f'{tuple(round(r, 4) for r in self.aspect_ratio_range)}'
        repr_str += f', max_attempts={self.max_attempts}'
        repr_str += f', interpolation={self.interpolation}'
        repr_str += f', backend={self.backend})'
        return repr_str
@TRANSFORMS.register_module()
class EfficientNetRandomCrop(RandomResizedCrop):
    """EfficientNet style RandomResizedCrop.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img
    - img_shape

    Args:
        scale (int): Desired output scale of the crop. Only int size is
            accepted, a square crop (size, size) is made.
        min_covered (Number): Minimum ratio of the cropped area to the original
            area. Defaults to 0.1.
        crop_padding (int): The crop padding parameter in efficientnet style
            center crop. Defaults to 32.
        crop_ratio_range (tuple): Range of the random size of the cropped
            image compared to the original image. Defaults to (0.08, 1.0).
        aspect_ratio_range (tuple): Range of the random aspect ratio of the
            cropped image compared to the original image.
            Defaults to (3. / 4., 4. / 3.).
        max_attempts (int): Maximum number of attempts before falling back to
            Central Crop. Defaults to 10.
        interpolation (str): Interpolation method, accepted values are
            'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
            'bicubic'.
        backend (str): The image resize backend type, accepted values are
            'cv2' and 'pillow'. Defaults to 'cv2'.
    """

    def __init__(self,
                 scale: int,
                 min_covered: float = 0.1,
                 crop_padding: int = 32,
                 interpolation: str = 'bicubic',
                 **kwarg):
        assert isinstance(scale, int)
        super().__init__(scale, interpolation=interpolation, **kwarg)
        assert min_covered >= 0, 'min_covered should be no less than 0.'
        assert crop_padding >= 0, 'crop_padding should be no less than 0.'

        self.min_covered = min_covered
        self.crop_padding = crop_padding

    # https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/data.py # noqa
    @cache_randomness
    def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (ndarray): Image to be cropped.

        Returns:
            tuple: Params (offset_h, offset_w, target_h, target_w) to be
                passed to `crop` for a random sized crop.
        """
        h, w = img.shape[:2]
        area = h * w
        min_target_area = self.crop_ratio_range[0] * area
        max_target_area = self.crop_ratio_range[1] * area

        for _ in range(self.max_attempts):
            aspect_ratio = np.random.uniform(*self.aspect_ratio_range)
            min_target_h = int(
                round(math.sqrt(min_target_area / aspect_ratio)))
            max_target_h = int(
                round(math.sqrt(max_target_area / aspect_ratio)))

            if max_target_h * aspect_ratio > w:
                # Shrink the height bound so the implied width fits in the
                # image; the epsilon guards against float rounding up.
                max_target_h = int((w + 0.5 - 1e-7) / aspect_ratio)
                if max_target_h * aspect_ratio > w:
                    max_target_h -= 1

            max_target_h = min(max_target_h, h)
            min_target_h = min(max_target_h, min_target_h)

            # slightly differs from tf implementation
            target_h = int(
                round(np.random.uniform(min_target_h, max_target_h)))
            target_w = int(round(target_h * aspect_ratio))
            target_area = target_h * target_w

            # slight differs from tf. In tf, if target_area > max_target_area,
            # area will be recalculated
            if (target_area < min_target_area or target_area > max_target_area
                    or target_w > w or target_h > h
                    or target_area < self.min_covered * area):
                continue

            offset_h = np.random.randint(0, h - target_h + 1)
            offset_w = np.random.randint(0, w - target_w + 1)
            return offset_h, offset_w, target_h, target_w

        # Fallback to central crop
        img_short = min(h, w)
        # NOTE(review): ``crop_size`` stays a float here (unlike the int
        # targets above); mirrors the reference implementation — confirm
        # downstream ``mmcv.imcrop`` handles float box coordinates.
        crop_size = self.scale[0] / (self.scale[0] +
                                     self.crop_padding) * img_short

        offset_h = max(0, int(round((h - crop_size) / 2.)))
        offset_w = max(0, int(round((w - crop_size) / 2.)))
        return offset_h, offset_w, crop_size, crop_size

    def __repr__(self):
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        # Extend the parent repr by stripping its trailing ')'.
        repr_str = super().__repr__()[:-1]
        repr_str += f', min_covered={self.min_covered}'
        repr_str += f', crop_padding={self.crop_padding})'
        return repr_str
@TRANSFORMS.register_module()
class RandomErasing(BaseTransform):
    """Randomly selects a rectangle region in an image and erase pixels.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img

    Args:
        erase_prob (float): Probability that image will be randomly erased.
            Default: 0.5
        min_area_ratio (float): Minimum erased area / input image area
            Default: 0.02
        max_area_ratio (float): Maximum erased area / input image area
            Default: 0.4
        aspect_range (sequence | float): Aspect ratio range of erased area.
            if float, it will be converted to (aspect_ratio, 1/aspect_ratio)
            Default: (3/10, 10/3)
        mode (str): Fill method in erased area, can be:

            - const (default): All pixels are assign with the same value.
            - rand: each pixel is assigned with a random value in [0, 255]

        fill_color (sequence | Number): Base color filled in erased area.
            Defaults to (128, 128, 128).
        fill_std (sequence | Number, optional): If set and ``mode`` is 'rand',
            fill erased area with random color from normal distribution
            (mean=fill_color, std=fill_std); If not set, fill erased area with
            random color from uniform distribution (0~255). Defaults to None.

    Note:
        See `Random Erasing Data Augmentation
        <https://arxiv.org/pdf/1708.04896.pdf>`_

        This paper provided 4 modes: RE-R, RE-M, RE-0, RE-255, and use RE-M as
        default. The config of these 4 modes are:

        - RE-R: RandomErasing(mode='rand')
        - RE-M: RandomErasing(mode='const', fill_color=(123.67, 116.3, 103.5))
        - RE-0: RandomErasing(mode='const', fill_color=0)
        - RE-255: RandomErasing(mode='const', fill_color=255)
    """

    def __init__(self,
                 erase_prob=0.5,
                 min_area_ratio=0.02,
                 max_area_ratio=0.4,
                 aspect_range=(3 / 10, 10 / 3),
                 mode='const',
                 fill_color=(128, 128, 128),
                 fill_std=None):
        assert isinstance(erase_prob, float) and 0. <= erase_prob <= 1.
        assert isinstance(min_area_ratio, float) and 0. <= min_area_ratio <= 1.
        assert isinstance(max_area_ratio, float) and 0. <= max_area_ratio <= 1.
        assert min_area_ratio <= max_area_ratio, \
            'min_area_ratio should be smaller than max_area_ratio'
        if isinstance(aspect_range, float):
            # A single float r becomes the symmetric range (min(r, 1/r), its
            # reciprocal), so r and 1/r are interchangeable inputs.
            aspect_range = min(aspect_range, 1 / aspect_range)
            aspect_range = (aspect_range, 1 / aspect_range)
        assert isinstance(aspect_range, Sequence) and len(aspect_range) == 2 \
            and all(isinstance(x, float) for x in aspect_range), \
            'aspect_range should be a float or Sequence with two float.'
        assert all(x > 0 for x in aspect_range), \
            'aspect_range should be positive.'
        assert aspect_range[0] <= aspect_range[1], \
            'In aspect_range (min, max), min should be smaller than max.'
        assert mode in ['const', 'rand'], \
            'Please select `mode` from ["const", "rand"].'
        if isinstance(fill_color, Number):
            # Broadcast a scalar fill value to the three color channels.
            fill_color = [fill_color] * 3
        assert isinstance(fill_color, Sequence) and len(fill_color) == 3 \
            and all(isinstance(x, Number) for x in fill_color), \
            'fill_color should be a float or Sequence with three int.'
        if fill_std is not None:
            if isinstance(fill_std, Number):
                fill_std = [fill_std] * 3
            assert isinstance(fill_std, Sequence) and len(fill_std) == 3 \
                and all(isinstance(x, Number) for x in fill_std), \
                'fill_std should be a float or Sequence with three int.'

        self.erase_prob = erase_prob
        self.min_area_ratio = min_area_ratio
        self.max_area_ratio = max_area_ratio
        self.aspect_range = aspect_range
        self.mode = mode
        self.fill_color = fill_color
        self.fill_std = fill_std

    def _fill_pixels(self, img, top, left, h, w):
        """Fill pixels to the patch of image."""
        if self.mode == 'const':
            patch = np.empty((h, w, 3), dtype=np.uint8)
            patch[:, :] = np.array(self.fill_color, dtype=np.uint8)
        elif self.fill_std is None:
            # Uniform distribution
            patch = np.random.uniform(0, 256, (h, w, 3)).astype(np.uint8)
        else:
            # Normal distribution
            patch = np.random.normal(self.fill_color, self.fill_std, (h, w, 3))
            patch = np.clip(patch.astype(np.int32), 0, 255).astype(np.uint8)

        img[top:top + h, left:left + w] = patch
        return img

    @cache_randomness
    def random_disable(self):
        """Randomly disable the transform."""
        return np.random.rand() > self.erase_prob

    @cache_randomness
    def random_patch(self, img_h, img_w):
        """Randomly generate patch the erase."""
        # convert the aspect ratio to log space to equally handle width and
        # height.
        log_aspect_range = np.log(
            np.array(self.aspect_range, dtype=np.float32))
        aspect_ratio = np.exp(np.random.uniform(*log_aspect_range))
        area = img_h * img_w
        area *= np.random.uniform(self.min_area_ratio, self.max_area_ratio)

        h = min(int(round(np.sqrt(area * aspect_ratio))), img_h)
        w = min(int(round(np.sqrt(area / aspect_ratio))), img_w)
        top = np.random.randint(0, img_h - h) if img_h > h else 0
        left = np.random.randint(0, img_w - w) if img_w > w else 0
        return top, left, h, w

    def transform(self, results):
        """
        Args:
            results (dict): Results dict from pipeline

        Returns:
            dict: Results after the transformation.
        """
        if self.random_disable():
            return results

        img = results['img']
        img_h, img_w = img.shape[:2]

        # ``_fill_pixels`` writes into ``img`` in place.
        img = self._fill_pixels(img, *self.random_patch(img_h, img_w))

        results['img'] = img

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(erase_prob={self.erase_prob}, '
        repr_str += f'min_area_ratio={self.min_area_ratio}, '
        repr_str += f'max_area_ratio={self.max_area_ratio}, '
        repr_str += f'aspect_range={self.aspect_range}, '
        repr_str += f'mode={self.mode}, '
        repr_str += f'fill_color={self.fill_color}, '
        repr_str += f'fill_std={self.fill_std})'
        return repr_str
@TRANSFORMS.register_module()
class EfficientNetCenterCrop(BaseTransform):
    r"""EfficientNet style center crop.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img
    - img_shape

    Args:
        crop_size (int): Expected size after cropping with the format
            of (h, w).
        crop_padding (int): The crop padding parameter in efficientnet style
            center crop. Defaults to 32.
        interpolation (str): Interpolation method, accepted values are
            'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Only valid if
            ``efficientnet_style`` is True. Defaults to 'bicubic'.
        backend (str): The image resize backend type, accepted values are
            `cv2` and `pillow`. Only valid if efficientnet style is True.
            Defaults to `cv2`.

    Notes:
        - If the image is smaller than the crop size, return the original
          image.
        - A center crop of size

          .. math::
              \text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} +
              \text{crop_padding}} \times \text{short_edge}

          is taken first, and the crop is then resized to ``crop_size``.
    """

    def __init__(self,
                 crop_size: int,
                 crop_padding: int = 32,
                 interpolation: str = 'bicubic',
                 backend: str = 'cv2'):
        assert isinstance(crop_size, int)
        assert crop_size > 0
        assert crop_padding >= 0
        assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
                                 'lanczos')

        self.crop_size = crop_size
        self.crop_padding = crop_padding
        self.interpolation = interpolation
        self.backend = backend

    def transform(self, results: dict) -> dict:
        """Apply the EfficientNet style center crop and resize.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: EfficientNet style center cropped results; the 'img_shape'
                key is updated according to the crop size.
        """
        img = results['img']
        height, width = img.shape[:2]

        # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L118 # noqa
        short_edge = min(height, width)
        # Shrink the window by crop_size / (crop_size + crop_padding) of the
        # short edge before cropping.
        crop_size = self.crop_size / (self.crop_size +
                                      self.crop_padding) * short_edge

        top = max(0, int(round((height - crop_size) / 2.)))
        left = max(0, int(round((width - crop_size) / 2.)))
        # Inclusive box coordinates, hence the ``- 1``.
        bbox = np.array(
            [left, top, left + crop_size - 1, top + crop_size - 1])
        cropped = mmcv.imcrop(img, bboxes=bbox)

        resized = mmcv.imresize(
            cropped, (self.crop_size, self.crop_size),
            interpolation=self.interpolation,
            backend=self.backend)
        results['img'] = resized
        results['img_shape'] = resized.shape

        return results

    def __repr__(self):
        """Return a string describing this transform's configuration."""
        fields = [
            f'crop_size={self.crop_size}',
            f'crop_padding={self.crop_padding}',
            f'interpolation={self.interpolation}',
            f'backend={self.backend}',
        ]
        return f"{self.__class__.__name__}({', '.join(fields)})"
@TRANSFORMS.register_module()
class ResizeEdge(BaseTransform):
    """Resize images along the specified edge.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img
    - img_shape

    **Added Keys:**

    - scale
    - scale_factor

    Args:
        scale (int): The edge scale to resizing.
        edge (str): The edge to resize. Defaults to 'short'.
        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
            These two backends generates slightly different results.
            Defaults to 'cv2'.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend.
            Defaults to 'bilinear'.
    """

    def __init__(self,
                 scale: int,
                 edge: str = 'short',
                 backend: str = 'cv2',
                 interpolation: str = 'bilinear') -> None:
        allow_edges = ['short', 'long', 'width', 'height']
        assert edge in allow_edges, \
            f'Invalid edge "{edge}", please specify from {allow_edges}.'
        self.edge = edge
        self.scale = scale
        self.backend = backend
        self.interpolation = interpolation

    def _resize_img(self, results: dict) -> None:
        """Resize ``results['img']`` to ``results['scale']`` and record the
        realized size and per-axis scale factors."""
        resized, w_scale, h_scale = mmcv.imresize(
            results['img'],
            results['scale'],
            interpolation=self.interpolation,
            return_scale=True,
            backend=self.backend)
        results['img'] = resized
        results['img_shape'] = resized.shape[:2]
        # Store the actual output size as (w, h); it may differ slightly from
        # the requested scale depending on the backend.
        results['scale'] = resized.shape[:2][::-1]
        results['scale_factor'] = (w_scale, h_scale)

    def transform(self, results: Dict) -> Dict:
        """Resize the image so the chosen edge matches ``self.scale``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Resized results; 'img', 'scale', 'scale_factor' and
                'img_shape' keys are updated in the result dict.
        """
        assert 'img' in results, 'No `img` field in the input.'

        h, w = results['img'].shape[:2]
        # Decide which axis is pinned to ``self.scale``; the other axis keeps
        # the original aspect ratio.
        pin_width = (
            self.edge == 'width'
            or (self.edge == 'short' and w < h)
            or (self.edge == 'long' and w > h))
        if pin_width:
            width, height = self.scale, int(self.scale * h / w)
        else:
            height, width = self.scale, int(self.scale * w / h)
        results['scale'] = (width, height)

        self._resize_img(results)
        return results

    def __repr__(self):
        """Return a string describing this transform's configuration."""
        fields = [
            f'scale={self.scale}',
            f'edge={self.edge}',
            f'backend={self.backend}',
            f'interpolation={self.interpolation}',
        ]
        return f"{self.__class__.__name__}({', '.join(fields)})"
@TRANSFORMS.register_module()
class ColorJitter(BaseTransform):
    """Randomly change the brightness, contrast and saturation of an image.

    Modified from
    https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py
    Licensed under the BSD 3-Clause License.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img

    Args:
        brightness (float | Sequence[float] (min, max)): How much to jitter
            brightness. brightness_factor is chosen uniformly from
            ``[max(0, 1 - brightness), 1 + brightness]`` or the given
            ``[min, max]``. Should be non negative numbers. Defaults to 0.
        contrast (float | Sequence[float] (min, max)): How much to jitter
            contrast. contrast_factor is chosen uniformly from
            ``[max(0, 1 - contrast), 1 + contrast]`` or the given
            ``[min, max]``. Should be non negative numbers. Defaults to 0.
        saturation (float | Sequence[float] (min, max)): How much to jitter
            saturation. saturation_factor is chosen uniformly from
            ``[max(0, 1 - saturation), 1 + saturation]`` or the given
            ``[min, max]``. Should be non negative numbers. Defaults to 0.
        hue (float | Sequence[float] (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from ``[-hue, hue]`` (0 <= hue
            <= 0.5) or the given ``[min, max]`` (-0.5 <= min <= max <= 0.5).
            Defaults to 0.
        backend (str): The backend to operate the image. Defaults to 'pillow'.
    """

    def __init__(self,
                 brightness: Union[float, Sequence[float]] = 0.,
                 contrast: Union[float, Sequence[float]] = 0.,
                 saturation: Union[float, Sequence[float]] = 0.,
                 hue: Union[float, Sequence[float]] = 0.,
                 backend='pillow'):
        self.brightness = self._set_range(brightness, 'brightness')
        self.contrast = self._set_range(contrast, 'contrast')
        self.saturation = self._set_range(saturation, 'saturation')
        # Hue is centered on 0 and bounded to [-0.5, 0.5]; the others are
        # centered on 1 (the identity factor) and must be non-negative.
        self.hue = self._set_range(hue, 'hue', center=0, bound=(-0.5, 0.5))
        self.backend = backend

    def _set_range(self, value, name, center=1, bound=(0, float('inf'))):
        """Normalize a jitter magnitude to a ``(min, max)`` tuple or ``None``.

        A scalar ``v`` becomes ``(center - v, center + v)``; a 2-sequence is
        clipped to ``bound`` (with a warning). Returns ``None`` when the range
        degenerates to the identity, so ``transform`` can skip the channel.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(
                    f'If {name} is a single number, it must be non negative.')
            value = (center - float(value), center + float(value))

        if isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                value = np.clip(value, bound[0], bound[1])
                from mmengine.logging import MMLogger
                logger = MMLogger.get_current_instance()
                logger.warning(f'ColorJitter {name} values exceed the bound '
                               f'{bound}, clipped to the bound.')
        else:
            raise TypeError(f'{name} should be a single number '
                            'or a list/tuple with length 2.')

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        else:
            value = tuple(value)

        return value

    @cache_randomness
    def _rand_params(self):
        """Get random parameters including magnitudes and indices of
        transforms."""
        # Random order of the four adjustments, applied per call.
        trans_inds = np.random.permutation(4)
        b, c, s, h = (None, ) * 4

        if self.brightness is not None:
            b = np.random.uniform(self.brightness[0], self.brightness[1])
        if self.contrast is not None:
            c = np.random.uniform(self.contrast[0], self.contrast[1])
        if self.saturation is not None:
            s = np.random.uniform(self.saturation[0], self.saturation[1])
        if self.hue is not None:
            h = np.random.uniform(self.hue[0], self.hue[1])

        return trans_inds, b, c, s, h

    def transform(self, results: Dict) -> Dict:
        """Randomly jitter the image colors.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: ColorJitter results, 'img' key is updated in result dict.
        """
        img = results['img']
        trans_inds, brightness, contrast, saturation, hue = self._rand_params()

        # Apply the enabled adjustments in the sampled random order.
        for index in trans_inds:
            if index == 0 and brightness is not None:
                img = mmcv.adjust_brightness(
                    img, brightness, backend=self.backend)
            elif index == 1 and contrast is not None:
                img = mmcv.adjust_contrast(img, contrast, backend=self.backend)
            elif index == 2 and saturation is not None:
                img = mmcv.adjust_color(
                    img, alpha=saturation, backend=self.backend)
            elif index == 3 and hue is not None:
                img = mmcv.adjust_hue(img, hue, backend=self.backend)

        results['img'] = img
        return results

    def __repr__(self):
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(brightness={self.brightness}, '
        repr_str += f'contrast={self.contrast}, '
        repr_str += f'saturation={self.saturation}, '
        repr_str += f'hue={self.hue}, '
        # ``backend`` was previously missing from the repr although it is a
        # constructor argument; include it for consistency with the other
        # transforms in this file.
        repr_str += f'backend={self.backend})'
        return repr_str
@TRANSFORMS.register_module()
class Lighting(BaseTransform):
    """Adjust images lighting using AlexNet-style PCA jitter.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img

    Args:
        eigval (Sequence[float]): the eigenvalue of the convariance matrix
            of pixel values, respectively.
        eigvec (list[list]): the eigenvector of the convariance matrix of
            pixel values, respectively.
        alphastd (float): The standard deviation for distribution of alpha.
            Defaults to 0.1.
        to_rgb (bool): Whether to convert img to rgb. Defaults to False.
    """

    def __init__(self,
                 eigval: Sequence[float],
                 eigvec: Sequence[Sequence[float]],
                 alphastd: float = 0.1,
                 to_rgb: bool = False):
        assert isinstance(eigval, Sequence), \
            f'eigval must be Sequence, got {type(eigval)} instead.'
        assert isinstance(eigvec, Sequence), \
            f'eigvec must be Sequence, got {type(eigvec)} instead.'
        for vec in eigvec:
            # All eigenvectors must have the same dimensionality so the
            # resulting array below is rectangular.
            assert isinstance(vec, Sequence) and len(vec) == len(eigvec[0]), \
                'eigvec must contains lists with equal length.'
        # Accept int as well as float, as the error message always promised
        # (the previous check rejected int despite saying "float or int").
        assert isinstance(alphastd, (float, int)), \
            'alphastd should be of type ' \
            f'float or int, got {type(alphastd)} instead.'

        self.eigval = np.array(eigval)
        self.eigvec = np.array(eigvec)
        self.alphastd = alphastd
        self.to_rgb = to_rgb

    def transform(self, results: Dict) -> Dict:
        """Apply PCA-based lighting jitter to the image.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Lightinged results, 'img' key is updated in result dict.
        """
        assert 'img' in results, 'No `img` field in the input.'

        img = results['img']
        img_lighting = mmcv.adjust_lighting(
            img,
            self.eigval,
            self.eigvec,
            alphastd=self.alphastd,
            to_rgb=self.to_rgb)
        # ``adjust_lighting`` returns float; cast back to the input dtype so
        # downstream transforms see the same image type they were given.
        results['img'] = img_lighting.astype(img.dtype)

        return results

    def __repr__(self):
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(eigval={self.eigval.tolist()}, '
        repr_str += f'eigvec={self.eigvec.tolist()}, '
        repr_str += f'alphastd={self.alphastd}, '
        repr_str += f'to_rgb={self.to_rgb})'
        return repr_str
# 'Albu' is used in previous versions of mmpretrain, here is for compatibility
# users can use both 'Albumentations' and 'Albu'.
@TRANSFORMS.register_module(['Albumentations', 'Albu'])
class Albumentations(BaseTransform):
    """Wrapper to use augmentation from albumentations library.

    **Required Keys:**

    - img

    **Modified Keys:**

    - img
    - img_shape

    Adds custom transformations from albumentations library.
    More details can be found in
    `Albumentations <https://albumentations.readthedocs.io>`_.

    An example of ``transforms`` is as followed:

    .. code-block::

        [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]

    Args:
        transforms (List[Dict]): List of albumentations transform configs.
        keymap (Optional[Dict]): Mapping of mmpretrain to albumentations
            fields, in format {'input key':'albumentation-style key'}.
            Defaults to None.

    Example:
        >>> import mmcv
        >>> from mmpretrain.datasets import Albumentations
        >>> transforms = [
        ...     dict(
        ...         type='ShiftScaleRotate',
        ...         shift_limit=0.0625,
        ...         scale_limit=0.0,
        ...         rotate_limit=0,
        ...         interpolation=1,
        ...         p=0.5),
        ...     dict(
        ...         type='RandomBrightnessContrast',
        ...         brightness_limit=[0.1, 0.3],
        ...         contrast_limit=[0.1, 0.3],
        ...         p=0.2),
        ...     dict(type='ChannelShuffle', p=0.1),
        ...     dict(
        ...         type='OneOf',
        ...         transforms=[
        ...             dict(type='Blur', blur_limit=3, p=1.0),
        ...             dict(type='MedianBlur', blur_limit=3, p=1.0)
        ...         ],
        ...         p=0.1),
        ... ]
        >>> albu = Albumentations(transforms)
        >>> data = {'img': mmcv.imread('./demo/demo.JPEG')}
        >>> data = albu(data)
        >>> print(data['img'].shape)
        (375, 500, 3)
    """

    def __init__(self, transforms: List[Dict], keymap: Optional[Dict] = None):
        if albumentations is None:
            raise RuntimeError('albumentations is not installed')
        else:
            from albumentations import Compose as albu_Compose

        assert isinstance(transforms, list), 'transforms must be a list.'
        if keymap is not None:
            assert isinstance(keymap, dict), 'keymap must be None or a dict. '

        self.transforms = transforms

        self.aug = albu_Compose(
            [self.albu_builder(t) for t in self.transforms])

        if not keymap:
            self.keymap_to_albu = dict(img='image')
        else:
            self.keymap_to_albu = keymap
        # Inverse mapping used to restore the original keys afterwards.
        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

    def albu_builder(self, cfg: Dict):
        """Import a module from albumentations.

        It inherits some of :func:`build_from_cfg` logic.

        Args:
            cfg (dict): Config dict. It should at least contain the key "type".

        Returns:
            obj: The constructed object.
        """
        assert isinstance(cfg, dict) and 'type' in cfg, 'each item in ' \
            "transforms must be a dict with keyword 'type'."
        args = cfg.copy()

        obj_type = args.pop('type')
        if mmengine.is_str(obj_type):
            obj_cls = getattr(albumentations, obj_type)
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        if 'transforms' in args:
            # Recursively build nested transforms (e.g. inside ``OneOf``).
            args['transforms'] = [
                self.albu_builder(transform)
                for transform in args['transforms']
            ]

        return obj_cls(**args)

    @staticmethod
    def mapper(d, keymap):
        """Dictionary mapper.

        Renames keys according to keymap provided.

        Args:
            d (dict): old dict
            keymap (dict): {'old_key':'new_key'}

        Returns:
            dict: new dict.
        """
        # Use ``dict.items`` instead of the previous
        # ``zip(d.keys(), d.values())`` plus per-key re-lookup.
        return {keymap.get(k, k): v for k, v in d.items()}

    def transform(self, results: Dict) -> Dict:
        """Transform function to perform albumentations transforms.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Transformed results, 'img' and 'img_shape' keys are
                updated in result dict.
        """
        assert 'img' in results, 'No `img` field in the input.'

        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)
        results = self.aug(**results)

        # back to the original format
        results = self.mapper(results, self.keymap_back)
        results['img_shape'] = results['img'].shape[:2]

        return results

    def __repr__(self):
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={repr(self.transforms)})'
        return repr_str
@TRANSFORMS.register_module()
class SimMIMMaskGenerator(BaseTransform):
    """Generate random block mask for each Image.

    **Added Keys**:

    - mask

    This module is used in SimMIM to generate masks.

    Args:
        input_size (int): Size of input image. Defaults to 192.
        mask_patch_size (int): Size of each block mask. Defaults to 32.
        model_patch_size (int): Patch size of each token. Defaults to 4.
        mask_ratio (float): The mask ratio of image. Defaults to 0.6.
    """

    def __init__(self,
                 input_size: int = 192,
                 mask_patch_size: int = 32,
                 model_patch_size: int = 4,
                 mask_ratio: float = 0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # The mask grid must tile the image exactly, and every mask patch
        # must cover an integer number of model patches.
        assert self.input_size % self.mask_patch_size == 0
        assert self.mask_patch_size % self.model_patch_size == 0

        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size**2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def transform(self, results: dict) -> dict:
        """Generate a random block mask and add it under the ``mask`` key.

        Args:
            results (dict): Result dict from previous pipeline.

        Returns:
            dict: Result dict with added key ``mask``.
        """
        # Pick which coarse-grid tokens get masked.
        chosen = np.random.permutation(self.token_count)[:self.mask_count]
        flat_mask = np.zeros(self.token_count, dtype=int)
        flat_mask[chosen] = 1

        # Expand the coarse grid to model-patch resolution.
        grid = flat_mask.reshape((self.rand_size, self.rand_size))
        grid = grid.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

        results.update({'mask': grid})

        return results

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}'
                f'(input_size={self.input_size}, '
                f'mask_patch_size={self.mask_patch_size}, '
                f'model_patch_size={self.model_patch_size}, '
                f'mask_ratio={self.mask_ratio})')
@TRANSFORMS.register_module()
class BEiTMaskGenerator(BaseTransform):
    """Generate mask for image.

    **Added Keys**:

    - mask

    This module is borrowed from
    https://github.com/microsoft/unilm/tree/master/beit

    Args:
        input_size (int): The size of input image.
        num_masking_patches (int): The number of patches to be masked.
        min_num_patches (int): The minimum number of patches to be masked
            in the process of generating mask. Defaults to 4.
        max_num_patches (int, optional): The maximum number of patches to be
            masked in the process of generating mask. Defaults to None.
        min_aspect (float): The minimum aspect ratio of mask blocks. Defaults
            to 0.3.
        max_aspect (float, optional): The maximum aspect ratio of mask blocks.
            Defaults to None.
    """

    def __init__(self,
                 input_size: int,
                 num_masking_patches: int,
                 min_num_patches: int = 4,
                 max_num_patches: Optional[int] = None,
                 min_aspect: float = 0.3,
                 max_aspect: Optional[float] = None) -> None:
        if not isinstance(input_size, tuple):
            input_size = (input_size, ) * 2
        self.height, self.width = input_size
        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches
        self.min_num_patches = min_num_patches
        # A single block may never exceed the overall masking budget.
        self.max_num_patches = num_masking_patches if max_num_patches is None \
            else max_num_patches
        # Aspect ratios are sampled log-uniformly in [min_aspect, max_aspect].
        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int:
        """Generate mask recursively.

        Args:
            mask (np.ndarray): The mask to be generated.
            max_mask_patches (int): The maximum number of patches to be masked.

        Returns:
            int: The number of patches masked.
        """
        delta = 0
        # Try up to 10 random candidate blocks; stop at the first one that
        # masks at least one new patch without exceeding the budget.
        for _ in range(10):
            target_area = np.random.uniform(self.min_num_patches,
                                            max_mask_patches)
            aspect_ratio = math.exp(np.random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = np.random.randint(0, self.height - h)
                left = np.random.randint(0, self.width - w)
                # Patches already masked inside the candidate block.
                num_masked = mask[top:top + h, left:left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1
                if delta > 0:
                    break
        return delta

    def transform(self, results: dict) -> dict:
        """Method to generate random block mask for each Image in BEiT.

        Args:
            results (dict): Result dict from previous pipeline.

        Returns:
            dict: Result dict with added key ``mask``.
        """
        mask = np.zeros(shape=(self.height, self.width), dtype=int)
        mask_count = 0
        # Keep placing blocks until exactly ``num_masking_patches`` patches
        # are masked; ``_mask`` is capped so it never overshoots.
        while mask_count != self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)
            delta = self._mask(mask, max_mask_patches)
            mask_count += delta
        results.update({'mask': mask})
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(height={self.height}, '
        repr_str += f'width={self.width}, '
        repr_str += f'num_patches={self.num_patches}, '
        repr_str += f'num_masking_patches={self.num_masking_patches}, '
        repr_str += f'min_num_patches={self.min_num_patches}, '
        repr_str += f'max_num_patches={self.max_num_patches}, '
        repr_str += f'log_aspect_ratio={self.log_aspect_ratio})'
        return repr_str
@TRANSFORMS.register_module()
class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform):
    """Crop the given PIL Image to random size and aspect ratio with random
    interpolation.

    **Required Keys**:

    - img

    **Modified Keys**:

    - img

    **Added Keys**:

    - target_img

    This module is borrowed from
    https://github.com/microsoft/unilm/tree/master/beit.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a
    random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
    is made. This crop is finally resized to given size. This is popularly used
    to train the Inception networks. This module first crops the image and
    resizes the crop to two different sizes.

    Args:
        size (Union[tuple, int]): Expected output size of each edge of the
            first image.
        second_size (Union[tuple, int], optional): Expected output size of each
            edge of the second image.
        scale (tuple[float, float]): Range of size of the origin size cropped.
            Defaults to (0.08, 1.0).
        ratio (tuple[float, float]): Range of aspect ratio of the origin aspect
            ratio cropped. Defaults to (3./4., 4./3.).
        interpolation (str): The interpolation for the first image. Defaults
            to ``bilinear``.
        second_interpolation (str): The interpolation for the second image.
            Defaults to ``lanczos``.

    Raises:
        ValueError: If ``scale`` or ``ratio`` is not of kind ``(min, max)``.
    """

    def __init__(self,
                 size: Union[tuple, int],
                 second_size=None,
                 scale=(0.08, 1.0),
                 ratio=(3. / 4., 4. / 3.),
                 interpolation='bilinear',
                 second_interpolation='lanczos') -> None:
        if isinstance(size, tuple):
            self.size = size
        else:
            self.size = (size, size)
        if second_size is not None:
            if isinstance(second_size, tuple):
                self.second_size = second_size
            else:
                self.second_size = (second_size, second_size)
        else:
            self.second_size = None
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            # BUG FIX: the original code evaluated this message as a bare
            # string expression (a no-op), silently accepting invalid ranges.
            raise ValueError('range should be of kind (min, max)')

        if interpolation == 'random':
            # The concrete mode is drawn per sample in ``transform``.
            self.interpolation = ('bilinear', 'bicubic')
        else:
            self.interpolation = interpolation
        self.second_interpolation = second_interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img: np.ndarray, scale: tuple,
                   ratio: tuple) -> Sequence[int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (np.ndarray): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect
                ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        img_h, img_w = img.shape[:2]
        area = img_h * img_w

        # Rejection-sample a crop rectangle; give up after 10 attempts.
        for _ in range(10):
            target_area = np.random.uniform(*scale) * area
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(np.random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img_w and h < img_h:
                i = np.random.randint(0, img_h - h)
                j = np.random.randint(0, img_w - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = img_w / img_h
        if in_ratio < min(ratio):
            w = img_w
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = img_h
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = img_w
            h = img_h
        i = (img_h - h) // 2
        j = (img_w - w) // 2
        return i, j, h, w

    def transform(self, results: dict) -> dict:
        """Crop the given image and resize it to two different sizes.

        This module crops the given image randomly and resize the crop to two
        different sizes. This is popularly used in BEiT-style masked image
        modeling, where an off-the-shelf model is used to provide the target.

        Args:
            results (dict): Results from previous pipeline.

        Returns:
            dict: Results after applying this transformation.
        """
        img = results['img']
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = np.random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        if self.second_size is None:
            img = img[i:i + h, j:j + w]
            img = mmcv.imresize(img, self.size, interpolation=interpolation)
            results.update({'img': img})
        else:
            img = img[i:i + h, j:j + w]
            img_sample = mmcv.imresize(
                img, self.size, interpolation=interpolation)
            # The second (teacher/tokenizer) view is resized from the SAME
            # crop with its own interpolation mode.
            img_target = mmcv.imresize(
                img, self.second_size, interpolation=self.second_interpolation)
            results.update({'img': [img_sample, img_target]})
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(size={self.size}, '
        repr_str += f'second_size={self.second_size}, '
        repr_str += f'interpolation={self.interpolation}, '
        repr_str += f'second_interpolation={self.second_interpolation}, '
        repr_str += f'scale={self.scale}, '
        repr_str += f'ratio={self.ratio})'
        return repr_str
@TRANSFORMS.register_module()
class CleanCaption(BaseTransform):
    """Clean caption text.

    Remove some useless punctuation for the caption task.

    **Required Keys:**

    - ``*keys``

    **Modified Keys:**

    - ``*keys``

    Args:
        keys (Sequence[str], optional): The keys of text to be cleaned.
            Defaults to 'gt_caption'.
        remove_chars (str): The characters to be removed. Defaults to
            :py:attr:`string.punctuation`.
        lowercase (bool): Whether to convert the text to lowercase.
            Defaults to True.
        remove_dup_space (bool): Whether to remove duplicated whitespaces.
            Defaults to True.
        strip (bool): Whether to remove leading and trailing whitespaces.
            Defaults to True.
    """

    def __init__(
        self,
        keys='gt_caption',
        remove_chars=string.punctuation,
        lowercase=True,
        remove_dup_space=True,
        strip=True,
    ):
        self.keys = [keys] if isinstance(keys, str) else keys
        # Translation table that deletes every unwanted character in one pass.
        self.transtab = str.maketrans('', '', remove_chars)
        self.lowercase = lowercase
        self.remove_dup_space = remove_dup_space
        self.strip = strip

    def _clean(self, text):
        """Perform text cleaning before tokenizer."""
        if self.strip:
            text = text.strip()

        text = text.translate(self.transtab)

        if self.remove_dup_space:
            # Collapse runs of whitespace into single spaces.
            text = re.sub(r'\s{2,}', ' ', text)

        if self.lowercase:
            text = text.lower()

        return text

    def clean(self, text):
        """Perform text cleaning before tokenizer."""
        if isinstance(text, str):
            return self._clean(text)
        if isinstance(text, (list, tuple)):
            return [self._clean(item) for item in text]
        raise TypeError('text must be a string or a list of strings')

    def transform(self, results: dict) -> dict:
        """Method to clean the input text data."""
        for key in self.keys:
            results[key] = self.clean(results[key])
        return results
@TRANSFORMS.register_module()
class OFAAddObjects(BaseTransform):
    """Append VinVL-detected object tags to the question as an OFA prompt.

    **Required Keys:**

    - objects

    **Modified Keys:**

    - question

    **Added Keys:**

    - decoder_prompt
    """

    def transform(self, results: dict) -> dict:
        if 'objects' not in results:
            raise ValueError(
                'Some OFA fine-tuned models requires `objects` field in the '
                'dataset, which is generated by VinVL. Or please use '
                'zero-shot configs. See '
                'https://github.com/OFA-Sys/OFA/issues/189')

        if 'question' in results:
            prompt = '{} object: {}'.format(
                results['question'],
                ' '.join(results['objects']),
            )
            results['decoder_prompt'] = prompt
            results['question'] = prompt

        # BUG FIX: ``BaseTransform.__call__`` relies on ``transform``
        # returning the results dict; the original returned None, which
        # drops the sample from the pipeline.
        return results
@TRANSFORMS.register_module()
class RandomTranslatePad(BaseTransform):
    """Pad the image onto a square zero canvas with an optional random offset.

    The input image is placed on a ``size x size`` float32 canvas. When
    ``aug_translate`` is True the placement offset is sampled uniformly,
    otherwise the image is centered. Ground-truth boxes, if present, are
    shifted by the same offset.

    Args:
        size (int): Output canvas size. Defaults to 640.
        aug_translate (bool): Whether to randomly sample the placement
            offset. Defaults to False.
    """

    def __init__(self, size=640, aug_translate=False):
        self.size = size
        self.aug_translate = aug_translate

    @cache_randomness
    def rand_translate_params(self, dh, dw):
        # BUG FIX: ``np.random.randint(0, 0)`` raises ValueError, which
        # happened whenever the image already filled the canvas in one
        # dimension (dh == 0 or dw == 0). Use offset 0 in that case.
        top = np.random.randint(0, dh) if dh > 0 else 0
        left = np.random.randint(0, dw) if dw > 0 else 0
        return top, left

    def transform(self, results: dict) -> dict:
        img = results['img']
        h, w = img.shape[:-1]
        dw = self.size - w
        dh = self.size - h
        # NOTE(review): assumes h <= size and w <= size; a larger input would
        # make dh/dw negative and break the placement below — confirm callers
        # resize beforehand.
        if self.aug_translate:
            top, left = self.rand_translate_params(dh, dw)
        else:
            # Center placement; the -0.1 keeps rounding consistent with the
            # original implementation.
            top = round(dh / 2.0 - 0.1)
            left = round(dw / 2.0 - 0.1)

        out_img = np.zeros((self.size, self.size, 3), dtype=np.float32)
        out_img[top:top + h, left:left + w, :] = img
        results['img'] = out_img
        results['img_shape'] = (self.size, self.size)

        # translate box
        if 'gt_bboxes' in results.keys():
            for i in range(len(results['gt_bboxes'])):
                box = results['gt_bboxes'][i]
                box[0], box[2] = box[0] + left, box[2] + left
                box[1], box[3] = box[1] + top, box[3] + top
                results['gt_bboxes'][i] = box

        return results
@TRANSFORMS.register_module()
class MAERandomResizedCrop(transforms.RandomResizedCrop):
    """RandomResizedCrop for matching TF/TPU implementation: no for-loop is
    used.

    This may lead to results different with torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 # noqa: E501
    """

    @staticmethod
    def get_params(img: Image.Image, scale: tuple, ratio: tuple) -> Tuple:
        """Sample crop parameters with a single draw (no rejection loop).

        Unlike torchvision's implementation, out-of-bounds samples are
        clamped to the image size instead of being re-drawn.

        Args:
            img (Image.Image): The input PIL image.
            scale (tuple): Range of the area fraction of the crop.
            ratio (tuple): Range of the aspect ratio of the crop.

        Returns:
            Tuple: ``(i, j, h, w)`` crop parameters.
        """
        width, height = img.size
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        # Aspect ratio is sampled log-uniformly.
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))
        # Clamp rather than reject, matching the TF/TPU behaviour.
        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1, )).item()
        j = torch.randint(0, width - w + 1, size=(1, )).item()

        return i, j, h, w

    def forward(self, results: dict) -> dict:
        """The forward function of MAERandomResizedCrop.

        Args:
            results (dict): The results dict contains the image and all these
                information related to the image.

        Returns:
            dict: The results dict contains the cropped image and all these
                information related to the image.
        """
        img = results['img']
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
        results['img'] = img
        return results
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Union
from mmcv.transforms import BaseTransform
PIPELINE_TYPE = List[Union[dict, BaseTransform]]
def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
    """Return the index of the first transform matching ``target``.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.

    Returns:
        int: The transform index. Returns -1 if not found.
    """
    for idx, transform in enumerate(pipeline):
        if isinstance(transform, dict):
            cfg_type = transform['type']
            # The config's ``type`` may be either a class or its name.
            name = cfg_type.__name__ if isinstance(cfg_type,
                                                   type) else cfg_type
        else:
            name = transform.__class__.__name__
        if name == target:
            return idx
    return -1
def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False):
    """Remove every occurrence of the target transform from the pipeline.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.
        inplace (bool): Whether to modify the pipeline inplace.

    Returns:
        The modified transform.
    """
    if not inplace:
        pipeline = copy.deepcopy(pipeline)

    # Pop matches one at a time until none remain.
    idx = get_transform_idx(pipeline, target)
    while idx >= 0:
        pipeline.pop(idx)
        idx = get_transform_idx(pipeline, target)
    return pipeline
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Callable, List, Union
from mmcv.transforms import BaseTransform, Compose
from mmpretrain.registry import TRANSFORMS
# Define type of transform or transform config
Transform = Union[dict, Callable[[dict], dict]]
@TRANSFORMS.register_module()
class MultiView(BaseTransform):
    """A transform wrapper for multiple views of an image.

    Args:
        transforms (list[list[dict | callable]]): A list of transform
            pipelines; each inner list is composed into one pipeline.
        num_views (int | list[int]): The number of views to generate with
            each pipeline. When a list, it must have the same length as
            ``transforms``.

    Examples:
        >>> # Example 1: MultiViews 1 pipeline with 2 views
        >>> pipeline = [
        >>>     dict(type='MultiView',
        >>>         num_views=2,
        >>>         transforms=[
        >>>             [
        >>>                dict(type='Resize', scale=224))],
        >>>         ])
        >>> ]
        >>> # Example 2: MultiViews 2 pipelines, the first with 2 views,
        >>> # the second with 6 views
        >>> pipeline = [
        >>>     dict(type='MultiView',
        >>>         num_views=[2, 6],
        >>>         transforms=[
        >>>             [
        >>>                dict(type='Resize', scale=224)],
        >>>             [
        >>>                dict(type='Resize', scale=224),
        >>>                dict(type='RandomSolarize')],
        >>>         ])
        >>> ]
    """

    def __init__(self, transforms: List[List[Transform]],
                 num_views: Union[int, List[int]]) -> None:

        if isinstance(num_views, int):
            num_views = [num_views]
        assert isinstance(num_views, List)
        assert len(num_views) == len(transforms)
        self.num_views = num_views

        # One composed pipeline per entry in ``transforms``.
        self.pipelines = []
        for trans in transforms:
            pipeline = Compose(trans)
            self.pipelines.append(pipeline)

        # Flatten so that ``self.transforms[k]`` is the pipeline producing
        # the k-th view.
        self.transforms = []
        for i in range(len(num_views)):
            self.transforms.extend([self.pipelines[i]] * num_views[i])

    def transform(self, results: dict) -> dict:
        """Apply transformation to inputs.

        Args:
            results (dict): Result dict from previous pipelines.

        Returns:
            dict: Transformed results.
        """
        multi_views_outputs = dict(img=[])
        for trans in self.transforms:
            # Each view starts from an independent copy of the inputs.
            inputs = copy.deepcopy(results)
            outputs = trans(inputs)

            multi_views_outputs['img'].append(outputs['img'])
        results.update(multi_views_outputs)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__ + '('
        for i, p in enumerate(self.pipelines):
            repr_str += f'\nPipeline {i + 1} with {self.num_views[i]} views:\n'
            repr_str += str(p)
        repr_str += ')'
        return repr_str
@TRANSFORMS.register_module()
class ApplyToList(BaseTransform):
    """A transform wrapper to apply the wrapped transforms to a list of items.

    For example, to load and resize a list of images.

    Args:
        transforms (list[dict | callable]): Sequence of transform config dict
            to be wrapped.
        scatter_key (str): The key to scatter data dict. If the field is a
            list, scatter the list to multiple data dicts to do transformation.
        collate_keys (List[str]): The keys to collate from multiple data dicts.
            The fields in ``collate_keys`` will be composed into a list after
            transformation, and the other fields will be adopted from the
            first data dict.
    """

    def __init__(self, transforms, scatter_key, collate_keys):
        super().__init__()

        self.transforms = Compose([TRANSFORMS.build(t) for t in transforms])
        self.scatter_key = scatter_key
        # The scatter key itself is always collated back into a list.
        self.collate_keys = set(collate_keys) | {scatter_key}

    def transform(self, results: dict):
        scatter_field = results.get(self.scatter_key)

        if not isinstance(scatter_field, list):
            # Nothing to scatter; run the pipeline on the dict directly.
            return self.transforms(results)

        # One independent data dict (and pipeline run) per list item.
        processed = []
        for item in scatter_field:
            single = copy.deepcopy(results)
            single[self.scatter_key] = item
            processed.append(self.transforms(single))

        # Merge: collated fields become lists, the rest come from the
        # first processed dict.
        merged = processed[0]
        for key in processed[0].keys():
            if key in self.collate_keys:
                merged[key] = [single[key] for single in processed]
        return merged
# Copyright (c) OpenMMLab. All rights reserved.
import gzip
import hashlib
import os
import os.path
import shutil
import tarfile
import tempfile
import urllib.error
import urllib.request
import zipfile
from mmengine.fileio import LocalBackend, get_file_backend
__all__ = [
'rm_suffix', 'check_integrity', 'download_and_extract_archive',
'open_maybe_compressed_file'
]
def rm_suffix(s, suffix=None):
    """Remove a trailing suffix from a string.

    If ``suffix`` is None, everything from the last ``'.'`` (inclusive) is
    removed; otherwise everything from the last occurrence of ``suffix``.

    BUG FIX: when the separator is absent, ``str.rfind`` returns -1 and the
    old slice ``s[:-1]`` silently dropped the last character; the string is
    now returned unchanged in that case.

    Args:
        s (str): The input string.
        suffix (str, optional): The suffix to strip. Defaults to None.

    Returns:
        str: The trimmed string.
    """
    marker = '.' if suffix is None else suffix
    idx = s.rfind(marker)
    return s if idx == -1 else s[:idx]
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024):
    """Compute the MD5 hex digest of a file.

    Local files are read incrementally in ``chunk_size`` blocks to bound
    memory usage; other backends are fetched in one shot.
    """
    digest = hashlib.md5()
    backend = get_file_backend(fpath, enable_singleton=True)
    if not isinstance(backend, LocalBackend):
        digest.update(backend.get(fpath))
        return digest.hexdigest()

    # Enable chunk update for local file.
    with open(fpath, 'rb') as f:
        while True:
            block = f.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def check_md5(fpath, md5, **kwargs):
    """Return True if the MD5 of ``fpath`` matches the expected digest."""
    actual = calculate_md5(fpath, **kwargs)
    return actual == md5
def check_integrity(fpath, md5=None):
    """Return True if ``fpath`` exists and, when given, matches ``md5``."""
    if not os.path.isfile(fpath):
        return False
    # Without an expected digest, mere existence counts as intact.
    return True if md5 is None else check_md5(fpath, md5)
def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    """Download object at the given URL to a local path.

    Modified from
    https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved,
            e.g. ``/tmp/temporary_file``
        hash_prefix (string, optional): If not None, the SHA256 downloaded
            file should start with ``hash_prefix``. Defaults to None.
        progress (bool): whether or not to display a progress bar to stderr.
            Defaults to True

    Raises:
        RuntimeError: If ``hash_prefix`` is given and the downloaded file's
            SHA256 digest does not start with it.
    """
    file_size = None
    req = urllib.request.Request(url)
    u = urllib.request.urlopen(req)
    meta = u.info()
    # Python 2 responses exposed ``getheaders``; Python 3 uses ``get_all``.
    if hasattr(meta, 'getheaders'):
        content_length = meta.getheaders('Content-Length')
    else:
        content_length = meta.get_all('Content-Length')
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    # We deliberately save it in a temp file and move it after download is
    # complete. This prevents a local file being overridden by a broken
    # download.
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

    import rich.progress
    columns = [
        rich.progress.DownloadColumn(),
        rich.progress.BarColumn(bar_width=None),
        rich.progress.TimeRemainingColumn(),
    ]
    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()
        with rich.progress.Progress(*columns) as pbar:
            task = pbar.add_task('download', total=file_size, visible=progress)
            while True:
                # Stream in 8 KiB chunks to keep memory bounded.
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                if hash_prefix is not None:
                    sha256.update(buffer)
                pbar.update(task, advance=len(buffer))
        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError(
                    'invalid hash value (expected "{}", got "{}")'.format(
                        hash_prefix, digest))
        shutil.move(f.name, dst)
    finally:
        # Always close and clean up the temp file; on success it has already
        # been moved to ``dst``, so the existence check skips removal.
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from.
        root (str): Directory to place downloaded file in.
        filename (str | None): Name to save the file under.
            If filename is None, use the basename of the URL.
        md5 (str | None): MD5 checksum of the download.
            If md5 is None, download without md5 check.

    Raises:
        RuntimeError: If the downloaded file is missing or fails the md5
            check.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f'Using downloaded and verified file: {fpath}')
    else:
        try:
            print(f'Downloading {url} to {fpath}')
            download_url_to_file(url, fpath)
        except (urllib.error.URLError, IOError) as e:
            if url[:5] == 'https':
                # Retry once over plain http before giving up.
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      f' Downloading {url} to {fpath}')
                download_url_to_file(url, fpath)
            else:
                raise e
        # check integrity of downloaded file
        if not check_integrity(fpath, md5):
            raise RuntimeError('File not found or corrupted.')
def _is_tarxz(filename):
return filename.endswith('.tar.xz')
def _is_tar(filename):
return filename.endswith('.tar')
def _is_targz(filename):
return filename.endswith('.tar.gz')
def _is_tgz(filename):
return filename.endswith('.tgz')
def _is_gzip(filename):
return filename.endswith('.gz') and not filename.endswith('.tar.gz')
def _is_zip(filename):
return filename.endswith('.zip')
def extract_archive(from_path, to_path=None, remove_finished=False):
if to_path is None:
to_path = os.path.dirname(from_path)
if _is_tar(from_path):
with tarfile.open(from_path, 'r') as tar:
tar.extractall(path=to_path)
elif _is_targz(from_path) or _is_tgz(from_path):
with tarfile.open(from_path, 'r:gz') as tar:
tar.extractall(path=to_path)
elif _is_tarxz(from_path):
with tarfile.open(from_path, 'r:xz') as tar:
tar.extractall(path=to_path)
elif _is_gzip(from_path):
to_path = os.path.join(
to_path,
os.path.splitext(os.path.basename(from_path))[0])
with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f:
out_f.write(zip_f.read())
elif _is_zip(from_path):
with zipfile.ZipFile(from_path, 'r') as z:
z.extractall(to_path)
else:
raise ValueError(f'Extraction of {from_path} not supported')
if remove_finished:
os.remove(from_path)
def download_and_extract_archive(url,
                                 download_root,
                                 extract_root=None,
                                 filename=None,
                                 md5=None,
                                 remove_finished=False):
    """Download an archive from ``url``, then extract it.

    Args:
        url (str): URL to download the archive from.
        download_root (str): Directory to place the downloaded archive in.
        extract_root (str, optional): Directory to extract into. Defaults
            to ``download_root``.
        filename (str, optional): Name to save the archive under. Defaults
            to the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. Defaults to None
            (no check).
        remove_finished (bool): Whether to delete the archive after
            extraction. Defaults to False.
    """
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f'Extracting {archive} to {extract_root}')
    extract_archive(archive, extract_root, remove_finished)
def open_maybe_compressed_file(path: str):
    """Return a file object that possibly decompresses 'path' on the fly.

    Decompression occurs when argument `path` is a string and ends with '.gz'
    or '.xz'.
    """
    if not isinstance(path, str):
        # Already a file-like object: hand it back untouched.
        return path
    if path.endswith('.gz'):
        import gzip
        return gzip.open(path, 'rb')
    elif path.endswith('.xz'):
        import lzma
        return lzma.open(path, 'rb')
    else:
        return open(path, 'rb')
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine.fileio import load
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
class VGVQA(BaseDataset):
    """Visual Genome VQA dataset."""

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Compared to ``BaseDataset``, the only difference is that the VQA
        annotation file is already a list of data. There is no 'metainfo'.

        Returns:
            List[dict]: One parsed info dict per sample.

        Raises:
            TypeError: If the annotation file does not contain a list, or an
                element parsed from it is not a dict.
        """
        raw_data_list = load(self.ann_file)
        if not isinstance(raw_data_list, list):
            # BUG FIX: the message previously said "a dict" although the
            # check requires a list.
            raise TypeError(
                f'The VQA annotations loaded from annotation file '
                f'should be a list, but got {type(raw_data_list)}!')

        # load and parse data_infos.
        data_list = []
        for raw_data_info in raw_data_list:
            # parse raw data information to target format
            data_info = self.parse_data_info(raw_data_info)
            if isinstance(data_info, dict):
                # For VQA tasks, each `data_info` looks like:
                # {
                #     "question_id": 986769,
                #     "question": "How many people are there?",
                #     "answer": "two",
                #     "image": "image/1.jpg",
                #     "dataset": "vg"
                # }

                # change 'image' key to 'img_path'
                # TODO: This process will be removed, after the annotation file
                # is preprocess.
                data_info['img_path'] = data_info['image']
                del data_info['image']

                if 'answer' in data_info:
                    # add answer_weight & answer_count, delete duplicate answer
                    if data_info['dataset'] == 'vqa':
                        answer_weight = {}
                        for answer in data_info['answer']:
                            if answer in answer_weight.keys():
                                answer_weight[answer] += 1 / len(
                                    data_info['answer'])
                            else:
                                answer_weight[answer] = 1 / len(
                                    data_info['answer'])

                        data_info['answer'] = list(answer_weight.keys())
                        data_info['answer_weight'] = list(
                            answer_weight.values())
                        data_info['answer_count'] = len(answer_weight)
                    elif data_info['dataset'] == 'vg':
                        # NOTE(review): this branch writes 'answers' while the
                        # 'vqa' branch writes back to 'answer' — looks
                        # inconsistent; confirm against consumers before
                        # unifying.
                        data_info['answers'] = [data_info['answer']]
                        data_info['answer_weight'] = [0.2]
                        data_info['answer_count'] = 1

                data_list.append(data_info)

            else:
                raise TypeError(
                    f'Each VQA data element loaded from annotation file '
                    f'should be a dict, but got {type(data_info)}!')

        return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import re
from itertools import chain
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class VisualGenomeQA(BaseDataset):
    """Visual Genome Question Answering dataset.

    dataset structure: ::

        data_root
        ├── image
        │   ├── 1.jpg
        │   ├── 2.jpg
        │   └── ...
        └── question_answers.json

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``
            and ``question_file``.
        data_prefix (str): The directory of images. Defaults to ``"image"``.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to ``"question_answers.json"``.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str = 'image',
                 ann_file: str = 'question_answers.json',
                 **kwarg):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwarg,
        )

    def _create_image_index(self):
        """Build a mapping from image id to image file path.

        The image id is parsed as the last run of digits in each file name
        (e.g. ``1.jpg`` -> ``1``); files without digits are skipped.
        """
        img_prefix = self.data_prefix['img_path']

        files = mmengine.list_dir_or_file(img_prefix, list_dir=False)
        image_index = {}
        for file in files:
            image_id = re.findall(r'\d+', file)
            if len(image_id) > 0:
                image_id = int(image_id[-1])
                image_index[image_id] = mmengine.join_path(img_prefix, file)

        return image_index

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        annotations = mmengine.load(self.ann_file)

        # The original Visual Genome annotation file and question file includes
        # only image id but no image file paths.
        self.image_index = self._create_image_index()

        data_list = []
        for qas in chain.from_iterable(ann['qas'] for ann in annotations):
            # ann example
            # {
            #     'id': 1,
            #     'qas': [
            #         {
            #             'a_objects': [],
            #             'question': 'What color is the clock?',
            #             'image_id': 1,
            #             'qa_id': 986768,
            #             'answer': 'Two.',
            #             'q_objects': [],
            #         }
            #         ...
            #     ]
            # }
            data_info = {
                'img_path': self.image_index[qas['image_id']],
                # Fixed: the annotation key is 'question' (the previous
                # misspelled 'quesiton' lookup raised KeyError and emitted a
                # key the pipeline does not consume).
                'question': qas['question'],
                # Fixed: the annotation stores the id under 'qa_id', not
                # 'question_id' (see the example above).
                'question_id': qas['qa_id'],
                'image_id': qas['image_id'],
                'gt_answer': [qas['answer']],
            }

            data_list.append(data_info)

        return data_list
# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class VizWiz(BaseDataset):
    """VizWiz dataset.

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``
            and ``question_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwarg):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwarg,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        img_prefix = self.data_prefix['img_path']
        data_list = []

        for ann in mmengine.load(self.ann_file):
            # Each annotation provides at least 'image' and 'question'.
            # Annotated splits additionally carry 'answerable' plus a list of
            # crowd-sourced 'answers', each with an 'answer_confidence'.
            info = {
                'question': ann['question'],
                'img_path': mmengine.join_path(img_prefix, ann['image']),
            }

            # Only answerable samples get ground-truth answers; samples
            # without an 'answerable' field (e.g. test split) or with
            # answerable == 0 are kept with question/image only.
            if ann.get('answerable') == 1:
                # Keep confident, answerable responses, then collapse
                # duplicates into (answer, frequency-weight) pairs.
                confident = [
                    item['answer'] for item in ann['answers']
                    if item['answer_confidence'] == 'yes'
                    and item['answer'] != 'unanswerable'
                ]
                tally = Counter(confident)
                total = len(confident)
                info['gt_answer'] = list(tally)
                info['gt_answer_weight'] = [
                    freq / total for freq in tally.values()
                ]

            data_list.append(info)

        return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import xml.etree.ElementTree as ET
from typing import List, Optional, Union
from mmengine import get_file_backend, list_from_file
from mmengine.logging import MMLogger
from mmpretrain.registry import DATASETS
from .base_dataset import expanduser
from .categories import VOC2007_CATEGORIES
from .multi_label import MultiLabelDataset
@DATASETS.register_module()
class VOC(MultiLabelDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.

    After decompression, the dataset directory structure is as follows:

    VOC dataset directory: ::

        VOC2007
        ├── JPEGImages
        │   ├── xxx.jpg
        │   ├── xxy.jpg
        │   └── ...
        ├── Annotations
        │   ├── xxx.xml
        │   ├── xxy.xml
        │   └── ...
        └── ImageSets
            └── Main
                ├── train.txt
                ├── val.txt
                ├── trainval.txt
                ├── test.txt
                └── ...

    Extra difficult label is in VOC annotations, we will use
    `gt_label_difficult` to record the difficult labels in each sample
    and corresponding evaluation should take care of this field
    to calculate metrics. Usually, difficult labels are reckoned as
    negative in defaults.

    Args:
        data_root (str): The root directory for VOC dataset.
        split (str, optional): The dataset split, supports "train",
            "val", "trainval", and "test". Default to "trainval".
        image_set_path (str, optional): The path of image set, The file which
            lists image ids of the sub dataset, and this path is relative
            to ``data_root``. Default to ''.
        data_prefix (dict): Prefix for data and annotation, keyword
            'img_path' and 'ann_path' can be set. Defaults to be
            ``dict(img_path='JPEGImages', ann_path='Annotations')``.
        metainfo (dict, optional): Meta information for dataset, such as
            categories information. Defaults to None.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.

    Examples:
        >>> from mmpretrain.datasets import VOC
        >>> train_dataset = VOC(data_root='data/VOC2007', split='trainval')
        >>> train_dataset
        Dataset VOC
            Number of samples:  5011
            Number of categories:       20
            Prefix of dataset:  data/VOC2007
            Path of image set:  data/VOC2007/ImageSets/Main/trainval.txt
            Prefix of images:   data/VOC2007/JPEGImages
            Prefix of annotations:      data/VOC2007/Annotations
        >>> test_dataset = VOC(data_root='data/VOC2007', split='test')
        >>> test_dataset
        Dataset VOC
            Number of samples:  4952
            Number of categories:       20
            Prefix of dataset:  data/VOC2007
            Path of image set:  data/VOC2007/ImageSets/Main/test.txt
            Prefix of images:   data/VOC2007/JPEGImages
            Prefix of annotations:      data/VOC2007/Annotations
    """  # noqa: E501

    METAINFO = {'classes': VOC2007_CATEGORIES}

    def __init__(self,
                 data_root: str,
                 split: str = 'trainval',
                 image_set_path: str = '',
                 data_prefix: Union[str, dict] = dict(
                     img_path='JPEGImages', ann_path='Annotations'),
                 test_mode: bool = False,
                 metainfo: Optional[dict] = None,
                 **kwargs):
        self.backend = get_file_backend(data_root, enable_singleton=True)

        if split:
            splits = ['train', 'val', 'trainval', 'test']
            assert split in splits, \
                f"The split must be one of {splits}, but get '{split}'"
            self.split = split

            # When a split is given, fill in the default prefixes and the
            # image-set file path from the standard VOC directory layout.
            if not data_prefix:
                data_prefix = dict(
                    img_path='JPEGImages', ann_path='Annotations')
            if not image_set_path:
                image_set_path = self.backend.join_path(
                    'ImageSets', 'Main', f'{split}.txt')

        # To handle the BC-breaking
        if (split == 'train' or split == 'trainval') and test_mode:
            logger = MMLogger.get_current_instance()
            logger.warning(f'split="{split}" but test_mode=True. '
                           f'The {split} set will be used.')

        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))
        assert isinstance(data_prefix, dict) and 'img_path' in data_prefix, \
            '`data_prefix` must be a dict with key img_path'

        # Annotations are required whenever labels must be loaded: any
        # non-val/test split, or any time the dataset is not in test mode.
        if (split and split not in ['val', 'test']) or not test_mode:
            assert 'ann_path' in data_prefix and data_prefix[
                'ann_path'] is not None, \
                '"ann_path" must be set in `data_prefix`' \
                'when validation or test set is used.'

        self.data_root = data_root
        self.image_set_path = self.backend.join_path(data_root, image_set_path)

        super().__init__(
            ann_file='',
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    @property
    def ann_prefix(self):
        """The prefix of annotation files, or None if no 'ann_path' set."""
        if 'ann_path' in self.data_prefix:
            return self.data_prefix['ann_path']
        else:
            return None

    def _get_labels_from_xml(self, img_id):
        """Get gt_labels and labels_difficult from xml file.

        Objects whose class name is not in ``self.CLASSES`` are skipped;
        objects flagged ``difficult`` are collected separately.
        """
        xml_path = self.backend.join_path(self.ann_prefix, f'{img_id}.xml')
        content = self.backend.get(xml_path)
        root = ET.fromstring(content)

        labels, labels_difficult = set(), set()
        for obj in root.findall('object'):
            label_name = obj.find('name').text
            # in case customized dataset has wrong labels
            # or CLASSES has been override.
            if label_name not in self.CLASSES:
                continue
            label = self.class_to_idx[label_name]
            difficult = int(obj.find('difficult').text)
            if difficult:
                labels_difficult.add(label)
            else:
                labels.add(label)

        return list(labels), list(labels_difficult)

    def load_data_list(self):
        """Load images and ground truth labels.

        When no annotation prefix is available (e.g. unlabeled test set),
        ``gt_label`` and ``gt_label_difficult`` are left as None.
        """
        data_list = []
        img_ids = list_from_file(self.image_set_path)

        for img_id in img_ids:
            img_path = self.backend.join_path(self.img_prefix, f'{img_id}.jpg')

            labels, labels_difficult = None, None
            if self.ann_prefix is not None:
                labels, labels_difficult = self._get_labels_from_xml(img_id)

            info = dict(
                img_path=img_path,
                gt_label=labels,
                gt_label_difficult=labels_difficult)
            data_list.append(info)

        return data_list

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [
            f'Prefix of dataset: \t{self.data_root}',
            f'Path of image set: \t{self.image_set_path}',
            f'Prefix of images: \t{self.img_prefix}',
            f'Prefix of annotations: \t{self.ann_prefix}'
        ]

        return body
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class VSR(BaseDataset):
    """VSR: Visual Spatial Reasoning dataset.

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``
            and ``question_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwarg):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwarg,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        # Each annotation looks like:
        # {"image": "train2017/000000372029.jpg",
        #  "question": "The dog is on the surfboard.",
        #  "answer": true}
        img_prefix = self.data_prefix['img_path']
        data_list = []
        for ann in mmengine.load(self.ann_file):
            # The boolean answer is mapped to the literal strings expected
            # by the yes/no answer head.
            data_list.append({
                'img_path': mmengine.join_path(img_prefix, ann['image']),
                'question': ann['question'],
                'gt_answer': 'yes' if ann['answer'] else 'no',
            })
        return data_list
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .runners import * # noqa: F401, F403
from .schedulers import * # noqa: F401, F403
# Copyright (c) OpenMMLab. All rights reserved.
from .class_num_check_hook import ClassNumCheckHook
from .densecl_hook import DenseCLHook
from .ema_hook import EMAHook
from .margin_head_hooks import SetAdaptiveMarginsHook
from .precise_bn_hook import PreciseBNHook
from .retriever_hooks import PrepareProtoBeforeValLoopHook
from .simsiam_hook import SimSiamHook
from .swav_hook import SwAVHook
from .switch_recipe_hook import SwitchRecipeHook
from .visualization_hook import VisualizationHook
from .warmup_param_hook import WarmupParamHook
# Public API of the hooks subpackage.
__all__ = [
    'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
    'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook',
    'SetAdaptiveMarginsHook', 'EMAHook', 'SimSiamHook', 'DenseCLHook',
    'SwAVHook', 'WarmupParamHook'
]
# Copyright (c) OpenMMLab. All rights reserved
from mmengine.hooks import Hook
from mmengine.utils import is_seq_of
from mmpretrain.registry import HOOKS
@HOOKS.register_module()
class ClassNumCheckHook(Hook):
    """Class Number Check HOOK.

    Before training, validation and testing, verify that every module of the
    model exposing a ``num_classes`` attribute agrees with the number of
    classes declared by the dataset's ``CLASSES``.
    """

    def _check_head(self, runner, dataset):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        Args:
            runner (obj:`Runner`): runner object.
            dataset (obj: `BaseDataset`): the dataset to check.
        """
        model = runner.model
        if dataset.CLASSES is None:
            # No class metadata available: warn only, since the dataset may
            # be used deliberately without category information.
            runner.logger.warning(
                f'Please set class information in `metainfo` '
                f'in the {dataset.__class__.__name__} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
        else:
            assert is_seq_of(dataset.CLASSES, str), \
                (f'Class information in `metainfo` in '
                 f'{dataset.__class__.__name__} should be a tuple of str.')
            # Check every module carrying a `num_classes` attribute
            # (typically the classification head).
            for _, module in model.named_modules():
                if hasattr(module, 'num_classes'):
                    # Fixed message: "does not matches" -> "does not match",
                    # and added the missing '(' before the class count.
                    assert module.num_classes == len(dataset.CLASSES), \
                        (f'The `num_classes` ({module.num_classes}) in '
                         f'{module.__class__.__name__} of '
                         f'{model.__class__.__name__} does not match '
                         f'the length of class information in `metainfo` '
                         f'({len(dataset.CLASSES)}) in '
                         f'{dataset.__class__.__name__}')

    def before_train(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj: `IterBasedRunner`): Iter based Runner.
        """
        self._check_head(runner, runner.train_dataloader.dataset)

    def before_val(self, runner):
        """Check whether the validation dataset is compatible with head.

        Args:
            runner (obj:`IterBasedRunner`): Iter based Runner.
        """
        self._check_head(runner, runner.val_dataloader.dataset)

    def before_test(self, runner):
        """Check whether the test dataset is compatible with head.

        Args:
            runner (obj:`IterBasedRunner`): Iter based Runner.
        """
        self._check_head(runner, runner.test_dataloader.dataset)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from mmengine.hooks import Hook
from mmpretrain.registry import HOOKS
from mmpretrain.utils import get_ori_model
@HOOKS.register_module()
class DenseCLHook(Hook):
    """Hook for DenseCL.

    This hook includes ``loss_lambda`` warmup in DenseCL.
    Borrowed from the authors' code: `<https://github.com/WXinlong/DenseCL>`_.

    Args:
        start_iters (int): The number of warmup iterations to set
            ``loss_lambda=0``. Defaults to 1000.
    """

    def __init__(self, start_iters: int = 1000) -> None:
        self.start_iters = start_iters

    def before_train(self, runner) -> None:
        """Obtain ``loss_lambda`` from algorithm."""
        model = get_ori_model(runner.model)
        assert hasattr(model, 'loss_lambda'), \
            "The runner must have attribute \"loss_lambda\" in DenseCL."
        # Remember the configured value so it can be restored after warmup.
        self.loss_lambda = model.loss_lambda

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: Optional[Sequence[dict]] = None) -> None:
        """Adjust ``loss_lambda`` every train iter."""
        model = get_ori_model(runner.model)
        assert hasattr(model, 'loss_lambda'), \
            "The runner must have attribute \"loss_lambda\" in DenseCL."
        # Zero out the dense loss weight during warmup, restore it after.
        if runner.iter >= self.start_iters:
            model.loss_lambda = self.loss_lambda
        else:
            model.loss_lambda = 0.
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import itertools
import warnings
from typing import Dict, Optional
from mmengine.hooks import EMAHook as BaseEMAHook
from mmengine.logging import MMLogger
from mmengine.runner import Runner
from mmpretrain.registry import HOOKS
@HOOKS.register_module()
class EMAHook(BaseEMAHook):
    """A Hook to apply Exponential Moving Average (EMA) on the model during
    training.

    Comparing with :class:`mmengine.hooks.EMAHook`, this hook accepts
    ``evaluate_on_ema`` and ``evaluate_on_origin`` arguments. By default, the
    ``evaluate_on_ema`` is enabled, and if you want to do validation and
    testing on both original and EMA models, please set both arguments
    ``True``.

    Note:
        - EMAHook takes priority over CheckpointHook.
        - The original model parameters are actually saved in ema field after
          train.
        - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.

    Args:
        ema_type (str): The type of EMA strategy to use. You can find the
            supported strategies in :mod:`mmengine.model.averaged_model`.
            Defaults to 'ExponentialMovingAverage'.
        strict_load (bool): Whether to strictly enforce that the keys of
            ``state_dict`` in checkpoint match the keys returned by
            ``self.module.state_dict``. Defaults to False.
            Changed in v0.3.0.
        begin_iter (int): The number of iteration to enable ``EMAHook``.
            Defaults to 0.
        begin_epoch (int): The number of epoch to enable ``EMAHook``.
            Defaults to 0.
        evaluate_on_ema (bool): Whether to evaluate (validate and test)
            on EMA model during val-loop and test-loop. Defaults to True.
        evaluate_on_origin (bool): Whether to evaluate (validate and test)
            on the original model during val-loop and test-loop.
            Defaults to False.
        **kwargs: Keyword arguments passed to subclasses of
            :obj:`BaseAveragedModel`
    """

    priority = 'NORMAL'

    def __init__(self,
                 ema_type: str = 'ExponentialMovingAverage',
                 strict_load: bool = False,
                 begin_iter: int = 0,
                 begin_epoch: int = 0,
                 evaluate_on_ema: bool = True,
                 evaluate_on_origin: bool = False,
                 **kwargs):
        super().__init__(
            ema_type=ema_type,
            strict_load=strict_load,
            begin_iter=begin_iter,
            begin_epoch=begin_epoch,
            **kwargs)

        # Evaluating on neither model would make the val/test loops useless,
        # so fall back to evaluating the original model.
        if not evaluate_on_ema and not evaluate_on_origin:
            warnings.warn(
                'Automatically set `evaluate_on_origin=True` since the '
                '`evaluate_on_ema` is disabled. If you want to disable '
                'all validation, please modify the `val_interval` of '
                'the `train_cfg`.', UserWarning)
            evaluate_on_origin = True

        self.evaluate_on_ema = evaluate_on_ema
        self.evaluate_on_origin = evaluate_on_origin
        # Set True by `after_load_checkpoint` when the checkpoint carries an
        # 'ema_state_dict'; read by `before_train` to decide initialization.
        self.load_ema_from_ckpt = False

    def before_train(self, runner) -> None:
        """Copy EMA parameters into the source model when a checkpoint with
        EMA state was loaded but training is not being resumed."""
        super().before_train(runner)
        # NOTE(review): relies on the private `runner._resume` flag — confirm
        # against the mmengine Runner version in use.
        if not runner._resume and self.load_ema_from_ckpt:
            # If loaded EMA state dict but not want to resume training
            # overwrite the EMA state dict with the source model.
            MMLogger.get_current_instance().info(
                'Load from a checkpoint with EMA parameters but not '
                'resume training. Initialize the model parameters with '
                'EMA parameters')
            for p_ema, p_src in zip(self._ema_params, self._src_params):
                p_src.data.copy_(p_ema.data)

    def before_val_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before
        validation.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.evaluate_on_ema:
            # Swap when evaluate on ema
            self._swap_ema_parameters()

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after validation.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if self.evaluate_on_ema:
            # Swap when evaluate on ema
            self._swap_ema_parameters()

        if self.evaluate_on_ema and self.evaluate_on_origin:
            # Re-evaluate if evaluate on both ema and origin.
            # The swap above restored the original parameters, so this second
            # pass runs on the original model; its metrics are logged with an
            # `_origin` suffix.
            val_loop = runner.val_loop

            runner.model.eval()
            for idx, data_batch in enumerate(val_loop.dataloader):
                val_loop.run_iter(idx, data_batch)

            # compute metrics
            origin_metrics = val_loop.evaluator.evaluate(
                len(val_loop.dataloader.dataset))

            for k, v in origin_metrics.items():
                runner.message_hub.update_scalar(f'val/{k}_origin', v)

    def before_test_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before test.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.evaluate_on_ema:
            # Swap when evaluate on ema
            self._swap_ema_parameters()
            MMLogger.get_current_instance().info('Start testing on EMA model.')
        else:
            MMLogger.get_current_instance().info(
                'Start testing on the original model.')

    def after_test_epoch(self,
                         runner: Runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after test.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if self.evaluate_on_ema:
            # Swap when evaluate on ema
            self._swap_ema_parameters()

        if self.evaluate_on_ema and self.evaluate_on_origin:
            # Re-evaluate if evaluate on both ema and origin.
            MMLogger.get_current_instance().info(
                'Start testing on the original model.')
            test_loop = runner.test_loop

            runner.model.eval()
            for idx, data_batch in enumerate(test_loop.dataloader):
                test_loop.run_iter(idx, data_batch)

            # compute metrics
            origin_metrics = test_loop.evaluator.evaluate(
                len(test_loop.dataloader.dataset))

            for k, v in origin_metrics.items():
                runner.message_hub.update_scalar(f'test/{k}_origin', v)

    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
        """Resume ema parameters from checkpoint.

        Args:
            runner (Runner): The runner of the testing process.
        """
        from mmengine.runner.checkpoint import load_state_dict
        if 'ema_state_dict' in checkpoint:
            # The original model parameters are actually saved in ema
            # field swap the weights back to resume ema state.
            self._swap_ema_state_dict(checkpoint)
            self.ema_model.load_state_dict(
                checkpoint['ema_state_dict'], strict=self.strict_load)
            self.load_ema_from_ckpt = True

        # Support load checkpoint without ema state dict.
        else:
            load_state_dict(
                self.ema_model.module,
                copy.deepcopy(checkpoint['state_dict']),
                strict=self.strict_load)

    @property
    def _src_params(self):
        """Parameters (and buffers, when tracked) of the source model."""
        if self.ema_model.update_buffers:
            return itertools.chain(self.src_model.parameters(),
                                   self.src_model.buffers())
        else:
            return self.src_model.parameters()

    @property
    def _ema_params(self):
        """Parameters (and buffers, when tracked) of the EMA model."""
        if self.ema_model.update_buffers:
            return itertools.chain(self.ema_model.module.parameters(),
                                   self.ema_model.module.buffers())
        else:
            return self.ema_model.module.parameters()
# Copyright (c) OpenMMLab. All rights reserved
import numpy as np
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmpretrain.models.heads import ArcFaceClsHead
from mmpretrain.registry import HOOKS
@HOOKS.register_module()
class SetAdaptiveMarginsHook(Hook):
    r"""Set adaptive-margins in ArcFaceClsHead based on the power of
    category-wise count.

    A PyTorch implementation of paper `Google Landmark Recognition 2020
    Competition Third Place Solution <https://arxiv.org/abs/2010.05350>`_.
    The margins will be
    :math:`\text{f}(n) = (marginMax - marginMin) · norm(n^p) + marginMin`.
    The `n` indicates the number of occurrences of a category.

    Args:
        margin_min (float): Lower bound of margins. Defaults to 0.05.
        margin_max (float): Upper bound of margins. Defaults to 0.5.
        power (float): The power of category freqercy. Defaults to -0.25.
    """

    def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None:
        self.margin_min = margin_min
        self.margin_max = margin_max
        self.margin_range = margin_max - margin_min
        self.p = power

    def before_train(self, runner):
        """change the margins in ArcFaceClsHead.

        Args:
            runner (obj: `Runner`): Runner.

        Raises:
            ValueError: If the model has no ``head`` or its head is not an
                :class:`ArcFaceClsHead`.
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        # Fail early with a clear error when the head is missing or of the
        # wrong type (the previous check silently accepted models without a
        # `head` attribute, which crashed later with AttributeError).
        if not (hasattr(model, 'head')
                and isinstance(model.head, ArcFaceClsHead)):
            raise ValueError(
                'Hook ``SetFreqPowAdvMarginsHook`` could only be used '
                f'for ``ArcFaceClsHead``, but get '
                f'{type(getattr(model, "head", None))}')

        # generate margins base on the dataset.
        gt_labels = runner.train_dataloader.dataset.get_gt_labels()
        label_count = np.bincount(gt_labels)
        label_count[label_count == 0] = 1  # At least one occurrence
        pow_freq = np.power(label_count, self.p)

        # Normalize the powered frequencies into [0, 1], then map them into
        # [margin_min, margin_max].
        # NOTE(review): if every class has the same count, max_f == min_f and
        # this division produces NaNs — confirm that datasets used with this
        # hook are always imbalanced.
        min_f, max_f = pow_freq.min(), pow_freq.max()
        normized_pow_freq = (pow_freq - min_f) / (max_f - min_f)
        margins = normized_pow_freq * self.margin_range + self.margin_min

        # Use the unwrapped model here: under a distributed wrapper,
        # `runner.model` has no direct `head` attribute.
        assert len(margins) == model.head.num_classes
        model.head.set_margins(margins)
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/facebookresearch/pycls/blob/f8cd962737e33ce9e19b3083a33551da95c2d9c0/pycls/core/net.py # noqa: E501
# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0 # noqa: E501
import itertools
import logging
from typing import List, Optional, Sequence, Union
import mmengine
import torch
import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop, Runner
from mmengine.utils import ProgressBar
from torch.functional import Tensor
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.utils.data import DataLoader
from mmpretrain.registry import HOOKS
# Type alias for the optional `data_batch` argument passed to hook methods.
DATA_BATCH = Optional[Sequence[dict]]
def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]:
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of
    the process group.

    Args:
        tensors (List[torch.Tensor]): The tensors to process.
        num_gpus (int): The number of gpus to use
    Returns:
        List[torch.Tensor]: The processed tensors.
    """
    # Single-process training: nothing to reduce.
    if num_gpus == 1:
        return tensors

    # Launch all reductions asynchronously, then wait for every handle, so
    # the reductions can overlap with each other.
    handles = [
        torch.distributed.all_reduce(tensor, async_op=True)
        for tensor in tensors
    ]
    for handle in handles:
        handle.wait()

    # Scale the summed values by the inverse of the process-group size.
    scale = 1.0 / num_gpus
    for tensor in tensors:
        tensor.mul_(scale)
    return tensors
@torch.no_grad()
def update_bn_stats(
        model: nn.Module,
        loader: DataLoader,
        num_samples: int = 8192,
        logger: Optional[Union[logging.Logger, str]] = None) -> None:
    """Computes precise BN stats on training data.

    Args:
        model (nn.module): The model whose bn stats will be recomputed.
        loader (DataLoader): PyTorch dataloader._dataloader
        num_samples (int): The number of samples to update the bn stats.
            Defaults to 8192.
        logger (logging.Logger or str, optional): If the type of logger is
            ``logging.Logger``, we directly use logger to log messages.
            Some special loggers are:

            - "silent": No message will be printed.
            - "current": Use latest created logger to log message.
            - other str: Instance name of logger. The corresponding logger
              will log message if it has been created, otherwise will raise a
              `ValueError`.
            - None: The `print()` method will be used to print log messages.
    """
    if is_model_wrapper(model):
        model = model.module

    # get dist info
    rank, world_size = mmengine.dist.get_dist_info()
    # Compute the number of mini-batches to use, if the size of dataloader is
    # less than num_iters, use all the samples in dataloader.
    num_iter = num_samples // (loader.batch_size * world_size)
    num_iter = min(num_iter, len(loader))
    # Retrieve the BN layers
    bn_layers = [
        m for m in model.modules()
        if m.training and isinstance(m, (_BatchNorm))
    ]
    if len(bn_layers) == 0:
        print_log('No BN found in model', logger=logger, level=logging.WARNING)
        return
    print_log(
        f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger)

    # Finds all the other norm layers with training=True.
    # IN/GN layers are intentionally only reported, not updated, since their
    # statistics are computed per sample rather than tracked as running stats.
    other_norm_layers = [
        m for m in model.modules()
        if m.training and isinstance(m, (_InstanceNorm, GroupNorm))
    ]
    if len(other_norm_layers) > 0:
        print_log(
            'IN/GN stats will not be updated in PreciseHook.',
            logger=logger,
            level=logging.INFO)

    # Initialize BN stats storage for computing
    # mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers]
    # Remember momentum values
    momentums = [bn.momentum for bn in bn_layers]
    # Set momentum to 1.0 to compute BN stats that reflect the current batch
    for bn in bn_layers:
        bn.momentum = 1.0
    # Average the BN stats for each BN layer over the batches
    if rank == 0:
        prog_bar = ProgressBar(num_iter)

    for data in itertools.islice(loader, num_iter):
        # NOTE(review): the second argument is presumably `training=False` —
        # confirm against the model's data_preprocessor signature.
        data = model.data_preprocessor(data, False)
        model(**data)

        # With momentum=1.0, running_mean/var now hold the current batch's
        # stats; accumulate their average across all sampled batches.
        for i, bn in enumerate(bn_layers):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter
        if rank == 0:
            prog_bar.update()

    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, world_size)
    running_vars = scaled_all_reduce(running_vars, world_size)
    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bn_layers):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]
@HOOKS.register_module()
class PreciseBNHook(Hook):
    """Precise BN hook.

    Recompute and update the batch norm stats to make them more precise. During
    training both BN stats and the weight are changing after every iteration,
    so the running average can not precisely reflect the actual stats of the
    current model.

    With this hook, the BN stats are recomputed with fixed weights, to make the
    running average more precise. Specifically, it computes the true average of
    per-batch mean/variance instead of the running average. See Sec. 3 of the
    paper `Rethinking Batch in BatchNorm <https://arxiv.org/abs/2105.07576>`
    for details.

    This hook will update BN stats, so it should be executed before
    ``CheckpointHook`` and ``EMAHook``, generally set its priority to
    "ABOVE_NORMAL".

    Args:
        num_samples (int): The number of samples to update the bn stats.
            Defaults to 8192.
        interval (int): Perform precise bn interval. If the train loop is
            `EpochBasedTrainLoop` or `by_epoch=True`, its unit is 'epoch';
            if the train loop is `IterBasedTrainLoop` or `by_epoch=False`,
            its unit is 'iter'. Defaults to 1.
    """

    def __init__(self, num_samples: int = 8192, interval: int = 1) -> None:
        assert interval > 0 and num_samples > 0, "'interval' and " \
            "'num_samples' must be bigger than 0."

        self.interval = interval
        self.num_samples = num_samples

    def _perform_precise_bn(self, runner: Runner) -> None:
        """Recompute BN stats over the training dataloader."""
        print_log(
            f'Running Precise BN for {self.num_samples} samples...',
            logger=runner.logger)
        update_bn_stats(
            runner.model,
            runner.train_loop.dataloader,
            self.num_samples,
            logger=runner.logger)
        print_log('Finish Precise BN, BN stats updated.', logger=runner.logger)

    def after_train_epoch(self, runner: Runner) -> None:
        """Run precise BN every ``interval`` epochs (epoch-based loops only).

        Args:
            runner (obj:`Runner`): The runner of the training process.
        """
        if not isinstance(runner.train_loop, EpochBasedTrainLoop):
            return
        if self.every_n_epochs(runner, self.interval):
            self._perform_precise_bn(runner)

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Run precise BN every ``interval`` iters (iter-based loops only).

        Args:
            runner (obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
        """
        if not isinstance(runner.train_loop, IterBasedTrainLoop):
            return
        if self.every_n_train_iters(runner, self.interval):
            self._perform_precise_bn(runner)
# Copyright (c) OpenMMLab. All rights reserved
import warnings
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmpretrain.models import BaseRetriever
from mmpretrain.registry import HOOKS
@HOOKS.register_module()
class PrepareProtoBeforeValLoopHook(Hook):
    """The hook to prepare the prototype in retrievers.

    Since the encoders of the retriever changes during training, the prototype
    changes accordingly. So the `prototype_vecs` needs to be regenerated before
    validation loop.
    """

    def before_val(self, runner) -> None:
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        # Only retrievers carry a prototype to refresh; anything else is a
        # configuration mistake worth surfacing.
        if not isinstance(model, BaseRetriever):
            warnings.warn(
                'Only the `mmpretrain.models.retrievers.BaseRetriever` '
                'can execute `PrepareRetrieverPrototypeHook`, but got '
                f'`{type(model)}`')
            return

        if hasattr(model, 'prepare_prototype'):
            model.prepare_prototype()
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from mmengine.hooks import Hook
from mmpretrain.registry import HOOKS
@HOOKS.register_module()
class SimSiamHook(Hook):
    """Hook for SimSiam.

    This hook is for SimSiam to fix learning rate of predictor.

    Args:
        fix_pred_lr (bool): whether to fix the lr of predictor or not.
        lr (float): the value of fixed lr.
        adjust_by_epoch (bool, optional): whether to set lr by epoch or iter.
            Defaults to True.
    """

    def __init__(self,
                 fix_pred_lr: bool,
                 lr: float,
                 adjust_by_epoch: Optional[bool] = True) -> None:
        self.fix_pred_lr = fix_pred_lr
        self.lr = lr
        self.adjust_by_epoch = adjust_by_epoch

    def _fix_predictor_lr(self, runner) -> None:
        # Pin the lr of every param group flagged with a truthy `fix_lr`.
        for group in runner.optim_wrapper.optimizer.param_groups:
            if group.get('fix_lr', False):
                group['lr'] = self.lr

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: Optional[Sequence[dict]] = None) -> None:
        """fix lr of predictor by iter."""
        if not self.adjust_by_epoch and self.fix_pred_lr:
            self._fix_predictor_lr(runner)

    def before_train_epoch(self, runner) -> None:
        """fix lr of predictor by epoch."""
        if self.fix_pred_lr:
            self._fix_predictor_lr(runner)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence
import torch
from mmengine.device import get_device
from mmengine.dist import get_rank, get_world_size, is_distributed
from mmengine.hooks import Hook
from mmengine.logging import MMLogger
from mmpretrain.registry import HOOKS
from mmpretrain.utils import get_ori_model
@HOOKS.register_module()
class SwAVHook(Hook):
    """Hook for SwAV.

    This hook builds the queue in SwAV according to ``epoch_queue_starts``.
    The queue will be saved in ``runner.work_dir`` or loaded at start epoch
    if the path folder has queues saved before.

    Args:
        batch_size (int): the batch size per GPU for computing.
        epoch_queue_starts (int, optional): from this epoch, starts to use the
            queue. Defaults to 15.
        crops_for_assign (list[int], optional): list of crops id used for
            computing assignments. Defaults to [0, 1].
        feat_dim (int, optional): feature dimension of output vector.
            Defaults to 128.
        queue_length (int, optional): length of the queue (0 for no queue).
            Defaults to 0.
        interval (int, optional): the interval to save the queue.
            Defaults to 1.
        frozen_layers_cfg (dict, optional): Dict to config frozen layers.
            The key-value pair is layer name and its frozen iters. If frozen,
            the layers don't need gradient. Defaults to dict().
    """

    def __init__(
        self,
        batch_size: int,
        epoch_queue_starts: Optional[int] = 15,
        crops_for_assign: Optional[List[int]] = [0, 1],
        feat_dim: Optional[int] = 128,
        queue_length: Optional[int] = 0,
        interval: Optional[int] = 1,
        frozen_layers_cfg: Optional[Dict] = dict()
    ) -> None:
        # Effective (global) batch size across all processes.
        self.batch_size = batch_size * get_world_size()
        self.epoch_queue_starts = epoch_queue_starts
        self.crops_for_assign = crops_for_assign
        self.feat_dim = feat_dim
        self.queue_length = queue_length
        self.interval = interval
        self.frozen_layers_cfg = frozen_layers_cfg
        # Tracks the current frozen/unfrozen state of configured layers.
        self.requires_grad = True
        self.queue = None

    def before_run(self, runner) -> None:
        """Check whether the queues exist locally or not."""
        if is_distributed():
            queue_file = 'queue' + str(get_rank()) + '.pth'
        else:
            queue_file = 'queue.pth'
        self.queue_path = osp.join(runner.work_dir, queue_file)

        # load the queues if queues exist locally
        if osp.isfile(self.queue_path):
            self.queue = torch.load(self.queue_path)['queue']
            get_ori_model(runner.model).head.loss_module.queue = self.queue
            MMLogger.get_current_instance().info(
                f'Load queue from file: {self.queue_path}')

        # the queue needs to be divisible by the batch size
        self.queue_length -= self.queue_length % self.batch_size

    def _toggle_layer_grad(self, runner, layer, flag) -> None:
        # Flip `requires_grad` on every parameter whose name contains `layer`.
        for name, param in get_ori_model(runner.model).named_parameters():
            if layer in name:
                param.requires_grad = flag

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: Optional[Sequence[dict]] = None) -> None:
        """Freeze layers before specific iters according to the config."""
        for layer, frozen_iters in self.frozen_layers_cfg.items():
            in_frozen_phase = runner.iter < frozen_iters
            if in_frozen_phase and self.requires_grad:
                self.requires_grad = False
                self._toggle_layer_grad(runner, layer, False)
            elif not in_frozen_phase and not self.requires_grad:
                self.requires_grad = True
                self._toggle_layer_grad(runner, layer, True)

    def before_train_epoch(self, runner) -> None:
        """Check the queues' state."""
        # optionally starts a queue
        needs_queue = (self.queue_length > 0
                       and runner.epoch >= self.epoch_queue_starts
                       and self.queue is None)
        if needs_queue:
            self.queue = torch.zeros(
                len(self.crops_for_assign),
                self.queue_length // runner.world_size,
                self.feat_dim,
                device=get_device(),
            )

        # set the boolean type of use_the_queue
        loss_module = get_ori_model(runner.model).head.loss_module
        loss_module.queue = self.queue
        loss_module.use_queue = False

    def after_train_epoch(self, runner) -> None:
        """Save the queues locally."""
        self.queue = get_ori_model(runner.model).head.loss_module.queue
        should_save = (self.queue is not None
                       and self.every_n_epochs(runner, self.interval))
        if should_save:
            torch.save({'queue': self.queue}, self.queue_path)
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from copy import deepcopy
from mmcv.transforms import Compose
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmpretrain.models.utils import RandomBatchAugment
from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS
@HOOKS.register_module()
class SwitchRecipeHook(Hook):
    """switch recipe during the training loop, including train pipeline, batch
    augments and loss currently.

    Args:
        schedule (list): Every item of the schedule list should be a dict, and
            the dict should have ``action_epoch`` and some of
            ``train_pipeline``, ``train_augments`` and ``loss`` keys:

            - ``action_epoch`` (int): switch training recipe at which epoch.
            - ``train_pipeline`` (list, optional): The new data pipeline of the
              train dataset. If not specified, keep the original settings.
            - ``batch_augments`` (dict | None, optional): The new batch
              augmentations of during training. See :mod:`Batch Augmentations
              <mmpretrain.models.utils.batch_augments>` for more details.
              If None, disable batch augmentations. If not specified, keep the
              original settings.
            - ``loss`` (dict, optional): The new loss module config. If not
              specified, keep the original settings.

    Example:
        To use this hook in config files.

        .. code:: python

            custom_hooks = [
                dict(
                    type='SwitchRecipeHook',
                    schedule=[
                        dict(
                            action_epoch=30,
                            train_pipeline=pipeline_after_30e,
                            batch_augments=batch_augments_after_30e,
                            loss=loss_after_30e,
                        ),
                        dict(
                            action_epoch=60,
                            # Disable batch augmentations after 60e
                            # and keep other settings.
                            batch_augments=None,
                        ),
                    ]
                )
            ]
    """
    priority = 'NORMAL'

    def __init__(self, schedule):
        recipes = {}
        for item in schedule:
            assert 'action_epoch' in item, \
                'Please set `action_epoch` in every item ' \
                'of the `schedule` in the SwitchRecipeHook.'
            # Deep-copy so building objects below does not mutate the config.
            item = deepcopy(item)
            if 'train_pipeline' in item:
                item['train_pipeline'] = Compose(item['train_pipeline'])
            if 'batch_augments' in item:
                augments = item['batch_augments']
                if isinstance(augments, dict):
                    augments = RandomBatchAugment(**augments)
                item['batch_augments'] = augments
            if 'loss' in item:
                loss = item['loss']
                if isinstance(loss, dict):
                    loss = MODELS.build(loss)
                item['loss'] = loss

            action_epoch = item.pop('action_epoch')
            assert action_epoch not in recipes, \
                f'The `action_epoch` {action_epoch} is repeated ' \
                'in the SwitchRecipeHook.'
            recipes[action_epoch] = item
        # Keep recipes ordered by epoch so resume replays them in order.
        self.schedule = OrderedDict(sorted(recipes.items()))

    def before_train(self, runner) -> None:
        """before run setting. If resume form a checkpoint, do all switch
        before the current epoch.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
        """
        if runner._resume:
            for action_epoch, recipe in self.schedule.items():
                if action_epoch >= runner.epoch + 1:
                    break
                self._do_switch(runner, recipe,
                                f' (resume recipe of epoch {action_epoch})')

    def before_train_epoch(self, runner):
        """do before train epoch."""
        recipe = self.schedule.get(runner.epoch + 1, None)
        if recipe is not None:
            self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}')

    def _do_switch(self, runner, recipe, extra_info=''):
        """do the switch aug process."""
        # Apply in a fixed order so the log output stays deterministic.
        if 'batch_augments' in recipe:
            self._switch_batch_augments(runner, recipe['batch_augments'])
            runner.logger.info(f'Switch batch augments{extra_info}.')

        if 'train_pipeline' in recipe:
            self._switch_train_pipeline(runner, recipe['train_pipeline'])
            runner.logger.info(f'Switch train pipeline{extra_info}.')

        if 'loss' in recipe:
            self._switch_loss(runner, recipe['loss'])
            runner.logger.info(f'Switch loss{extra_info}.')

    @staticmethod
    def _switch_batch_augments(runner, batch_augments):
        """switch the train augments."""
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        model.data_preprocessor.batch_augments = batch_augments

    @staticmethod
    def _switch_train_pipeline(runner, train_pipeline):
        """switch the train loader dataset pipeline."""

        def apply(ds, pipeline):
            if hasattr(ds, 'pipeline'):
                # for usual dataset
                ds.pipeline = pipeline
            elif hasattr(ds, 'datasets'):
                # for concat dataset wrapper
                for sub_ds in ds.datasets:
                    apply(sub_ds, pipeline)
            elif hasattr(ds, 'dataset'):
                # for other dataset wrappers
                apply(ds.dataset, pipeline)
            else:
                raise RuntimeError(
                    'Cannot access the `pipeline` of the dataset.')

        train_loader = runner.train_loop.dataloader
        apply(train_loader.dataset, train_pipeline)

        # To restart the iterator of dataloader when `persistent_workers=True`
        train_loader._iterator = None

    @staticmethod
    def _switch_loss(runner, loss_module):
        """switch the loss module."""
        model = runner.model
        if is_model_wrapper(model, MODEL_WRAPPERS):
            model = model.module

        if hasattr(model, 'loss_module'):
            model.loss_module = loss_module
        elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'):
            model.head.loss_module = loss_module
        else:
            raise RuntimeError('Cannot access the `loss_module` of the model.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment