# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os.path as osp
from collections import defaultdict
from typing import Dict, List, Optional
import numpy as np
from mmengine.fileio import get_local_path, load
from mmengine.utils import is_abs
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class OpenImagesDataset(BaseDetDataset):
"""Open Images dataset for detection.
Args:
ann_file (str): Annotation file path.
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
meta_file (str): File path to get image metas.
hierarchy_file (str): The file path of the class hierarchy.
image_level_ann_file (str): Human-verified image level annotation,
which is used in evaluation.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
"""
METAINFO: dict = dict(dataset_type='oid_v6')
def __init__(self,
label_file: str,
meta_file: str,
hierarchy_file: str,
image_level_ann_file: Optional[str] = None,
**kwargs) -> None:
self.label_file = label_file
self.meta_file = meta_file
self.hierarchy_file = hierarchy_file
self.image_level_ann_file = image_level_ann_file
super().__init__(**kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
last_img_id = None
instances = []
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
if last_img_id is None:
last_img_id = img_id
label_id = line[2]
assert label_id in self.label_id_mapping
label = int(self.label_id_mapping[label_id])
bbox = [
float(line[4]), # xmin
float(line[6]), # ymin
float(line[5]), # xmax
float(line[7]) # ymax
]
                    is_occluded = int(line[8]) == 1
                    is_truncated = int(line[9]) == 1
                    is_group_of = int(line[10]) == 1
                    is_depiction = int(line[11]) == 1
                    is_inside = int(line[12]) == 1
instance = dict(
bbox=bbox,
bbox_label=label,
ignore_flag=0,
is_occluded=is_occluded,
is_truncated=is_truncated,
is_group_of=is_group_of,
is_depiction=is_depiction,
is_inside=is_inside)
last_img_path = osp.join(self.data_prefix['img'],
f'{last_img_id}.jpg')
if img_id != last_img_id:
# switch to a new image, record previous image's data.
data_info = dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
)
data_list.append(data_info)
instances = []
instances.append(instance)
last_img_id = img_id
data_list.append(
dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = data_list[i]['img_id']
assert f'{img_id}.jpg' == osp.split(meta['filename'])[-1]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
tuple: Class name of OpenImages.
"""
index_list = []
classes_names = []
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
# self.cat2label[line[0]] = line[1]
classes_names.append(line[1])
index_list.append(line[0])
index_mapping = {index: i for i, index in enumerate(index_list)}
return classes_names, index_mapping
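    # Example rows of the label description CSV (editor's sketch; the MID
    # values below are illustrative, not shipped with this module):
    #   /m/011k07,Tortoise
    #   /m/011q46kg,Container
    # which would yield classes_names = ['Tortoise', 'Container'] and
    # index_mapping = {'/m/011k07': 0, '/m/011q46kg': 1}.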
def _parse_img_level_ann(self,
img_level_ann_file: str) -> Dict[str, List[dict]]:
"""Parse image level annotations from csv style ann_file.
Args:
img_level_ann_file (str): CSV style image level annotation
file path.
Returns:
            Dict[str, List[dict]]: Annotations where each key of the
                defaultdict indicates an image and maps to a list of dicts.
                Keys of the dicts are:
- `image_level_label` (int): Label id.
- `confidence` (float): Labels that are human-verified to be
present in an image have confidence = 1 (positive labels).
Labels that are human-verified to be absent from an image
have confidence = 0 (negative labels). Machine-generated
labels have fractional confidences, generally >= 0.5.
The higher the confidence, the smaller the chance for
the label to be a false positive.
"""
item_lists = defaultdict(list)
with get_local_path(
img_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
item_lists[img_id].append(
dict(
image_level_label=int(
self.label_id_mapping[line[2]]),
confidence=float(line[3])))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file. Hierarchy
for 600 classes can be found at https://storage.googleapis.com/openimag
es/2018_04/bbox_labels_600_hierarchy_visualizer/circle.html.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
""" # noqa
hierarchy = load(
hierarchy_file, file_format='json', backend_args=self.backend_args)
class_num = len(self._metainfo['classes'])
relation_matrix = np.eye(class_num, class_num)
relation_matrix = self._convert_hierarchy_tree(hierarchy,
relation_matrix)
return relation_matrix
def _convert_hierarchy_tree(self,
hierarchy_map: dict,
relation_matrix: np.ndarray,
parents: list = [],
get_all_parents: bool = True) -> np.ndarray:
"""Get matrix of the corresponding relationship between the parent
class and the child class.
Args:
hierarchy_map (dict): Including label name and corresponding
subcategory. Keys of dicts are:
                - `LabelName` (str): Name of the label.
- `Subcategory` (dict | list): Corresponding subcategory(ies).
relation_matrix (ndarray): The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
parents (list): Corresponding parent class.
            get_all_parents (bool): Whether to get all parent classes.
                Defaults to True.
Returns:
ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
"""
if 'Subcategory' in hierarchy_map:
for node in hierarchy_map['Subcategory']:
if 'LabelName' in node:
children_name = node['LabelName']
children_index = self.label_id_mapping[children_name]
children = [children_index]
else:
continue
if len(parents) > 0:
for parent_index in parents:
if get_all_parents:
children.append(parent_index)
relation_matrix[children_index, parent_index] = 1
relation_matrix = self._convert_hierarchy_tree(
node, relation_matrix, parents=children)
return relation_matrix
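    # Worked example (editor's sketch, toy hierarchy): with classes
    # ['Animal', 'Cat', 'Dog'] where 'Cat' and 'Dog' are subcategories of
    # 'Animal', the matrix keeps ones on the diagonal and adds child->parent
    # entries relation_matrix[1, 0] = relation_matrix[2, 0] = 1, so a 'Cat'
    # prediction also counts for its ancestor 'Animal' during evaluation.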
def _join_prefix(self):
"""Join ``self.data_root`` with annotation path."""
super()._join_prefix()
if not is_abs(self.label_file) and self.label_file:
self.label_file = osp.join(self.data_root, self.label_file)
if not is_abs(self.meta_file) and self.meta_file:
self.meta_file = osp.join(self.data_root, self.meta_file)
if not is_abs(self.hierarchy_file) and self.hierarchy_file:
self.hierarchy_file = osp.join(self.data_root, self.hierarchy_file)
if self.image_level_ann_file and not is_abs(self.image_level_ann_file):
self.image_level_ann_file = osp.join(self.data_root,
self.image_level_ann_file)
@DATASETS.register_module()
class OpenImagesChallengeDataset(OpenImagesDataset):
"""Open Images Challenge dataset for detection.
Args:
ann_file (str): Open Images Challenge box annotation in txt format.
"""
METAINFO: dict = dict(dataset_type='oid_challenge')
def __init__(self, ann_file: str, **kwargs) -> None:
if not ann_file.endswith('txt'):
raise TypeError('The annotation file of Open Images Challenge '
'should be a txt file.')
super().__init__(ann_file=ann_file, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
lines = f.readlines()
i = 0
while i < len(lines):
instances = []
filename = lines[i].rstrip()
i += 2
img_gt_size = int(lines[i])
i += 1
for j in range(img_gt_size):
sp = lines[i + j].split()
instances.append(
dict(
bbox=[
float(sp[1]),
float(sp[2]),
float(sp[3]),
float(sp[4])
],
bbox_label=int(sp[0]) - 1, # labels begin from 1
ignore_flag=0,
                                is_group_ofs=int(sp[5]) == 1))
i += img_gt_size
data_list.append(
dict(
img_path=osp.join(self.data_prefix['img'], filename),
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = osp.split(data_list[i]['img_path'])[-1][:-4]
assert img_id == osp.split(meta['filename'])[-1][:-4]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
data_list[i]['img_id'] = img_id
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
tuple: Class name of OpenImages.
"""
label_list = []
id_list = []
index_mapping = {}
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
label_name = line[0]
label_id = int(line[2])
label_list.append(line[1])
id_list.append(label_id)
index_mapping[label_name] = label_id - 1
indexes = np.argsort(id_list)
classes_names = []
for index in indexes:
classes_names.append(label_list[index])
return classes_names, index_mapping
def _parse_img_level_ann(self, image_level_ann_file):
"""Parse image level annotations from csv style ann_file.
Args:
image_level_ann_file (str): CSV style image level annotation
file path.
Returns:
            defaultdict[list[dict]]: Annotations where each key of the
                defaultdict indicates an image and maps to a list of dicts.
                Keys of the dicts are:
                - `image_level_label` (int): Label id.
                - `confidence` (float): Confidence of the label.
"""
item_lists = defaultdict(list)
with get_local_path(
image_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
                for i, line in enumerate(reader):
                    # skip the csv header row
                    if i == 0:
                        continue
                    img_id = line[0]
                    label_id = line[1]
                    assert label_id in self.label_id_mapping
                    image_level_label = int(
                        self.label_id_mapping[label_id])
                    confidence = float(line[2])
                    item_lists[img_id].append(
                        dict(
                            image_level_label=image_level_label,
                            confidence=confidence))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
"""
with get_local_path(
hierarchy_file, backend_args=self.backend_args) as local_path:
class_label_tree = np.load(local_path, allow_pickle=True)
return class_label_tree[1:, 1:]
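# Usage sketch (editor's addition): a minimal config for OpenImagesDataset.
# The annotation file names below follow the common Open Images v6 layout but
# are assumptions, not paths shipped with this module.
example_oid_cfg = dict(
    type='OpenImagesDataset',
    data_root='data/OpenImages/',
    ann_file='annotations/oidv6-train-annotations-bbox.csv',
    label_file='annotations/class-descriptions-boxable.csv',
    hierarchy_file='annotations/bbox_labels_600_hierarchy.json',
    meta_file='annotations/train-image-metas.pkl',
    image_level_ann_file=None,
    data_prefix=dict(img='OpenImages/train/'))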
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import os.path as osp
import random
from typing import Dict, List
import mmengine
from mmengine.dataset import BaseDataset
from mmdet.registry import DATASETS
@DATASETS.register_module()
class RefCocoDataset(BaseDataset):
"""RefCOCO dataset.
    The `RefCOCO` and `RefCOCO+` datasets are based on
`ReferItGame: Referring to Objects in Photographs of Natural Scenes
<http://tamaraberg.com/papers/referit.pdf>`_.
    The `RefCOCOg` dataset is based on
`Generation and Comprehension of Unambiguous Object Descriptions
<https://arxiv.org/abs/1511.02283>`_.
Args:
ann_file (str): Annotation file path.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (str): Prefix for training data.
split_file (str): Split file path.
split (str): Split name. Defaults to 'train'.
text_mode (str): Text mode. Defaults to 'random'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
def __init__(self,
data_root: str,
ann_file: str,
split_file: str,
data_prefix: Dict,
split: str = 'train',
text_mode: str = 'random',
**kwargs):
self.split_file = split_file
self.split = split
assert text_mode in ['original', 'random', 'concat', 'select_first']
self.text_mode = text_mode
super().__init__(
data_root=data_root,
data_prefix=data_prefix,
ann_file=ann_file,
**kwargs,
)
def _join_prefix(self):
if not mmengine.is_abs(self.split_file) and self.split_file:
self.split_file = osp.join(self.data_root, self.split_file)
return super()._join_prefix()
def _init_refs(self):
"""Initialize the refs for RefCOCO."""
anns, imgs = {}, {}
for ann in self.instances['annotations']:
anns[ann['id']] = ann
for img in self.instances['images']:
imgs[img['id']] = img
refs, ref_to_ann = {}, {}
for ref in self.splits:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
# add mapping related to ref
refs[ref_id] = ref
ref_to_ann[ref_id] = anns[ann_id]
self.refs = refs
self.ref_to_ann = ref_to_ann
def load_data_list(self) -> List[dict]:
"""Load data list."""
self.splits = mmengine.load(self.split_file, file_format='pkl')
self.instances = mmengine.load(self.ann_file, file_format='json')
self._init_refs()
img_prefix = self.data_prefix['img_path']
ref_ids = [
ref['ref_id'] for ref in self.splits if ref['split'] == self.split
]
full_anno = []
for ref_id in ref_ids:
ref = self.refs[ref_id]
ann = self.ref_to_ann[ref_id]
ann.update(ref)
full_anno.append(ann)
image_id_list = []
final_anno = {}
for anno in full_anno:
image_id_list.append(anno['image_id'])
final_anno[anno['ann_id']] = anno
annotations = [value for key, value in final_anno.items()]
coco_train_id = []
image_annot = {}
for i in range(len(self.instances['images'])):
coco_train_id.append(self.instances['images'][i]['id'])
image_annot[self.instances['images'][i]
['id']] = self.instances['images'][i]
images = []
for image_id in list(set(image_id_list)):
images += [image_annot[image_id]]
data_list = []
grounding_dict = collections.defaultdict(list)
for anno in annotations:
image_id = int(anno['image_id'])
grounding_dict[image_id].append(anno)
join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
for image in images:
img_id = image['id']
instances = []
sentences = []
for grounding_anno in grounding_dict[img_id]:
texts = [x['raw'].lower() for x in grounding_anno['sentences']]
# random select one text
if self.text_mode == 'random':
idx = random.randint(0, len(texts) - 1)
text = [texts[idx]]
# concat all texts
elif self.text_mode == 'concat':
text = [''.join(texts)]
# select the first text
elif self.text_mode == 'select_first':
text = [texts[0]]
# use all texts
elif self.text_mode == 'original':
text = texts
else:
raise ValueError(f'Invalid text mode "{self.text_mode}".')
ins = [{
'mask': grounding_anno['segmentation'],
'ignore_flag': 0
}] * len(text)
instances.extend(ins)
sentences.extend(text)
data_info = {
'img_path': join_path(img_prefix, image['file_name']),
'img_id': img_id,
'instances': instances,
'text': sentences
}
data_list.append(data_info)
if len(data_list) == 0:
raise ValueError(f'No sample in split "{self.split}".')
return data_list
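# Behavior sketch (editor's addition, assumed sentences): how ``text_mode``
# maps the referring expressions of one annotation to the ``text`` field.
# With texts = ['the left dog', 'dog on the left']:
#   'original'     -> ['the left dog', 'dog on the left']
#   'random'       -> one of the two, e.g. ['dog on the left']
#   'concat'       -> ['the left dogdog on the left']  (plain join, no space)
#   'select_first' -> ['the left dog']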
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from collections import defaultdict
from typing import Any, Dict, List
import numpy as np
from mmengine.dataset import BaseDataset
from mmengine.utils import check_file_exist
from mmdet.registry import DATASETS
@DATASETS.register_module()
class ReIDDataset(BaseDataset):
"""Dataset for ReID.
Args:
triplet_sampler (dict, optional): The sampler for hard mining
triplet loss. Defaults to None.
            Keys of the dict are:
            - num_ids (int): The number of person ids.
            - ins_per_id (int): The number of images for each person.
"""
def __init__(self, triplet_sampler: dict = None, *args, **kwargs):
self.triplet_sampler = triplet_sampler
super().__init__(*args, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ''self.ann_file''.
Returns:
list[dict]: A list of annotation.
"""
assert isinstance(self.ann_file, str)
check_file_exist(self.ann_file)
data_list = []
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
info = dict(img_prefix=self.data_prefix)
if self.data_prefix['img_path'] is not None:
info['img_path'] = osp.join(self.data_prefix['img_path'],
filename)
else:
info['img_path'] = filename
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_list.append(info)
self._parse_ann_info(data_list)
return data_list
def _parse_ann_info(self, data_list: List[dict]):
"""Parse person id annotations."""
index_tmp_dic = defaultdict(list) # pid->[idx1,...,idxN]
self.index_dic = dict() # pid->array([idx1,...,idxN])
for idx, info in enumerate(data_list):
pid = info['gt_label']
index_tmp_dic[int(pid)].append(idx)
for pid, idxs in index_tmp_dic.items():
self.index_dic[pid] = np.asarray(idxs, dtype=np.int64)
self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)
def prepare_data(self, idx: int) -> Any:
"""Get data processed by ''self.pipeline''.
Args:
idx (int): The index of ''data_info''
Returns:
Any: Depends on ''self.pipeline''
"""
data_info = self.get_data_info(idx)
if self.triplet_sampler is not None:
img_info = self.triplet_sampling(data_info['gt_label'],
**self.triplet_sampler)
data_info = copy.deepcopy(img_info) # triplet -> list
else:
data_info = copy.deepcopy(data_info) # no triplet -> dict
return self.pipeline(data_info)
def triplet_sampling(self,
pos_pid,
num_ids: int = 8,
ins_per_id: int = 4) -> Dict:
"""Triplet sampler for hard mining triplet loss. First, for one
pos_pid, random sample ins_per_id images with same person id.
Then, random sample num_ids - 1 images for each negative id.
Finally, random sample ins_per_id images for each negative id.
Args:
pos_pid (ndarray): The person id of the anchor.
num_ids (int): The number of person ids.
ins_per_id (int): The number of images for each person.
Returns:
Dict: Annotation information of num_ids X ins_per_id images.
"""
        assert len(self.pids) >= num_ids, \
            'The number of person ids in the training set must ' \
            'be no fewer than num_ids.'
pos_idxs = self.index_dic[int(
pos_pid)] # all positive idxs for pos_pid
idxs_list = []
# select positive samplers
idxs_list.extend(pos_idxs[np.random.choice(
pos_idxs.shape[0], ins_per_id, replace=True)])
# select negative ids
neg_pids = np.random.choice(
[i for i, _ in enumerate(self.pids) if i != pos_pid],
num_ids - 1,
replace=False)
# select negative samplers for each negative id
for neg_pid in neg_pids:
neg_idxs = self.index_dic[neg_pid]
idxs_list.extend(neg_idxs[np.random.choice(
neg_idxs.shape[0], ins_per_id, replace=True)])
# return the final triplet batch
triplet_img_infos = []
for idx in idxs_list:
triplet_img_infos.append(copy.deepcopy(self.get_data_info(idx)))
# Collect data_list scatters (list of dict -> dict of list)
out = dict()
for key in triplet_img_infos[0].keys():
out[key] = [_info[key] for _info in triplet_img_infos]
return out
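# Usage sketch (editor's addition; file names and values are assumptions):
# with num_ids=8 and ins_per_id=4, ``triplet_sampling`` returns a
# dict-of-lists covering 8 * 4 = 32 images: 4 sharing the anchor's person id
# plus 4 for each of the 7 sampled negative ids.
# dataset = ReIDDataset(
#     ann_file='ann.txt',
#     data_prefix=dict(img_path='imgs/'),
#     triplet_sampler=dict(num_ids=8, ins_per_id=4),
#     pipeline=[])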
# Copyright (c) OpenMMLab. All rights reserved.
from .batch_sampler import (AspectRatioBatchSampler,
MultiDataAspectRatioBatchSampler,
TrackAspectRatioBatchSampler)
from .class_aware_sampler import ClassAwareSampler
from .multi_data_sampler import MultiDataSampler
from .multi_source_sampler import GroupMultiSourceSampler, MultiSourceSampler
from .track_img_sampler import TrackImgSampler
__all__ = [
'ClassAwareSampler', 'AspectRatioBatchSampler', 'MultiSourceSampler',
'GroupMultiSourceSampler', 'TrackImgSampler',
'TrackAspectRatioBatchSampler', 'MultiDataSampler',
'MultiDataAspectRatioBatchSampler'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
from torch.utils.data import BatchSampler, Sampler
from mmdet.datasets.samplers.track_img_sampler import TrackImgSampler
from mmdet.registry import DATA_SAMPLERS
# TODO: maybe replace with a data_loader wrapper
@DATA_SAMPLERS.register_module()
class AspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __init__(self,
sampler: Sampler,
batch_size: int,
drop_last: bool = False) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
# two groups for w < h and w >= h
self._aspect_ratio_buckets = [[] for _ in range(2)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
def __len__(self) -> int:
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
@DATA_SAMPLERS.register_module()
class TrackAspectRatioBatchSampler(AspectRatioBatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
# hard code to solve TrackImgSampler
if isinstance(self.sampler, TrackImgSampler):
video_idx, _ = idx
else:
video_idx = idx
# video_idx
data_info = self.sampler.dataset.get_data_info(video_idx)
# data_info {video_id, images, video_length}
img_data_info = data_info['images'][0]
width, height = img_data_info['width'], img_data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
@DATA_SAMPLERS.register_module()
class MultiDataAspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch for multi-source datasets.
Args:
sampler (Sampler): Base sampler.
batch_size (Sequence(int)): Size of mini-batch for multi-source
datasets.
num_datasets(int): Number of multi-source datasets.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __init__(self,
sampler: Sampler,
batch_size: Sequence[int],
num_datasets: int,
drop_last: bool = True) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
self.sampler = sampler
self.batch_size = batch_size
self.num_datasets = num_datasets
self.drop_last = drop_last
# two groups for w < h and w >= h for each dataset --> 2 * num_datasets
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
aspect_ratio_bucket_id = 0 if width < height else 1
bucket_id = dataset_source_idx * 2 + aspect_ratio_bucket_id
bucket = self._buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size[dataset_source_idx]:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
for i in range(self.num_datasets):
left_data = self._buckets[i * 2 + 0] + self._buckets[i * 2 + 1]
while len(left_data) > 0:
if len(left_data) <= self.batch_size[i]:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size[i]]
left_data = left_data[self.batch_size[i]:]
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __len__(self) -> int:
sizes = [0 for _ in range(self.num_datasets)]
for idx in self.sampler:
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
sizes[dataset_source_idx] += 1
if self.drop_last:
lens = 0
for i in range(self.num_datasets):
lens += sizes[i] // self.batch_size[i]
return lens
else:
lens = 0
for i in range(self.num_datasets):
lens += (sizes[i] + self.batch_size[i] -
1) // self.batch_size[i]
return lens
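# Demo (editor's addition): grouping behavior of AspectRatioBatchSampler on a
# toy dataset. ``_ToyDataset`` and its shapes are assumptions, used for
# illustration only.
from torch.utils.data import SequentialSampler

class _ToyDataset:
    # alternate portrait (w < h) and landscape (w >= h) images
    shapes = [(10, 20), (20, 10), (12, 24), (24, 12), (8, 16), (16, 8)]
    def __len__(self):
        return len(self.shapes)
    def get_data_info(self, idx):
        width, height = self.shapes[idx]
        return dict(width=width, height=height)

if __name__ == '__main__':
    toy_sampler = SequentialSampler(_ToyDataset())
    toy_sampler.dataset = toy_sampler.data_source  # the wrapper reads .dataset
    for batch in AspectRatioBatchSampler(toy_sampler, batch_size=2):
        print(batch)  # [0, 2], [1, 3], then the mixed leftover [4, 5]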
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Iterator, Optional, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class ClassAwareSampler(Sampler):
r"""Sampler that restricts data loading to the label of the dataset.
A class-aware sampling strategy to effectively tackle the
non-uniform class distribution. The length of the training data is
consistent with source data. Simple improvements based on `Relay
Backpropagation for Effective Learning of Deep Convolutional
Neural Networks <https://arxiv.org/abs/1512.05830>`_
    The implementation logic is referenced from
    https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py
Args:
dataset: Dataset used for sampling.
seed (int, optional): random seed used to shuffle the sampler.
This number should be identical across all
processes in the distributed group. Defaults to None.
num_sample_class (int): The number of samples taken from each
per-label list. Defaults to 1.
"""
def __init__(self,
dataset: BaseDataset,
seed: Optional[int] = None,
num_sample_class: int = 1) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.epoch = 0
# Must be the same across all workers. If None, will use a
# random seed shared among workers
# (require synchronization among all workers)
if seed is None:
seed = sync_random_seed()
self.seed = seed
# The number of samples taken from each per-label list
assert num_sample_class > 0 and isinstance(num_sample_class, int)
self.num_sample_class = num_sample_class
# Get per-label image list from dataset
self.cat_dict = self.get_cat2imgs()
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size))
self.total_size = self.num_samples * self.world_size
# get number of images containing each category
self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
# filter labels without images
self.valid_cat_inds = [
i for i, length in enumerate(self.num_cat_imgs) if length != 0
]
self.num_classes = len(self.valid_cat_inds)
def get_cat2imgs(self) -> Dict[int, list]:
"""Get a dict with class as key and img_ids as values.
Returns:
            dict[int, list]: A dict of per-label image lists, where each key
                is a label index and the corresponding value lists the
                indices of images that contain this label.
"""
classes = self.dataset.metainfo.get('classes', None)
if classes is None:
raise ValueError('dataset metainfo must contain `classes`')
# sort the label index
cat2imgs = {i: [] for i in range(len(classes))}
for i in range(len(self.dataset)):
cat_ids = set(self.dataset.get_cat_ids(i))
for cat in cat_ids:
cat2imgs[cat].append(i)
return cat2imgs
def __iter__(self) -> Iterator[int]:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
# initialize label list
label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
# initialize each per-label image list
data_iter_dict = dict()
for i in self.valid_cat_inds:
data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)
def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
"""Traverse the categories and extract `num_sample_cls` image
indexes of the corresponding categories one by one."""
id_indices = []
for _ in range(len(cls_list)):
cls_idx = next(cls_list)
for _ in range(num_sample_cls):
id = next(data_dict[cls_idx])
id_indices.append(id)
return id_indices
# deterministically shuffle based on epoch
num_bins = int(
math.ceil(self.total_size * 1.0 / self.num_classes /
self.num_sample_class))
indices = []
for i in range(num_bins):
indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
self.num_sample_class)
# fix extra samples to make it evenly divisible
if len(indices) >= self.total_size:
indices = indices[:self.total_size]
else:
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
class RandomCycleIter:
"""Shuffle the list and do it again after the list have traversed.
The implementation logic is referred to
https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
Example:
>>> label_list = [0, 1, 2, 4, 5]
>>> g = torch.Generator()
>>> g.manual_seed(0)
>>> label_iter_list = RandomCycleIter(label_list, generator=g)
>>> index = next(label_iter_list)
Args:
data (list or ndarray): The data that needs to be shuffled.
        generator: A torch.Generator object, which is used to set the seed
            for generating random numbers.
""" # noqa: W605
def __init__(self,
data: Union[list, np.ndarray],
generator: torch.Generator = None) -> None:
self.data = data
self.length = len(data)
self.index = torch.randperm(self.length, generator=generator).numpy()
self.i = 0
self.generator = generator
def __iter__(self) -> Iterator:
return self
def __len__(self) -> int:
return len(self.data)
def __next__(self):
if self.i == self.length:
self.index = torch.randperm(
self.length, generator=self.generator).numpy()
self.i = 0
idx = self.data[self.index[self.i]]
self.i += 1
return idx
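# Usage sketch (editor's addition): plugging ClassAwareSampler into a
# dataloader config; ``num_sample_class=1`` matches the default above.
# train_dataloader = dict(
#     batch_size=2,
#     sampler=dict(type='ClassAwareSampler', num_sample_class=1),
#     dataset=...)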
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS
from torch.utils.data import Sampler
@DATA_SAMPLERS.register_module()
class MultiDataSampler(Sampler):
"""The default data sampler for both distributed and non-distributed
environment.
It has several differences from the PyTorch ``DistributedSampler`` as
below:
1. This sampler supports non-distributed environment.
2. The round up behaviors are a little different.
    - If ``round_up=True``, this sampler will add extra samples to make the
      number of samples evenly divisible by the world size. And
this behavior is the same as the ``DistributedSampler`` with
``drop_last=False``.
- If ``round_up=False``, this sampler won't remove or add any samples
while the ``DistributedSampler`` with ``drop_last=True`` will remove
tail samples.
Args:
dataset (Sized): The dataset.
        dataset_ratio (Sequence[int]): The ratios of different datasets.
seed (int, optional): Random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Defaults to None.
round_up (bool): Whether to add extra samples to make the number of
samples evenly divisible by the world size. Defaults to True.
"""
def __init__(self,
dataset: Sized,
dataset_ratio: Sequence[int],
seed: Optional[int] = None,
round_up: bool = True) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.dataset_ratio = dataset_ratio
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
if self.round_up:
self.num_samples = math.ceil(len(self.dataset) / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil(
(len(self.dataset) - rank) / world_size)
self.total_size = len(self.dataset)
self.sizes = [len(dataset) for dataset in self.dataset.datasets]
dataset_weight = [
torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio)
for i, (r, s) in enumerate(zip(self.dataset_ratio, self.sizes))
]
self.weights = torch.cat(dataset_weight)
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.multinomial(
self.weights, len(self.weights), generator=g,
replacement=True).tolist()
# add extra samples to make it evenly divisible
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
# subsample
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
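# Worked example (editor's addition, assumed sizes): two concatenated
# datasets with sizes = [100, 400] and dataset_ratio = [1, 1]:
#   max(sizes) = 400 and sum(dataset_ratio) = 2, so
#   per-sample weight of dataset 0: 400 / 100 * 1 / 2 = 2.0
#   per-sample weight of dataset 1: 400 / 400 * 1 / 2 = 0.5
# Each dataset then contributes equal expected mass (100 * 2.0 == 400 * 0.5)
# to the multinomial draw despite the 4x size difference.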
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Iterator, List, Optional, Sized, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class MultiSourceSampler(Sampler):
r"""Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
Examples:
>>> dataset_type = 'ConcatDataset'
>>> sub_dataset_type = 'CocoDataset'
>>> data_root = 'data/coco/'
>>> sup_ann = '../coco_semi_annos/instances_train2017.1@10.json'
>>> unsup_ann = '../coco_semi_annos/' \
>>> 'instances_train2017.1@10-unlabeled.json'
>>> dataset = dict(type=dataset_type,
>>> datasets=[
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=sup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=sup_pipeline),
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=unsup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=unsup_pipeline),
>>> ])
>>> train_dataloader = dict(
>>> batch_size=5,
>>> num_workers=5,
>>> persistent_workers=True,
>>> sampler=dict(type='MultiSourceSampler',
>>> batch_size=5, source_ratio=[1, 4]),
>>> batch_sampler=None,
>>> dataset=dataset)
"""
def __init__(self,
dataset: Sized,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
        assert hasattr(dataset, 'cumulative_sizes'),\
            f'The dataset must be ConcatDataset, but got {dataset}'
assert isinstance(batch_size, int) and batch_size > 0, \
'batch_size must be a positive integer value, ' \
f'but got batch_size={batch_size}'
assert isinstance(source_ratio, list), \
f'source_ratio must be a list, but got source_ratio={source_ratio}'
assert len(source_ratio) == len(dataset.cumulative_sizes), \
'The length of source_ratio must be equal to ' \
f'the number of datasets, but got source_ratio={source_ratio}'
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.cumulative_sizes = [0] + dataset.cumulative_sizes
self.batch_size = batch_size
self.source_ratio = source_ratio
self.num_per_source = [
int(batch_size * sr / sum(source_ratio)) for sr in source_ratio
]
self.num_per_source[0] = batch_size - sum(self.num_per_source[1:])
        assert sum(self.num_per_source) == batch_size, \
            'The sum of num_per_source must be equal to ' \
            f'batch_size, but got {self.num_per_source}'
self.seed = sync_random_seed() if seed is None else seed
self.shuffle = shuffle
self.source2inds = {
source: self._indices_of_rank(len(ds))
for source, ds in enumerate(dataset.datasets)
}
def _infinite_indices(self, sample_size: int) -> Iterator[int]:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
while True:
if self.shuffle:
yield from torch.randperm(sample_size, generator=g).tolist()
else:
yield from torch.arange(sample_size).tolist()
def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
"""Slice the infinite indices by rank."""
yield from itertools.islice(
self._infinite_indices(sample_size), self.rank, None,
self.world_size)
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.source2inds[source]:
idx += self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
def __len__(self) -> int:
return len(self.dataset)
def set_epoch(self, epoch: int) -> None:
"""Not supported in `epoch-based runner."""
pass
@DATA_SAMPLERS.register_module()
class GroupMultiSourceSampler(MultiSourceSampler):
r"""Group Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets but the same group to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
"""
def __init__(self,
dataset: BaseDataset,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
super().__init__(
dataset=dataset,
batch_size=batch_size,
source_ratio=source_ratio,
shuffle=shuffle,
seed=seed)
self._get_source_group_info()
self.group_source2inds = [{
source:
self._indices_of_rank(self.group2size_per_source[source][group])
for source in range(len(dataset.datasets))
} for group in range(len(self.group_ratio))]
def _get_source_group_info(self) -> None:
self.group2size_per_source = [{0: 0, 1: 0}, {0: 0, 1: 0}]
self.group2inds_per_source = [{0: [], 1: []}, {0: [], 1: []}]
for source, dataset in enumerate(self.dataset.datasets):
for idx in range(len(dataset)):
data_info = dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
group = 0 if width < height else 1
self.group2size_per_source[source][group] += 1
self.group2inds_per_source[source][group].append(idx)
self.group_sizes = np.zeros(2, dtype=np.int64)
for group2size in self.group2size_per_source:
for group, size in group2size.items():
self.group_sizes[group] += size
self.group_ratio = self.group_sizes / sum(self.group_sizes)
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
group = np.random.choice(
list(range(len(self.group_ratio))), p=self.group_ratio)
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.group_source2inds[group][source]:
idx = self.group2inds_per_source[source][group][
idx] + self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
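# Worked example (editor's addition): with batch_size=5 and
# source_ratio=[1, 4], num_per_source = [int(5 * 1 / 5), int(5 * 4 / 5)]
# = [1, 4], so each batch draws 1 index from source 0 and 4 from source 1;
# the first entry then absorbs any rounding remainder
# (batch_size - sum(num_per_source[1:])).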
# Copyright (c) OpenMMLab. All rights reserved.
import math
import random
from typing import Iterator, Optional, Sized
import numpy as np
from mmengine.dataset import ClassBalancedDataset, ConcatDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
from ..base_video_dataset import BaseVideoDataset
@DATA_SAMPLERS.register_module()
class TrackImgSampler(Sampler):
"""Sampler that providing image-level sampling outputs for video datasets
in tracking tasks. It could be both used in both distributed and
non-distributed environment.
If using the default sampler in pytorch, the subsequent data receiver will
get one video, which is not desired in some cases:
(Take a non-distributed environment as an example)
1. In test mode, we want only one image is fed into the data pipeline. This
is in consideration of memory usage since feeding the whole video commonly
requires a large amount of memory (>=20G on MOTChallenge17 dataset), which
is not available in some machines.
2. In training mode, we may want to make sure all the images in one video
are randomly sampled once in one epoch and this can not be guaranteed in
the default sampler in pytorch.
Args:
dataset (Sized): Dataset used for sampling.
seed (int, optional): random seed used to shuffle the sampler. This
number should be identical across all processes in the distributed
group. Defaults to None.
"""
def __init__(
self,
dataset: Sized,
seed: Optional[int] = None,
) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.epoch = 0
if seed is None:
self.seed = sync_random_seed()
else:
self.seed = seed
self.dataset = dataset
self.indices = []
# Hard code here to handle different dataset wrapper
if isinstance(self.dataset, ConcatDataset):
cat_datasets = self.dataset.datasets
assert isinstance(
cat_datasets[0], BaseVideoDataset
), f'expected BaseVideoDataset, but got {type(cat_datasets[0])}'
self.test_mode = cat_datasets[0].test_mode
            assert not self.test_mode, \
                "'ConcatDataset' should not exist in test mode"
for dataset in cat_datasets:
num_videos = len(dataset)
for video_ind in range(num_videos):
self.indices.extend([
(video_ind, frame_ind) for frame_ind in range(
dataset.get_len_per_video(video_ind))
])
elif isinstance(self.dataset, ClassBalancedDataset):
ori_dataset = self.dataset.dataset
assert isinstance(
ori_dataset, BaseVideoDataset
), f'expected BaseVideoDataset, but got {type(ori_dataset)}'
self.test_mode = ori_dataset.test_mode
            assert not self.test_mode, \
                "'ClassBalancedDataset' should not exist in test mode"
video_indices = self.dataset.repeat_indices
for index in video_indices:
self.indices.extend([(index, frame_ind) for frame_ind in range(
ori_dataset.get_len_per_video(index))])
else:
            assert isinstance(
                self.dataset, BaseVideoDataset
            ), ('TrackImgSampler is only supported in BaseVideoDataset or '
                'dataset wrapper: ClassBalancedDataset and ConcatDataset, '
                f'but got {type(self.dataset)}')
self.test_mode = self.dataset.test_mode
num_videos = len(self.dataset)
if self.test_mode:
# in test mode, the images belong to the same video must be put
# on the same device.
if num_videos < self.world_size:
                    raise ValueError(f'only {num_videos} videos loaded, '
                                     f'but {self.world_size} gpus were given.')
chunks = np.array_split(
list(range(num_videos)), self.world_size)
for videos_inds in chunks:
indices_chunk = []
for video_ind in videos_inds:
indices_chunk.extend([
(video_ind, frame_ind) for frame_ind in range(
self.dataset.get_len_per_video(video_ind))
])
self.indices.append(indices_chunk)
else:
for video_ind in range(num_videos):
self.indices.extend([
(video_ind, frame_ind) for frame_ind in range(
self.dataset.get_len_per_video(video_ind))
])
if self.test_mode:
self.num_samples = len(self.indices[self.rank])
self.total_size = sum(
[len(index_list) for index_list in self.indices])
else:
self.num_samples = int(
math.ceil(len(self.indices) * 1.0 / self.world_size))
self.total_size = self.num_samples * self.world_size
def __iter__(self) -> Iterator:
if self.test_mode:
# in test mode, the order of frames can not be shuffled.
indices = self.indices[self.rank]
else:
# deterministically shuffle based on epoch
rng = random.Random(self.epoch + self.seed)
indices = rng.sample(self.indices, len(self.indices))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.world_size]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
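# Behavior sketch (editor's addition, assumed lengths): with two videos of
# 3 and 2 frames, the flattened training indices are
#   [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
# so every element addresses a single frame rather than a whole video.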
# Copyright (c) OpenMMLab. All rights reserved.
from .augment_wrappers import AutoAugment, RandAugment
from .colorspace import (AutoContrast, Brightness, Color, ColorTransform,
Contrast, Equalize, Invert, Posterize, Sharpness,
Solarize, SolarizeAdd)
from .formatting import (ImageToTensor, PackDetInputs, PackReIDInputs,
PackTrackInputs, ToTensor, Transpose)
from .frame_sampling import BaseFrameSample, UniformRefFrameSample
from .geometric import (GeomTransform, Rotate, ShearX, ShearY, TranslateX,
TranslateY)
from .instaboost import InstaBoost
from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations,
LoadEmptyAnnotations, LoadImageFromNDArray,
LoadMultiChannelImageFromFiles, LoadPanopticAnnotations,
LoadProposals, LoadTrackAnnotations)
from .transformers_glip import GTBoxSubOne_GLIP, RandomFlip_GLIP
from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut,
Expand, FixScaleResize, FixShapeResize,
MinIoURandomCrop, MixUp, Mosaic, Pad,
PhotoMetricDistortion, RandomAffine,
RandomCenterCropPad, RandomCrop, RandomErasing,
RandomFlip, RandomShift, Resize, ResizeShortestEdge,
SegRescale, YOLOXHSVRandomAug)
from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder
__all__ = [
'PackDetInputs', 'ToTensor', 'ImageToTensor', 'Transpose',
'LoadImageFromNDArray', 'LoadAnnotations', 'LoadPanopticAnnotations',
'LoadMultiChannelImageFromFiles', 'LoadProposals', 'Resize', 'RandomFlip',
'RandomCrop', 'SegRescale', 'MinIoURandomCrop', 'Expand',
'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad',
'AutoAugment', 'CutOut', 'ShearX', 'ShearY', 'Rotate', 'Color', 'Equalize',
'Brightness', 'Contrast', 'TranslateX', 'TranslateY', 'RandomShift',
'Mosaic', 'MixUp', 'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste',
'FilterAnnotations', 'Pad', 'GeomTransform', 'ColorTransform',
'RandAugment', 'Sharpness', 'Solarize', 'SolarizeAdd', 'Posterize',
'AutoContrast', 'Invert', 'MultiBranch', 'RandomErasing',
'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp',
'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader',
'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample',
'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize',
'ResizeShortestEdge', 'GTBoxSubOne_GLIP', 'RandomFlip_GLIP'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import numpy as np
from mmcv.transforms import RandomChoice
from mmcv.transforms.utils import cache_randomness
from mmengine.config import ConfigDict
from mmdet.registry import TRANSFORMS
# AutoAugment uses reinforcement learning to search for
# some widely useful data augmentation strategies,
# here we provide AUTOAUG_POLICIES_V0.
# For AUTOAUG_POLICIES_V0, each tuple is an augmentation
# operation of the form (operation, probability, magnitude).
# Each element in policies is a policy that will be applied
# sequentially on the image.
# RandAugment defines a data augmentation search space, RANDAUG_SPACE,
# sampling 1~3 data augmentations each time, and
# setting the magnitude of each data augmentation randomly,
# which will be applied sequentially on the image.
_MAX_LEVEL = 10
AUTOAUG_POLICIES_V0 = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
def policies_v0():
"""Autoaugment policies that was used in AutoAugment Paper."""
policies = list()
for policy_args in AUTOAUG_POLICIES_V0:
policy = list()
for args in policy_args:
policy.append(dict(type=args[0], prob=args[1], level=args[2]))
policies.append(policy)
return policies
RANDAUG_SPACE = [[dict(type='AutoContrast')], [dict(type='Equalize')],
[dict(type='Invert')], [dict(type='Rotate')],
[dict(type='Posterize')], [dict(type='Solarize')],
[dict(type='SolarizeAdd')], [dict(type='Color')],
[dict(type='Contrast')], [dict(type='Brightness')],
[dict(type='Sharpness')], [dict(type='ShearX')],
[dict(type='ShearY')], [dict(type='TranslateX')],
[dict(type='TranslateY')]]
def level_to_mag(level: Optional[int], min_mag: float,
max_mag: float) -> float:
"""Map from level to magnitude."""
if level is None:
return round(np.random.rand() * (max_mag - min_mag) + min_mag, 1)
else:
return round(level / _MAX_LEVEL * (max_mag - min_mag) + min_mag, 1)
@TRANSFORMS.register_module()
class AutoAugment(RandomChoice):
"""Auto augmentation.
This data augmentation is proposed in `AutoAugment: Learning
Augmentation Policies from Data <https://arxiv.org/abs/1805.09501>`_
and in `Learning Data Augmentation Strategies for Object Detection
<https://arxiv.org/pdf/1906.11172>`_.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_bboxes_labels
- gt_masks
- gt_ignore_flags
- gt_seg_map
Added Keys:
- homography_matrix
Args:
        policies (List[List[Union[dict, ConfigDict]]]):
            The policies of auto augmentation. Each policy in ``policies``
            is a specific augmentation policy, and is composed of several
            augmentations. When AutoAugment is called, a random policy in
            ``policies`` will be selected to augment images.
            Defaults to policies_v0().
prob (list[float], optional): The probabilities associated
with each policy. The length should be equal to the policy
number and the sum should be 1. If not given, a uniform
distribution will be assumed. Defaults to None.
Examples:
>>> policies = [
>>> [
>>> dict(type='Sharpness', prob=0.0, level=8),
>>> dict(type='ShearX', prob=0.4, level=0,)
>>> ],
>>> [
>>> dict(type='Rotate', prob=0.6, level=10),
>>> dict(type='Color', prob=1.0, level=6)
>>> ]
>>> ]
>>> augmentation = AutoAugment(policies)
        >>> img = np.ones((100, 100, 3))
        >>> gt_bboxes = np.ones((10, 4))
>>> results = dict(img=img, gt_bboxes=gt_bboxes)
>>> results = augmentation(results)
"""
def __init__(self,
policies: List[List[Union[dict, ConfigDict]]] = policies_v0(),
prob: Optional[List[float]] = None) -> None:
assert isinstance(policies, list) and len(policies) > 0, \
'Policies must be a non-empty list.'
for policy in policies:
assert isinstance(policy, list) and len(policy) > 0, \
'Each policy in policies must be a non-empty list.'
for augment in policy:
assert isinstance(augment, dict) and 'type' in augment, \
'Each specific augmentation must be a dict with key' \
' "type".'
super().__init__(transforms=policies, prob=prob)
self.policies = policies
def __repr__(self) -> str:
return f'{self.__class__.__name__}(policies={self.policies}, ' \
f'prob={self.prob})'
@TRANSFORMS.register_module()
class RandAugment(RandomChoice):
"""Rand augmentation.
This data augmentation is proposed in `RandAugment:
Practical automated data augmentation with a reduced
search space <https://arxiv.org/abs/1909.13719>`_.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_bboxes_labels
- gt_masks
- gt_ignore_flags
- gt_seg_map
Added Keys:
- homography_matrix
Args:
        aug_space (List[List[Union[dict, ConfigDict]]]): The augmentation
            space of rand augmentation. Each element of ``aug_space`` is a
            list containing one specific transform. When RandAugment is
            called, ``aug_num`` transforms will be randomly selected from
            ``aug_space`` to augment images. Defaults to ``RANDAUG_SPACE``.
        aug_num (int): Number of augmentations to apply sequentially.
Defaults to 2.
prob (list[float], optional): The probabilities associated with
            each augmentation. The length should be equal to the size of
            the augmentation space and the sum should be 1. If not given,
a uniform distribution will be assumed. Defaults to None.
Examples:
        >>> aug_space = [
        >>>     [dict(type='Sharpness')],
        >>>     [dict(type='ShearX')],
        >>>     [dict(type='Color')],
        >>> ]
>>> augmentation = RandAugment(aug_space)
        >>> img = np.ones((100, 100, 3))
        >>> gt_bboxes = np.ones((10, 4))
>>> results = dict(img=img, gt_bboxes=gt_bboxes)
>>> results = augmentation(results)
"""
def __init__(self,
                 aug_space: List[List[Union[dict, ConfigDict]]] = RANDAUG_SPACE,
aug_num: int = 2,
prob: Optional[List[float]] = None) -> None:
assert isinstance(aug_space, list) and len(aug_space) > 0, \
'Augmentation space must be a non-empty list.'
for aug in aug_space:
            assert isinstance(aug, list) and len(aug) == 1, \
                'Each augmentation in aug_space must be a list of length 1.'
for transform in aug:
assert isinstance(transform, dict) and 'type' in transform, \
'Each specific transform must be a dict with key' \
' "type".'
super().__init__(transforms=aug_space, prob=prob)
self.aug_space = aug_space
self.aug_num = aug_num
@cache_randomness
def random_pipeline_index(self):
indices = np.arange(len(self.transforms))
return np.random.choice(
indices, self.aug_num, p=self.prob, replace=False)
def transform(self, results: dict) -> dict:
"""Transform function to use RandAugment.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with RandAugment.
"""
for idx in self.random_pipeline_index():
results = self.transforms[idx](results)
return results
def __repr__(self) -> str:
return f'{self.__class__.__name__}(' \
f'aug_space={self.aug_space}, '\
f'aug_num={self.aug_num}, ' \
f'prob={self.prob})'
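# Sampling sketch (illustrative): with the default RANDAUG_SPACE of 15
# single-transform candidates and aug_num=2, each call draws two distinct
# indices without replacement and applies the two transforms sequentially:
#
#     randaug = RandAugment(aug_num=2)
#     results = randaug(results)  # e.g. Equalize followed by TranslateX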
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmdet.registry import TRANSFORMS
from .augment_wrappers import _MAX_LEVEL, level_to_mag
@TRANSFORMS.register_module()
class ColorTransform(BaseTransform):
"""Base class for color transformations. All color transformations need to
inherit from this base class. ``ColorTransform`` unifies the class
attributes and class functions of color transformations (Color, Brightness,
    Contrast, Sharpness, Solarize, SolarizeAdd, Equalize, AutoContrast,
    Invert, and Posterize). These transforms only distort color channels,
    without impacting the locations of instances.
Required Keys:
- img
Modified Keys:
- img
Args:
        prob (float): The probability for performing the color
            transformation and should be in range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for color transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for color transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0 <= prob <= 1.0, f'The probability of the transformation ' \
f'should be in range [0,1], got {prob}.'
assert level is None or isinstance(level, int), \
f'The level should be None or type int, got {type(level)}.'
assert level is None or 0 <= level <= _MAX_LEVEL, \
f'The level should be in range [0,{_MAX_LEVEL}], got {level}.'
assert isinstance(min_mag, float), \
f'min_mag should be type float, got {type(min_mag)}.'
assert isinstance(max_mag, float), \
f'max_mag should be type float, got {type(max_mag)}.'
        assert min_mag <= max_mag, \
            f'min_mag should be smaller than or equal to max_mag, ' \
            f'got min_mag={min_mag} and max_mag={max_mag}.'
self.prob = prob
self.level = level
self.min_mag = min_mag
self.max_mag = max_mag
def _transform_img(self, results: dict, mag: float) -> None:
"""Transform the image."""
pass
@cache_randomness
def _random_disable(self):
"""Randomly disable the transform."""
return np.random.rand() > self.prob
@cache_randomness
def _get_mag(self):
"""Get the magnitude of the transform."""
return level_to_mag(self.level, self.min_mag, self.max_mag)
def transform(self, results: dict) -> dict:
"""Transform function for images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Transformed results.
"""
if self._random_disable():
return results
mag = self._get_mag()
self._transform_img(results, mag)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'level={self.level}, '
repr_str += f'min_mag={self.min_mag}, '
repr_str += f'max_mag={self.max_mag})'
return repr_str
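# Subclass sketch (a hypothetical ``Identity`` transform, shown only to
# illustrate the hook): the concrete color transforms below override nothing
# but ``_transform_img``; probability handling and magnitude sampling are
# inherited from ``ColorTransform``.
#
# @TRANSFORMS.register_module()
# class Identity(ColorTransform):
#     def _transform_img(self, results: dict, mag: float) -> None:
#         results['img'] = results['img']  # no-op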
@TRANSFORMS.register_module()
class Color(ColorTransform):
"""Adjust the color balance of the image, in a manner similar to the
controls on a colour TV set. A magnitude=0 gives a black & white image,
whereas magnitude=1 gives the original image. The bboxes, masks and
segmentations are not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Color transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Color transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Color transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Color should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Color should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Apply Color transformation to image."""
        # NOTE: by default the image is assumed to be in BGR format
img = results['img']
results['img'] = mmcv.adjust_color(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Brightness(ColorTransform):
"""Adjust the brightness of the image. A magnitude=0 gives a black image,
whereas magnitude=1 gives the original image. The bboxes, masks and
segmentations are not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Brightness transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Brightness transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Brightness transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Brightness should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Brightness should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Adjust the brightness of image."""
img = results['img']
results['img'] = mmcv.adjust_brightness(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Contrast(ColorTransform):
"""Control the contrast of the image. A magnitude=0 gives a gray image,
    whereas magnitude=1 gives the original image. The bboxes, masks and
segmentations are not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Contrast transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Contrast transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Contrast transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Contrast should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Contrast should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Adjust the image contrast."""
img = results['img']
results['img'] = mmcv.adjust_contrast(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Sharpness(ColorTransform):
"""Adjust images sharpness. A positive magnitude would enhance the
sharpness and a negative magnitude would make the image blurry. A
magnitude=0 gives the origin img.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Sharpness transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Sharpness transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Sharpness transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Sharpness should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Sharpness should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Adjust the image sharpness."""
img = results['img']
results['img'] = mmcv.adjust_sharpness(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Solarize(ColorTransform):
"""Solarize images (Invert all pixels above a threshold value of
magnitude.).
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Solarize transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Solarize transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for Solarize transformation.
Defaults to 256.0.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 256.0) -> None:
assert 0. <= min_mag <= 256.0, f'min_mag for Solarize should be ' \
f'in range [0, 256], got {min_mag}.'
assert 0. <= max_mag <= 256.0, f'max_mag for Solarize should be ' \
f'in range [0, 256], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Invert all pixel values above magnitude."""
img = results['img']
results['img'] = mmcv.solarize(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class SolarizeAdd(ColorTransform):
"""SolarizeAdd images. For each pixel in the image that is less than 128,
add an additional amount to it decided by the magnitude.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing SolarizeAdd
transformation. Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for SolarizeAdd transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for SolarizeAdd transformation.
Defaults to 110.0.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 110.0) -> None:
assert 0. <= min_mag <= 110.0, f'min_mag for SolarizeAdd should be ' \
f'in range [0, 110], got {min_mag}.'
assert 0. <= max_mag <= 110.0, f'max_mag for SolarizeAdd should be ' \
f'in range [0, 110], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""SolarizeAdd the image."""
img = results['img']
img_solarized = np.where(img < 128, np.minimum(img + mag, 255), img)
results['img'] = img_solarized.astype(img.dtype)
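# Worked example of the rule above (illustrative): with mag=60, a pixel of
# value 100 becomes min(100 + 60, 255) = 160, while a pixel of value 200
# (>= 128) is left unchanged.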
@TRANSFORMS.register_module()
class Posterize(ColorTransform):
"""Posterize images (reduce the number of bits for each color channel).
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Posterize
transformation. Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Posterize transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for Posterize transformation.
Defaults to 4.0.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 4.0) -> None:
assert 0. <= min_mag <= 8.0, f'min_mag for Posterize should be ' \
f'in range [0, 8], got {min_mag}.'
assert 0. <= max_mag <= 8.0, f'max_mag for Posterize should be ' \
f'in range [0, 8], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Posterize the image."""
img = results['img']
results['img'] = mmcv.posterize(img, math.ceil(mag)).astype(img.dtype)
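# Note (illustrative): ``math.ceil(mag)`` turns the sampled float magnitude
# into the integer number of bits kept per channel, e.g. mag=2.3 posterizes
# the image to 3 bits.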
@TRANSFORMS.register_module()
class Equalize(ColorTransform):
"""Equalize the image histogram. The bboxes, masks and segmentations are
not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Equalize transformation.
Defaults to 1.0.
level (int, optional): No use for Equalize transformation.
Defaults to None.
min_mag (float): No use for Equalize transformation. Defaults to 0.1.
max_mag (float): No use for Equalize transformation. Defaults to 1.9.
"""
def _transform_img(self, results: dict, mag: float) -> None:
"""Equalizes the histogram of one image."""
img = results['img']
results['img'] = mmcv.imequalize(img).astype(img.dtype)
@TRANSFORMS.register_module()
class AutoContrast(ColorTransform):
"""Auto adjust image contrast.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing AutoContrast should
be in range [0, 1]. Defaults to 1.0.
level (int, optional): No use for AutoContrast transformation.
Defaults to None.
min_mag (float): No use for AutoContrast transformation.
Defaults to 0.1.
max_mag (float): No use for AutoContrast transformation.
Defaults to 1.9.
"""
def _transform_img(self, results: dict, mag: float) -> None:
"""Auto adjust image contrast."""
img = results['img']
results['img'] = mmcv.auto_contrast(img).astype(img.dtype)
@TRANSFORMS.register_module()
class Invert(ColorTransform):
"""Invert images.
Required Keys:
- img
Modified Keys:
- img
Args:
        prob (float): The probability for performing Invert transformation
            and should be in range [0, 1]. Defaults to 1.0.
level (int, optional): No use for Invert transformation.
Defaults to None.
min_mag (float): No use for Invert transformation. Defaults to 0.1.
max_mag (float): No use for Invert transformation. Defaults to 1.9.
"""
def _transform_img(self, results: dict, mag: float) -> None:
"""Invert the image."""
img = results['img']
results['img'] = mmcv.iminvert(img).astype(img.dtype)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
import numpy as np
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData, PixelData
from mmdet.registry import TRANSFORMS
from mmdet.structures import DetDataSample, ReIDDataSample, TrackDataSample
from mmdet.structures.bbox import BaseBoxes
@TRANSFORMS.register_module()
class PackDetInputs(BaseTransform):
"""Pack the inputs data for the detection / semantic segmentation /
panoptic segmentation.
The ``img_meta`` item is always populated. The contents of the
    ``img_meta`` dictionary depend on ``meta_keys``. By default this includes:
- ``img_id``: id of the image
- ``img_path``: path to the image file
- ``ori_shape``: original shape of the image as a tuple (h, w)
- ``img_shape``: shape of the image input to the network as a tuple \
(h, w). Note that images may be zero padded on the \
bottom/right if the batch tensor is larger than this shape.
- ``scale_factor``: a float indicating the preprocessing scale
- ``flip``: a boolean indicating if image flip transform was used
- ``flip_direction``: the flipping direction
Args:
meta_keys (Sequence[str], optional): Meta keys to be converted to
``mmcv.DataContainer`` and collected in ``data[img_metas]``.
Default: ``('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')``
"""
mapping_table = {
'gt_bboxes': 'bboxes',
'gt_bboxes_labels': 'labels',
'gt_masks': 'masks'
}
def __init__(self,
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')):
self.meta_keys = meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
            - 'data_samples' (obj:`DetDataSample`): The annotation info of the
              sample.
"""
packed_results = dict()
if 'img' in results:
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
            # To improve the computational speed by 3-5 times, apply:
# If image is not contiguous, use
# `numpy.transpose()` followed by `numpy.ascontiguousarray()`
# If image is already contiguous, use
# `torch.permute()` followed by `torch.contiguous()`
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if not img.flags.c_contiguous:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
else:
img = to_tensor(img).permute(2, 0, 1).contiguous()
packed_results['inputs'] = img
if 'gt_ignore_flags' in results:
valid_idx = np.where(results['gt_ignore_flags'] == 0)[0]
ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0]
data_sample = DetDataSample()
instance_data = InstanceData()
ignore_instance_data = InstanceData()
for key in self.mapping_table.keys():
if key not in results:
continue
if key == 'gt_masks' or isinstance(results[key], BaseBoxes):
if 'gt_ignore_flags' in results:
instance_data[
self.mapping_table[key]] = results[key][valid_idx]
ignore_instance_data[
self.mapping_table[key]] = results[key][ignore_idx]
else:
instance_data[self.mapping_table[key]] = results[key]
else:
if 'gt_ignore_flags' in results:
instance_data[self.mapping_table[key]] = to_tensor(
results[key][valid_idx])
ignore_instance_data[self.mapping_table[key]] = to_tensor(
results[key][ignore_idx])
else:
instance_data[self.mapping_table[key]] = to_tensor(
results[key])
data_sample.gt_instances = instance_data
data_sample.ignored_instances = ignore_instance_data
if 'proposals' in results:
proposals = InstanceData(
bboxes=to_tensor(results['proposals']),
scores=to_tensor(results['proposals_scores']))
data_sample.proposals = proposals
if 'gt_seg_map' in results:
gt_sem_seg_data = dict(
sem_seg=to_tensor(results['gt_seg_map'][None, ...].copy()))
gt_sem_seg_data = PixelData(**gt_sem_seg_data)
if 'ignore_index' in results:
metainfo = dict(ignore_index=results['ignore_index'])
gt_sem_seg_data.set_metainfo(metainfo)
data_sample.gt_sem_seg = gt_sem_seg_data
img_meta = {}
for key in self.meta_keys:
if key in results:
img_meta[key] = results[key]
data_sample.set_metainfo(img_meta)
packed_results['data_samples'] = data_sample
return packed_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str
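# Output sketch (illustrative): for a ``results`` dict produced by a
# loading/augmentation pipeline,
#
#     packed = PackDetInputs().transform(results)
#
# packed['inputs'] is a torch.Tensor of shape (C, H, W), and
# packed['data_samples'] is a DetDataSample whose .gt_instances holds the
# valid boxes/labels/masks, .ignored_instances the entries with
# gt_ignore_flags == 1, and .metainfo the selected meta keys.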
@TRANSFORMS.register_module()
class ToTensor:
"""Convert some results to :obj:`torch.Tensor` by given keys.
Args:
keys (Sequence[str]): Keys that need to be converted to Tensor.
"""
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
"""Call function to convert data in results to :obj:`torch.Tensor`.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data converted
to :obj:`torch.Tensor`.
"""
for key in self.keys:
results[key] = to_tensor(results[key])
return results
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
@TRANSFORMS.register_module()
class ImageToTensor:
"""Convert image to :obj:`torch.Tensor` by given keys.
The dimension order of input image is (H, W, C). The pipeline will convert
    it to (C, H, W). If only 2 dimensions (H, W) are given, the output will be
(1, H, W).
Args:
keys (Sequence[str]): Key of images to be converted to Tensor.
"""
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
"""Call function to convert image in results to :obj:`torch.Tensor` and
transpose the channel order.
Args:
results (dict): Result dict contains the image data to convert.
Returns:
dict: The result dict contains the image converted
to :obj:`torch.Tensor` and permuted to (C, H, W) order.
"""
for key in self.keys:
img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
results[key] = to_tensor(img).permute(2, 0, 1).contiguous()
return results
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
@TRANSFORMS.register_module()
class Transpose:
"""Transpose some results by given keys.
Args:
keys (Sequence[str]): Keys of results to be transposed.
order (Sequence[int]): Order of transpose.
"""
def __init__(self, keys, order):
self.keys = keys
self.order = order
def __call__(self, results):
"""Call function to transpose the channel order of data in results.
Args:
results (dict): Result dict contains the data to transpose.
Returns:
dict: The result dict contains the data transposed to \
``self.order``.
"""
for key in self.keys:
results[key] = results[key].transpose(self.order)
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(keys={self.keys}, order={self.order})'
@TRANSFORMS.register_module()
class WrapFieldsToLists:
"""Wrap fields of the data dictionary into lists for evaluation.
This class can be used as a last step of a test or validation
pipeline for single image evaluation or inference.
Example:
>>> test_pipeline = [
>>> dict(type='LoadImageFromFile'),
>>> dict(type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
>>> dict(type='Pad', size_divisor=32),
>>> dict(type='ImageToTensor', keys=['img']),
>>> dict(type='Collect', keys=['img']),
>>> dict(type='WrapFieldsToLists')
>>> ]
"""
def __call__(self, results):
"""Call function to wrap fields into lists.
Args:
results (dict): Result dict contains the data to wrap.
Returns:
            dict: The result dict where all values are wrapped \
                into lists.
"""
# Wrap dict fields into lists
for key, val in results.items():
results[key] = [val]
return results
def __repr__(self):
return f'{self.__class__.__name__}()'
@TRANSFORMS.register_module()
class PackTrackInputs(BaseTransform):
"""Pack the inputs data for the multi object tracking and video instance
segmentation. All the information of images are packed to ``inputs``. All
the information except images are packed to ``data_samples``. In order to
get the original annotaiton and meta info, we add `instances` key into meta
keys.
Args:
meta_keys (Sequence[str]): Meta keys to be collected in
``data_sample.metainfo``. Defaults to None.
        default_meta_keys (tuple): Default meta keys. Defaults to ('img_id',
            'img_path', 'ori_shape', 'img_shape', 'scale_factor',
            'flip', 'flip_direction', 'frame_id', 'video_id',
            'video_length', 'ori_video_length', 'instances').
"""
mapping_table = {
'gt_bboxes': 'bboxes',
'gt_bboxes_labels': 'labels',
'gt_masks': 'masks',
'gt_instances_ids': 'instances_ids'
}
def __init__(self,
meta_keys: Optional[dict] = None,
default_meta_keys: tuple = ('img_id', 'img_path', 'ori_shape',
'img_shape', 'scale_factor',
'flip', 'flip_direction',
'frame_id', 'video_id',
'video_length',
'ori_video_length', 'instances')):
self.meta_keys = default_meta_keys
if meta_keys is not None:
if isinstance(meta_keys, str):
meta_keys = (meta_keys, )
else:
assert isinstance(meta_keys, tuple), \
'meta_keys must be str or tuple'
self.meta_keys += meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (dict[Tensor]): The forward data of models.
- 'data_samples' (obj:`TrackDataSample`): The annotation info of
the samples.
"""
packed_results = dict()
packed_results['inputs'] = dict()
# 1. Pack images
if 'img' in results:
imgs = results['img']
imgs = np.stack(imgs, axis=0)
imgs = imgs.transpose(0, 3, 1, 2)
packed_results['inputs'] = to_tensor(imgs)
# 2. Pack InstanceData
if 'gt_ignore_flags' in results:
gt_ignore_flags_list = results['gt_ignore_flags']
valid_idx_list, ignore_idx_list = [], []
for gt_ignore_flags in gt_ignore_flags_list:
valid_idx = np.where(gt_ignore_flags == 0)[0]
ignore_idx = np.where(gt_ignore_flags == 1)[0]
valid_idx_list.append(valid_idx)
ignore_idx_list.append(ignore_idx)
        assert 'img_id' in results, "'img_id' must be contained in " \
            'the results for counting the number of images'
num_imgs = len(results['img_id'])
instance_data_list = [InstanceData() for _ in range(num_imgs)]
ignore_instance_data_list = [InstanceData() for _ in range(num_imgs)]
for key in self.mapping_table.keys():
if key not in results:
continue
if key == 'gt_masks':
mapped_key = self.mapping_table[key]
gt_masks_list = results[key]
if 'gt_ignore_flags' in results:
for i, gt_mask in enumerate(gt_masks_list):
valid_idx, ignore_idx = valid_idx_list[
i], ignore_idx_list[i]
instance_data_list[i][mapped_key] = gt_mask[valid_idx]
ignore_instance_data_list[i][mapped_key] = gt_mask[
ignore_idx]
else:
for i, gt_mask in enumerate(gt_masks_list):
instance_data_list[i][mapped_key] = gt_mask
else:
anns_list = results[key]
if 'gt_ignore_flags' in results:
for i, ann in enumerate(anns_list):
valid_idx, ignore_idx = valid_idx_list[
i], ignore_idx_list[i]
instance_data_list[i][
self.mapping_table[key]] = to_tensor(
ann[valid_idx])
ignore_instance_data_list[i][
self.mapping_table[key]] = to_tensor(
ann[ignore_idx])
else:
for i, ann in enumerate(anns_list):
instance_data_list[i][
self.mapping_table[key]] = to_tensor(ann)
det_data_samples_list = []
for i in range(num_imgs):
det_data_sample = DetDataSample()
det_data_sample.gt_instances = instance_data_list[i]
det_data_sample.ignored_instances = ignore_instance_data_list[i]
det_data_samples_list.append(det_data_sample)
# 3. Pack metainfo
for key in self.meta_keys:
if key not in results:
continue
img_metas_list = results[key]
for i, img_meta in enumerate(img_metas_list):
det_data_samples_list[i].set_metainfo({f'{key}': img_meta})
track_data_sample = TrackDataSample()
track_data_sample.video_data_samples = det_data_samples_list
if 'key_frame_flags' in results:
key_frame_flags = np.asarray(results['key_frame_flags'])
key_frames_inds = np.where(key_frame_flags)[0].tolist()
ref_frames_inds = np.where(~key_frame_flags)[0].tolist()
track_data_sample.set_metainfo(
dict(key_frames_inds=key_frames_inds))
track_data_sample.set_metainfo(
dict(ref_frames_inds=ref_frames_inds))
packed_results['data_samples'] = track_data_sample
return packed_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys})'
return repr_str
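# Shape sketch (illustrative): for a clip of N sampled frames, the packed
# ``inputs`` tensor has shape (N, C, H, W), and ``data_samples`` is one
# TrackDataSample holding N per-frame DetDataSample objects; when
# ``key_frame_flags`` is present, the key/reference frame indices are
# recorded in the TrackDataSample metainfo.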
@TRANSFORMS.register_module()
class PackReIDInputs(BaseTransform):
"""Pack the inputs data for the ReID. The ``meta_info`` item is always
    populated. The contents of the ``meta_info`` dictionary depend on
``meta_keys``. By default this includes:
- ``img_path``: path to the image file.
- ``ori_shape``: original shape of the image as a tuple (H, W).
- ``img_shape``: shape of the image input to the network as a tuple
(H, W). Note that images may be zero padded on the bottom/right
if the batch tensor is larger than this shape.
- ``scale``: scale of the image as a tuple (W, H).
- ``scale_factor``: a float indicating the pre-processing scale.
- ``flip``: a boolean indicating if image flip transform was used.
- ``flip_direction``: the flipping direction.
Args:
        meta_keys (Sequence[str], optional): The meta keys to be saved in
            the ``metainfo`` of the packed ``data_sample``.
"""
default_meta_keys = ('img_path', 'ori_shape', 'img_shape', 'scale',
'scale_factor')
def __init__(self, meta_keys: Sequence[str] = ()) -> None:
self.meta_keys = self.default_meta_keys
if meta_keys is not None:
if isinstance(meta_keys, str):
meta_keys = (meta_keys, )
else:
assert isinstance(meta_keys, tuple), \
'meta_keys must be str or tuple.'
self.meta_keys += meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (dict[Tensor]): The forward data of models.
- 'data_samples' (obj:`ReIDDataSample`): The meta info of the
sample.
"""
packed_results = dict(inputs=dict(), data_samples=None)
assert 'img' in results, 'Missing the key ``img``.'
_type = type(results['img'])
label = results['gt_label']
if _type == list:
img = results['img']
label = np.stack(label, axis=0) # (N,)
assert all([type(v) == _type for v in results.values()]), \
'All items in the results must have the same type.'
else:
img = [results['img']]
img = np.stack(img, axis=3) # (H, W, C, N)
img = img.transpose(3, 2, 0, 1) # (N, C, H, W)
img = np.ascontiguousarray(img)
packed_results['inputs'] = to_tensor(img)
data_sample = ReIDDataSample()
data_sample.set_gt_label(label)
meta_info = dict()
for key in self.meta_keys:
meta_info[key] = results[key]
data_sample.set_metainfo(meta_info)
packed_results['data_samples'] = data_sample
return packed_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
import random
from collections import defaultdict
from typing import Dict, List, Optional, Union
from mmcv.transforms import BaseTransform
from mmdet.registry import TRANSFORMS
@TRANSFORMS.register_module()
class BaseFrameSample(BaseTransform):
"""Directly get the key frame, no reference frames.
Args:
collect_video_keys (list[str]): The keys of video info to be
collected.
"""
def __init__(self,
collect_video_keys: List[str] = ['video_id', 'video_length']):
self.collect_video_keys = collect_video_keys
def prepare_data(self, video_infos: dict,
sampled_inds: List[int]) -> Dict[str, List]:
"""Prepare data for the subsequent pipeline.
Args:
video_infos (dict): The whole video information.
sampled_inds (list[int]): The sampled frame indices.
Returns:
dict: The processed data information.
"""
frames_anns = video_infos['images']
final_data_info = defaultdict(list)
# for data in frames_anns:
for index in sampled_inds:
data = frames_anns[index]
            # copy video-level info into each image-level record
for key in self.collect_video_keys:
if key == 'video_length':
data['ori_video_length'] = video_infos[key]
data['video_length'] = len(sampled_inds)
else:
data[key] = video_infos[key]
# Collate data_list (list of dict to dict of list)
for key, value in data.items():
final_data_info[key].append(value)
return final_data_info
def transform(self, video_infos: dict) -> Optional[Dict[str, List]]:
"""Transform the video information.
Args:
video_infos (dict): The whole video information.
Returns:
dict: The data information of the key frames.
"""
if 'key_frame_id' in video_infos:
key_frame_id = video_infos['key_frame_id']
assert isinstance(video_infos['key_frame_id'], int)
else:
key_frame_id = random.sample(
list(range(video_infos['video_length'])), 1)[0]
results = self.prepare_data(video_infos, [key_frame_id])
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(collect_video_keys={self.collect_video_keys})'
return repr_str
@TRANSFORMS.register_module()
class UniformRefFrameSample(BaseFrameSample):
"""Uniformly sample reference frames.
Args:
num_ref_imgs (int): Number of reference frames to be sampled.
frame_range (int | list[int]): Range of frames to be sampled around
key frame. If int, the range is [-frame_range, frame_range].
Defaults to 10.
filter_key_img (bool): Whether to filter the key frame when
sampling reference frames. Defaults to True.
collect_video_keys (list[str]): The keys of video info to be
collected.
"""
def __init__(self,
num_ref_imgs: int = 1,
frame_range: Union[int, List[int]] = 10,
filter_key_img: bool = True,
collect_video_keys: List[str] = ['video_id', 'video_length']):
self.num_ref_imgs = num_ref_imgs
self.filter_key_img = filter_key_img
if isinstance(frame_range, int):
            assert frame_range >= 0, 'frame_range cannot be a negative value.'
frame_range = [-frame_range, frame_range]
elif isinstance(frame_range, list):
assert len(frame_range) == 2, 'The length must be 2.'
assert frame_range[0] <= 0 and frame_range[1] >= 0
for i in frame_range:
assert isinstance(i, int), 'Each element must be int.'
else:
raise TypeError('The type of frame_range must be int or list.')
self.frame_range = frame_range
super().__init__(collect_video_keys=collect_video_keys)
def sampling_frames(self, video_length: int, key_frame_id: int):
"""Sampling frames.
Args:
video_length (int): The length of the video.
key_frame_id (int): The key frame id.
Returns:
list[int]: The sampled frame indices.
"""
if video_length > 1:
left = max(0, key_frame_id + self.frame_range[0])
right = min(key_frame_id + self.frame_range[1], video_length - 1)
frame_ids = list(range(0, video_length))
valid_ids = frame_ids[left:right + 1]
if self.filter_key_img and key_frame_id in valid_ids:
valid_ids.remove(key_frame_id)
assert len(
valid_ids
) > 0, 'After filtering key frame, there are no valid frames'
if len(valid_ids) < self.num_ref_imgs:
valid_ids = valid_ids * self.num_ref_imgs
ref_frame_ids = random.sample(valid_ids, self.num_ref_imgs)
else:
ref_frame_ids = [key_frame_id] * self.num_ref_imgs
sampled_frames_ids = [key_frame_id] + ref_frame_ids
sampled_frames_ids = sorted(sampled_frames_ids)
key_frames_ind = sampled_frames_ids.index(key_frame_id)
key_frame_flags = [False] * len(sampled_frames_ids)
key_frame_flags[key_frames_ind] = True
return sampled_frames_ids, key_frame_flags
def transform(self, video_infos: dict) -> Optional[Dict[str, List]]:
"""Transform the video information.
Args:
video_infos (dict): The whole video information.
Returns:
dict: The data information of the sampled frames.
"""
if 'key_frame_id' in video_infos:
key_frame_id = video_infos['key_frame_id']
assert isinstance(video_infos['key_frame_id'], int)
else:
key_frame_id = random.sample(
list(range(video_infos['video_length'])), 1)[0]
(sampled_frames_ids, key_frame_flags) = self.sampling_frames(
video_infos['video_length'], key_frame_id=key_frame_id)
results = self.prepare_data(video_infos, sampled_frames_ids)
results['key_frame_flags'] = key_frame_flags
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(num_ref_imgs={self.num_ref_imgs}, '
repr_str += f'frame_range={self.frame_range}, '
repr_str += f'filter_key_img={self.filter_key_img}, '
repr_str += f'collect_video_keys={self.collect_video_keys})'
return repr_str
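# Worked example (illustrative): with num_ref_imgs=2, frame_range=10 and
# filter_key_img=True, a video of length 30 and key_frame_id=15 gives
# valid_ids = frames 5..25 excluding 15; two reference ids are drawn from
# it without replacement, the ids are sorted together with the key frame,
# and key_frame_flags marks the position of frame 15 in the sorted list.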
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
import cv2
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import autocast_box_type
from .augment_wrappers import _MAX_LEVEL, level_to_mag
@TRANSFORMS.register_module()
class GeomTransform(BaseTransform):
"""Base class for geometric transformations. All geometric transformations
need to inherit from this base class. ``GeomTransform`` unifies the class
attributes and class functions of geometric transformations (ShearX,
ShearY, Rotate, TranslateX, and TranslateY), and records the homography
matrix.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability for performing the geometric
transformation and should be in range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for geometric transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for geometric transformation.
Defaults to 1.0.
reversal_prob (float): The probability that reverses the geometric
transformation magnitude. Should be in range [0,1].
Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 1.0,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0 <= prob <= 1.0, f'The probability of the transformation ' \
f'should be in range [0,1], got {prob}.'
assert level is None or isinstance(level, int), \
f'The level should be None or type int, got {type(level)}.'
assert level is None or 0 <= level <= _MAX_LEVEL, \
f'The level should be in range [0,{_MAX_LEVEL}], got {level}.'
assert isinstance(min_mag, float), \
f'min_mag should be type float, got {type(min_mag)}.'
assert isinstance(max_mag, float), \
f'max_mag should be type float, got {type(max_mag)}.'
        assert min_mag <= max_mag, \
            f'min_mag should be smaller than or equal to max_mag, ' \
            f'got min_mag={min_mag} and max_mag={max_mag}.'
        assert isinstance(reversal_prob, float), \
            f'reversal_prob should be type float, got {type(reversal_prob)}.'
        assert 0 <= reversal_prob <= 1.0, \
            f'The reversal probability of the transformation magnitude ' \
            f'should be in range [0,1], got {reversal_prob}.'
if isinstance(img_border_value, (float, int)):
img_border_value = tuple([float(img_border_value)] * 3)
elif isinstance(img_border_value, tuple):
assert len(img_border_value) == 3, \
f'img_border_value as tuple must have 3 elements, ' \
f'got {len(img_border_value)}.'
img_border_value = tuple([float(val) for val in img_border_value])
else:
raise ValueError(
'img_border_value must be float or tuple with 3 elements.')
        assert np.all([0 <= val <= 255 for val in img_border_value]), 'all ' \
            'elements of img_border_value should be in range [0, 255], ' \
            f'got {img_border_value}.'
self.prob = prob
self.level = level
self.min_mag = min_mag
self.max_mag = max_mag
self.reversal_prob = reversal_prob
self.img_border_value = img_border_value
self.mask_border_value = mask_border_value
self.seg_ignore_label = seg_ignore_label
self.interpolation = interpolation
def _transform_img(self, results: dict, mag: float) -> None:
"""Transform the image."""
pass
def _transform_masks(self, results: dict, mag: float) -> None:
"""Transform the masks."""
pass
def _transform_seg(self, results: dict, mag: float) -> None:
"""Transform the segmentation map."""
pass
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for the geometric transformation."""
return np.eye(3, dtype=np.float32)
def _transform_bboxes(self, results: dict, mag: float) -> None:
"""Transform the bboxes."""
results['gt_bboxes'].project_(self.homography_matrix)
results['gt_bboxes'].clip_(results['img_shape'])
def _record_homography_matrix(self, results: dict) -> None:
"""Record the homography matrix for the geometric transformation."""
if results.get('homography_matrix', None) is None:
results['homography_matrix'] = self.homography_matrix
else:
results['homography_matrix'] = self.homography_matrix @ results[
'homography_matrix']
@cache_randomness
def _random_disable(self):
"""Randomly disable the transform."""
return np.random.rand() > self.prob
@cache_randomness
def _get_mag(self):
"""Get the magnitude of the transform."""
mag = level_to_mag(self.level, self.min_mag, self.max_mag)
return -mag if np.random.rand() > self.reversal_prob else mag
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Transform function for images, bounding boxes, masks and semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Transformed results.
"""
if self._random_disable():
return results
mag = self._get_mag()
self.homography_matrix = self._get_homography_matrix(results, mag)
self._record_homography_matrix(results)
self._transform_img(results, mag)
if results.get('gt_bboxes', None) is not None:
self._transform_bboxes(results, mag)
if results.get('gt_masks', None) is not None:
self._transform_masks(results, mag)
if results.get('gt_seg_map', None) is not None:
self._transform_seg(results, mag)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'level={self.level}, '
repr_str += f'min_mag={self.min_mag}, '
repr_str += f'max_mag={self.max_mag}, '
repr_str += f'reversal_prob={self.reversal_prob}, '
repr_str += f'img_border_value={self.img_border_value}, '
repr_str += f'mask_border_value={self.mask_border_value}, '
repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
repr_str += f'interpolation={self.interpolation})'
return repr_str
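# Composition sketch (illustrative): homographies of chained geometric
# transforms accumulate by left-multiplication in
# ``_record_homography_matrix``, so a pipeline that applies ShearX and then
# Rotate ends with
#
#     results['homography_matrix'] == H_rotate @ H_shear
#
# mapping original-image coordinates to final-image coordinates.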
@TRANSFORMS.register_module()
class ShearX(GeomTransform):
"""Shear the images, bboxes, masks and segmentation map horizontally.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability for performing Shear and should be in
range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum angle for the horizontal shear.
Defaults to 0.0.
max_mag (float): The maximum angle for the horizontal shear.
Defaults to 30.0.
reversal_prob (float): The probability that reverses the horizontal
shear magnitude. Should be in range [0,1]. Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 30.0,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 90., \
f'min_mag angle for ShearX should be ' \
f'in range [0, 90], got {min_mag}.'
assert 0. <= max_mag <= 90., \
f'max_mag angle for ShearX should be ' \
f'in range [0, 90], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
@cache_randomness
def _get_mag(self):
"""Get the magnitude of the transform."""
mag = level_to_mag(self.level, self.min_mag, self.max_mag)
mag = np.tan(mag * np.pi / 180)
return -mag if np.random.rand() > self.reversal_prob else mag
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for ShearX."""
return np.array([[1, mag, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Shear the image horizontally."""
results['img'] = mmcv.imshear(
results['img'],
mag,
direction='horizontal',
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Shear the masks horizontally."""
results['gt_masks'] = results['gt_masks'].shear(
results['img_shape'],
mag,
direction='horizontal',
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Shear the segmentation map horizontally."""
results['gt_seg_map'] = mmcv.imshear(
results['gt_seg_map'],
mag,
direction='horizontal',
border_value=self.seg_ignore_label,
interpolation='nearest')
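# Coordinate sketch (illustrative): the ShearX homography maps a point
# (x, y) to (x + mag * y, y), so horizontal displacement grows linearly
# with the row index; ``mag`` is tan(angle) of the sampled shear angle.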
@TRANSFORMS.register_module()
class ShearY(GeomTransform):
"""Shear the images, bboxes, masks and segmentation map vertically.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability for performing ShearY and should be in
range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum angle for the vertical shear.
Defaults to 0.0.
max_mag (float): The maximum angle for the vertical shear.
Defaults to 30.0.
reversal_prob (float): The probability that reverses the vertical
shear magnitude. Should be in range [0,1]. Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 30.,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 90., \
f'min_mag angle for ShearY should be ' \
f'in range [0, 90], got {min_mag}.'
assert 0. <= max_mag <= 90., \
f'max_mag angle for ShearY should be ' \
f'in range [0, 90], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
@cache_randomness
def _get_mag(self):
"""Get the magnitude of the transform."""
mag = level_to_mag(self.level, self.min_mag, self.max_mag)
mag = np.tan(mag * np.pi / 180)
return -mag if np.random.rand() > self.reversal_prob else mag
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for ShearY."""
return np.array([[1, 0, 0], [mag, 1, 0], [0, 0, 1]], dtype=np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Shear the image vertically."""
results['img'] = mmcv.imshear(
results['img'],
mag,
direction='vertical',
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Shear the masks vertically."""
results['gt_masks'] = results['gt_masks'].shear(
results['img_shape'],
mag,
direction='vertical',
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Shear the segmentation map vertically."""
results['gt_seg_map'] = mmcv.imshear(
results['gt_seg_map'],
mag,
direction='vertical',
border_value=self.seg_ignore_label,
interpolation='nearest')
@TRANSFORMS.register_module()
class Rotate(GeomTransform):
"""Rotate the images, bboxes, masks and segmentation map.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
        prob (float): The probability for performing the transformation and
            should be in range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
        min_mag (float): The minimum angle for rotation.
Defaults to 0.0.
max_mag (float): The maximum angle for rotation.
Defaults to 30.0.
reversal_prob (float): The probability that reverses the rotation
magnitude. Should be in range [0,1]. Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 30.0,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 180., \
f'min_mag for Rotate should be in range [0,180], got {min_mag}.'
assert 0. <= max_mag <= 180., \
f'max_mag for Rotate should be in range [0,180], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for Rotate."""
img_shape = results['img_shape']
center = ((img_shape[1] - 1) * 0.5, (img_shape[0] - 1) * 0.5)
cv2_rotation_matrix = cv2.getRotationMatrix2D(center, -mag, 1.0)
return np.concatenate(
[cv2_rotation_matrix,
np.array([0, 0, 1]).reshape((1, 3))]).astype(np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Rotate the image."""
results['img'] = mmcv.imrotate(
results['img'],
mag,
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Rotate the masks."""
results['gt_masks'] = results['gt_masks'].rotate(
results['img_shape'],
mag,
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Rotate the segmentation map."""
results['gt_seg_map'] = mmcv.imrotate(
results['gt_seg_map'],
mag,
border_value=self.seg_ignore_label,
interpolation='nearest')
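# Note (an assumption worth checking against mmcv's docs): the angle passed
# to cv2.getRotationMatrix2D is negated because OpenCV treats positive
# angles as counter-clockwise, while mmcv.imrotate treats positive angles
# as clockwise; negating keeps the recorded homography consistent with the
# warped image.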
@TRANSFORMS.register_module()
class TranslateX(GeomTransform):
"""Translate the images, bboxes, masks and segmentation map horizontally.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
        prob (float): The probability for performing the transformation and
            should be in range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum pixel's offset ratio for horizontal
translation. Defaults to 0.0.
max_mag (float): The maximum pixel's offset ratio for horizontal
translation. Defaults to 0.1.
reversal_prob (float): The probability that reverses the horizontal
translation magnitude. Should be in range [0,1]. Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 0.1,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 1., \
f'min_mag ratio for TranslateX should be ' \
f'in range [0, 1], got {min_mag}.'
assert 0. <= max_mag <= 1., \
f'max_mag ratio for TranslateX should be ' \
f'in range [0, 1], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for TranslateX."""
mag = int(results['img_shape'][1] * mag)
return np.array([[1, 0, mag], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Translate the image horizontally."""
mag = int(results['img_shape'][1] * mag)
results['img'] = mmcv.imtranslate(
results['img'],
mag,
direction='horizontal',
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Translate the masks horizontally."""
mag = int(results['img_shape'][1] * mag)
results['gt_masks'] = results['gt_masks'].translate(
results['img_shape'],
mag,
direction='horizontal',
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Translate the segmentation map horizontally."""
mag = int(results['img_shape'][1] * mag)
results['gt_seg_map'] = mmcv.imtranslate(
results['gt_seg_map'],
mag,
direction='horizontal',
border_value=self.seg_ignore_label,
interpolation='nearest')
@TRANSFORMS.register_module()
class TranslateY(GeomTransform):
"""Translate the images, bboxes, masks and segmentation map vertically.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability of performing the transformation,
which should be in range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will be randomly sampled from
[0, _MAX_LEVEL]. Defaults to None.
min_mag (float): The minimum offset ratio of the image height for
vertical translation. Defaults to 0.0.
max_mag (float): The maximum offset ratio of the image height for
vertical translation. Defaults to 0.1.
reversal_prob (float): The probability that reverses the vertical
translation magnitude. Should be in range [0, 1]. Defaults to 0.5.
img_border_value (int | float | tuple): The fill value for the
image border. If float, the same value will be used for all
three channels of the image. If a tuple, it should have
3 elements. Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for the segmentation map.
Note this value must equal ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
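Examples:
A minimal usage sketch (editor's illustration, not part of the
original source; ``results`` is assumed to come from a loading
pipeline):
.. code-block:: python
# shift the image and targets vertically by up to 10% of
# the image height
transform = TranslateY(prob=1.0, max_mag=0.1, reversal_prob=0.5)
results = transform(results)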
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 0.1,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 1., \
f'min_mag ratio for TranslateY should be ' \
f'in range [0,1], got {min_mag}.'
assert 0. <= max_mag <= 1., \
f'max_mag ratio for TranslateY should be ' \
f'in range [0,1], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for TranslateY."""
mag = int(results['img_shape'][0] * mag)
return np.array([[1, 0, 0], [0, 1, mag], [0, 0, 1]], dtype=np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Translate the image vertically."""
mag = int(results['img_shape'][0] * mag)
results['img'] = mmcv.imtranslate(
results['img'],
mag,
direction='vertical',
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Translate masks vertically."""
mag = int(results['img_shape'][0] * mag)
results['gt_masks'] = results['gt_masks'].translate(
results['img_shape'],
mag,
direction='vertical',
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Translate segmentation map vertically."""
mag = int(results['img_shape'][0] * mag)
results['gt_seg_map'] = mmcv.imtranslate(
results['gt_seg_map'],
mag,
direction='vertical',
border_value=self.seg_ignore_label,
interpolation='nearest')
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import numpy as np
from mmcv.transforms import BaseTransform
from mmdet.registry import TRANSFORMS
@TRANSFORMS.register_module()
class InstaBoost(BaseTransform):
r"""Data augmentation method in `InstaBoost: Boosting Instance
Segmentation Via Probability Map Guided Copy-Pasting
<https://arxiv.org/abs/1908.07801>`_.
Refer to https://github.com/GothicAi/Instaboost for implementation details.
Required Keys:
- img (np.uint8)
- instances
Modified Keys:
- img (np.uint8)
- instances
Args:
action_candidate (tuple): Action candidates. "normal", "horizontal", \
"vertical", "skip" are supported. Defaults to ('normal', \
'horizontal', 'skip').
action_prob (tuple): Corresponding action probabilities. Should be \
the same length as action_candidate. Defaults to (1, 0, 0).
scale (tuple): (min scale, max scale). Defaults to (0.8, 1.2).
dx (int): The maximum x-axis shift will be (instance width) / dx.
Defaults to 15.
dy (int): The maximum y-axis shift will be (instance height) / dy.
Defaults to 15.
theta (tuple): (min rotation degree, max rotation degree). \
Defaults to (-1, 1).
color_prob (float): Probability of applying color augmentation.
Defaults to 0.5.
hflag (bool): Whether to use heatmap guidance. Defaults to False.
aug_ratio (float): Probability of applying this transformation. \
Defaults to 0.5.
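Examples:
A config-style sketch (editor's illustration; the surrounding
pipeline entries are assumptions, not part of this module):
.. code-block:: python
train_pipeline = [
dict(type='LoadImageFromFile'),
# InstaBoost edits the raw ``results['instances']``, so it is
# placed before ``LoadAnnotations``
dict(type='InstaBoost', aug_ratio=0.5),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
]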
"""
def __init__(self,
action_candidate: tuple = ('normal', 'horizontal', 'skip'),
action_prob: tuple = (1, 0, 0),
scale: tuple = (0.8, 1.2),
dx: int = 15,
dy: int = 15,
theta: tuple = (-1, 1),
color_prob: float = 0.5,
hflag: bool = False,
aug_ratio: float = 0.5) -> None:
import matplotlib
import matplotlib.pyplot as plt
default_backend = plt.get_backend()
try:
import instaboostfast as instaboost
except ImportError:
raise ImportError(
'Please run "pip install instaboostfast" '
'to install instaboostfast first for instaboost augmentation.')
# instaboost will modify the default backend
# and cause visualization to fail.
matplotlib.use(default_backend)
self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
scale, dx, dy, theta,
color_prob, hflag)
self.aug_ratio = aug_ratio
def _load_anns(self, results: dict) -> Tuple[list, list]:
"""Convert raw anns to instaboost expected input format."""
anns = []
ignore_anns = []
for instance in results['instances']:
label = instance['bbox_label']
bbox = instance['bbox']
mask = instance['mask']
x1, y1, x2, y2 = bbox
# assert (x2 - x1) >= 1 and (y2 - y1) >= 1
bbox = [x1, y1, x2 - x1, y2 - y1]
if instance['ignore_flag'] == 0:
anns.append({
'category_id': label,
'segmentation': mask,
'bbox': bbox
})
else:
# Ignored instances are kept as-is and excluded from augmentation
ignore_anns.append(instance)
return anns, ignore_anns
def _parse_anns(self, results: dict, anns: list, ignore_anns: list,
img: np.ndarray) -> dict:
"""Restore the result of instaboost processing to the original anns
format."""
instances = []
for ann in anns:
x1, y1, w, h = ann['bbox']
# TODO: more essential bug need to be fixed in instaboost
if w <= 0 or h <= 0:
continue
bbox = [x1, y1, x1 + w, y1 + h]
instances.append(
dict(
bbox=bbox,
bbox_label=ann['category_id'],
mask=ann['segmentation'],
ignore_flag=0))
instances.extend(ignore_anns)
results['img'] = img
results['instances'] = instances
return results
def transform(self, results: dict) -> dict:
"""The transform function."""
img = results['img']
ori_type = img.dtype
if 'instances' not in results or len(results['instances']) == 0:
return results
anns, ignore_anns = self._load_anns(results)
if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
try:
import instaboostfast as instaboost
except ImportError:
raise ImportError('Please run "pip install instaboostfast" '
'to install instaboostfast first.')
anns, img = instaboost.get_new_data(
anns, img.astype(np.uint8), self.cfg, background=None)
results = self._parse_anns(results, anns, ignore_anns,
img.astype(ori_type))
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(aug_ratio={self.aug_ratio})'
return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch
from mmcv.transforms import BaseTransform
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
from mmcv.transforms import LoadImageFromFile
from mmengine.fileio import get
from mmengine.structures import BaseDataElement
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import get_box_type
from mmdet.structures.bbox.box_type import autocast_box_type
from mmdet.structures.mask import BitmapMasks, PolygonMasks
@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
"""Load an image from ``results['img']``.
Similar to :obj:`LoadImageFromFile`, but the image has been loaded as
:obj:`np.ndarray` in ``results['img']``. Can be used when loading
images from a webcam.
Required Keys:
- img
Modified Keys:
- img
- img_path
- img_shape
- ori_shape
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
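Examples:
A minimal sketch (editor's illustration):
.. code-block:: python
import numpy as np
transform = LoadImageFromNDArray()
results = transform(dict(img=np.zeros((480, 640, 3), np.uint8)))
assert results['img_shape'] == (480, 640)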
"""
def transform(self, results: dict) -> dict:
"""Transform function to add image meta information.
Args:
results (dict): Result dict with Webcam read image in
``results['img']``.
Returns:
dict: The dict contains loaded image and meta information.
"""
img = results['img']
if self.to_float32:
img = img.astype(np.float32)
results['img_path'] = None
results['img'] = img
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results
@TRANSFORMS.register_module()
class LoadMultiChannelImageFromFiles(BaseTransform):
"""Load multi-channel images from a list of separate channel files.
Required Keys:
- img_path
Modified Keys:
- img
- img_shape
- ori_shape
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
color_type (str): The flag argument for :func:``mmcv.imfrombytes``.
Defaults to 'unchanged'.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :func:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
file_client_args (dict): Arguments to instantiate the
corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend in mmdet >= 3.0.0rc7. Defaults to None.
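Examples:
A minimal sketch (editor's illustration; the per-channel file
names are hypothetical):
.. code-block:: python
transform = LoadMultiChannelImageFromFiles(to_float32=True)
results = dict(img_path=['ch0.tif', 'ch1.tif', 'ch2.tif'])
results = transform(results)
# ``results['img']`` stacks the channels along the last axis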
"""
def __init__(
self,
to_float32: bool = False,
color_type: str = 'unchanged',
imdecode_backend: str = 'cv2',
file_client_args: dict = None,
backend_args: dict = None,
) -> None:
self.to_float32 = to_float32
self.color_type = color_type
self.imdecode_backend = imdecode_backend
self.backend_args = backend_args
if file_client_args is not None:
raise RuntimeError(
'The `file_client_args` is deprecated, '
'please use `backend_args` instead, please refer to'
'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
)
def transform(self, results: dict) -> dict:
"""Transform functions to load multiple images and get images meta
information.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded images and meta information.
"""
assert isinstance(results['img_path'], list)
img = []
for name in results['img_path']:
img_bytes = get(name, backend_args=self.backend_args)
img.append(
mmcv.imfrombytes(
img_bytes,
flag=self.color_type,
backend=self.imdecode_backend))
img = np.stack(img, axis=-1)
if self.to_float32:
img = img.astype(np.float32)
results['img'] = img
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'to_float32={self.to_float32}, '
f"color_type='{self.color_type}', "
f"imdecode_backend='{self.imdecode_backend}', "
f'backend_args={self.backend_args})')
return repr_str
@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
"""Load and process the ``instances`` and ``seg_map`` annotation provided
by dataset.
The annotation format is as the following:
.. code-block:: python
{
'instances':
[
{
# List of 4 numbers representing the bounding box of the
# instance, in (x1, y1, x2, y2) order.
'bbox': [x1, y1, x2, y2],
# Label of the instance.
'bbox_label': 1,
# Used in instance/panoptic segmentation. The segmentation mask
# of the instance or the information of segments.
# 1. If list[list[float]], it represents a list of polygons,
# one for each connected component of the object. Each
# list[float] is one simple polygon in the format of
# [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute
# coordinates in unit of pixels.
# 2. If dict, it represents the per-pixel segmentation mask in
# COCO's compressed RLE format. The dict should have keys
# “size” and “counts”. Can be loaded by pycocotools
'mask': list[list[float]] or dict,
}
]
# Filename of semantic or panoptic segmentation ground truth file.
'seg_map_path': 'a/b/c'
}
After this module, the annotation has been changed to the format below:
.. code-block:: python
{
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in an image
'gt_bboxes': BaseBoxes(N, 4)
# In int type.
'gt_bboxes_labels': np.ndarray(N, )
# In built-in class
'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
# In uint8 type.
'gt_seg_map': np.ndarray (H, W)
}
Required Keys:
- height
- width
- instances
- bbox (optional)
- bbox_label
- mask (optional)
- ignore_flag
- seg_map_path (optional)
Added Keys:
- gt_bboxes (BaseBoxes[torch.float32])
- gt_bboxes_labels (np.int64)
- gt_masks (BitmapMasks | PolygonMasks)
- gt_seg_map (np.uint8)
- gt_ignore_flags (bool)
Args:
with_bbox (bool): Whether to parse and load the bbox annotation.
Defaults to True.
with_label (bool): Whether to parse and load the label annotation.
Defaults to True.
with_mask (bool): Whether to parse and load the mask annotation.
Defaults to False.
with_seg (bool): Whether to parse and load the semantic segmentation
annotation. Defaults to False.
poly2mask (bool): Whether to convert mask to bitmap. Defaults to True.
box_type (str): The box type used to wrap the bboxes. If ``box_type``
is None, gt_bboxes will keep being np.ndarray. Defaults to 'hbox'.
reduce_zero_label (bool): Whether to reduce all label values
by 1. Usually used for datasets where 0 is the background label.
Defaults to False.
ignore_index (int): The label index to be ignored.
Valid only if reduce_zero_label is true. Defaults to 255.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :func:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
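Examples:
A minimal sketch (editor's illustration; the instance values are
hypothetical):
.. code-block:: python
transform = LoadAnnotations(with_bbox=True, with_label=True)
results = dict(
ori_shape=(480, 640),
instances=[dict(bbox=[10., 10., 50., 60.],
bbox_label=0, ignore_flag=0)])
results = transform(results)
# ``results['gt_bboxes']`` is a HorizontalBoxes of shape (1, 4)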
"""
def __init__(
self,
with_mask: bool = False,
poly2mask: bool = True,
box_type: str = 'hbox',
# use for semseg
reduce_zero_label: bool = False,
ignore_index: int = 255,
**kwargs) -> None:
super(LoadAnnotations, self).__init__(**kwargs)
self.with_mask = with_mask
self.poly2mask = poly2mask
self.box_type = box_type
self.reduce_zero_label = reduce_zero_label
self.ignore_index = ignore_index
def _load_bboxes(self, results: dict) -> None:
"""Private function to load bounding box annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict contains loaded bounding box annotations.
"""
gt_bboxes = []
gt_ignore_flags = []
for instance in results.get('instances', []):
gt_bboxes.append(instance['bbox'])
gt_ignore_flags.append(instance['ignore_flag'])
if self.box_type is None:
results['gt_bboxes'] = np.array(
gt_bboxes, dtype=np.float32).reshape((-1, 4))
else:
_, box_type_cls = get_box_type(self.box_type)
results['gt_bboxes'] = box_type_cls(gt_bboxes, dtype=torch.float32)
results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
def _load_labels(self, results: dict) -> None:
"""Private function to load label annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict contains loaded label annotations.
"""
gt_bboxes_labels = []
for instance in results.get('instances', []):
gt_bboxes_labels.append(instance['bbox_label'])
# TODO: Inconsistent with mmcv, consider how to deal with it later.
results['gt_bboxes_labels'] = np.array(
gt_bboxes_labels, dtype=np.int64)
def _poly2mask(self, mask_ann: Union[list, dict], img_h: int,
img_w: int) -> np.ndarray:
"""Private function to convert masks represented with polygon to
bitmaps.
Args:
mask_ann (list | dict): Polygon mask annotation input.
img_h (int): The height of output mask.
img_w (int): The width of output mask.
Returns:
np.ndarray: The decoded bitmap mask of shape (img_h, img_w).
"""
if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
rle = maskUtils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
else:
# rle
rle = mask_ann
mask = maskUtils.decode(rle)
return mask
def _process_masks(self, results: dict) -> list:
"""Process gt_masks and filter invalid polygons.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
list: Processed gt_masks.
"""
gt_masks = []
gt_ignore_flags = []
for instance in results.get('instances', []):
gt_mask = instance['mask']
# If the annotation of segmentation mask is invalid,
# ignore the whole instance.
if isinstance(gt_mask, list):
gt_mask = [
np.array(polygon) for polygon in gt_mask
if len(polygon) % 2 == 0 and len(polygon) >= 6
]
if len(gt_mask) == 0:
# ignore this instance and set gt_mask to a fake mask
instance['ignore_flag'] = 1
gt_mask = [np.zeros(6)]
elif not self.poly2mask:
# `PolygonMasks` requires polygons in the format List[np.ndarray];
# other formats are invalid.
instance['ignore_flag'] = 1
gt_mask = [np.zeros(6)]
elif isinstance(gt_mask, dict) and \
not (gt_mask.get('counts') is not None and
gt_mask.get('size') is not None and
isinstance(gt_mask['counts'], (list, str))):
# if gt_mask is a dict, it should include `counts` and `size`
# so that the RLE can be decoded into a bitmap mask
instance['ignore_flag'] = 1
gt_mask = [np.zeros(6)]
gt_masks.append(gt_mask)
# re-process gt_ignore_flags
gt_ignore_flags.append(instance['ignore_flag'])
results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
return gt_masks
def _load_masks(self, results: dict) -> None:
"""Private function to load mask annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
"""
h, w = results['ori_shape']
gt_masks = self._process_masks(results)
if self.poly2mask:
gt_masks = BitmapMasks(
[self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
else:
# fake polygon masks will be ignored in `PackDetInputs`
gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
results['gt_masks'] = gt_masks
def _load_seg_map(self, results: dict) -> None:
"""Private function to load semantic segmentation annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict contains loaded semantic segmentation annotations.
"""
if results.get('seg_map_path', None) is None:
return
img_bytes = get(
results['seg_map_path'], backend_args=self.backend_args)
gt_semantic_seg = mmcv.imfrombytes(
img_bytes, flag='unchanged',
backend=self.imdecode_backend).squeeze()
if self.reduce_zero_label:
# avoid using underflow conversion
gt_semantic_seg[gt_semantic_seg == 0] = self.ignore_index
gt_semantic_seg = gt_semantic_seg - 1
gt_semantic_seg[gt_semantic_seg == self.ignore_index -
1] = self.ignore_index
# modify if custom classes
if results.get('label_map', None) is not None:
# Add deep copy to solve bug of repeatedly
# replace `gt_semantic_seg`, which is reported in
# https://github.com/open-mmlab/mmsegmentation/pull/1445/
gt_semantic_seg_copy = gt_semantic_seg.copy()
for old_id, new_id in results['label_map'].items():
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
results['gt_seg_map'] = gt_semantic_seg
results['ignore_index'] = self.ignore_index
def transform(self, results: dict) -> dict:
"""Function to load multiple types annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict contains loaded bounding box, label and
semantic segmentation.
"""
if self.with_bbox:
self._load_bboxes(results)
if self.with_label:
self._load_labels(results)
if self.with_mask:
self._load_masks(results)
if self.with_seg:
self._load_seg_map(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(with_bbox={self.with_bbox}, '
repr_str += f'with_label={self.with_label}, '
repr_str += f'with_mask={self.with_mask}, '
repr_str += f'with_seg={self.with_seg}, '
repr_str += f'poly2mask={self.poly2mask}, '
repr_str += f"imdecode_backend='{self.imdecode_backend}', "
repr_str += f'backend_args={self.backend_args})'
return repr_str
@TRANSFORMS.register_module()
class LoadPanopticAnnotations(LoadAnnotations):
"""Load multiple types of panoptic annotations.
The annotation format is as follows:
.. code-block:: python
{
'instances':
[
{
# List of 4 numbers representing the bounding box of the
# instance, in (x1, y1, x2, y2) order.
'bbox': [x1, y1, x2, y2],
# Label of the instance.
'bbox_label': 1,
},
...
]
'segments_info':
[
{
# id = cls_id + instance_id * INSTANCE_OFFSET
'id': int,
# Contiguous category id defined in dataset.
'category': int,
# Thing flag.
'is_thing': bool
},
...
]
# Filename of semantic or panoptic segmentation ground truth file.
'seg_map_path': 'a/b/c'
}
After this module, the annotation has been changed to the format below:
.. code-block:: python
{
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in an image
'gt_bboxes': BaseBoxes(N, 4)
# In int type.
'gt_bboxes_labels': np.ndarray(N, )
# In built-in class
'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
# In uint8 type.
'gt_seg_map': np.ndarray (H, W)
}
Required Keys:
- height
- width
- instances
- bbox
- bbox_label
- ignore_flag
- segments_info
- id
- category
- is_thing
- seg_map_path
Added Keys:
- gt_bboxes (BaseBoxes[torch.float32])
- gt_bboxes_labels (np.int64)
- gt_masks (BitmapMasks | PolygonMasks)
- gt_seg_map (np.uint8)
- gt_ignore_flags (bool)
Args:
with_bbox (bool): Whether to parse and load the bbox annotation.
Defaults to True.
with_label (bool): Whether to parse and load the label annotation.
Defaults to True.
with_mask (bool): Whether to parse and load the mask annotation.
Defaults to True.
with_seg (bool): Whether to parse and load the semantic segmentation
annotation. Defaults to False.
box_type (str): The box type used to wrap the bboxes. Defaults to 'hbox'.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :func:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend in mmdet >= 3.0.0rc7. Defaults to None.
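Examples:
A sketch of the segment id convention mentioned above (editor's
illustration; ``INSTANCE_OFFSET`` is assumed to be 1000 here):
.. code-block:: python
INSTANCE_OFFSET = 1000  # assumed value
seg_id = 7 + 3 * INSTANCE_OFFSET  # thing: category 7, instance 3
cls_id = seg_id % INSTANCE_OFFSET  # 7
inst_id = seg_id // INSTANCE_OFFSET  # 3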
"""
def __init__(self,
with_bbox: bool = True,
with_label: bool = True,
with_mask: bool = True,
with_seg: bool = True,
box_type: str = 'hbox',
imdecode_backend: str = 'cv2',
backend_args: dict = None) -> None:
try:
from panopticapi import utils
except ImportError:
raise ImportError(
'panopticapi is not installed, please install it by: '
'pip install git+https://github.com/cocodataset/'
'panopticapi.git.')
self.rgb2id = utils.rgb2id
super(LoadPanopticAnnotations, self).__init__(
with_bbox=with_bbox,
with_label=with_label,
with_mask=with_mask,
with_seg=with_seg,
with_keypoints=False,
box_type=box_type,
imdecode_backend=imdecode_backend,
backend_args=backend_args)
def _load_masks_and_semantic_segs(self, results: dict) -> None:
"""Private function to load mask and semantic segmentation annotations.
In gt_semantic_seg, the foreground label is from ``0`` to
``num_things - 1``, the background label is from ``num_things`` to
``num_things + num_stuff - 1``, and 255 means the ignored label
(``VOID``).
Args:
results (dict): Result dict from :obj:``mmdet.CustomDataset``.
"""
# seg_map_path is None when running inference on a dataset without gts.
if results.get('seg_map_path', None) is None:
return
img_bytes = get(
results['seg_map_path'], backend_args=self.backend_args)
pan_png = mmcv.imfrombytes(
img_bytes, flag='color', channel_order='rgb').squeeze()
pan_png = self.rgb2id(pan_png)
gt_masks = []
gt_seg = np.zeros_like(pan_png) + 255 # 255 as ignore
for segment_info in results['segments_info']:
mask = (pan_png == segment_info['id'])
gt_seg = np.where(mask, segment_info['category'], gt_seg)
# The legal thing masks
if segment_info.get('is_thing'):
gt_masks.append(mask.astype(np.uint8))
if self.with_mask:
h, w = results['ori_shape']
gt_masks = BitmapMasks(gt_masks, h, w)
results['gt_masks'] = gt_masks
if self.with_seg:
results['gt_seg_map'] = gt_seg
def transform(self, results: dict) -> dict:
"""Function to load multiple types panoptic annotations.
Args:
results (dict): Result dict from :obj:``mmdet.CustomDataset``.
Returns:
dict: The dict contains loaded bounding box, label, mask and
semantic segmentation annotations.
"""
if self.with_bbox:
self._load_bboxes(results)
if self.with_label:
self._load_labels(results)
if self.with_mask or self.with_seg:
# The tasks completed by '_load_masks' and '_load_semantic_segs'
# in LoadAnnotations are merged into one function.
self._load_masks_and_semantic_segs(results)
return results
@TRANSFORMS.register_module()
class LoadProposals(BaseTransform):
"""Load proposal pipeline.
Required Keys:
- proposals
Modified Keys:
- proposals
Args:
num_max_proposals (int, optional): Maximum number of proposals to load.
If not specified, all proposals will be loaded.
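Examples:
A minimal sketch (editor's illustration with random proposals;
real proposals are assumed to be pre-sorted by score):
.. code-block:: python
import numpy as np
transform = LoadProposals(num_max_proposals=100)
results = dict(proposals=dict(
bboxes=np.random.rand(300, 4).astype(np.float32),
scores=np.random.rand(300).astype(np.float32)))
results = transform(results)
assert results['proposals'].shape == (100, 4)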
"""
def __init__(self, num_max_proposals: Optional[int] = None) -> None:
self.num_max_proposals = num_max_proposals
def transform(self, results: dict) -> dict:
"""Transform function to load proposals from file.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded proposal annotations.
"""
proposals = results['proposals']
# the type of proposals should be `dict` or `InstanceData`
assert isinstance(proposals, dict) \
or isinstance(proposals, BaseDataElement)
bboxes = proposals['bboxes'].astype(np.float32)
assert bboxes.shape[1] == 4, \
f'Proposals should have shapes (n, 4), but found {bboxes.shape}'
if 'scores' in proposals:
scores = proposals['scores'].astype(np.float32)
assert bboxes.shape[0] == scores.shape[0]
else:
scores = np.zeros(bboxes.shape[0], dtype=np.float32)
if self.num_max_proposals is not None:
# proposals are assumed to be sorted by score when dumped
bboxes = bboxes[:self.num_max_proposals]
scores = scores[:self.num_max_proposals]
if len(bboxes) == 0:
bboxes = np.zeros((0, 4), dtype=np.float32)
scores = np.zeros(0, dtype=np.float32)
results['proposals'] = bboxes
results['proposals_scores'] = scores
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(num_max_proposals={self.num_max_proposals})'
@TRANSFORMS.register_module()
class FilterAnnotations(BaseTransform):
"""Filter invalid annotations.
Required Keys:
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
Modified Keys:
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_masks (optional)
- gt_ignore_flags (optional)
Args:
min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth
boxes. Defaults to (1., 1.).
min_gt_mask_area (int): Minimum foreground area of ground truth masks.
Defaults to 1.
by_box (bool): Filter instances with bounding boxes not meeting the
min_gt_bbox_wh threshold. Defaults to True.
by_mask (bool): Filter instances with masks not meeting the
min_gt_mask_area threshold. Defaults to False.
keep_empty (bool): Whether to return None when no valid instances
remain after filtering. Defaults to True.
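Examples:
A minimal sketch (editor's illustration; the sample boxes are
hypothetical):
.. code-block:: python
import numpy as np
from mmdet.structures.bbox import HorizontalBoxes
transform = FilterAnnotations(
min_gt_bbox_wh=(10., 10.), keep_empty=False)
results = dict(
gt_bboxes=HorizontalBoxes(np.array(
[[0., 0., 5., 5.], [0., 0., 50., 50.]], dtype=np.float32)),
gt_bboxes_labels=np.array([0, 1], dtype=np.int64),
gt_ignore_flags=np.array([False, False]))
results = transform(results)
# only the 50x50 box passes the (10, 10) threshold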
"""
def __init__(self,
min_gt_bbox_wh: Tuple[int, int] = (1, 1),
min_gt_mask_area: int = 1,
by_box: bool = True,
by_mask: bool = False,
keep_empty: bool = True) -> None:
# TODO: add more filter options
assert by_box or by_mask
self.min_gt_bbox_wh = min_gt_bbox_wh
self.min_gt_mask_area = min_gt_mask_area
self.by_box = by_box
self.by_mask = by_mask
self.keep_empty = keep_empty
@autocast_box_type()
def transform(self, results: dict) -> Union[dict, None]:
"""Transform function to filter annotations.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""
assert 'gt_bboxes' in results
gt_bboxes = results['gt_bboxes']
if gt_bboxes.shape[0] == 0:
return results
tests = []
if self.by_box:
tests.append(
((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
(gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())
if self.by_mask:
assert 'gt_masks' in results
gt_masks = results['gt_masks']
tests.append(gt_masks.areas >= self.min_gt_mask_area)
keep = tests[0]
for t in tests[1:]:
keep = keep & t
if not keep.any():
if self.keep_empty:
return None
keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags')
for key in keys:
if key in results:
results[key] = results[key][keep]
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(min_gt_bbox_wh={self.min_gt_bbox_wh}, ' \
f'keep_empty={self.keep_empty})'
@TRANSFORMS.register_module()
class LoadEmptyAnnotations(BaseTransform):
"""Load Empty Annotations for unlabeled images.
Added Keys:
- gt_bboxes (np.float32)
- gt_bboxes_labels (np.int64)
- gt_masks (BitmapMasks | PolygonMasks)
- gt_seg_map (np.uint8)
- gt_ignore_flags (bool)
Args:
with_bbox (bool): Whether to load the pseudo bbox annotation.
Defaults to True.
with_label (bool): Whether to load the pseudo label annotation.
Defaults to True.
with_mask (bool): Whether to load the pseudo mask annotation.
Defaults to False.
with_seg (bool): Whether to load the pseudo semantic segmentation
annotation. Defaults to False.
seg_ignore_label (int): The fill value used for the segmentation map.
Note this value must equal ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
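Examples:
A minimal sketch (editor's illustration):
.. code-block:: python
transform = LoadEmptyAnnotations(with_bbox=True, with_label=True)
results = transform(dict(img_shape=(480, 640)))
# results['gt_bboxes'].shape == (0, 4)
# results['gt_bboxes_labels'].shape == (0,)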
"""
def __init__(self,
with_bbox: bool = True,
with_label: bool = True,
with_mask: bool = False,
with_seg: bool = False,
seg_ignore_label: int = 255) -> None:
self.with_bbox = with_bbox
self.with_label = with_label
self.with_mask = with_mask
self.with_seg = with_seg
self.seg_ignore_label = seg_ignore_label
def transform(self, results: dict) -> dict:
"""Transform function to load empty annotations.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""
if self.with_bbox:
results['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
results['gt_ignore_flags'] = np.zeros((0, ), dtype=bool)
if self.with_label:
results['gt_bboxes_labels'] = np.zeros((0, ), dtype=np.int64)
if self.with_mask:
# TODO: support PolygonMasks
h, w = results['img_shape']
gt_masks = np.zeros((0, h, w), dtype=np.uint8)
results['gt_masks'] = BitmapMasks(gt_masks, h, w)
if self.with_seg:
h, w = results['img_shape']
results['gt_seg_map'] = self.seg_ignore_label * np.ones(
(h, w), dtype=np.uint8)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(with_bbox={self.with_bbox}, '
repr_str += f'with_label={self.with_label}, '
repr_str += f'with_mask={self.with_mask}, '
repr_str += f'with_seg={self.with_seg}, '
repr_str += f'seg_ignore_label={self.seg_ignore_label})'
return repr_str
@TRANSFORMS.register_module()
class InferencerLoader(BaseTransform):
"""Load an image from ``results['img']``.
Similar to :obj:`LoadImageFromFile`, but the image has been loaded as
:obj:`np.ndarray` in ``results['img']``. Can be used when loading
images from a webcam.
Required Keys:
- img
Modified Keys:
- img
- img_path
- img_shape
- ori_shape
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
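Examples:
A sketch of the three accepted input types (editor's illustration;
the file name is hypothetical):
.. code-block:: python
import numpy as np
loader = InferencerLoader()
results = loader('demo.jpg')  # str: load from file
results = loader(np.zeros((480, 640, 3), np.uint8))  # ndarray
results = loader(dict(img=np.zeros((480, 640, 3), np.uint8)))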
"""
def __init__(self, **kwargs) -> None:
super().__init__()
self.from_file = TRANSFORMS.build(
dict(type='LoadImageFromFile', **kwargs))
self.from_ndarray = TRANSFORMS.build(
dict(type='mmdet.LoadImageFromNDArray', **kwargs))
def transform(self, results: Union[str, np.ndarray, dict]) -> dict:
"""Transform function to add image meta information.
Args:
results (str, np.ndarray or dict): The result.
Returns:
dict: The dict contains loaded image and meta information.
"""
if isinstance(results, str):
inputs = dict(img_path=results)
elif isinstance(results, np.ndarray):
inputs = dict(img=results)
elif isinstance(results, dict):
inputs = results
else:
raise NotImplementedError
if 'img' in inputs:
return self.from_ndarray(inputs)
return self.from_file(inputs)
@TRANSFORMS.register_module()
class LoadTrackAnnotations(LoadAnnotations):
"""Load and process the ``instances`` and ``seg_map`` annotation provided
by dataset. It must load ``instances_ids`` which is only used in the
tracking tasks. The annotation format is as the following:
.. code-block:: python
{
'instances':
[
{
# List of 4 numbers representing the bounding box of the
# instance, in (x1, y1, x2, y2) order.
'bbox': [x1, y1, x2, y2],
# Label of the instance.
'bbox_label': 1,
# Used in tracking.
# Id of instances.
'instance_id': 100,
# Used in instance/panoptic segmentation. The segmentation mask
# of the instance or the information of segments.
# 1. If list[list[float]], it represents a list of polygons,
# one for each connected component of the object. Each
# list[float] is one simple polygon in the format of
# [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute
# coordinates in unit of pixels.
# 2. If dict, it represents the per-pixel segmentation mask in
# COCO's compressed RLE format. The dict should have keys
# “size” and “counts”. Can be loaded by pycocotools
'mask': list[list[float]] or dict,
}
]
# Filename of semantic or panoptic segmentation ground truth file.
'seg_map_path': 'a/b/c'
}
After this module, the annotation has been changed to the format below:
.. code-block:: python
{
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in an image
'gt_bboxes': np.ndarray(N, 4)
# In int type.
'gt_bboxes_labels': np.ndarray(N, )
# In built-in class
'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
# In uint8 type.
'gt_seg_map': np.ndarray (H, W)
}
Required Keys:
- height (optional)
- width (optional)
- instances
- bbox (optional)
- bbox_label
- instance_id (optional)
- mask (optional)
- ignore_flag (optional)
- seg_map_path (optional)
Added Keys:
- gt_bboxes (np.float32)
- gt_bboxes_labels (np.int32)
- gt_instances_ids (np.int32)
- gt_masks (BitmapMasks | PolygonMasks)
- gt_seg_map (np.uint8)
- gt_ignore_flags (bool)
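Examples:
A minimal sketch (editor's illustration; the instance values are
hypothetical):
.. code-block:: python
transform = LoadTrackAnnotations(with_bbox=True, with_label=True)
results = dict(instances=[dict(
bbox=[10., 10., 40., 50.], bbox_label=0,
instance_id=100, ignore_flag=0)])
results = transform(results)
# results['gt_instances_ids'] -> array([100], dtype=int32)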
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def _load_bboxes(self, results: dict) -> None:
"""Private function to load bounding box annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict contains loaded bounding box annotations.
"""
gt_bboxes = []
gt_ignore_flags = []
# TODO: use bbox_type
for instance in results['instances']:
# Datasets that are only used for evaluation may not have
# ground-truth boxes.
if 'bbox' in instance:
gt_bboxes.append(instance['bbox'])
if 'ignore_flag' in instance:
gt_ignore_flags.append(instance['ignore_flag'])
# TODO: check this case
if len(gt_bboxes) != len(gt_ignore_flags):
# There may be no ``gt_ignore_flags`` in some cases, we treat them
# as all False in order to keep the length of ``gt_bboxes`` and
# ``gt_ignore_flags`` the same
gt_ignore_flags = [False] * len(gt_bboxes)
results['gt_bboxes'] = np.array(
gt_bboxes, dtype=np.float32).reshape(-1, 4)
results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
def _load_instances_ids(self, results: dict) -> None:
"""Private function to load instances id annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict containing instances id annotations.
"""
gt_instances_ids = []
for instance in results['instances']:
gt_instances_ids.append(instance['instance_id'])
results['gt_instances_ids'] = np.array(
gt_instances_ids, dtype=np.int32)
def transform(self, results: dict) -> dict:
"""Function to load multiple types annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict containing loaded bounding box, label, instances id
and semantic segmentation annotations.
"""
results = super().transform(results)
self._load_instances_ids(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(with_bbox={self.with_bbox}, '
repr_str += f'with_label={self.with_label}, '
repr_str += f'with_mask={self.with_mask}, '
repr_str += f'with_seg={self.with_seg}, '
repr_str += f'poly2mask={self.poly2mask}, '
repr_str += f"imdecode_backend='{self.imdecode_backend}', "
repr_str += f'backend_args={self.backend_args})'
return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import HorizontalBoxes, autocast_box_type
from .transforms import RandomFlip
@TRANSFORMS.register_module()
class GTBoxSubOne_GLIP(BaseTransform):
"""Subtract 1 from the x2 and y2 coordinates of the gt_bboxes."""
def transform(self, results: dict) -> dict:
if 'gt_bboxes' in results:
gt_bboxes = results['gt_bboxes']
if isinstance(gt_bboxes, np.ndarray):
gt_bboxes[:, 2:] -= 1
results['gt_bboxes'] = gt_bboxes
elif isinstance(gt_bboxes, HorizontalBoxes):
gt_bboxes = results['gt_bboxes'].tensor
gt_bboxes[:, 2:] -= 1
results['gt_bboxes'] = HorizontalBoxes(gt_bboxes)
else:
raise NotImplementedError
return results
@TRANSFORMS.register_module()
class RandomFlip_GLIP(RandomFlip):
"""Flip the image & bboxes & masks & segs horizontally or vertically.
When using horizontal flipping, the corresponding bbox x-coordinate needs
to be additionally subtracted by one.
"""
@autocast_box_type()
def _flip(self, results: dict) -> None:
"""Flip images, bounding boxes, and semantic segmentation map."""
# flip image
results['img'] = mmcv.imflip(
results['img'], direction=results['flip_direction'])
img_shape = results['img'].shape[:2]
# flip bboxes
if results.get('gt_bboxes', None) is not None:
results['gt_bboxes'].flip_(img_shape, results['flip_direction'])
# Only change this line
if results['flip_direction'] == 'horizontal':
results['gt_bboxes'].translate_([-1, 0])
# TODO: check it
# flip masks
if results.get('gt_masks', None) is not None:
results['gt_masks'] = results['gt_masks'].flip(
results['flip_direction'])
# flip segs
if results.get('gt_seg_map', None) is not None:
results['gt_seg_map'] = mmcv.imflip(
results['gt_seg_map'], direction=results['flip_direction'])
# record homography matrix for flip
self._record_homography_matrix(results)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import math
import warnings
from typing import List, Optional, Sequence, Tuple, Union
import cv2
import mmcv
import numpy as np
from mmcv.image import imresize
from mmcv.image.geometric import _scale_size
from mmcv.transforms import BaseTransform
from mmcv.transforms import Pad as MMCV_Pad
from mmcv.transforms import RandomFlip as MMCV_RandomFlip
from mmcv.transforms import Resize as MMCV_Resize
from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness
from mmengine.dataset import BaseDataset
from mmengine.utils import is_str
from numpy import random
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import HorizontalBoxes, autocast_box_type
from mmdet.structures.mask import BitmapMasks, PolygonMasks
from mmdet.utils import log_img_scale
try:
from imagecorruptions import corrupt
except ImportError:
corrupt = None
try:
import albumentations
from albumentations import Compose
except ImportError:
albumentations = None
Compose = None
Number = Union[int, float]
def _fixed_scale_size(
size: Tuple[int, int],
scale: Union[float, int, tuple],
) -> Tuple[int, int]:
"""Rescale a size by a ratio.
Args:
size (tuple[int]): (w, h).
scale (float | tuple(float)): Scaling factor.
Returns:
tuple[int]: scaled size.
"""
if isinstance(scale, (float, int)):
scale = (scale, scale)
w, h = size
# don't need 0.5 offset
return int(w * float(scale[0])), int(h * float(scale[1]))
def rescale_size(old_size: tuple,
scale: Union[float, int, tuple],
return_scale: bool = False) -> tuple:
"""Calculate the new size to be rescaled to.
Args:
old_size (tuple[int]): The old size (w, h) of image.
scale (float | tuple[int]): The scaling factor or maximum size.
If it is a float number, then the image will be rescaled by this
factor, else if it is a tuple of 2 integers, then the image will
be rescaled as large as possible within the scale.
return_scale (bool): Whether to return the scaling factor besides the
rescaled image size.
Returns:
tuple[int]: The new rescaled image size.
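Examples:
Editor's illustration of the two ``scale`` modes:
.. code-block:: python
rescale_size((1000, 600), 0.5)  # -> (500, 300)
# with a tuple, scale as large as possible within the bound:
rescale_size((1000, 600), (1333, 800))  # roughly (1333, 799)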
"""
w, h = old_size
if isinstance(scale, (float, int)):
if scale <= 0:
raise ValueError(f'Invalid scale {scale}, must be positive.')
scale_factor = scale
elif isinstance(scale, tuple):
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
else:
raise TypeError(
f'Scale must be a number or tuple of int, but got {type(scale)}')
# only change this
new_size = _fixed_scale_size((w, h), scale_factor)
if return_scale:
return new_size, scale_factor
else:
return new_size
def imrescale(
img: np.ndarray,
scale: Union[float, Tuple[int, int]],
return_scale: bool = False,
interpolation: str = 'bilinear',
backend: Optional[str] = None
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
"""Resize image while keeping the aspect ratio.
Args:
img (ndarray): The input image.
scale (float | tuple[int]): The scaling factor or maximum size.
If it is a float number, then the image will be rescaled by this
factor, else if it is a tuple of 2 integers, then the image will
be rescaled as large as possible within the scale.
return_scale (bool): Whether to return the scaling factor besides the
rescaled image.
interpolation (str): Same as :func:`resize`.
backend (str | None): Same as :func:`resize`.
Returns:
ndarray: The rescaled image.
"""
h, w = img.shape[:2]
new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
rescaled_img = imresize(
img, new_size, interpolation=interpolation, backend=backend)
if return_scale:
return rescaled_img, scale_factor
else:
return rescaled_img
@TRANSFORMS.register_module()
class Resize(MMCV_Resize):
"""Resize images & bbox & seg.
This transform resizes the input image according to ``scale`` or
``scale_factor``. Bboxes, masks, and seg map are then resized
with the same scale factor.
If ``scale`` and ``scale_factor`` are both set, ``scale`` will be used
for resizing.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- scale
- scale_factor
- keep_ratio
- homography_matrix
Args:
scale (int or tuple): Image scales for resizing. Defaults to None.
scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None.
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Defaults to False.
clip_object_border (bool): Whether to clip the objects
outside the border of the image. In some dataset like MOT17, the gt
bboxes are allowed to cross the border of images. Therefore, we
don't need to clip the gt bboxes in these cases. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generate slightly different results. Defaults
to 'cv2'.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
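Examples:
A minimal sketch (editor's illustration):
.. code-block:: python
# keep-ratio resize into a (1333, 800) bound
transform = Resize(scale=(1333, 800), keep_ratio=True)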
"""
def _resize_masks(self, results: dict) -> None:
"""Resize masks with ``results['scale']``"""
if results.get('gt_masks', None) is not None:
if self.keep_ratio:
results['gt_masks'] = results['gt_masks'].rescale(
results['scale'])
else:
results['gt_masks'] = results['gt_masks'].resize(
results['img_shape'])
def _resize_bboxes(self, results: dict) -> None:
"""Resize bounding boxes with ``results['scale_factor']``."""
if results.get('gt_bboxes', None) is not None:
results['gt_bboxes'].rescale_(results['scale_factor'])
if self.clip_object_border:
results['gt_bboxes'].clip_(results['img_shape'])
def _record_homography_matrix(self, results: dict) -> None:
"""Record the homography matrix for the Resize."""
w_scale, h_scale = results['scale_factor']
homography_matrix = np.array(
[[w_scale, 0, 0], [0, h_scale, 0], [0, 0, 1]], dtype=np.float32)
if results.get('homography_matrix', None) is None:
results['homography_matrix'] = homography_matrix
else:
results['homography_matrix'] = homography_matrix @ results[
'homography_matrix']
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Transform function to resize images, bounding boxes and semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys
are updated in result dict.
"""
if self.scale:
results['scale'] = self.scale
else:
img_shape = results['img'].shape[:2]
results['scale'] = _scale_size(img_shape[::-1], self.scale_factor)
self._resize_img(results)
self._resize_bboxes(results)
self._resize_masks(results)
self._resize_seg(results)
self._record_homography_matrix(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(scale={self.scale}, '
repr_str += f'scale_factor={self.scale_factor}, '
repr_str += f'keep_ratio={self.keep_ratio}, '
repr_str += f'clip_object_border={self.clip_object_border}, '
repr_str += f'backend={self.backend}, '
repr_str += f'interpolation={self.interpolation})'
return repr_str
@TRANSFORMS.register_module()
class FixScaleResize(Resize):
"""Compared to Resize, FixScaleResize fixes the scaling issue when
`keep_ratio=true`."""
def _resize_img(self, results):
"""Resize images with ``results['scale']``."""
if results.get('img', None) is not None:
if self.keep_ratio:
img, scale_factor = imrescale(
results['img'],
results['scale'],
interpolation=self.interpolation,
return_scale=True,
backend=self.backend)
new_h, new_w = img.shape[:2]
h, w = results['img'].shape[:2]
w_scale = new_w / w
h_scale = new_h / h
else:
img, w_scale, h_scale = mmcv.imresize(
results['img'],
results['scale'],
interpolation=self.interpolation,
return_scale=True,
backend=self.backend)
results['img'] = img
results['img_shape'] = img.shape[:2]
results['scale_factor'] = (w_scale, h_scale)
results['keep_ratio'] = self.keep_ratio
@TRANSFORMS.register_module()
class ResizeShortestEdge(BaseTransform):
"""Resize the image and mask while keeping the aspect ratio unchanged.
Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501
This transform attempts to scale the shorter edge to the given
`scale`, as long as the longer edge does not exceed `max_size`.
If `max_size` is reached, then downscale so that the longer
edge does not exceed `max_size`.
Required Keys:
- img
- gt_seg_map (optional)
Modified Keys:
- img
- img_shape
- gt_seg_map (optional)
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
scale (Union[int, Tuple[int, int]]): The target short edge length.
If it's a tuple, the min value will be used as the short edge
length.
max_size (int): The maximum allowed longest edge length.
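Examples:
A minimal sketch (editor's illustration):
.. code-block:: python
# scale the short edge to 800 px while capping the long edge
# at 1333 px
transform = ResizeShortestEdge(scale=800, max_size=1333)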
"""
def __init__(self,
scale: Union[int, Tuple[int, int]],
max_size: Optional[int] = None,
resize_type: str = 'Resize',
**resize_kwargs) -> None:
super().__init__()
self.scale = scale
self.max_size = max_size
self.resize_cfg = dict(type=resize_type, **resize_kwargs)
self.resize = TRANSFORMS.build({'scale': 0, **self.resize_cfg})
def _get_output_shape(
self, img: np.ndarray,
short_edge_length: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
"""Compute the target image shape with the given `short_edge_length`.
Args:
img (np.ndarray): The input image.
short_edge_length (Union[int, Tuple[int, int]]): The target short
edge length. If it's tuple, will select the min value as the
short edge length.
"""
h, w = img.shape[:2]
if isinstance(short_edge_length, int):
size = short_edge_length * 1.0
elif isinstance(short_edge_length, tuple):
size = min(short_edge_length) * 1.0
scale = size / min(h, w)
if h < w:
new_h, new_w = size, scale * w
else:
new_h, new_w = scale * h, size
if self.max_size and max(new_h, new_w) > self.max_size:
scale = self.max_size * 1.0 / max(new_h, new_w)
new_h *= scale
new_w *= scale
new_h = int(new_h + 0.5)
new_w = int(new_w + 0.5)
return new_w, new_h
def transform(self, results: dict) -> dict:
self.resize.scale = self._get_output_shape(results['img'], self.scale)
return self.resize(results)
@TRANSFORMS.register_module()
class FixShapeResize(Resize):
"""Resize images & bbox & seg to the specified size.
This transform resizes the input image according to ``width`` and
``height``. Bboxes, masks, and seg map are then resized
with the same parameters.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- scale
- scale_factor
- keep_ratio
- homography_matrix
Args:
width (int): Width for resizing.
height (int): Height for resizing.
pad_val (Number | dict[str, Number], optional): Padding value used if
the pad_mode is "constant". If it is a single number, the value
to pad the image is the number and to pad the semantic
segmentation map is 255. If it is a dict, it should have the
following keys:
- img: The value to pad the image.
- seg: The value to pad the semantic segmentation map.
Defaults to dict(img=0, seg=255).
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Defaults to False.
clip_object_border (bool): Whether to clip the objects
outside the border of the image. In some dataset like MOT17, the gt
bboxes are allowed to cross the border of images. Therefore, we
don't need to clip the gt bboxes in these cases. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generate slightly different results. Defaults
to 'cv2'.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
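Examples:
A minimal sketch (editor's illustration):
.. code-block:: python
# letterbox-style resize: keep the aspect ratio, then pad the
# result up to a fixed 640x640 shape
transform = FixShapeResize(width=640, height=640, keep_ratio=True)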
"""
def __init__(self,
width: int,
height: int,
pad_val: Union[Number, dict] = dict(img=0, seg=255),
keep_ratio: bool = False,
clip_object_border: bool = True,
backend: str = 'cv2',
interpolation: str = 'bilinear') -> None:
assert width is not None and height is not None, (
'`width` and '
'`height` cannot be `None`')
self.width = width
self.height = height
self.scale = (width, height)
self.backend = backend
self.interpolation = interpolation
self.keep_ratio = keep_ratio
self.clip_object_border = clip_object_border
if keep_ratio is True:
# padding to the fixed size when keep_ratio=True
self.pad_transform = Pad(size=self.scale, pad_val=pad_val)
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Transform function to resize images, bounding boxes and semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys
are updated in result dict.
"""
img = results['img']
h, w = img.shape[:2]
if self.keep_ratio:
scale_factor = min(self.width / w, self.height / h)
results['scale_factor'] = (scale_factor, scale_factor)
real_w, real_h = int(w * float(scale_factor) +
0.5), int(h * float(scale_factor) + 0.5)
img, scale_factor = mmcv.imrescale(
results['img'], (real_w, real_h),
interpolation=self.interpolation,
return_scale=True,
backend=self.backend)
# the w_scale and h_scale have a minor difference
# a real fix should be done in mmcv.imrescale in the future
results['img'] = img
results['img_shape'] = img.shape[:2]
results['keep_ratio'] = self.keep_ratio
results['scale'] = (real_w, real_h)
else:
results['scale'] = (self.width, self.height)
results['scale_factor'] = (self.width / w, self.height / h)
super()._resize_img(results)
self._resize_bboxes(results)
self._resize_masks(results)
self._resize_seg(results)
self._record_homography_matrix(results)
if self.keep_ratio:
self.pad_transform(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(width={self.width}, height={self.height}, '
repr_str += f'keep_ratio={self.keep_ratio}, '
repr_str += f'clip_object_border={self.clip_object_border}, '
repr_str += f'backend={self.backend}, '
repr_str += f'interpolation={self.interpolation})'
return repr_str
@TRANSFORMS.register_module()
class RandomFlip(MMCV_RandomFlip):
"""Flip the image & bbox & mask & segmentation map. Added or Updated keys:
flip, flip_direction, img, gt_bboxes, and gt_seg_map. There are 3 flip
modes:
- ``prob`` is float, ``direction`` is string: the image will be
``direction``ly flipped with probability of ``prob`` .
E.g., ``prob=0.5``, ``direction='horizontal'``,
then image will be horizontally flipped with probability of 0.5.
- ``prob`` is float, ``direction`` is list of string: the image will
be ``direction[i]``ly flipped with probability of
``prob/len(direction)``.
E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``,
then image will be horizontally flipped with probability of 0.25,
vertically with probability of 0.25.
- ``prob`` is list of float, ``direction`` is list of string:
given ``len(prob) == len(direction)``, the image will
be ``direction[i]``ly flipped with probability of ``prob[i]``.
E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
'vertical']``, then image will be horizontally flipped with
probability of 0.3, vertically with probability of 0.5.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- flip
- flip_direction
- homography_matrix
Args:
prob (float | list[float], optional): The flipping probability.
Defaults to None.
direction (str | list[str]): The flipping direction. Options are
'horizontal', 'vertical', and 'diagonal'. If input is a list,
the length must equal ``prob``. Each element in ``prob``
indicates the flip probability of the corresponding direction.
Defaults to 'horizontal'.
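Examples:
Sketches of the three flip modes described above (editor's
illustration):
.. code-block:: python
# horizontal flip with probability 0.5
transform = RandomFlip(prob=0.5)
# horizontal or vertical flip, probability 0.25 each
transform = RandomFlip(
prob=0.5, direction=['horizontal', 'vertical'])
# per-direction probabilities 0.3 and 0.5
transform = RandomFlip(
prob=[0.3, 0.5], direction=['horizontal', 'vertical'])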
"""
def _record_homography_matrix(self, results: dict) -> None:
"""Record the homography matrix for the RandomFlip."""
cur_dir = results['flip_direction']
h, w = results['img'].shape[:2]
if cur_dir == 'horizontal':
homography_matrix = np.array([[-1, 0, w], [0, 1, 0], [0, 0, 1]],
dtype=np.float32)
elif cur_dir == 'vertical':
homography_matrix = np.array([[1, 0, 0], [0, -1, h], [0, 0, 1]],
dtype=np.float32)
elif cur_dir == 'diagonal':
homography_matrix = np.array([[-1, 0, w], [0, -1, h], [0, 0, 1]],
dtype=np.float32)
else:
homography_matrix = np.eye(3, dtype=np.float32)
if results.get('homography_matrix', None) is None:
results['homography_matrix'] = homography_matrix
else:
results['homography_matrix'] = homography_matrix @ results[
'homography_matrix']
@autocast_box_type()
def _flip(self, results: dict) -> None:
"""Flip images, bounding boxes, and semantic segmentation map."""
# flip image
results['img'] = mmcv.imflip(
results['img'], direction=results['flip_direction'])
img_shape = results['img'].shape[:2]
# flip bboxes
if results.get('gt_bboxes', None) is not None:
results['gt_bboxes'].flip_(img_shape, results['flip_direction'])
# flip masks
if results.get('gt_masks', None) is not None:
results['gt_masks'] = results['gt_masks'].flip(
results['flip_direction'])
# flip segs
if results.get('gt_seg_map', None) is not None:
results['gt_seg_map'] = mmcv.imflip(
results['gt_seg_map'], direction=results['flip_direction'])
# record homography matrix for flip
self._record_homography_matrix(results)
@TRANSFORMS.register_module()
class RandomShift(BaseTransform):
"""Shift the image and box given shift pixels and probability.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32])
- gt_bboxes_labels (np.int64)
- gt_ignore_flags (bool) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_bboxes_labels
- gt_ignore_flags (bool) (optional)
Args:
prob (float): Probability of applying the shift. Defaults to 0.5.
max_shift_px (int): The max pixels for shifting. Defaults to 32.
filter_thr_px (int): The width and height threshold for filtering.
Bboxes (and their corresponding targets) whose width or height
falls below this threshold will be filtered out. Defaults to 1.
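Examples:
A minimal sketch (editor's illustration):
.. code-block:: python
# shift the image and boxes by at most 32 px half of the time,
# dropping boxes whose remaining size falls below 1 px
transform = RandomShift(prob=0.5, max_shift_px=32, filter_thr_px=1)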
"""
def __init__(self,
prob: float = 0.5,
max_shift_px: int = 32,
filter_thr_px: int = 1) -> None:
assert 0 <= prob <= 1
assert max_shift_px >= 0
self.prob = prob
self.max_shift_px = max_shift_px
self.filter_thr_px = int(filter_thr_px)
@cache_randomness
def _random_prob(self) -> float:
return random.uniform(0, 1)
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Transform function to random shift images, bounding boxes.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Shift results.
"""
if self._random_prob() < self.prob:
img_shape = results['img'].shape[:2]
random_shift_x = random.randint(-self.max_shift_px,
self.max_shift_px)
random_shift_y = random.randint(-self.max_shift_px,
self.max_shift_px)
new_x = max(0, random_shift_x)
ori_x = max(0, -random_shift_x)
new_y = max(0, random_shift_y)
ori_y = max(0, -random_shift_y)
# TODO: support mask and semantic segmentation maps.
bboxes = results['gt_bboxes'].clone()
bboxes.translate_([random_shift_x, random_shift_y])
# clip border
bboxes.clip_(img_shape)
# remove invalid bboxes
valid_inds = (bboxes.widths > self.filter_thr_px).numpy() & (
bboxes.heights > self.filter_thr_px).numpy()
# If the shift does not contain any gt-bbox area, skip this
# image.
if not valid_inds.any():
return results
bboxes = bboxes[valid_inds]
results['gt_bboxes'] = bboxes
results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
valid_inds]
if results.get('gt_ignore_flags', None) is not None:
results['gt_ignore_flags'] = \
results['gt_ignore_flags'][valid_inds]
# shift img
img = results['img']
new_img = np.zeros_like(img)
img_h, img_w = img.shape[:2]
new_h = img_h - np.abs(random_shift_y)
new_w = img_w - np.abs(random_shift_x)
new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
= img[ori_y:ori_y + new_h, ori_x:ori_x + new_w]
results['img'] = new_img
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'max_shift_px={self.max_shift_px}, '
repr_str += f'filter_thr_px={self.filter_thr_px})'
return repr_str
@TRANSFORMS.register_module()
class Pad(MMCV_Pad):
"""Pad the image & segmentation map.
    There are three padding modes: (1) pad to a fixed size, (2) pad to the
    minimum size that is divisible by some number, and (3) pad to square.
    Padding to square and padding to a divisible minimum size can also be
    used at the same time.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_masks
- gt_seg_map
Added Keys:
- pad_shape
- pad_fixed_size
- pad_size_divisor
Args:
size (tuple, optional): Fixed padding size.
Expected padding shape (width, height). Defaults to None.
size_divisor (int, optional): The divisor of padded size. Defaults to
None.
pad_to_square (bool): Whether to pad the image into a square.
Currently only used for YOLOX. Defaults to False.
        pad_val (Number | dict[str, Number], optional): Padding value used
            when ``padding_mode`` is "constant". If it is a single number, the value
to pad the image is the number and to pad the semantic
segmentation map is 255. If it is a dict, it should have the
following keys:
- img: The value to pad the image.
- seg: The value to pad the semantic segmentation map.
Defaults to dict(img=0, seg=255).
padding_mode (str): Type of padding. Should be: constant, edge,
reflect or symmetric. Defaults to 'constant'.
- constant: pads with a constant value, this value is specified
with pad_val.
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with 2
elements on both sides in reflect mode will result in
[3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last value
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
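
    Examples:
        A minimal sketch (hypothetical sizes), padding both dimensions up to
        the next multiple of 32:

        >>> import numpy as np
        >>> pad = Pad(size_divisor=32)
        >>> results = dict(img=np.zeros((50, 60, 3), dtype=np.uint8))
        >>> pad(results)['img'].shape
        (64, 64, 3)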
"""
def _pad_masks(self, results: dict) -> None:
"""Pad masks according to ``results['pad_shape']``."""
if results.get('gt_masks', None) is not None:
pad_val = self.pad_val.get('masks', 0)
pad_shape = results['pad_shape'][:2]
results['gt_masks'] = results['gt_masks'].pad(
pad_shape, pad_val=pad_val)
def transform(self, results: dict) -> dict:
"""Call function to pad images, masks, semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
self._pad_img(results)
self._pad_seg(results)
self._pad_masks(results)
return results
@TRANSFORMS.register_module()
class RandomCrop(BaseTransform):
"""Random crop the image & bboxes & masks.
The absolute ``crop_size`` is sampled based on ``crop_type`` and
``image_size``, then the cropped results are generated.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_masks (optional)
- gt_ignore_flags (optional)
- gt_seg_map (optional)
    - gt_instances_ids (optional, only used in MOT/VIS)
Added Keys:
- homography_matrix
Args:
crop_size (tuple): The relative ratio or absolute pixels of
(width, height).
crop_type (str, optional): One of "relative_range", "relative",
"absolute", "absolute_range". "relative" randomly crops
(h * crop_size[0], w * crop_size[1]) part from an input of size
(h, w). "relative_range" uniformly samples relative crop size from
range [crop_size[0], 1] and [crop_size[1], 1] for height and width
respectively. "absolute" crops from an input with absolute size
(crop_size[0], crop_size[1]). "absolute_range" uniformly samples
crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w
in range [crop_size[0], min(w, crop_size[1])].
Defaults to "absolute".
allow_negative_crop (bool, optional): Whether to allow a crop that does
not contain any bbox area. Defaults to False.
recompute_bbox (bool, optional): Whether to re-compute the boxes based
on cropped instance masks. Defaults to False.
bbox_clip_border (bool, optional): Whether clip the objects outside
the border of the image. Defaults to True.
Note:
- If the image is smaller than the absolute crop size, return the
original image.
- The keys for bboxes, labels and masks must be aligned. That is,
``gt_bboxes`` corresponds to ``gt_labels`` and ``gt_masks``, and
``gt_bboxes_ignore`` corresponds to ``gt_labels_ignore`` and
``gt_masks_ignore``.
- If the crop does not contain any gt-bbox region and
``allow_negative_crop`` is set to False, skip this image.
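
    Examples:
        Minimal config sketches for two crop types (hypothetical sizes):

        >>> abs_crop = dict(type='RandomCrop', crop_size=(384, 600),
        ...                 crop_type='absolute_range')
        >>> rel_crop = dict(type='RandomCrop', crop_size=(0.6, 0.6),
        ...                 crop_type='relative_range')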
"""
def __init__(self,
crop_size: tuple,
crop_type: str = 'absolute',
allow_negative_crop: bool = False,
recompute_bbox: bool = False,
bbox_clip_border: bool = True) -> None:
if crop_type not in [
'relative_range', 'relative', 'absolute', 'absolute_range'
]:
raise ValueError(f'Invalid crop_type {crop_type}.')
if crop_type in ['absolute', 'absolute_range']:
assert crop_size[0] > 0 and crop_size[1] > 0
assert isinstance(crop_size[0], int) and isinstance(
crop_size[1], int)
if crop_type == 'absolute_range':
assert crop_size[0] <= crop_size[1]
else:
assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
self.crop_size = crop_size
self.crop_type = crop_type
self.allow_negative_crop = allow_negative_crop
self.bbox_clip_border = bbox_clip_border
self.recompute_bbox = recompute_bbox
def _crop_data(self, results: dict, crop_size: Tuple[int, int],
allow_negative_crop: bool) -> Union[dict, None]:
"""Function to randomly crop images, bounding boxes, masks, semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
crop_size (Tuple[int, int]): Expected absolute size after
cropping, (h, w).
allow_negative_crop (bool): Whether to allow a crop that does not
contain any bbox area.
Returns:
results (Union[dict, None]): Randomly cropped results, 'img_shape'
key in result dict is updated according to crop size. None will
be returned when there is no valid bbox after cropping.
"""
assert crop_size[0] > 0 and crop_size[1] > 0
img = results['img']
margin_h = max(img.shape[0] - crop_size[0], 0)
margin_w = max(img.shape[1] - crop_size[1], 0)
offset_h, offset_w = self._rand_offset((margin_h, margin_w))
crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
# Record the homography matrix for the RandomCrop
homography_matrix = np.array(
[[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]],
dtype=np.float32)
if results.get('homography_matrix', None) is None:
results['homography_matrix'] = homography_matrix
else:
results['homography_matrix'] = homography_matrix @ results[
'homography_matrix']
# crop the image
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
img_shape = img.shape
results['img'] = img
results['img_shape'] = img_shape[:2]
# crop bboxes accordingly and clip to the image boundary
if results.get('gt_bboxes', None) is not None:
bboxes = results['gt_bboxes']
bboxes.translate_([-offset_w, -offset_h])
if self.bbox_clip_border:
bboxes.clip_(img_shape[:2])
valid_inds = bboxes.is_inside(img_shape[:2]).numpy()
# If the crop does not contain any gt-bbox area and
# allow_negative_crop is False, skip this image.
if (not valid_inds.any() and not allow_negative_crop):
return None
results['gt_bboxes'] = bboxes[valid_inds]
if results.get('gt_ignore_flags', None) is not None:
results['gt_ignore_flags'] = \
results['gt_ignore_flags'][valid_inds]
if results.get('gt_bboxes_labels', None) is not None:
results['gt_bboxes_labels'] = \
results['gt_bboxes_labels'][valid_inds]
if results.get('gt_masks', None) is not None:
results['gt_masks'] = results['gt_masks'][
valid_inds.nonzero()[0]].crop(
np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
if self.recompute_bbox:
results['gt_bboxes'] = results['gt_masks'].get_bboxes(
type(results['gt_bboxes']))
# We should remove the instance ids corresponding to invalid boxes.
if results.get('gt_instances_ids', None) is not None:
results['gt_instances_ids'] = \
results['gt_instances_ids'][valid_inds]
# crop semantic seg
if results.get('gt_seg_map', None) is not None:
results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
crop_x1:crop_x2]
return results
@cache_randomness
def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]:
"""Randomly generate crop offset.
Args:
margin (Tuple[int, int]): The upper bound for the offset generated
randomly.
Returns:
Tuple[int, int]: The random offset for the crop.
"""
margin_h, margin_w = margin
offset_h = np.random.randint(0, margin_h + 1)
offset_w = np.random.randint(0, margin_w + 1)
return offset_h, offset_w
@cache_randomness
def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]:
"""Randomly generates the absolute crop size based on `crop_type` and
`image_size`.
Args:
image_size (Tuple[int, int]): (h, w).
Returns:
crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels.
"""
h, w = image_size
if self.crop_type == 'absolute':
return min(self.crop_size[1], h), min(self.crop_size[0], w)
elif self.crop_type == 'absolute_range':
crop_h = np.random.randint(
min(h, self.crop_size[0]),
min(h, self.crop_size[1]) + 1)
crop_w = np.random.randint(
min(w, self.crop_size[0]),
min(w, self.crop_size[1]) + 1)
return crop_h, crop_w
elif self.crop_type == 'relative':
crop_w, crop_h = self.crop_size
return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
else:
# 'relative_range'
crop_size = np.asarray(self.crop_size, dtype=np.float32)
crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
@autocast_box_type()
def transform(self, results: dict) -> Union[dict, None]:
"""Transform function to randomly crop images, bounding boxes, masks,
semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
results (Union[dict, None]): Randomly cropped results, 'img_shape'
key in result dict is updated according to crop size. None will
be returned when there is no valid bbox after cropping.
"""
image_size = results['img'].shape[:2]
crop_size = self._get_crop_size(image_size)
results = self._crop_data(results, crop_size, self.allow_negative_crop)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(crop_size={self.crop_size}, '
repr_str += f'crop_type={self.crop_type}, '
repr_str += f'allow_negative_crop={self.allow_negative_crop}, '
repr_str += f'recompute_bbox={self.recompute_bbox}, '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str
@TRANSFORMS.register_module()
class SegRescale(BaseTransform):
"""Rescale semantic segmentation maps.
    This transform rescales the ``gt_seg_map`` according to ``scale_factor``.
Required Keys:
- gt_seg_map
Modified Keys:
- gt_seg_map
Args:
scale_factor (float): The scale factor of the final output. Defaults
to 1.
        backend (str): Image rescale backend, choices are 'cv2' and 'pillow'.
            These two backends generate slightly different results. Defaults
            to 'cv2'.
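
    Examples:
        A minimal sketch (hypothetical shapes):

        >>> import numpy as np
        >>> rescale = SegRescale(scale_factor=0.5)
        >>> results = dict(gt_seg_map=np.zeros((64, 64), dtype=np.uint8))
        >>> rescale(results)['gt_seg_map'].shape
        (32, 32)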
"""
def __init__(self, scale_factor: float = 1, backend: str = 'cv2') -> None:
self.scale_factor = scale_factor
self.backend = backend
def transform(self, results: dict) -> dict:
"""Transform function to scale the semantic segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with semantic segmentation map scaled.
"""
if self.scale_factor != 1:
results['gt_seg_map'] = mmcv.imrescale(
results['gt_seg_map'],
self.scale_factor,
interpolation='nearest',
backend=self.backend)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(scale_factor={self.scale_factor}, '
repr_str += f'backend={self.backend})'
return repr_str
@TRANSFORMS.register_module()
class PhotoMetricDistortion(BaseTransform):
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
8. randomly swap channels
Required Keys:
- img (np.uint8)
Modified Keys:
- img (np.float32)
Args:
brightness_delta (int): delta of brightness.
contrast_range (sequence): range of contrast.
saturation_range (sequence): range of saturation.
hue_delta (int): delta of hue.
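
    Examples:
        A minimal sketch (hypothetical image); note that the image dtype
        becomes float32:

        >>> import numpy as np
        >>> distort = PhotoMetricDistortion()
        >>> img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
        >>> distort(dict(img=img))['img'].dtype
        dtype('float32')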
"""
def __init__(self,
brightness_delta: int = 32,
contrast_range: Sequence[Number] = (0.5, 1.5),
saturation_range: Sequence[Number] = (0.5, 1.5),
hue_delta: int = 18) -> None:
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
@cache_randomness
def _random_flags(self) -> Sequence[Number]:
mode = random.randint(2)
brightness_flag = random.randint(2)
contrast_flag = random.randint(2)
saturation_flag = random.randint(2)
hue_flag = random.randint(2)
swap_flag = random.randint(2)
delta_value = random.uniform(-self.brightness_delta,
self.brightness_delta)
alpha_value = random.uniform(self.contrast_lower, self.contrast_upper)
saturation_value = random.uniform(self.saturation_lower,
self.saturation_upper)
hue_value = random.uniform(-self.hue_delta, self.hue_delta)
swap_value = random.permutation(3)
return (mode, brightness_flag, contrast_flag, saturation_flag,
hue_flag, swap_flag, delta_value, alpha_value,
saturation_value, hue_value, swap_value)
def transform(self, results: dict) -> dict:
"""Transform function to perform photometric distortion on images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images distorted.
"""
assert 'img' in results, '`img` is not found in results'
img = results['img']
img = img.astype(np.float32)
(mode, brightness_flag, contrast_flag, saturation_flag, hue_flag,
swap_flag, delta_value, alpha_value, saturation_value, hue_value,
swap_value) = self._random_flags()
# random brightness
if brightness_flag:
img += delta_value
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
if mode == 1:
if contrast_flag:
img *= alpha_value
# convert color from BGR to HSV
img = mmcv.bgr2hsv(img)
# random saturation
if saturation_flag:
img[..., 1] *= saturation_value
            # For a float32 image converted from BGR to HSV by OpenCV,
            # the valid saturation range is [0, 1]
if saturation_value > 1:
img[..., 1] = img[..., 1].clip(0, 1)
# random hue
if hue_flag:
img[..., 0] += hue_value
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = mmcv.hsv2bgr(img)
# random contrast
if mode == 0:
if contrast_flag:
img *= alpha_value
# randomly swap channels
if swap_flag:
img = img[..., swap_value]
results['img'] = img
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(brightness_delta={self.brightness_delta}, '
repr_str += 'contrast_range='
repr_str += f'{(self.contrast_lower, self.contrast_upper)}, '
repr_str += 'saturation_range='
repr_str += f'{(self.saturation_lower, self.saturation_upper)}, '
repr_str += f'hue_delta={self.hue_delta})'
return repr_str
@TRANSFORMS.register_module()
class Expand(BaseTransform):
"""Random expand the image & bboxes & masks & segmentation map.
Randomly place the original image on a canvas of ``ratio`` x original image
size filled with mean values. The ratio is in the range of ratio_range.
Required Keys:
- img
- img_shape
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_masks
- gt_seg_map
Args:
mean (sequence): mean value of dataset.
to_rgb (bool): if need to convert the order of mean to align with RGB.
        ratio_range (sequence): range of expand ratio.
        seg_ignore_label (int): label of the ignored segmentation map.
        prob (float): probability of applying this transformation.
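
    Examples:
        A minimal sketch (hypothetical values); with ``prob=1.0`` the output
        canvas is between 1x and 4x the input size by default:

        >>> import numpy as np
        >>> expand = Expand(mean=(114, 114, 114), to_rgb=False, prob=1.0)
        >>> results = dict(img=np.zeros((10, 10, 3), dtype=np.uint8))
        >>> results = expand(results)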
"""
def __init__(self,
mean: Sequence[Number] = (0, 0, 0),
to_rgb: bool = True,
ratio_range: Sequence[Number] = (1, 4),
seg_ignore_label: int = None,
prob: float = 0.5) -> None:
self.to_rgb = to_rgb
self.ratio_range = ratio_range
if to_rgb:
self.mean = mean[::-1]
else:
self.mean = mean
self.min_ratio, self.max_ratio = ratio_range
self.seg_ignore_label = seg_ignore_label
self.prob = prob
@cache_randomness
def _random_prob(self) -> float:
return random.uniform(0, 1)
@cache_randomness
def _random_ratio(self) -> float:
return random.uniform(self.min_ratio, self.max_ratio)
@cache_randomness
def _random_left_top(self, ratio: float, h: int,
w: int) -> Tuple[int, int]:
left = int(random.uniform(0, w * ratio - w))
top = int(random.uniform(0, h * ratio - h))
return left, top
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Transform function to expand images, bounding boxes, masks,
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images, bounding boxes, masks, segmentation
map expanded.
"""
if self._random_prob() > self.prob:
return results
assert 'img' in results, '`img` is not found in results'
img = results['img']
h, w, c = img.shape
ratio = self._random_ratio()
        # speed up expand when handling a large image
if np.all(self.mean == self.mean[0]):
expand_img = np.empty((int(h * ratio), int(w * ratio), c),
img.dtype)
expand_img.fill(self.mean[0])
else:
expand_img = np.full((int(h * ratio), int(w * ratio), c),
self.mean,
dtype=img.dtype)
left, top = self._random_left_top(ratio, h, w)
expand_img[top:top + h, left:left + w] = img
results['img'] = expand_img
results['img_shape'] = expand_img.shape[:2]
# expand bboxes
if results.get('gt_bboxes', None) is not None:
results['gt_bboxes'].translate_([left, top])
# expand masks
if results.get('gt_masks', None) is not None:
results['gt_masks'] = results['gt_masks'].expand(
int(h * ratio), int(w * ratio), top, left)
# expand segmentation map
if results.get('gt_seg_map', None) is not None:
gt_seg = results['gt_seg_map']
expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
self.seg_ignore_label,
dtype=gt_seg.dtype)
expand_gt_seg[top:top + h, left:left + w] = gt_seg
results['gt_seg_map'] = expand_gt_seg
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(mean={self.mean}, to_rgb={self.to_rgb}, '
repr_str += f'ratio_range={self.ratio_range}, '
repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
repr_str += f'prob={self.prob})'
return repr_str
@TRANSFORMS.register_module()
class MinIoURandomCrop(BaseTransform):
"""Random crop the image & bboxes & masks & segmentation map, the cropped
patches have minimum IoU requirement with original image & bboxes & masks.
& segmentation map, the IoU threshold is randomly selected from min_ious.
Required Keys:
- img
- img_shape
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_bboxes_labels
- gt_masks
- gt_ignore_flags
- gt_seg_map
Args:
min_ious (Sequence[float]): minimum IoU threshold for all intersections
with bounding boxes.
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
where a >= min_crop_size).
bbox_clip_border (bool, optional): Whether clip the objects outside
the border of the image. Defaults to True.
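
    Examples:
        A minimal pipeline-config sketch (default-like values):

        >>> crop = dict(type='MinIoURandomCrop',
        ...             min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
        ...             min_crop_size=0.3)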
"""
def __init__(self,
min_ious: Sequence[float] = (0.1, 0.3, 0.5, 0.7, 0.9),
min_crop_size: float = 0.3,
bbox_clip_border: bool = True) -> None:
self.min_ious = min_ious
self.sample_mode = (1, *min_ious, 0)
self.min_crop_size = min_crop_size
self.bbox_clip_border = bbox_clip_border
@cache_randomness
def _random_mode(self) -> Number:
return random.choice(self.sample_mode)
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Transform function to crop images and bounding boxes with minimum
IoU constraint.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images and bounding boxes cropped, \
'img_shape' key is updated.
"""
assert 'img' in results, '`img` is not found in results'
assert 'gt_bboxes' in results, '`gt_bboxes` is not found in results'
img = results['img']
boxes = results['gt_bboxes']
h, w, c = img.shape
while True:
mode = self._random_mode()
self.mode = mode
if mode == 1:
return results
min_iou = self.mode
for i in range(50):
new_w = random.uniform(self.min_crop_size * w, w)
new_h = random.uniform(self.min_crop_size * h, h)
# h / w in [0.5, 2]
if new_h / new_w < 0.5 or new_h / new_w > 2:
continue
left = random.uniform(w - new_w)
top = random.uniform(h - new_h)
patch = np.array(
(int(left), int(top), int(left + new_w), int(top + new_h)))
# Line or point crop is not allowed
if patch[2] == patch[0] or patch[3] == patch[1]:
continue
overlaps = boxes.overlaps(
HorizontalBoxes(patch.reshape(-1, 4).astype(np.float32)),
boxes).numpy().reshape(-1)
if len(overlaps) > 0 and overlaps.min() < min_iou:
continue
                # centers of boxes should be inside the cropped img
# only adjust boxes and instance masks when the gt is not empty
if len(overlaps) > 0:
# adjust boxes
def is_center_of_bboxes_in_patch(boxes, patch):
centers = boxes.centers.numpy()
mask = ((centers[:, 0] > patch[0]) *
(centers[:, 1] > patch[1]) *
(centers[:, 0] < patch[2]) *
(centers[:, 1] < patch[3]))
return mask
mask = is_center_of_bboxes_in_patch(boxes, patch)
if not mask.any():
continue
if results.get('gt_bboxes', None) is not None:
boxes = results['gt_bboxes']
mask = is_center_of_bboxes_in_patch(boxes, patch)
boxes = boxes[mask]
boxes.translate_([-patch[0], -patch[1]])
if self.bbox_clip_border:
boxes.clip_(
[patch[3] - patch[1], patch[2] - patch[0]])
results['gt_bboxes'] = boxes
# ignore_flags
if results.get('gt_ignore_flags', None) is not None:
results['gt_ignore_flags'] = \
results['gt_ignore_flags'][mask]
# labels
if results.get('gt_bboxes_labels', None) is not None:
results['gt_bboxes_labels'] = results[
'gt_bboxes_labels'][mask]
# mask fields
if results.get('gt_masks', None) is not None:
results['gt_masks'] = results['gt_masks'][
mask.nonzero()[0]].crop(patch)
# adjust the img no matter whether the gt is empty before crop
img = img[patch[1]:patch[3], patch[0]:patch[2]]
results['img'] = img
results['img_shape'] = img.shape[:2]
# seg fields
if results.get('gt_seg_map', None) is not None:
results['gt_seg_map'] = results['gt_seg_map'][
patch[1]:patch[3], patch[0]:patch[2]]
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(min_ious={self.min_ious}, '
repr_str += f'min_crop_size={self.min_crop_size}, '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str
@TRANSFORMS.register_module()
class Corrupt(BaseTransform):
"""Corruption augmentation.
Corruption transforms implemented based on
`imagecorruptions <https://github.com/bethgelab/imagecorruptions>`_.
Required Keys:
- img (np.uint8)
Modified Keys:
- img (np.uint8)
Args:
corruption (str): Corruption name.
severity (int): The severity of corruption. Defaults to 1.
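
    Examples:
        A minimal config sketch; the corruption name must be one of the
        names defined by the optional ``imagecorruptions`` package
        (``'gaussian_noise'`` here is such a name):

        >>> corrupt = dict(type='Corrupt', corruption='gaussian_noise',
        ...                severity=2)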
"""
def __init__(self, corruption: str, severity: int = 1) -> None:
self.corruption = corruption
self.severity = severity
def transform(self, results: dict) -> dict:
"""Call function to corrupt image.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images corrupted.
"""
if corrupt is None:
raise RuntimeError('imagecorruptions is not installed')
results['img'] = corrupt(
results['img'].astype(np.uint8),
corruption_name=self.corruption,
severity=self.severity)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(corruption={self.corruption}, '
repr_str += f'severity={self.severity})'
return repr_str
@TRANSFORMS.register_module()
@avoid_cache_randomness
class Albu(BaseTransform):
"""Albumentation augmentation.
Adds custom transformations from Albumentations library.
    Please visit `https://albumentations.readthedocs.io`
to get more information.
Required Keys:
- img (np.uint8)
- gt_bboxes (HorizontalBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
Modified Keys:
- img (np.uint8)
- gt_bboxes (HorizontalBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- img_shape (tuple)
    An example of ``transforms`` is as follows:
.. code-block::
[
dict(
type='ShiftScaleRotate',
shift_limit=0.0625,
scale_limit=0.0,
rotate_limit=0,
interpolation=1,
p=0.5),
dict(
type='RandomBrightnessContrast',
brightness_limit=[0.1, 0.3],
contrast_limit=[0.1, 0.3],
p=0.2),
dict(type='ChannelShuffle', p=0.1),
dict(
type='OneOf',
transforms=[
dict(type='Blur', blur_limit=3, p=1.0),
dict(type='MedianBlur', blur_limit=3, p=1.0)
],
p=0.1),
]
Args:
transforms (list[dict]): A list of albu transformations
bbox_params (dict, optional): Bbox_params for albumentation `Compose`
keymap (dict, optional): Contains
{'input key':'albumentation-style key'}
skip_img_without_anno (bool): Whether to skip the image if no ann left
after aug. Defaults to False.
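
    Examples:
        A fuller config sketch with bbox handling (hypothetical values,
        mirroring typical detection configs):

        >>> albu = dict(
        ...     type='Albu',
        ...     transforms=[dict(type='Blur', blur_limit=3, p=0.2)],
        ...     bbox_params=dict(
        ...         type='BboxParams',
        ...         format='pascal_voc',
        ...         label_fields=['gt_bboxes_labels', 'gt_ignore_flags'],
        ...         filter_lost_elements=True),
        ...     keymap=dict(img='image', gt_bboxes='bboxes'),
        ...     skip_img_without_anno=True)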
"""
def __init__(self,
transforms: List[dict],
bbox_params: Optional[dict] = None,
keymap: Optional[dict] = None,
skip_img_without_anno: bool = False) -> None:
if Compose is None:
raise RuntimeError('albumentations is not installed')
# Args will be modified later, copying it will be safer
transforms = copy.deepcopy(transforms)
if bbox_params is not None:
bbox_params = copy.deepcopy(bbox_params)
if keymap is not None:
keymap = copy.deepcopy(keymap)
self.transforms = transforms
self.filter_lost_elements = False
self.skip_img_without_anno = skip_img_without_anno
# A simple workaround to remove masks without boxes
if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params
and 'filter_lost_elements' in bbox_params):
self.filter_lost_elements = True
self.origin_label_fields = bbox_params['label_fields']
bbox_params['label_fields'] = ['idx_mapper']
del bbox_params['filter_lost_elements']
self.bbox_params = (
self.albu_builder(bbox_params) if bbox_params else None)
self.aug = Compose([self.albu_builder(t) for t in self.transforms],
bbox_params=self.bbox_params)
if not keymap:
self.keymap_to_albu = {
'img': 'image',
'gt_masks': 'masks',
'gt_bboxes': 'bboxes'
}
else:
self.keymap_to_albu = keymap
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
def albu_builder(self, cfg: dict) -> albumentations:
"""Import a module from albumentations.
It inherits some of :func:`build_from_cfg` logic.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
Returns:
obj: The constructed object.
"""
assert isinstance(cfg, dict) and 'type' in cfg
args = cfg.copy()
obj_type = args.pop('type')
if is_str(obj_type):
if albumentations is None:
raise RuntimeError('albumentations is not installed')
obj_cls = getattr(albumentations, obj_type)
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
if 'transforms' in args:
args['transforms'] = [
self.albu_builder(transform)
for transform in args['transforms']
]
return obj_cls(**args)
@staticmethod
def mapper(d: dict, keymap: dict) -> dict:
"""Dictionary mapper. Renames keys according to keymap provided.
Args:
d (dict): old dict
keymap (dict): {'old_key':'new_key'}
Returns:
dict: new dict.
"""
updated_dict = {}
        for k, v in d.items():
            new_k = keymap.get(k, k)
            updated_dict[new_k] = v
return updated_dict
@autocast_box_type()
def transform(self, results: dict) -> Union[dict, None]:
"""Transform function of Albu."""
# TODO: gt_seg_map is not currently supported
# dict to albumentations format
results = self.mapper(results, self.keymap_to_albu)
results, ori_masks = self._preprocess_results(results)
results = self.aug(**results)
results = self._postprocess_results(results, ori_masks)
if results is None:
return None
# back to the original format
results = self.mapper(results, self.keymap_back)
results['img_shape'] = results['img'].shape[:2]
return results
def _preprocess_results(self, results: dict) -> tuple:
"""Pre-processing results to facilitate the use of Albu."""
if 'bboxes' in results:
# to list of boxes
if not isinstance(results['bboxes'], HorizontalBoxes):
raise NotImplementedError(
'Albu only supports horizontal boxes now')
bboxes = results['bboxes'].numpy()
results['bboxes'] = [x for x in bboxes]
# add pseudo-field for filtration
if self.filter_lost_elements:
results['idx_mapper'] = np.arange(len(results['bboxes']))
# TODO: Support mask structure in albu
ori_masks = None
if 'masks' in results:
if isinstance(results['masks'], PolygonMasks):
raise NotImplementedError(
'Albu only supports BitMap masks now')
ori_masks = results['masks']
if albumentations.__version__ < '0.5':
results['masks'] = results['masks'].masks
else:
results['masks'] = [mask for mask in results['masks'].masks]
return results, ori_masks
def _postprocess_results(
self,
results: dict,
ori_masks: Optional[Union[BitmapMasks,
PolygonMasks]] = None) -> dict:
"""Post-processing Albu output."""
# albumentations may return np.array or list on different versions
if 'gt_bboxes_labels' in results and isinstance(
results['gt_bboxes_labels'], list):
results['gt_bboxes_labels'] = np.array(
results['gt_bboxes_labels'], dtype=np.int64)
if 'gt_ignore_flags' in results and isinstance(
results['gt_ignore_flags'], list):
results['gt_ignore_flags'] = np.array(
results['gt_ignore_flags'], dtype=bool)
if 'bboxes' in results:
if isinstance(results['bboxes'], list):
results['bboxes'] = np.array(
results['bboxes'], dtype=np.float32)
results['bboxes'] = results['bboxes'].reshape(-1, 4)
results['bboxes'] = HorizontalBoxes(results['bboxes'])
# filter label_fields
if self.filter_lost_elements:
for label in self.origin_label_fields:
results[label] = np.array(
[results[label][i] for i in results['idx_mapper']])
if 'masks' in results:
assert ori_masks is not None
results['masks'] = np.array(
[results['masks'][i] for i in results['idx_mapper']])
results['masks'] = ori_masks.__class__(
results['masks'], ori_masks.height, ori_masks.width)
if (not len(results['idx_mapper'])
and self.skip_img_without_anno):
return None
elif 'masks' in results:
results['masks'] = ori_masks.__class__(results['masks'],
ori_masks.height,
ori_masks.width)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
return repr_str
@TRANSFORMS.register_module()
@avoid_cache_randomness
class RandomCenterCropPad(BaseTransform):
"""Random center crop and random around padding for CornerNet.
This operation generates randomly cropped image from the original image and
pads it simultaneously. Different from :class:`RandomCrop`, the output
    shape may not strictly equal ``crop_size``. We choose a random value
    from ``ratios`` and the output shape could be larger or smaller than
    ``crop_size``. The padding operation is also different from :class:`Pad`:
    here we use around padding instead of right-bottom padding.
The relation between output image (padding image) and original image:
.. code:: text
output image
+----------------------------+
| padded area |
+------|----------------------------|----------+
| | cropped area | |
| | +---------------+ | |
| | | . center | | | original image
| | | range | | |
| | +---------------+ | |
+------|----------------------------|----------+
| padded area |
+----------------------------+
There are 5 main areas in the figure:
- output image: output image of this operation, also called padding
image in following instruction.
- original image: input image of this operation.
- padded area: non-intersect area of output image and original image.
- cropped area: the overlap of output image and original image.
    - center range: a smaller area from which the random center is chosen.
      The center range is computed from ``border`` and the original image's
      shape to avoid the random center being too close to the original
      image's border.
    This operation also acts differently in train and test mode; the summary
    pipeline is listed below.
Train pipeline:
1. Choose a ``random_ratio`` from ``ratios``, the shape of padding image
will be ``random_ratio * crop_size``.
2. Choose a ``random_center`` in center range.
3. Generate padding image with center matches the ``random_center``.
    4. Initialize the padding image with pixel values equal to ``mean``.
5. Copy the cropped area to padding image.
6. Refine annotations.
Test pipeline:
1. Compute output shape according to ``test_pad_mode``.
2. Generate padding image with center matches the original image
center.
    3. Initialize the padding image with pixel values equal to ``mean``.
4. Copy the ``cropped area`` to padding image.
Required Keys:
- img (np.float32)
- img_shape (tuple)
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
Modified Keys:
- img (np.float32)
- img_shape (tuple)
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
Args:
        crop_size (tuple, optional): expected size after crop; the final size
            will be computed according to ratio. Requires (width, height)
            in train mode, and None in test mode.
ratios (tuple, optional): random select a ratio from tuple and crop
image to (crop_size[0] * ratio) * (crop_size[1] * ratio).
Only available in train mode. Defaults to (0.9, 1.0, 1.1).
border (int, optional): max distance from center select area to image
border. Only available in train mode. Defaults to 128.
mean (sequence, optional): Mean values of 3 channels.
std (sequence, optional): Std values of 3 channels.
to_rgb (bool, optional): Whether to convert the image from BGR to RGB.
        test_mode (bool): whether to involve random variables in the
            transform. In train mode, crop_size is fixed, and center coords
            and ratio are randomly selected from predefined lists. In test
            mode, crop_size is the image's original shape, and center coords
            and ratio are fixed. Defaults to False.
test_pad_mode (tuple, optional): padding method and padding shape
value, only available in test mode. Default is using
'logical_or' with 127 as padding shape value.
- 'logical_or': final_shape = input_shape | padding_shape_value
- 'size_divisor': final_shape = int(
ceil(input_shape / padding_shape_value) * padding_shape_value)
Defaults to ('logical_or', 127).
test_pad_add_pix (int): Extra padding pixel in test mode.
Defaults to 0.
bbox_clip_border (bool): Whether clip the objects outside
the border of the image. Defaults to True.
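
    Examples:
        Minimal train/test config sketches (hypothetical values; ``mean``,
        ``std`` and ``to_rgb`` should match the normalization settings):

        >>> train_crop = dict(
        ...     type='RandomCenterCropPad',
        ...     crop_size=(511, 511),
        ...     ratios=(0.6, 0.8, 1.0, 1.2),
        ...     mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True,
        ...     test_pad_mode=None)
        >>> test_crop = dict(
        ...     type='RandomCenterCropPad',
        ...     crop_size=None, ratios=None, border=None,
        ...     mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True,
        ...     test_mode=True, test_pad_mode=('logical_or', 127))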
"""
def __init__(self,
crop_size: Optional[tuple] = None,
ratios: Optional[tuple] = (0.9, 1.0, 1.1),
border: Optional[int] = 128,
mean: Optional[Sequence] = None,
std: Optional[Sequence] = None,
to_rgb: Optional[bool] = None,
test_mode: bool = False,
test_pad_mode: Optional[tuple] = ('logical_or', 127),
test_pad_add_pix: int = 0,
bbox_clip_border: bool = True) -> None:
if test_mode:
assert crop_size is None, 'crop_size must be None in test mode'
assert ratios is None, 'ratios must be None in test mode'
assert border is None, 'border must be None in test mode'
assert isinstance(test_pad_mode, (list, tuple))
assert test_pad_mode[0] in ['logical_or', 'size_divisor']
else:
assert isinstance(crop_size, (list, tuple))
assert crop_size[0] > 0 and crop_size[1] > 0, (
                'crop_size must be > 0 in train mode')
assert isinstance(ratios, (list, tuple))
assert test_pad_mode is None, (
'test_pad_mode must be None in train mode')
self.crop_size = crop_size
self.ratios = ratios
self.border = border
# We do not set default value to mean, std and to_rgb because these
# hyper-parameters are easy to forget but could affect the performance.
# Please use the same setting as Normalize for performance assurance.
assert mean is not None and std is not None and to_rgb is not None
self.to_rgb = to_rgb
self.input_mean = mean
self.input_std = std
if to_rgb:
self.mean = mean[::-1]
self.std = std[::-1]
else:
self.mean = mean
self.std = std
self.test_mode = test_mode
self.test_pad_mode = test_pad_mode
self.test_pad_add_pix = test_pad_add_pix
self.bbox_clip_border = bbox_clip_border
def _get_border(self, border, size):
"""Get final border for the target size.
This function generates a ``final_border`` according to image's shape.
The area between ``final_border`` and ``size - final_border`` is the
``center range``. We randomly choose center from the ``center range``
        to avoid the random center being too close to the original image's
        border. The ``center range`` should also be larger than 0.
Args:
border (int): The initial border, default is 128.
size (int): The width or height of original image.
Returns:
int: The final border.
"""
k = 2 * border / size
i = pow(2, np.ceil(np.log2(np.ceil(k))) + (k == int(k)))
return border // i
def _filter_boxes(self, patch, boxes):
"""Check whether the center of each box is in the patch.
Args:
patch (list[int]): The cropped area, [left, top, right, bottom].
boxes (numpy array, (N x 4)): Ground truth boxes.
Returns:
mask (numpy array, (N,)): Each box is inside or outside the patch.
"""
center = boxes.centers.numpy()
mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (
center[:, 0] < patch[2]) * (
center[:, 1] < patch[3])
return mask
def _crop_image_and_paste(self, image, center, size):
"""Crop image with a given center and size, then paste the cropped
image to a blank image with two centers align.
This function is equivalent to generating a blank image with ``size``
as its shape. Then cover it on the original image with two centers (
the center of blank image and the random center of original image)
aligned. The overlap area is paste from the original image and the
outside area is filled with ``mean pixel``.
Args:
image (np array, H x W x C): Original image.
center (list[int]): Target crop center coord.
size (list[int]): Target crop size. [target_h, target_w]
Returns:
cropped_img (np array, target_h x target_w x C): Cropped image.
border (np array, 4): The distance of four border of
``cropped_img`` to the original image area, [top, bottom,
left, right]
patch (list[int]): The cropped area, [left, top, right, bottom].
"""
center_y, center_x = center
target_h, target_w = size
img_h, img_w, img_c = image.shape
x0 = max(0, center_x - target_w // 2)
x1 = min(center_x + target_w // 2, img_w)
y0 = max(0, center_y - target_h // 2)
y1 = min(center_y + target_h // 2, img_h)
patch = np.array((int(x0), int(y0), int(x1), int(y1)))
left, right = center_x - x0, x1 - center_x
top, bottom = center_y - y0, y1 - center_y
cropped_center_y, cropped_center_x = target_h // 2, target_w // 2
cropped_img = np.zeros((target_h, target_w, img_c), dtype=image.dtype)
for i in range(img_c):
cropped_img[:, :, i] += self.mean[i]
y_slice = slice(cropped_center_y - top, cropped_center_y + bottom)
x_slice = slice(cropped_center_x - left, cropped_center_x + right)
cropped_img[y_slice, x_slice, :] = image[y0:y1, x0:x1, :]
border = np.array([
cropped_center_y - top, cropped_center_y + bottom,
cropped_center_x - left, cropped_center_x + right
],
dtype=np.float32)
return cropped_img, border, patch
def _train_aug(self, results):
"""Random crop and around padding the original image.
Args:
            results (dict): Image information in the augment pipeline.
Returns:
results (dict): The updated dict.
"""
img = results['img']
h, w, c = img.shape
gt_bboxes = results['gt_bboxes']
while True:
scale = random.choice(self.ratios)
new_h = int(self.crop_size[1] * scale)
new_w = int(self.crop_size[0] * scale)
h_border = self._get_border(self.border, h)
w_border = self._get_border(self.border, w)
for i in range(50):
center_x = random.randint(low=w_border, high=w - w_border)
center_y = random.randint(low=h_border, high=h - h_border)
cropped_img, border, patch = self._crop_image_and_paste(
img, [center_y, center_x], [new_h, new_w])
                # If the image has no valid gt bbox, any crop patch is valid.
                if len(gt_bboxes) == 0:
                    results['img'] = cropped_img
                    results['img_shape'] = cropped_img.shape[:2]
                    return results
mask = self._filter_boxes(patch, gt_bboxes)
if not mask.any():
continue
results['img'] = cropped_img
results['img_shape'] = cropped_img.shape[:2]
x0, y0, x1, y1 = patch
left_w, top_h = center_x - x0, center_y - y0
cropped_center_x, cropped_center_y = new_w // 2, new_h // 2
# crop bboxes accordingly and clip to the image boundary
gt_bboxes = gt_bboxes[mask]
gt_bboxes.translate_([
cropped_center_x - left_w - x0,
cropped_center_y - top_h - y0
])
if self.bbox_clip_border:
gt_bboxes.clip_([new_h, new_w])
keep = gt_bboxes.is_inside([new_h, new_w]).numpy()
gt_bboxes = gt_bboxes[keep]
results['gt_bboxes'] = gt_bboxes
# ignore_flags
if results.get('gt_ignore_flags', None) is not None:
gt_ignore_flags = results['gt_ignore_flags'][mask]
results['gt_ignore_flags'] = \
gt_ignore_flags[keep]
# labels
if results.get('gt_bboxes_labels', None) is not None:
gt_labels = results['gt_bboxes_labels'][mask]
results['gt_bboxes_labels'] = gt_labels[keep]
if 'gt_masks' in results or 'gt_seg_map' in results:
raise NotImplementedError(
'RandomCenterCropPad only supports bbox.')
return results
def _test_aug(self, results):
"""Around padding the original image without cropping.
The padding mode and value are from ``test_pad_mode``.
Args:
            results (dict): Image information in the augment pipeline.
Returns:
results (dict): The updated dict.
"""
img = results['img']
h, w, c = img.shape
if self.test_pad_mode[0] in ['logical_or']:
# self.test_pad_add_pix is only used for centernet
target_h = (h | self.test_pad_mode[1]) + self.test_pad_add_pix
target_w = (w | self.test_pad_mode[1]) + self.test_pad_add_pix
elif self.test_pad_mode[0] in ['size_divisor']:
divisor = self.test_pad_mode[1]
target_h = int(np.ceil(h / divisor)) * divisor
target_w = int(np.ceil(w / divisor)) * divisor
else:
raise NotImplementedError(
                'RandomCenterCropPad only supports two testing pad modes: '
                'logical_or and size_divisor.')
cropped_img, border, _ = self._crop_image_and_paste(
img, [h // 2, w // 2], [target_h, target_w])
results['img'] = cropped_img
results['img_shape'] = cropped_img.shape[:2]
results['border'] = border
return results
@autocast_box_type()
def transform(self, results: dict) -> dict:
img = results['img']
assert img.dtype == np.float32, (
'RandomCenterCropPad needs the input image of dtype np.float32,'
' please set "to_float32=True" in "LoadImageFromFile" pipeline')
h, w, c = img.shape
assert c == len(self.mean)
if self.test_mode:
return self._test_aug(results)
else:
return self._train_aug(results)
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(crop_size={self.crop_size}, '
repr_str += f'ratios={self.ratios}, '
repr_str += f'border={self.border}, '
repr_str += f'mean={self.input_mean}, '
repr_str += f'std={self.input_std}, '
repr_str += f'to_rgb={self.to_rgb}, '
repr_str += f'test_mode={self.test_mode}, '
repr_str += f'test_pad_mode={self.test_pad_mode}, '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str
@TRANSFORMS.register_module()
class CutOut(BaseTransform):
"""CutOut operation.
Randomly drop some regions of image used in
`Cutout <https://arxiv.org/abs/1708.04552>`_.
Required Keys:
- img
Modified Keys:
- img
Args:
n_holes (int or tuple[int, int]): Number of regions to be dropped.
If it is given as a list, number of holes will be randomly
selected from the closed interval [``n_holes[0]``, ``n_holes[1]``].
cutout_shape (tuple[int, int] or list[tuple[int, int]], optional):
The candidate shape of dropped regions. It can be
``tuple[int, int]`` to use a fixed cutout shape, or
``list[tuple[int, int]]`` to randomly choose shape
from the list. Defaults to None.
cutout_ratio (tuple[float, float] or list[tuple[float, float]],
optional): The candidate ratio of dropped regions. It can be
``tuple[float, float]`` to use a fixed ratio or
``list[tuple[float, float]]`` to randomly choose ratio
from the list. Please note that ``cutout_shape`` and
``cutout_ratio`` cannot be both given at the same time.
Defaults to None.
fill_in (tuple[float, float, float] or tuple[int, int, int]): The value
of pixel to fill in the dropped regions. Defaults to (0, 0, 0).
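
    Examples:
        A minimal sketch (hypothetical shapes) dropping one 4x4 region:

        >>> import numpy as np
        >>> cutout = CutOut(n_holes=1, cutout_shape=(4, 4))
        >>> results = cutout(dict(img=np.ones((16, 16, 3), dtype=np.uint8)))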
"""
def __init__(
self,
n_holes: Union[int, Tuple[int, int]],
cutout_shape: Optional[Union[Tuple[int, int],
List[Tuple[int, int]]]] = None,
cutout_ratio: Optional[Union[Tuple[float, float],
List[Tuple[float, float]]]] = None,
fill_in: Union[Tuple[float, float, float], Tuple[int, int,
int]] = (0, 0, 0)
) -> None:
assert (cutout_shape is None) ^ (cutout_ratio is None), \
'Either cutout_shape or cutout_ratio should be specified.'
assert (isinstance(cutout_shape, (list, tuple))
or isinstance(cutout_ratio, (list, tuple)))
if isinstance(n_holes, tuple):
assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
else:
n_holes = (n_holes, n_holes)
self.n_holes = n_holes
self.fill_in = fill_in
self.with_ratio = cutout_ratio is not None
self.candidates = cutout_ratio if self.with_ratio else cutout_shape
if not isinstance(self.candidates, list):
self.candidates = [self.candidates]
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Call function to drop some regions of image."""
h, w, c = results['img'].shape
n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
for _ in range(n_holes):
x1 = np.random.randint(0, w)
y1 = np.random.randint(0, h)
index = np.random.randint(0, len(self.candidates))
if not self.with_ratio:
cutout_w, cutout_h = self.candidates[index]
else:
cutout_w = int(self.candidates[index][0] * w)
cutout_h = int(self.candidates[index][1] * h)
x2 = np.clip(x1 + cutout_w, 0, w)
y2 = np.clip(y1 + cutout_h, 0, h)
results['img'][y1:y2, x1:x2, :] = self.fill_in
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(n_holes={self.n_holes}, '
repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio
else f'cutout_shape={self.candidates}, ')
repr_str += f'fill_in={self.fill_in})'
return repr_str
@TRANSFORMS.register_module()
class Mosaic(BaseTransform):
"""Mosaic augmentation.
    Given 4 images, the mosaic transform combines them into one output
    image, composed of parts from each sub-image.
.. code:: text
mosaic transform
center_x
+------------------------------+
| pad | pad |
| +-----------+ |
| | | |
| | image1 |--------+ |
| | | | |
| | | image2 | |
center_y |----+-------------+-----------|
| | cropped | |
|pad | image3 | image4 |
| | | |
+----|-------------+-----------+
| |
+-------------+
    The mosaic transform steps are as follows:
        1. Choose the mosaic center as the intersection of the 4 images.
        2. Get the top-left image according to the index, and randomly
           sample another 3 images from the custom dataset.
        3. The sub-image will be cropped if it is larger than the mosaic
           patch.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
- mix_results (List[dict])
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_ignore_flags (optional)
Args:
img_scale (Sequence[int]): Image size before mosaic pipeline of single
image. The shape order should be (width, height).
Defaults to (640, 640).
center_ratio_range (Sequence[float]): Center ratio range of mosaic
output. Defaults to (0.5, 1.5).
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
pad_val (int): Pad value. Defaults to 114.
prob (float): Probability of applying this transformation.
Defaults to 1.0.
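
    Examples:
        A minimal pipeline-config sketch (hypothetical values). Mosaic reads
        ``mix_results``, which is filled in by ``MultiImageMixDataset``:

        >>> train_pipeline = [
        ...     dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0),
        ...     dict(type='RandomFlip', prob=0.5),
        ... ]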
"""
def __init__(self,
img_scale: Tuple[int, int] = (640, 640),
center_ratio_range: Tuple[float, float] = (0.5, 1.5),
bbox_clip_border: bool = True,
pad_val: float = 114.0,
prob: float = 1.0) -> None:
assert isinstance(img_scale, tuple)
assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
f'got {prob}.'
log_img_scale(img_scale, skip_square=True, shape_order='wh')
self.img_scale = img_scale
self.center_ratio_range = center_ratio_range
self.bbox_clip_border = bbox_clip_border
self.pad_val = pad_val
self.prob = prob
@cache_randomness
    def get_indexes(self, dataset: BaseDataset) -> list:
"""Call function to collect indexes.
Args:
dataset (:obj:`MultiImageMixDataset`): The dataset.
Returns:
list: indexes.
"""
indexes = [random.randint(0, len(dataset)) for _ in range(3)]
return indexes
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Mosaic transform function.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""
if random.uniform(0, 1) > self.prob:
return results
assert 'mix_results' in results
mosaic_bboxes = []
mosaic_bboxes_labels = []
mosaic_ignore_flags = []
if len(results['img'].shape) == 3:
mosaic_img = np.full(
(int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3),
self.pad_val,
dtype=results['img'].dtype)
else:
mosaic_img = np.full(
(int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)),
self.pad_val,
dtype=results['img'].dtype)
# mosaic center x, y
center_x = int(
random.uniform(*self.center_ratio_range) * self.img_scale[0])
center_y = int(
random.uniform(*self.center_ratio_range) * self.img_scale[1])
center_position = (center_x, center_y)
loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
for i, loc in enumerate(loc_strs):
if loc == 'top_left':
results_patch = copy.deepcopy(results)
else:
results_patch = copy.deepcopy(results['mix_results'][i - 1])
img_i = results_patch['img']
h_i, w_i = img_i.shape[:2]
# keep_ratio resize
scale_ratio_i = min(self.img_scale[1] / h_i,
self.img_scale[0] / w_i)
img_i = mmcv.imresize(
img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))
# compute the combine parameters
paste_coord, crop_coord = self._mosaic_combine(
loc, center_position, img_i.shape[:2][::-1])
x1_p, y1_p, x2_p, y2_p = paste_coord
x1_c, y1_c, x2_c, y2_c = crop_coord
# crop and paste image
mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]
# adjust coordinate
gt_bboxes_i = results_patch['gt_bboxes']
gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
gt_ignore_flags_i = results_patch['gt_ignore_flags']
padw = x1_p - x1_c
padh = y1_p - y1_c
gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
gt_bboxes_i.translate_([padw, padh])
mosaic_bboxes.append(gt_bboxes_i)
mosaic_bboxes_labels.append(gt_bboxes_labels_i)
mosaic_ignore_flags.append(gt_ignore_flags_i)
mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
if self.bbox_clip_border:
mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]])
# remove outside bboxes
inside_inds = mosaic_bboxes.is_inside(
[2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy()
mosaic_bboxes = mosaic_bboxes[inside_inds]
mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
results['img'] = mosaic_img
results['img_shape'] = mosaic_img.shape[:2]
results['gt_bboxes'] = mosaic_bboxes
results['gt_bboxes_labels'] = mosaic_bboxes_labels
results['gt_ignore_flags'] = mosaic_ignore_flags
return results
def _mosaic_combine(
self, loc: str, center_position_xy: Sequence[float],
img_shape_wh: Sequence[int]) -> Tuple[Tuple[int], Tuple[int]]:
"""Calculate global coordinate of mosaic image and local coordinate of
cropped sub-image.
Args:
loc (str): Index for the sub-image, loc in ('top_left',
'top_right', 'bottom_left', 'bottom_right').
center_position_xy (Sequence[float]): Mixing center for 4 images,
(x, y).
img_shape_wh (Sequence[int]): Width and height of sub-image
Returns:
tuple[tuple[float]]: Corresponding coordinate of pasting and
cropping
- paste_coord (tuple): paste corner coordinate in mosaic image.
- crop_coord (tuple): crop corner coordinate in mosaic image.
"""
assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
if loc == 'top_left':
# index0 to top left part of image
x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
max(center_position_xy[1] - img_shape_wh[1], 0), \
center_position_xy[0], \
center_position_xy[1]
crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - (
y2 - y1), img_shape_wh[0], img_shape_wh[1]
elif loc == 'top_right':
# index1 to top right part of image
x1, y1, x2, y2 = center_position_xy[0], \
max(center_position_xy[1] - img_shape_wh[1], 0), \
min(center_position_xy[0] + img_shape_wh[0],
self.img_scale[0] * 2), \
center_position_xy[1]
crop_coord = 0, img_shape_wh[1] - (y2 - y1), min(
img_shape_wh[0], x2 - x1), img_shape_wh[1]
elif loc == 'bottom_left':
# index2 to bottom left part of image
x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
center_position_xy[1], \
center_position_xy[0], \
min(self.img_scale[1] * 2, center_position_xy[1] +
img_shape_wh[1])
crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min(
y2 - y1, img_shape_wh[1])
else:
# index3 to bottom right part of image
x1, y1, x2, y2 = center_position_xy[0], \
center_position_xy[1], \
min(center_position_xy[0] + img_shape_wh[0],
self.img_scale[0] * 2), \
min(self.img_scale[1] * 2, center_position_xy[1] +
img_shape_wh[1])
crop_coord = 0, 0, min(img_shape_wh[0],
x2 - x1), min(y2 - y1, img_shape_wh[1])
paste_coord = x1, y1, x2, y2
return paste_coord, crop_coord
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(img_scale={self.img_scale}, '
repr_str += f'center_ratio_range={self.center_ratio_range}, '
repr_str += f'pad_val={self.pad_val}, '
repr_str += f'prob={self.prob})'
return repr_str
@TRANSFORMS.register_module()
class MixUp(BaseTransform):
"""MixUp data augmentation.
.. code:: text
mixup transform
+------------------------------+
| mixup image | |
| +--------|--------+ |
| | | | |
|---------------+ | |
| | | |
| | image | |
| | | |
| | | |
| |-----------------+ |
| pad |
+------------------------------+
    The mixup transform steps are as follows:
        1. Another random image is picked from the dataset and embedded in
           the top-left patch (after padding and resizing).
        2. The target of the mixup transform is the weighted average of the
           mixup image and the original image.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
- mix_results (List[dict])
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_ignore_flags (optional)
Args:
img_scale (Sequence[int]): Image output size after mixup pipeline.
The shape order should be (width, height). Defaults to (640, 640).
ratio_range (Sequence[float]): Scale ratio of mixup image.
Defaults to (0.5, 1.5).
flip_ratio (float): Horizontal flip ratio of mixup image.
Defaults to 0.5.
pad_val (int): Pad value. Defaults to 114.
max_iters (int): The maximum number of iterations. If the number of
iterations is greater than `max_iters`, but gt_bbox is still
empty, then the iteration is terminated. Defaults to 15.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
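
    Examples:
        A minimal pipeline-config sketch (hypothetical values); like
        ``Mosaic``, MixUp reads ``mix_results`` filled in by
        ``MultiImageMixDataset``:

        >>> mixup = dict(type='MixUp', img_scale=(640, 640),
        ...              ratio_range=(0.8, 1.6), pad_val=114.0)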
"""
def __init__(self,
img_scale: Tuple[int, int] = (640, 640),
ratio_range: Tuple[float, float] = (0.5, 1.5),
flip_ratio: float = 0.5,
pad_val: float = 114.0,
max_iters: int = 15,
bbox_clip_border: bool = True) -> None:
assert isinstance(img_scale, tuple)
log_img_scale(img_scale, skip_square=True, shape_order='wh')
self.dynamic_scale = img_scale
self.ratio_range = ratio_range
self.flip_ratio = flip_ratio
self.pad_val = pad_val
self.max_iters = max_iters
self.bbox_clip_border = bbox_clip_border
@cache_randomness
def get_indexes(self, dataset: BaseDataset) -> int:
"""Call function to collect indexes.
Args:
dataset (:obj:`MultiImageMixDataset`): The dataset.
Returns:
int: Index of a sample whose ``gt_bboxes`` is non-empty.
"""
for i in range(self.max_iters):
index = random.randint(0, len(dataset) - 1)  # randint is inclusive
gt_bboxes_i = dataset[index]['gt_bboxes']
if len(gt_bboxes_i) != 0:
break
return index
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""MixUp transform function.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""
assert 'mix_results' in results
assert len(
results['mix_results']) == 1, 'MixUp only supports 2 images now!'
if results['mix_results'][0]['gt_bboxes'].shape[0] == 0:
# empty bbox
return results
retrieve_results = results['mix_results'][0]
retrieve_img = retrieve_results['img']
jit_factor = random.uniform(*self.ratio_range)
is_flip = random.uniform(0, 1) > self.flip_ratio
if len(retrieve_img.shape) == 3:
out_img = np.ones(
(self.dynamic_scale[1], self.dynamic_scale[0], 3),
dtype=retrieve_img.dtype) * self.pad_val
else:
out_img = np.ones(
self.dynamic_scale[::-1],
dtype=retrieve_img.dtype) * self.pad_val
# 1. keep_ratio resize
scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0],
self.dynamic_scale[0] / retrieve_img.shape[1])
retrieve_img = mmcv.imresize(
retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
int(retrieve_img.shape[0] * scale_ratio)))
# 2. paste
out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img
# 3. scale jit
scale_ratio *= jit_factor
out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
int(out_img.shape[0] * jit_factor)))
# 4. flip
if is_flip:
out_img = out_img[:, ::-1, :]
# 5. random crop
ori_img = results['img']
origin_h, origin_w = out_img.shape[:2]
target_h, target_w = ori_img.shape[:2]
padded_img = np.ones((max(origin_h, target_h), max(
origin_w, target_w), 3)) * self.pad_val
padded_img = padded_img.astype(np.uint8)
padded_img[:origin_h, :origin_w] = out_img
x_offset, y_offset = 0, 0
if padded_img.shape[0] > target_h:
y_offset = random.randint(0, padded_img.shape[0] - target_h)
if padded_img.shape[1] > target_w:
x_offset = random.randint(0, padded_img.shape[1] - target_w)
padded_cropped_img = padded_img[y_offset:y_offset + target_h,
x_offset:x_offset + target_w]
# 6. adjust bbox
retrieve_gt_bboxes = retrieve_results['gt_bboxes']
retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
if self.bbox_clip_border:
retrieve_gt_bboxes.clip_([origin_h, origin_w])
if is_flip:
retrieve_gt_bboxes.flip_([origin_h, origin_w],
direction='horizontal')
# 7. filter
cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
if self.bbox_clip_border:
cp_retrieve_gt_bboxes.clip_([target_h, target_w])
# 8. mix up
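# Blending uses fixed, equal weights (YOLOX-style mixup) rather than
# a Beta-sampled ratio.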
ori_img = ori_img.astype(np.float32)
mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32)
retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']
mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
(results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
mixup_gt_bboxes_labels = np.concatenate(
(results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
mixup_gt_ignore_flags = np.concatenate(
(results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
# remove outside bbox
inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy()
mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
results['img'] = mixup_img.astype(np.uint8)
results['img_shape'] = mixup_img.shape[:2]
results['gt_bboxes'] = mixup_gt_bboxes
results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
results['gt_ignore_flags'] = mixup_gt_ignore_flags
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(dynamic_scale={self.dynamic_scale}, '
repr_str += f'ratio_range={self.ratio_range}, '
repr_str += f'flip_ratio={self.flip_ratio}, '
repr_str += f'pad_val={self.pad_val}, '
repr_str += f'max_iters={self.max_iters}, '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str
@TRANSFORMS.register_module()
class RandomAffine(BaseTransform):
"""Random affine transform data augmentation.
This operation randomly generates affine transform matrix which including
rotation, translation, shear and scaling transforms.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_ignore_flags (optional)
Args:
max_rotate_degree (float): Maximum degrees of rotation transform.
Defaults to 10.
max_translate_ratio (float): Maximum ratio of translation.
Defaults to 0.1.
scaling_ratio_range (tuple[float]): Min and max ratio of
scaling transform. Defaults to (0.5, 1.5).
max_shear_degree (float): Maximum degrees of shear
transform. Defaults to 2.
border (tuple[int]): Distance from width and height sides of input
image to adjust output shape. Only used in mosaic dataset.
Defaults to (0, 0).
border_val (tuple[int]): Border padding values of 3 channels.
Defaults to (114, 114, 114).
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
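Examples:
>>> # A minimal usage sketch with illustrative values (not taken
>>> # from the source); ``results`` must contain at least 'img',
>>> # 'gt_bboxes', 'gt_bboxes_labels' and 'gt_ignore_flags':
>>> transform = RandomAffine(
>>> max_rotate_degree=10.0,
>>> max_translate_ratio=0.1,
>>> scaling_ratio_range=(0.5, 1.5),
>>> max_shear_degree=2.0)
>>> # results = transform(results)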
"""
def __init__(self,
max_rotate_degree: float = 10.0,
max_translate_ratio: float = 0.1,
scaling_ratio_range: Tuple[float, float] = (0.5, 1.5),
max_shear_degree: float = 2.0,
border: Tuple[int, int] = (0, 0),
border_val: Tuple[int, int, int] = (114, 114, 114),
bbox_clip_border: bool = True) -> None:
assert 0 <= max_translate_ratio <= 1
assert scaling_ratio_range[0] <= scaling_ratio_range[1]
assert scaling_ratio_range[0] > 0
self.max_rotate_degree = max_rotate_degree
self.max_translate_ratio = max_translate_ratio
self.scaling_ratio_range = scaling_ratio_range
self.max_shear_degree = max_shear_degree
self.border = border
self.border_val = border_val
self.bbox_clip_border = bbox_clip_border
@cache_randomness
def _get_random_homography_matrix(self, height, width):
# Rotation
rotation_degree = random.uniform(-self.max_rotate_degree,
self.max_rotate_degree)
rotation_matrix = self._get_rotation_matrix(rotation_degree)
# Scaling
scaling_ratio = random.uniform(self.scaling_ratio_range[0],
self.scaling_ratio_range[1])
scaling_matrix = self._get_scaling_matrix(scaling_ratio)
# Shear
x_degree = random.uniform(-self.max_shear_degree,
self.max_shear_degree)
y_degree = random.uniform(-self.max_shear_degree,
self.max_shear_degree)
shear_matrix = self._get_shear_matrix(x_degree, y_degree)
# Translation
trans_x = random.uniform(-self.max_translate_ratio,
self.max_translate_ratio) * width
trans_y = random.uniform(-self.max_translate_ratio,
self.max_translate_ratio) * height
translate_matrix = self._get_translation_matrix(trans_x, trans_y)
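# Matrix multiplication composes right-to-left: scaling is applied
# first, then rotation, shear, and finally translation.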
warp_matrix = (
translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix)
return warp_matrix
@autocast_box_type()
def transform(self, results: dict) -> dict:
img = results['img']
height = img.shape[0] + self.border[1] * 2
width = img.shape[1] + self.border[0] * 2
warp_matrix = self._get_random_homography_matrix(height, width)
img = cv2.warpPerspective(
img,
warp_matrix,
dsize=(width, height),
borderValue=self.border_val)
results['img'] = img
results['img_shape'] = img.shape[:2]
bboxes = results['gt_bboxes']
num_bboxes = len(bboxes)
if num_bboxes:
bboxes.project_(warp_matrix)
if self.bbox_clip_border:
bboxes.clip_([height, width])
# remove outside bbox
valid_index = bboxes.is_inside([height, width]).numpy()
results['gt_bboxes'] = bboxes[valid_index]
results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
valid_index]
results['gt_ignore_flags'] = results['gt_ignore_flags'][
valid_index]
if 'gt_masks' in results:
raise NotImplementedError('RandomAffine only supports bbox.')
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(max_rotate_degree={self.max_rotate_degree}, '
repr_str += f'max_translate_ratio={self.max_translate_ratio}, '
repr_str += f'scaling_ratio_range={self.scaling_ratio_range}, '
repr_str += f'max_shear_degree={self.max_shear_degree}, '
repr_str += f'border={self.border}, '
repr_str += f'border_val={self.border_val}, '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str
@staticmethod
def _get_rotation_matrix(rotate_degrees: float) -> np.ndarray:
radian = math.radians(rotate_degrees)
rotation_matrix = np.array(
[[np.cos(radian), -np.sin(radian), 0.],
[np.sin(radian), np.cos(radian), 0.], [0., 0., 1.]],
dtype=np.float32)
return rotation_matrix
@staticmethod
def _get_scaling_matrix(scale_ratio: float) -> np.ndarray:
scaling_matrix = np.array(
[[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
dtype=np.float32)
return scaling_matrix
@staticmethod
def _get_shear_matrix(x_shear_degrees: float,
y_shear_degrees: float) -> np.ndarray:
x_radian = math.radians(x_shear_degrees)
y_radian = math.radians(y_shear_degrees)
shear_matrix = np.array([[1, np.tan(x_radian), 0.],
[np.tan(y_radian), 1, 0.], [0., 0., 1.]],
dtype=np.float32)
return shear_matrix
@staticmethod
def _get_translation_matrix(x: float, y: float) -> np.ndarray:
translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
dtype=np.float32)
return translation_matrix
@TRANSFORMS.register_module()
class YOLOXHSVRandomAug(BaseTransform):
"""Apply HSV augmentation to image sequentially. It is referenced from
https://github.com/Megvii-
BaseDetection/YOLOX/blob/main/yolox/data/data_augment.py#L21.
Required Keys:
- img
Modified Keys:
- img
Args:
hue_delta (int): Delta of hue. Defaults to 5.
saturation_delta (int): Delta of saturation. Defaults to 30.
value_delta (int): Delta of value. Defaults to 30.
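Examples:
>>> # A minimal runnable sketch with illustrative values (not taken
>>> # from the source); only 'img' is required:
>>> import numpy as np
>>> img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
>>> aug = YOLOXHSVRandomAug(hue_delta=5, saturation_delta=30,
>>> value_delta=30)
>>> results = aug.transform(dict(img=img))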
"""
def __init__(self,
hue_delta: int = 5,
saturation_delta: int = 30,
value_delta: int = 30) -> None:
self.hue_delta = hue_delta
self.saturation_delta = saturation_delta
self.value_delta = value_delta
@cache_randomness
def _get_hsv_gains(self):
hsv_gains = np.random.uniform(-1, 1, 3) * [
self.hue_delta, self.saturation_delta, self.value_delta
]
# random selection of h, s, v
hsv_gains *= np.random.randint(0, 2, 3)
# prevent overflow
hsv_gains = hsv_gains.astype(np.int16)
return hsv_gains
def transform(self, results: dict) -> dict:
img = results['img']
hsv_gains = self._get_hsv_gains()
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
img_hsv[..., 0] = (img_hsv[..., 0] + hsv_gains[0]) % 180
img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_gains[1], 0, 255)
img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_gains[2], 0, 255)
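# Convert back to BGR in place: passing ``dst=img`` writes the result
# into the original image buffer.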
cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img)
results['img'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(hue_delta={self.hue_delta}, '
repr_str += f'saturation_delta={self.saturation_delta}, '
repr_str += f'value_delta={self.value_delta})'
return repr_str
@TRANSFORMS.register_module()
class CopyPaste(BaseTransform):
"""Simple Copy-Paste is a Strong Data Augmentation Method for Instance
Segmentation The simple copy-paste transform steps are as follows:
1. The destination image is already resized with aspect ratio kept,
cropped and padded.
2. Randomly select a source image, which is also already resized
with aspect ratio kept, cropped and padded in a similar way
as the destination image.
3. Randomly select some objects from the source image.
4. Paste these source objects to the destination image directly,
due to the source and destination image have the same size.
5. Update object masks of the destination image, for some origin objects
may be occluded.
6. Generate bboxes from the updated destination masks and
filter some objects which are totally occluded, and adjust bboxes
which are partly occluded.
7. Append selected source bboxes, masks, and labels.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
- gt_masks (BitmapMasks) (optional)
Modified Keys:
- img
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_ignore_flags (optional)
- gt_masks (optional)
Args:
max_num_pasted (int): The maximum number of pasted objects.
Defaults to 100.
bbox_occluded_thr (int): The threshold of occluded bbox.
Defaults to 10.
mask_occluded_thr (int): The threshold of occluded mask.
Defaults to 300.
selected (bool): Whether select objects or not. If select is False,
all objects of the source image will be pasted to the
destination image.
Defaults to True.
paste_by_box (bool): Whether use boxes as masks when masks are not
available.
Defaults to False.
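Examples:
>>> # A minimal config sketch with illustrative values (not taken
>>> # from the source). Like MixUp, CopyPaste reads the source image
>>> # from ``results['mix_results']``:
>>> transform = dict(
>>> type='CopyPaste',
>>> max_num_pasted=100,
>>> bbox_occluded_thr=10,
>>> mask_occluded_thr=300,
>>> selected=True)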
"""
def __init__(
self,
max_num_pasted: int = 100,
bbox_occluded_thr: int = 10,
mask_occluded_thr: int = 300,
selected: bool = True,
paste_by_box: bool = False,
) -> None:
self.max_num_pasted = max_num_pasted
self.bbox_occluded_thr = bbox_occluded_thr
self.mask_occluded_thr = mask_occluded_thr
self.selected = selected
self.paste_by_box = paste_by_box
@cache_randomness
def get_indexes(self, dataset: BaseDataset) -> int:
"""Call function to collect indexes.s.
Args:
dataset (:obj:`MultiImageMixDataset`): The dataset.
Returns:
list: Indexes.
"""
return random.randint(0, len(dataset))
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Transform function to make a copy-paste of image.
Args:
results (dict): Result dict.
Returns:
dict: Result dict with copy-paste transformed.
"""
assert 'mix_results' in results
num_images = len(results['mix_results'])
assert num_images == 1, \
f'CopyPaste only supports processing 2 images, got {num_images}'
if self.selected:
selected_results = self._select_object(results['mix_results'][0])
else:
selected_results = results['mix_results'][0]
return self._copy_paste(results, selected_results)
@cache_randomness
def _get_selected_inds(self, num_bboxes: int) -> np.ndarray:
max_num_pasted = min(num_bboxes + 1, self.max_num_pasted)
num_pasted = np.random.randint(0, max_num_pasted)
return np.random.choice(num_bboxes, size=num_pasted, replace=False)
def get_gt_masks(self, results: dict) -> BitmapMasks:
"""Get gt_masks originally or generated based on bboxes.
If gt_masks is not contained in results,
it will be generated based on gt_bboxes.
Args:
results (dict): Result dict.
Returns:
BitmapMasks: gt_masks, originally or generated based on bboxes.
"""
if results.get('gt_masks', None) is not None:
if self.paste_by_box:
warnings.warn('gt_masks is already contained in results, '
'so paste_by_box is disabled.')
return results['gt_masks']
else:
if not self.paste_by_box:
raise RuntimeError('results does not contain masks.')
return results['gt_bboxes'].create_masks(results['img'].shape[:2])
def _select_object(self, results: dict) -> dict:
"""Select some objects from the source results."""
bboxes = results['gt_bboxes']
labels = results['gt_bboxes_labels']
masks = self.get_gt_masks(results)
ignore_flags = results['gt_ignore_flags']
selected_inds = self._get_selected_inds(bboxes.shape[0])
selected_bboxes = bboxes[selected_inds]
selected_labels = labels[selected_inds]
selected_masks = masks[selected_inds]
selected_ignore_flags = ignore_flags[selected_inds]
results['gt_bboxes'] = selected_bboxes
results['gt_bboxes_labels'] = selected_labels
results['gt_masks'] = selected_masks
results['gt_ignore_flags'] = selected_ignore_flags
return results
def _copy_paste(self, dst_results: dict, src_results: dict) -> dict:
"""CopyPaste transform function.
Args:
dst_results (dict): Result dict of the destination image.
src_results (dict): Result dict of the source image.
Returns:
dict: Updated result dict.
"""
dst_img = dst_results['img']
dst_bboxes = dst_results['gt_bboxes']
dst_labels = dst_results['gt_bboxes_labels']
dst_masks = self.get_gt_masks(dst_results)
dst_ignore_flags = dst_results['gt_ignore_flags']
src_img = src_results['img']
src_bboxes = src_results['gt_bboxes']
src_labels = src_results['gt_bboxes_labels']
src_masks = src_results['gt_masks']
src_ignore_flags = src_results['gt_ignore_flags']
if len(src_bboxes) == 0:
return dst_results
# update masks and generate bboxes from updated masks
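# ``composed_mask`` is the union of all selected source instance
# masks; destination pixels under it are overwritten by the source.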
composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0)
updated_dst_masks = self._get_updated_masks(dst_masks, composed_mask)
updated_dst_bboxes = updated_dst_masks.get_bboxes(type(dst_bboxes))
assert len(updated_dst_bboxes) == len(updated_dst_masks)
# filter totally occluded objects
l1_distance = (updated_dst_bboxes.tensor - dst_bboxes.tensor).abs()
bboxes_inds = (l1_distance <= self.bbox_occluded_thr).all(
dim=-1).numpy()
masks_inds = updated_dst_masks.masks.sum(
axis=(1, 2)) > self.mask_occluded_thr
valid_inds = bboxes_inds | masks_inds
# Paste source objects to destination image directly
img = dst_img * (1 - composed_mask[..., np.newaxis]
) + src_img * composed_mask[..., np.newaxis]
bboxes = src_bboxes.cat([updated_dst_bboxes[valid_inds], src_bboxes])
labels = np.concatenate([dst_labels[valid_inds], src_labels])
masks = np.concatenate(
[updated_dst_masks.masks[valid_inds], src_masks.masks])
ignore_flags = np.concatenate(
[dst_ignore_flags[valid_inds], src_ignore_flags])
dst_results['img'] = img
dst_results['gt_bboxes'] = bboxes
dst_results['gt_bboxes_labels'] = labels
dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1],
masks.shape[2])
dst_results['gt_ignore_flags'] = ignore_flags
return dst_results
def _get_updated_masks(self, masks: BitmapMasks,
composed_mask: np.ndarray) -> BitmapMasks:
"""Update masks with composed mask."""
assert masks.masks.shape[-2:] == composed_mask.shape[-2:], \
'Cannot compare two arrays of different size'
masks.masks = np.where(composed_mask, 0, masks.masks)
return masks
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(max_num_pasted={self.max_num_pasted}, '
repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, '
repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
repr_str += f'selected={self.selected}, '
repr_str += f'paste_by_box={self.paste_by_box})'
return repr_str
@TRANSFORMS.register_module()
class RandomErasing(BaseTransform):
"""RandomErasing operation.
Random Erasing randomly selects a rectangle region
in an image and erases its pixels with random values.
`RandomErasing <https://arxiv.org/abs/1708.04896>`_.
Required Keys:
- img
- gt_bboxes (HorizontalBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
- gt_masks (BitmapMasks) (optional)
Modified Keys:
- img
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_ignore_flags (optional)
- gt_masks (optional)
Args:
n_patches (int or tuple[int, int]): Number of regions to be dropped.
If it is given as a tuple, number of patches will be randomly
selected from the closed interval [``n_patches[0]``,
``n_patches[1]``].
ratio (float or tuple[float, float]): The ratio of erased regions.
It can be ``float`` to use a fixed ratio or ``tuple[float, float]``
to randomly choose ratio from the interval.
squared (bool): Whether to erase square region. Defaults to True.
bbox_erased_thr (float): The threshold for the maximum area proportion
of the bbox to be erased. When the proportion of the area where the
bbox is erased is greater than the threshold, the bbox will be
removed. Defaults to 0.9.
img_border_value (int or float or tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equal ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
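Examples:
>>> # A minimal config sketch with illustrative values (not taken
>>> # from the source); erases 1-5 square patches, each side up to
>>> # 20% of the corresponding image side:
>>> transform = dict(
>>> type='RandomErasing',
>>> n_patches=(1, 5),
>>> ratio=(0, 0.2),
>>> squared=True)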
"""
def __init__(
self,
n_patches: Union[int, Tuple[int, int]],
ratio: Union[float, Tuple[float, float]],
squared: bool = True,
bbox_erased_thr: float = 0.9,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
) -> None:
if isinstance(n_patches, tuple):
assert len(n_patches) == 2 and 0 <= n_patches[0] < n_patches[1]
else:
n_patches = (n_patches, n_patches)
if isinstance(ratio, tuple):
assert len(ratio) == 2 and 0 <= ratio[0] < ratio[1] <= 1
else:
ratio = (ratio, ratio)
self.n_patches = n_patches
self.ratio = ratio
self.squared = squared
self.bbox_erased_thr = bbox_erased_thr
self.img_border_value = img_border_value
self.mask_border_value = mask_border_value
self.seg_ignore_label = seg_ignore_label
@cache_randomness
def _get_patches(self, img_shape: Tuple[int, int]) -> np.ndarray:
"""Get patches for random erasing."""
patches = []
n_patches = np.random.randint(self.n_patches[0], self.n_patches[1] + 1)
for _ in range(n_patches):
if self.squared:
ratio = np.random.random() * (self.ratio[1] -
self.ratio[0]) + self.ratio[0]
ratio = (ratio, ratio)
else:
ratio = (np.random.random() * (self.ratio[1] - self.ratio[0]) +
self.ratio[0], np.random.random() *
(self.ratio[1] - self.ratio[0]) + self.ratio[0])
ph, pw = int(img_shape[0] * ratio[0]), int(img_shape[1] * ratio[1])
px1, py1 = np.random.randint(0,
img_shape[1] - pw), np.random.randint(
0, img_shape[0] - ph)
px2, py2 = px1 + pw, py1 + ph
patches.append([px1, py1, px2, py2])
return np.array(patches)
def _transform_img(self, results: dict, patches: List[list]) -> None:
"""Random erasing the image."""
for patch in patches:
px1, py1, px2, py2 = patch
results['img'][py1:py2, px1:px2, :] = self.img_border_value
def _transform_bboxes(self, results: dict, patches: List[list]) -> None:
"""Random erasing the bboxes."""
bboxes = results['gt_bboxes']
# TODO: unify the logic by using operators in BaseBoxes.
assert isinstance(bboxes, HorizontalBoxes)
bboxes = bboxes.numpy()
left_top = np.maximum(bboxes[:, None, :2], patches[:, :2])
right_bottom = np.minimum(bboxes[:, None, 2:], patches[:, 2:])
wh = np.maximum(right_bottom - left_top, 0)
inter_areas = wh[:, :, 0] * wh[:, :, 1]
bbox_areas = (bboxes[:, 2] - bboxes[:, 0]) * (
bboxes[:, 3] - bboxes[:, 1])
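# Approximate the erased fraction by summing per-patch intersection
# areas; overlapping patches may be double-counted.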
bboxes_erased_ratio = inter_areas.sum(-1) / (bbox_areas + 1e-7)
valid_inds = bboxes_erased_ratio < self.bbox_erased_thr
results['gt_bboxes'] = HorizontalBoxes(bboxes[valid_inds])
results['gt_bboxes_labels'] = results['gt_bboxes_labels'][valid_inds]
results['gt_ignore_flags'] = results['gt_ignore_flags'][valid_inds]
if results.get('gt_masks', None) is not None:
results['gt_masks'] = results['gt_masks'][valid_inds]
def _transform_masks(self, results: dict, patches: List[list]) -> None:
"""Random erasing the masks."""
for patch in patches:
px1, py1, px2, py2 = patch
results['gt_masks'].masks[:, py1:py2,
px1:px2] = self.mask_border_value
def _transform_seg(self, results: dict, patches: List[list]) -> None:
"""Random erasing the segmentation map."""
for patch in patches:
px1, py1, px2, py2 = patch
results['gt_seg_map'][py1:py2, px1:px2] = self.seg_ignore_label
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Transform function to erase some regions of image."""
patches = self._get_patches(results['img_shape'])
self._transform_img(results, patches)
if results.get('gt_bboxes', None) is not None:
self._transform_bboxes(results, patches)
if results.get('gt_masks', None) is not None:
self._transform_masks(results, patches)
if results.get('gt_seg_map', None) is not None:
self._transform_seg(results, patches)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(n_patches={self.n_patches}, '
repr_str += f'ratio={self.ratio}, '
repr_str += f'squared={self.squared}, '
repr_str += f'bbox_erased_thr={self.bbox_erased_thr}, '
repr_str += f'img_border_value={self.img_border_value}, '
repr_str += f'mask_border_value={self.mask_border_value}, '
repr_str += f'seg_ignore_label={self.seg_ignore_label})'
return repr_str
@TRANSFORMS.register_module()
class CachedMosaic(Mosaic):
"""Cached mosaic augmentation.
Cached mosaic transform will randomly select images from the cache
and combine them into one output image.
.. code:: text
mosaic transform
center_x
+------------------------------+
| pad | pad |
| +-----------+ |
| | | |
| | image1 |--------+ |
| | | | |
| | | image2 | |
center_y |----+-------------+-----------|
| | cropped | |
|pad | image3 | image4 |
| | | |
+----|-------------+-----------+
| |
+-------------+
The cached mosaic transform steps are as follows:
1. Append the results from the last transform into the cache.
2. Choose the mosaic center as the intersection of the 4 images.
3. Get the top left image according to the index, and randomly
sample another 3 images from the result cache.
4. Sub-images will be cropped if they are larger than the mosaic patch.
Required Keys:
- img
- gt_bboxes (np.float32) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_ignore_flags (optional)
Args:
img_scale (Sequence[int]): Image size before mosaic pipeline of single
image. The shape order should be (width, height).
Defaults to (640, 640).
center_ratio_range (Sequence[float]): Center ratio range of mosaic
output. Defaults to (0.5, 1.5).
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
pad_val (int): Pad value. Defaults to 114.
prob (float): Probability of applying this transformation.
Defaults to 1.0.
max_cached_images (int): The maximum length of the cache. The larger
the cache, the stronger the randomness of this transform. As a
rule of thumb, providing 10 caches for each image suffices for
randomness. Defaults to 40.
random_pop (bool): Whether to randomly pop a result from the cache
when the cache is full. If set to False, use FIFO popping method.
Defaults to True.
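Examples:
>>> # A minimal config sketch with illustrative values (not taken
>>> # from the source); cached transforms are typically chained,
>>> # e.g. CachedMosaic followed by CachedMixUp:
>>> pipeline = [
>>> dict(type='CachedMosaic', img_scale=(640, 640),
>>> max_cached_images=40, random_pop=True),
>>> dict(type='CachedMixUp', img_scale=(640, 640),
>>> ratio_range=(0.5, 1.5), max_cached_images=20),
>>> ]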
"""
def __init__(self,
*args,
max_cached_images: int = 40,
random_pop: bool = True,
**kwargs) -> None:
super().__init__(*args, **kwargs)
self.results_cache = []
self.random_pop = random_pop
assert max_cached_images >= 4, 'The length of cache must >= 4, ' \
f'but got {max_cached_images}.'
self.max_cached_images = max_cached_images
@cache_randomness
def get_indexes(self, cache: list) -> list:
"""Call function to collect indexes.
Args:
cache (list): The results cache.
Returns:
list: indexes.
"""
indexes = [random.randint(0, len(cache) - 1) for _ in range(3)]
return indexes
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Mosaic transform function.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""
# cache and pop images
self.results_cache.append(copy.deepcopy(results))
if len(self.results_cache) > self.max_cached_images:
if self.random_pop:
index = random.randint(0, len(self.results_cache) - 1)
else:
index = 0
self.results_cache.pop(index)
if len(self.results_cache) <= 4:
return results
if random.uniform(0, 1) > self.prob:
return results
indices = self.get_indexes(self.results_cache)
mix_results = [copy.deepcopy(self.results_cache[i]) for i in indices]
# TODO: refactor mosaic to reuse these code.
mosaic_bboxes = []
mosaic_bboxes_labels = []
mosaic_ignore_flags = []
mosaic_masks = []
with_mask = True if 'gt_masks' in results else False
if len(results['img'].shape) == 3:
mosaic_img = np.full(
(int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3),
self.pad_val,
dtype=results['img'].dtype)
else:
mosaic_img = np.full(
(int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)),
self.pad_val,
dtype=results['img'].dtype)
# mosaic center x, y
center_x = int(
random.uniform(*self.center_ratio_range) * self.img_scale[0])
center_y = int(
random.uniform(*self.center_ratio_range) * self.img_scale[1])
center_position = (center_x, center_y)
loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
for i, loc in enumerate(loc_strs):
if loc == 'top_left':
results_patch = copy.deepcopy(results)
else:
results_patch = copy.deepcopy(mix_results[i - 1])
img_i = results_patch['img']
h_i, w_i = img_i.shape[:2]
# keep_ratio resize
scale_ratio_i = min(self.img_scale[1] / h_i,
self.img_scale[0] / w_i)
img_i = mmcv.imresize(
img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))
# compute the combine parameters
paste_coord, crop_coord = self._mosaic_combine(
loc, center_position, img_i.shape[:2][::-1])
x1_p, y1_p, x2_p, y2_p = paste_coord
x1_c, y1_c, x2_c, y2_c = crop_coord
# crop and paste image
mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]
# adjust coordinate
gt_bboxes_i = results_patch['gt_bboxes']
gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
gt_ignore_flags_i = results_patch['gt_ignore_flags']
padw = x1_p - x1_c
padh = y1_p - y1_c
gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
gt_bboxes_i.translate_([padw, padh])
mosaic_bboxes.append(gt_bboxes_i)
mosaic_bboxes_labels.append(gt_bboxes_labels_i)
mosaic_ignore_flags.append(gt_ignore_flags_i)
if with_mask and results_patch.get('gt_masks', None) is not None:
gt_masks_i = results_patch['gt_masks']
gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i))
# BitmapMasks.translate expects out_shape in (h, w) order, while
# img_scale is (w, h).
gt_masks_i = gt_masks_i.translate(
out_shape=(int(self.img_scale[1] * 2),
int(self.img_scale[0] * 2)),
offset=padw,
direction='horizontal')
gt_masks_i = gt_masks_i.translate(
out_shape=(int(self.img_scale[1] * 2),
int(self.img_scale[0] * 2)),
offset=padh,
direction='vertical')
mosaic_masks.append(gt_masks_i)
mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
if self.bbox_clip_border:
mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]])
# remove outside bboxes
inside_inds = mosaic_bboxes.is_inside(
[2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy()
mosaic_bboxes = mosaic_bboxes[inside_inds]
mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
results['img'] = mosaic_img
results['img_shape'] = mosaic_img.shape[:2]
results['gt_bboxes'] = mosaic_bboxes
results['gt_bboxes_labels'] = mosaic_bboxes_labels
results['gt_ignore_flags'] = mosaic_ignore_flags
if with_mask:
mosaic_masks = mosaic_masks[0].cat(mosaic_masks)
results['gt_masks'] = mosaic_masks[inside_inds]
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(img_scale={self.img_scale}, '
repr_str += f'center_ratio_range={self.center_ratio_range}, '
repr_str += f'pad_val={self.pad_val}, '
repr_str += f'prob={self.prob}, '
repr_str += f'max_cached_images={self.max_cached_images}, '
repr_str += f'random_pop={self.random_pop})'
return repr_str
@TRANSFORMS.register_module()
class CachedMixUp(BaseTransform):
"""Cached mixup data augmentation.
.. code:: text
mixup transform
+------------------------------+
| mixup image | |
| +--------|--------+ |
| | | | |
|---------------+ | |
| | | |
| | image | |
| | | |
| | | |
| |-----------------+ |
| pad |
+------------------------------+
The cached mixup transform steps are as follows:
1. Append the results from the last transform into the cache.
2. Another random image is picked from the cache and embedded in
the top left patch (after padding and resizing).
3. The target of the mixup transform is the weighted average of the
mixup image and the origin image.
Required Keys:
- img
- gt_bboxes (np.float32) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
- mix_results (List[dict])
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_ignore_flags (optional)
Args:
img_scale (Sequence[int]): Image output size after mixup pipeline.
The shape order should be (width, height). Defaults to (640, 640).
ratio_range (Sequence[float]): Scale ratio of mixup image.
Defaults to (0.5, 1.5).
flip_ratio (float): Horizontal flip ratio of mixup image.
Defaults to 0.5.
pad_val (int): Pad value. Defaults to 114.
max_iters (int): The maximum number of iterations. If the number of
iterations is greater than `max_iters`, but gt_bbox is still
empty, then the iteration is terminated. Defaults to 15.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
max_cached_images (int): The maximum length of the cache. The larger
the cache, the stronger the randomness of this transform. As a
rule of thumb, providing 10 caches for each image suffices for
randomness. Defaults to 20.
random_pop (bool): Whether to randomly pop a result from the cache
when the cache is full. If set to False, use FIFO popping method.
Defaults to True.
prob (float): Probability of applying this transformation.
Defaults to 1.0.
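Examples:
>>> # Behavior sketch (illustrative, not from the source): each call
>>> # first appends the incoming results to the cache, evicts one
>>> # entry (random if ``random_pop`` else FIFO) once the cache
>>> # exceeds ``max_cached_images``, and then mixes with a randomly
>>> # drawn cached sample whose ``gt_bboxes`` is non-empty.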
"""
def __init__(self,
img_scale: Tuple[int, int] = (640, 640),
ratio_range: Tuple[float, float] = (0.5, 1.5),
flip_ratio: float = 0.5,
pad_val: float = 114.0,
max_iters: int = 15,
bbox_clip_border: bool = True,
max_cached_images: int = 20,
random_pop: bool = True,
prob: float = 1.0) -> None:
assert isinstance(img_scale, tuple)
assert max_cached_images >= 2, 'The length of cache must >= 2, ' \
f'but got {max_cached_images}.'
assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
f'got {prob}.'
self.dynamic_scale = img_scale
self.ratio_range = ratio_range
self.flip_ratio = flip_ratio
self.pad_val = pad_val
self.max_iters = max_iters
self.bbox_clip_border = bbox_clip_border
self.results_cache = []
self.max_cached_images = max_cached_images
self.random_pop = random_pop
self.prob = prob
@cache_randomness
def get_indexes(self, cache: list) -> int:
"""Call function to collect indexes.
Args:
cache (list): The result cache.
Returns:
int: index.
"""
for i in range(self.max_iters):
index = random.randint(0, len(cache) - 1)
gt_bboxes_i = cache[index]['gt_bboxes']
if len(gt_bboxes_i) != 0:
break
return index
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""MixUp transform function.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""
# cache and pop images
self.results_cache.append(copy.deepcopy(results))
if len(self.results_cache) > self.max_cached_images:
if self.random_pop:
index = random.randint(0, len(self.results_cache) - 1)
else:
index = 0
self.results_cache.pop(index)
if len(self.results_cache) <= 1:
return results
if random.uniform(0, 1) > self.prob:
return results
index = self.get_indexes(self.results_cache)
retrieve_results = copy.deepcopy(self.results_cache[index])
# TODO: refactor mixup to reuse these code.
if retrieve_results['gt_bboxes'].shape[0] == 0:
# empty bbox
return results
retrieve_img = retrieve_results['img']
with_mask = True if 'gt_masks' in results else False
jit_factor = random.uniform(*self.ratio_range)
is_flip = random.uniform(0, 1) > self.flip_ratio
if len(retrieve_img.shape) == 3:
out_img = np.ones(
(self.dynamic_scale[1], self.dynamic_scale[0], 3),
dtype=retrieve_img.dtype) * self.pad_val
else:
out_img = np.ones(
self.dynamic_scale[::-1],
dtype=retrieve_img.dtype) * self.pad_val
# 1. keep_ratio resize
scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0],
self.dynamic_scale[0] / retrieve_img.shape[1])
retrieve_img = mmcv.imresize(
retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
int(retrieve_img.shape[0] * scale_ratio)))
# 2. paste
out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img
# 3. scale jit
scale_ratio *= jit_factor
out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
int(out_img.shape[0] * jit_factor)))
# 4. flip
if is_flip:
out_img = out_img[:, ::-1, :]
# 5. random crop
ori_img = results['img']
origin_h, origin_w = out_img.shape[:2]
target_h, target_w = ori_img.shape[:2]
padded_img = np.ones((max(origin_h, target_h), max(
origin_w, target_w), 3)) * self.pad_val
padded_img = padded_img.astype(np.uint8)
padded_img[:origin_h, :origin_w] = out_img
x_offset, y_offset = 0, 0
if padded_img.shape[0] > target_h:
y_offset = random.randint(0, padded_img.shape[0] - target_h)
if padded_img.shape[1] > target_w:
x_offset = random.randint(0, padded_img.shape[1] - target_w)
padded_cropped_img = padded_img[y_offset:y_offset + target_h,
x_offset:x_offset + target_w]
# 6. adjust bbox
retrieve_gt_bboxes = retrieve_results['gt_bboxes']
retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
if with_mask:
retrieve_gt_masks = retrieve_results['gt_masks'].rescale(
scale_ratio)
if self.bbox_clip_border:
retrieve_gt_bboxes.clip_([origin_h, origin_w])
if is_flip:
retrieve_gt_bboxes.flip_([origin_h, origin_w],
direction='horizontal')
if with_mask:
retrieve_gt_masks = retrieve_gt_masks.flip()
# 7. filter
cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
if with_mask:
retrieve_gt_masks = retrieve_gt_masks.translate(
out_shape=(target_h, target_w),
offset=-x_offset,
direction='horizontal')
retrieve_gt_masks = retrieve_gt_masks.translate(
out_shape=(target_h, target_w),
offset=-y_offset,
direction='vertical')
if self.bbox_clip_border:
cp_retrieve_gt_bboxes.clip_([target_h, target_w])
# 8. mix up
ori_img = ori_img.astype(np.float32)
mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32)
retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']
mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
(results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
mixup_gt_bboxes_labels = np.concatenate(
(results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
mixup_gt_ignore_flags = np.concatenate(
(results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
if with_mask:
mixup_gt_masks = retrieve_gt_masks.cat(
[results['gt_masks'], retrieve_gt_masks])
# remove outside bbox
inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy()
mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
if with_mask:
mixup_gt_masks = mixup_gt_masks[inside_inds]
results['img'] = mixup_img.astype(np.uint8)
results['img_shape'] = mixup_img.shape[:2]
results['gt_bboxes'] = mixup_gt_bboxes
results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
results['gt_ignore_flags'] = mixup_gt_ignore_flags
if with_mask:
results['gt_masks'] = mixup_gt_masks
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(dynamic_scale={self.dynamic_scale}, '
repr_str += f'ratio_range={self.ratio_range}, '
repr_str += f'flip_ratio={self.flip_ratio}, '
repr_str += f'pad_val={self.pad_val}, '
repr_str += f'max_iters={self.max_iters}, '
repr_str += f'bbox_clip_border={self.bbox_clip_border}, '
repr_str += f'max_cached_images={self.max_cached_images}, '
repr_str += f'random_pop={self.random_pop}, '
repr_str += f'prob={self.prob})'
return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Callable, Dict, List, Optional, Union
import numpy as np
from mmcv.transforms import BaseTransform, Compose
from mmcv.transforms.utils import cache_random_params, cache_randomness
from mmdet.registry import TRANSFORMS
@TRANSFORMS.register_module()
class MultiBranch(BaseTransform):
r"""Multiple branch pipeline wrapper.
Generate multiple data-augmented versions of the same image.
`MultiBranch` needs to specify the branch names of all the
pipelines of the dataset. It performs the corresponding data
augmentation for the current branch and returns None for the other
branches, which ensures a consistent return format across
different samples.
Args:
branch_field (list): List of branch names.
branch_pipelines (dict): Dict of different pipeline configs
to be composed.
Examples:
>>> branch_field = ['sup', 'unsup_teacher', 'unsup_student']
>>> sup_pipeline = [
>>> dict(type='LoadImageFromFile'),
>>> dict(type='LoadAnnotations', with_bbox=True),
>>> dict(type='Resize', scale=(1333, 800), keep_ratio=True),
>>> dict(type='RandomFlip', prob=0.5),
>>> dict(
>>> type='MultiBranch',
>>> branch_field=branch_field,
>>> sup=dict(type='PackDetInputs'))
>>> ]
>>> weak_pipeline = [
>>> dict(type='LoadImageFromFile'),
>>> dict(type='LoadAnnotations', with_bbox=True),
>>> dict(type='Resize', scale=(1333, 800), keep_ratio=True),
>>> dict(type='RandomFlip', prob=0.0),
>>> dict(
>>> type='MultiBranch',
>>> branch_field=branch_field,
>>> sup=dict(type='PackDetInputs'))
>>> ]
>>> strong_pipeline = [
>>> dict(type='LoadImageFromFile'),
>>> dict(type='LoadAnnotations', with_bbox=True),
>>> dict(type='Resize', scale=(1333, 800), keep_ratio=True),
>>> dict(type='RandomFlip', prob=1.0),
>>> dict(
>>> type='MultiBranch',
>>> branch_field=branch_field,
>>> sup=dict(type='PackDetInputs'))
>>> ]
>>> unsup_pipeline = [
>>> dict(type='LoadImageFromFile'),
>>> dict(type='LoadEmptyAnnotations'),
>>> dict(
>>> type='MultiBranch',
>>> branch_field=branch_field,
>>> unsup_teacher=weak_pipeline,
>>> unsup_student=strong_pipeline)
>>> ]
>>> from mmcv.transforms import Compose
>>> sup_branch = Compose(sup_pipeline)
>>> unsup_branch = Compose(unsup_pipeline)
>>> print(sup_branch)
>>> Compose(
>>> LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2') # noqa
>>> LoadAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, poly2mask=True, imdecode_backend='cv2') # noqa
>>> Resize(scale=(1333, 800), scale_factor=None, keep_ratio=True, clip_object_border=True, backend=cv2, interpolation=bilinear) # noqa
>>> RandomFlip(prob=0.5, direction=horizontal)
>>> MultiBranch(branch_pipelines=['sup'])
>>> )
>>> print(unsup_branch)
>>> Compose(
>>> LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2') # noqa
>>> LoadEmptyAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, seg_ignore_label=255) # noqa
>>> MultiBranch(branch_pipelines=['unsup_teacher', 'unsup_student'])
>>> )
"""
def __init__(self, branch_field: List[str],
**branch_pipelines: dict) -> None:
self.branch_field = branch_field
self.branch_pipelines = {
branch: Compose(pipeline)
for branch, pipeline in branch_pipelines.items()
}
def transform(self, results: dict) -> dict:
"""Transform function to apply transforms sequentially.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict:
- 'inputs' (Dict[str, obj:`torch.Tensor`]): The forward data of
models from different branches.
- 'data_samples' (Dict[str, obj:`DetDataSample`]): The annotation
info of the sample from different branches.
"""
multi_results = {}
for branch in self.branch_field:
multi_results[branch] = {'inputs': None, 'data_samples': None}
for branch, pipeline in self.branch_pipelines.items():
branch_results = pipeline(copy.deepcopy(results))
# If one branch pipeline returns None,
# it will sample another data from dataset.
if branch_results is None:
return None
multi_results[branch] = branch_results
format_results = {}
for branch, results in multi_results.items():
for key in results.keys():
if format_results.get(key, None) is None:
format_results[key] = {branch: results[key]}
else:
format_results[key][branch] = results[key]
return format_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(branch_pipelines={list(self.branch_pipelines.keys())})'
return repr_str
@TRANSFORMS.register_module()
class RandomOrder(Compose):
"""Shuffle the transform Sequence."""
@cache_randomness
def _random_permutation(self):
return np.random.permutation(len(self.transforms))
def transform(self, results: Dict) -> Optional[Dict]:
"""Transform function to apply transforms in random order.
Args:
results (dict): A result dict contains the results to transform.
Returns:
dict or None: Transformed results.
"""
inds = self._random_permutation()
for idx in inds:
t = self.transforms[idx]
results = t(results)
if results is None:
return None
return results
def __repr__(self):
"""Compute the string representation."""
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += f'{t.__class__.__name__}, '
format_string += ')'
return format_string
@TRANSFORMS.register_module()
class ProposalBroadcaster(BaseTransform):
"""A transform wrapper to apply the wrapped transforms to process both
`gt_bboxes` and `proposals` without adding any extra code. It will do
the following steps:
1. Scatter the broadcasting targets to a list of inputs of the wrapped
transforms. The type of the list should be list[dict, dict], where
the first item is the original input and the second is a copy whose
`gt_bboxes` are rewritten by the `proposals`.
2. Apply ``self.transforms`` with the same random parameters, which are
shared via a context manager. The type of the outputs is a
list[dict, dict].
3. Gather the outputs, and update the `proposals` in the first item of
the outputs with the `gt_bboxes` in the second.
Args:
transforms (list, optional): Sequence of transform
object or config dict to be wrapped. Defaults to [].
Note: The `TransformBroadcaster` in MMCV can achieve the same operation as
`ProposalBroadcaster`, but it requires setting more complex parameters.
Examples:
>>> pipeline = [
>>> dict(type='LoadImageFromFile'),
>>> dict(type='LoadProposals', num_max_proposals=2000),
>>> dict(type='LoadAnnotations', with_bbox=True),
>>> dict(
>>> type='ProposalBroadcaster',
>>> transforms=[
>>> dict(type='Resize', scale=(1333, 800),
>>> keep_ratio=True),
>>> dict(type='RandomFlip', prob=0.5),
>>> ]),
>>> dict(type='PackDetInputs')]
"""
def __init__(self, transforms: List[Union[dict, Callable]] = []) -> None:
self.transforms = Compose(transforms)
def transform(self, results: dict) -> dict:
"""Apply wrapped transform functions to process both `gt_bboxes` and
`proposals`.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
assert results.get('proposals', None) is not None, \
'`proposals` should be in the results, please delete ' \
'`ProposalBroadcaster` in your configs, or check whether ' \
'you have loaded proposals successfully.'
inputs = self._process_input(results)
outputs = self._apply_transforms(inputs)
outputs = self._process_output(outputs)
return outputs
def _process_input(self, data: dict) -> list:
"""Scatter the broadcasting targets to a list of inputs of the wrapped
transforms.
Args:
data (dict): The original input data.
Returns:
list[dict]: A list of input data.
"""
cp_data = copy.deepcopy(data)
cp_data['gt_bboxes'] = cp_data['proposals']
scatters = [data, cp_data]
return scatters
def _apply_transforms(self, inputs: list) -> list:
"""Apply ``self.transforms``.
Args:
inputs (list[dict, dict]): list of input data.
Returns:
list[dict]: The output of the wrapped pipeline.
"""
assert len(inputs) == 2
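# ``cache_random_params`` caches the random parameters of the wrapped
# transforms so both inputs receive identical augmentations.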
ctx = cache_random_params
with ctx(self.transforms):
output_scatters = [self.transforms(_input) for _input in inputs]
return output_scatters
def _process_output(self, output_scatters: list) -> dict:
"""Gathering and renaming data items.
Args:
output_scatters (list[dict, dict]): The output of the wrapped
pipeline.
Returns:
dict: Updated result dict.
"""
assert isinstance(output_scatters, list) and \
isinstance(output_scatters[0], dict) and \
len(output_scatters) == 2
outputs = output_scatters[0]
outputs['proposals'] = output_scatters[1]['gt_bboxes']
return outputs