Commit 26b83c4a authored by dengjb's avatar dengjb
Browse files

update codes

parent 2f6baaee
Pipeline #1045 failed with stages
in 0 seconds
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List, Optional
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class ODVGDataset(BaseDetDataset):
"""object detection and visual grounding dataset."""
def __init__(self,
*args,
data_root: str = '',
label_map_file: Optional[str] = None,
need_text: bool = True,
**kwargs) -> None:
self.dataset_mode = 'VG'
self.need_text = need_text
if label_map_file:
label_map_file = osp.join(data_root, label_map_file)
with open(label_map_file, 'r') as file:
self.label_map = json.load(file)
self.dataset_mode = 'OD'
super().__init__(*args, data_root=data_root, **kwargs)
assert self.return_classes is True
def load_data_list(self) -> List[dict]:
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
data_list = [json.loads(line) for line in f]
out_data_list = []
for data in data_list:
data_info = {}
img_path = osp.join(self.data_prefix['img'], data['filename'])
data_info['img_path'] = img_path
data_info['height'] = data['height']
data_info['width'] = data['width']
if self.dataset_mode == 'OD':
if self.need_text:
data_info['text'] = self.label_map
anno = data.get('detection', {})
instances = [obj for obj in anno.get('instances', [])]
bboxes = [obj['bbox'] for obj in instances]
bbox_labels = [str(obj['label']) for obj in instances]
instances = []
for bbox, label in zip(bboxes, bbox_labels):
instance = {}
x1, y1, x2, y2 = bbox
inter_w = max(0, min(x2, data['width']) - max(x1, 0))
inter_h = max(0, min(y2, data['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if (x2 - x1) < 1 or (y2 - y1) < 1:
continue
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = int(label)
instances.append(instance)
data_info['instances'] = instances
data_info['dataset_mode'] = self.dataset_mode
out_data_list.append(data_info)
else:
anno = data['grounding']
data_info['text'] = anno['caption']
regions = anno['regions']
instances = []
phrases = {}
for i, region in enumerate(regions):
bbox = region['bbox']
phrase = region['phrase']
tokens_positive = region['tokens_positive']
if not isinstance(bbox[0], list):
bbox = [bbox]
for box in bbox:
instance = {}
x1, y1, x2, y2 = box
inter_w = max(0, min(x2, data['width']) - max(x1, 0))
inter_h = max(0, min(y2, data['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if (x2 - x1) < 1 or (y2 - y1) < 1:
continue
instance['ignore_flag'] = 0
instance['bbox'] = box
instance['bbox_label'] = i
phrases[i] = {
'phrase': phrase,
'tokens_positive': tokens_positive
}
instances.append(instance)
data_info['instances'] = instances
data_info['phrases'] = phrases
data_info['dataset_mode'] = self.dataset_mode
out_data_list.append(data_info)
del data_list
return out_data_list
# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os.path as osp
from collections import defaultdict
from typing import Dict, List, Optional
import numpy as np
from mmengine.fileio import get_local_path, load
from mmengine.utils import is_abs
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class OpenImagesDataset(BaseDetDataset):
"""Open Images dataset for detection.
Args:
ann_file (str): Annotation file path.
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
meta_file (str): File path to get image metas.
hierarchy_file (str): The file path of the class hierarchy.
image_level_ann_file (str): Human-verified image level annotation,
which is used in evaluation.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
"""
METAINFO: dict = dict(dataset_type='oid_v6')
def __init__(self,
label_file: str,
meta_file: str,
hierarchy_file: str,
image_level_ann_file: Optional[str] = None,
**kwargs) -> None:
self.label_file = label_file
self.meta_file = meta_file
self.hierarchy_file = hierarchy_file
self.image_level_ann_file = image_level_ann_file
super().__init__(**kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
last_img_id = None
instances = []
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
if last_img_id is None:
last_img_id = img_id
label_id = line[2]
assert label_id in self.label_id_mapping
label = int(self.label_id_mapping[label_id])
bbox = [
float(line[4]), # xmin
float(line[6]), # ymin
float(line[5]), # xmax
float(line[7]) # ymax
]
is_occluded = True if int(line[8]) == 1 else False
is_truncated = True if int(line[9]) == 1 else False
is_group_of = True if int(line[10]) == 1 else False
is_depiction = True if int(line[11]) == 1 else False
is_inside = True if int(line[12]) == 1 else False
instance = dict(
bbox=bbox,
bbox_label=label,
ignore_flag=0,
is_occluded=is_occluded,
is_truncated=is_truncated,
is_group_of=is_group_of,
is_depiction=is_depiction,
is_inside=is_inside)
last_img_path = osp.join(self.data_prefix['img'],
f'{last_img_id}.jpg')
if img_id != last_img_id:
# switch to a new image, record previous image's data.
data_info = dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
)
data_list.append(data_info)
instances = []
instances.append(instance)
last_img_id = img_id
data_list.append(
dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = data_list[i]['img_id']
assert f'{img_id}.jpg' == osp.split(meta['filename'])[-1]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
tuple: Class name of OpenImages.
"""
index_list = []
classes_names = []
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
# self.cat2label[line[0]] = line[1]
classes_names.append(line[1])
index_list.append(line[0])
index_mapping = {index: i for i, index in enumerate(index_list)}
return classes_names, index_mapping
def _parse_img_level_ann(self,
img_level_ann_file: str) -> Dict[str, List[dict]]:
"""Parse image level annotations from csv style ann_file.
Args:
img_level_ann_file (str): CSV style image level annotation
file path.
Returns:
Dict[str, List[dict]]: Annotations where item of the defaultdict
indicates an image, each of which has (n) dicts.
Keys of dicts are:
- `image_level_label` (int): Label id.
- `confidence` (float): Labels that are human-verified to be
present in an image have confidence = 1 (positive labels).
Labels that are human-verified to be absent from an image
have confidence = 0 (negative labels). Machine-generated
labels have fractional confidences, generally >= 0.5.
The higher the confidence, the smaller the chance for
the label to be a false positive.
"""
item_lists = defaultdict(list)
with get_local_path(
img_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
item_lists[img_id].append(
dict(
image_level_label=int(
self.label_id_mapping[line[2]]),
confidence=float(line[3])))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file. Hierarchy
for 600 classes can be found at https://storage.googleapis.com/openimag
es/2018_04/bbox_labels_600_hierarchy_visualizer/circle.html.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
""" # noqa
hierarchy = load(
hierarchy_file, file_format='json', backend_args=self.backend_args)
class_num = len(self._metainfo['classes'])
relation_matrix = np.eye(class_num, class_num)
relation_matrix = self._convert_hierarchy_tree(hierarchy,
relation_matrix)
return relation_matrix
def _convert_hierarchy_tree(self,
hierarchy_map: dict,
relation_matrix: np.ndarray,
parents: list = [],
get_all_parents: bool = True) -> np.ndarray:
"""Get matrix of the corresponding relationship between the parent
class and the child class.
Args:
hierarchy_map (dict): Including label name and corresponding
subcategory. Keys of dicts are:
- `LabeName` (str): Name of the label.
- `Subcategory` (dict | list): Corresponding subcategory(ies).
relation_matrix (ndarray): The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
parents (list): Corresponding parent class.
get_all_parents (bool): Whether get all parent names.
Default: True
Returns:
ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
"""
if 'Subcategory' in hierarchy_map:
for node in hierarchy_map['Subcategory']:
if 'LabelName' in node:
children_name = node['LabelName']
children_index = self.label_id_mapping[children_name]
children = [children_index]
else:
continue
if len(parents) > 0:
for parent_index in parents:
if get_all_parents:
children.append(parent_index)
relation_matrix[children_index, parent_index] = 1
relation_matrix = self._convert_hierarchy_tree(
node, relation_matrix, parents=children)
return relation_matrix
def _join_prefix(self):
"""Join ``self.data_root`` with annotation path."""
super()._join_prefix()
if not is_abs(self.label_file) and self.label_file:
self.label_file = osp.join(self.data_root, self.label_file)
if not is_abs(self.meta_file) and self.meta_file:
self.meta_file = osp.join(self.data_root, self.meta_file)
if not is_abs(self.hierarchy_file) and self.hierarchy_file:
self.hierarchy_file = osp.join(self.data_root, self.hierarchy_file)
if self.image_level_ann_file and not is_abs(self.image_level_ann_file):
self.image_level_ann_file = osp.join(self.data_root,
self.image_level_ann_file)
@DATASETS.register_module()
class OpenImagesChallengeDataset(OpenImagesDataset):
"""Open Images Challenge dataset for detection.
Args:
ann_file (str): Open Images Challenge box annotation in txt format.
"""
METAINFO: dict = dict(dataset_type='oid_challenge')
def __init__(self, ann_file: str, **kwargs) -> None:
if not ann_file.endswith('txt'):
raise TypeError('The annotation file of Open Images Challenge '
'should be a txt file.')
super().__init__(ann_file=ann_file, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
lines = f.readlines()
i = 0
while i < len(lines):
instances = []
filename = lines[i].rstrip()
i += 2
img_gt_size = int(lines[i])
i += 1
for j in range(img_gt_size):
sp = lines[i + j].split()
instances.append(
dict(
bbox=[
float(sp[1]),
float(sp[2]),
float(sp[3]),
float(sp[4])
],
bbox_label=int(sp[0]) - 1, # labels begin from 1
ignore_flag=0,
is_group_ofs=True if int(sp[5]) == 1 else False))
i += img_gt_size
data_list.append(
dict(
img_path=osp.join(self.data_prefix['img'], filename),
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = osp.split(data_list[i]['img_path'])[-1][:-4]
assert img_id == osp.split(meta['filename'])[-1][:-4]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
data_list[i]['img_id'] = img_id
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
tuple: Class name of OpenImages.
"""
label_list = []
id_list = []
index_mapping = {}
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
label_name = line[0]
label_id = int(line[2])
label_list.append(line[1])
id_list.append(label_id)
index_mapping[label_name] = label_id - 1
indexes = np.argsort(id_list)
classes_names = []
for index in indexes:
classes_names.append(label_list[index])
return classes_names, index_mapping
def _parse_img_level_ann(self, image_level_ann_file):
"""Parse image level annotations from csv style ann_file.
Args:
image_level_ann_file (str): CSV style image level annotation
file path.
Returns:
defaultdict[list[dict]]: Annotations where item of the defaultdict
indicates an image, each of which has (n) dicts.
Keys of dicts are:
- `image_level_label` (int): of shape 1.
- `confidence` (float): of shape 1.
"""
item_lists = defaultdict(list)
with get_local_path(
image_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
i = -1
for line in reader:
i += 1
if i == 0:
continue
else:
img_id = line[0]
label_id = line[1]
assert label_id in self.label_id_mapping
image_level_label = int(
self.label_id_mapping[label_id])
confidence = float(line[2])
item_lists[img_id].append(
dict(
image_level_label=image_level_label,
confidence=confidence))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
"""
with get_local_path(
hierarchy_file, backend_args=self.backend_args) as local_path:
class_label_tree = np.load(local_path, allow_pickle=True)
return class_label_tree[1:, 1:]
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import os.path as osp
import random
from typing import Dict, List
import mmengine
from mmengine.dataset import BaseDataset
from mmdet.registry import DATASETS
@DATASETS.register_module()
class RefCocoDataset(BaseDataset):
"""RefCOCO dataset.
The `Refcoco` and `Refcoco+` dataset is based on
`ReferItGame: Referring to Objects in Photographs of Natural Scenes
<http://tamaraberg.com/papers/referit.pdf>`_.
The `Refcocog` dataset is based on
`Generation and Comprehension of Unambiguous Object Descriptions
<https://arxiv.org/abs/1511.02283>`_.
Args:
ann_file (str): Annotation file path.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (str): Prefix for training data.
split_file (str): Split file path.
split (str): Split name. Defaults to 'train'.
text_mode (str): Text mode. Defaults to 'random'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
def __init__(self,
data_root: str,
ann_file: str,
split_file: str,
data_prefix: Dict,
split: str = 'train',
text_mode: str = 'random',
**kwargs):
self.split_file = split_file
self.split = split
assert text_mode in ['original', 'random', 'concat', 'select_first']
self.text_mode = text_mode
super().__init__(
data_root=data_root,
data_prefix=data_prefix,
ann_file=ann_file,
**kwargs,
)
def _join_prefix(self):
if not mmengine.is_abs(self.split_file) and self.split_file:
self.split_file = osp.join(self.data_root, self.split_file)
return super()._join_prefix()
def _init_refs(self):
"""Initialize the refs for RefCOCO."""
anns, imgs = {}, {}
for ann in self.instances['annotations']:
anns[ann['id']] = ann
for img in self.instances['images']:
imgs[img['id']] = img
refs, ref_to_ann = {}, {}
for ref in self.splits:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
# add mapping related to ref
refs[ref_id] = ref
ref_to_ann[ref_id] = anns[ann_id]
self.refs = refs
self.ref_to_ann = ref_to_ann
def load_data_list(self) -> List[dict]:
"""Load data list."""
self.splits = mmengine.load(self.split_file, file_format='pkl')
self.instances = mmengine.load(self.ann_file, file_format='json')
self._init_refs()
img_prefix = self.data_prefix['img_path']
ref_ids = [
ref['ref_id'] for ref in self.splits if ref['split'] == self.split
]
full_anno = []
for ref_id in ref_ids:
ref = self.refs[ref_id]
ann = self.ref_to_ann[ref_id]
ann.update(ref)
full_anno.append(ann)
image_id_list = []
final_anno = {}
for anno in full_anno:
image_id_list.append(anno['image_id'])
final_anno[anno['ann_id']] = anno
annotations = [value for key, value in final_anno.items()]
coco_train_id = []
image_annot = {}
for i in range(len(self.instances['images'])):
coco_train_id.append(self.instances['images'][i]['id'])
image_annot[self.instances['images'][i]
['id']] = self.instances['images'][i]
images = []
for image_id in list(set(image_id_list)):
images += [image_annot[image_id]]
data_list = []
grounding_dict = collections.defaultdict(list)
for anno in annotations:
image_id = int(anno['image_id'])
grounding_dict[image_id].append(anno)
join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
for image in images:
img_id = image['id']
instances = []
sentences = []
for grounding_anno in grounding_dict[img_id]:
texts = [x['raw'].lower() for x in grounding_anno['sentences']]
# random select one text
if self.text_mode == 'random':
idx = random.randint(0, len(texts) - 1)
text = [texts[idx]]
# concat all texts
elif self.text_mode == 'concat':
text = [''.join(texts)]
# select the first text
elif self.text_mode == 'select_first':
text = [texts[0]]
# use all texts
elif self.text_mode == 'original':
text = texts
else:
raise ValueError(f'Invalid text mode "{self.text_mode}".')
ins = [{
'mask': grounding_anno['segmentation'],
'ignore_flag': 0
}] * len(text)
instances.extend(ins)
sentences.extend(text)
data_info = {
'img_path': join_path(img_prefix, image['file_name']),
'img_id': img_id,
'instances': instances,
'text': sentences
}
data_list.append(data_info)
if len(data_list) == 0:
raise ValueError(f'No sample in split "{self.split}".')
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from collections import defaultdict
from typing import Any, Dict, List
import numpy as np
from mmengine.dataset import BaseDataset
from mmengine.utils import check_file_exist
from mmdet.registry import DATASETS
@DATASETS.register_module()
class ReIDDataset(BaseDataset):
"""Dataset for ReID.
Args:
triplet_sampler (dict, optional): The sampler for hard mining
triplet loss. Defaults to None.
keys: num_ids (int): The number of person ids.
ins_per_id (int): The number of image for each person.
"""
def __init__(self, triplet_sampler: dict = None, *args, **kwargs):
self.triplet_sampler = triplet_sampler
super().__init__(*args, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ''self.ann_file''.
Returns:
list[dict]: A list of annotation.
"""
assert isinstance(self.ann_file, str)
check_file_exist(self.ann_file)
data_list = []
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
info = dict(img_prefix=self.data_prefix)
if self.data_prefix['img_path'] is not None:
info['img_path'] = osp.join(self.data_prefix['img_path'],
filename)
else:
info['img_path'] = filename
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_list.append(info)
self._parse_ann_info(data_list)
return data_list
def _parse_ann_info(self, data_list: List[dict]):
"""Parse person id annotations."""
index_tmp_dic = defaultdict(list) # pid->[idx1,...,idxN]
self.index_dic = dict() # pid->array([idx1,...,idxN])
for idx, info in enumerate(data_list):
pid = info['gt_label']
index_tmp_dic[int(pid)].append(idx)
for pid, idxs in index_tmp_dic.items():
self.index_dic[pid] = np.asarray(idxs, dtype=np.int64)
self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)
def prepare_data(self, idx: int) -> Any:
"""Get data processed by ''self.pipeline''.
Args:
idx (int): The index of ''data_info''
Returns:
Any: Depends on ''self.pipeline''
"""
data_info = self.get_data_info(idx)
if self.triplet_sampler is not None:
img_info = self.triplet_sampling(data_info['gt_label'],
**self.triplet_sampler)
data_info = copy.deepcopy(img_info) # triplet -> list
else:
data_info = copy.deepcopy(data_info) # no triplet -> dict
return self.pipeline(data_info)
def triplet_sampling(self,
pos_pid,
num_ids: int = 8,
ins_per_id: int = 4) -> Dict:
"""Triplet sampler for hard mining triplet loss. First, for one
pos_pid, random sample ins_per_id images with same person id.
Then, random sample num_ids - 1 images for each negative id.
Finally, random sample ins_per_id images for each negative id.
Args:
pos_pid (ndarray): The person id of the anchor.
num_ids (int): The number of person ids.
ins_per_id (int): The number of images for each person.
Returns:
Dict: Annotation information of num_ids X ins_per_id images.
"""
assert len(self.pids) >= num_ids, \
'The number of person ids in the training set must ' \
'be greater than the number of person ids in the sample.'
pos_idxs = self.index_dic[int(
pos_pid)] # all positive idxs for pos_pid
idxs_list = []
# select positive samplers
idxs_list.extend(pos_idxs[np.random.choice(
pos_idxs.shape[0], ins_per_id, replace=True)])
# select negative ids
neg_pids = np.random.choice(
[i for i, _ in enumerate(self.pids) if i != pos_pid],
num_ids - 1,
replace=False)
# select negative samplers for each negative id
for neg_pid in neg_pids:
neg_idxs = self.index_dic[neg_pid]
idxs_list.extend(neg_idxs[np.random.choice(
neg_idxs.shape[0], ins_per_id, replace=True)])
# return the final triplet batch
triplet_img_infos = []
for idx in idxs_list:
triplet_img_infos.append(copy.deepcopy(self.get_data_info(idx)))
# Collect data_list scatters (list of dict -> dict of list)
out = dict()
for key in triplet_img_infos[0].keys():
out[key] = [_info[key] for _info in triplet_img_infos]
return out
# Copyright (c) OpenMMLab. All rights reserved.
from .batch_sampler import (AspectRatioBatchSampler,
MultiDataAspectRatioBatchSampler,
TrackAspectRatioBatchSampler)
from .class_aware_sampler import ClassAwareSampler
from .custom_sample_size_sampler import CustomSampleSizeSampler
from .multi_data_sampler import MultiDataSampler
from .multi_source_sampler import GroupMultiSourceSampler, MultiSourceSampler
from .track_img_sampler import TrackImgSampler
__all__ = [
'ClassAwareSampler', 'AspectRatioBatchSampler', 'MultiSourceSampler',
'GroupMultiSourceSampler', 'TrackImgSampler',
'TrackAspectRatioBatchSampler', 'MultiDataSampler',
'MultiDataAspectRatioBatchSampler', 'CustomSampleSizeSampler'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
from torch.utils.data import BatchSampler, Sampler
from mmdet.datasets.samplers.track_img_sampler import TrackImgSampler
from mmdet.registry import DATA_SAMPLERS
# TODO: maybe replace with a data_loader wrapper
@DATA_SAMPLERS.register_module()
class AspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __init__(self,
sampler: Sampler,
batch_size: int,
drop_last: bool = False) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
# two groups for w < h and w >= h
self._aspect_ratio_buckets = [[] for _ in range(2)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
def __len__(self) -> int:
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
@DATA_SAMPLERS.register_module()
class TrackAspectRatioBatchSampler(AspectRatioBatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
# hard code to solve TrackImgSampler
if isinstance(self.sampler, TrackImgSampler):
video_idx, _ = idx
else:
video_idx = idx
# video_idx
data_info = self.sampler.dataset.get_data_info(video_idx)
# data_info {video_id, images, video_length}
img_data_info = data_info['images'][0]
width, height = img_data_info['width'], img_data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
@DATA_SAMPLERS.register_module()
class MultiDataAspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch for multi-source datasets.
Args:
sampler (Sampler): Base sampler.
batch_size (Sequence(int)): Size of mini-batch for multi-source
datasets.
num_datasets(int): Number of multi-source datasets.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __init__(self,
sampler: Sampler,
batch_size: Sequence[int],
num_datasets: int,
drop_last: bool = True) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
self.sampler = sampler
self.batch_size = batch_size
self.num_datasets = num_datasets
self.drop_last = drop_last
# two groups for w < h and w >= h for each dataset --> 2 * num_datasets
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
aspect_ratio_bucket_id = 0 if width < height else 1
bucket_id = dataset_source_idx * 2 + aspect_ratio_bucket_id
bucket = self._buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size[dataset_source_idx]:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
for i in range(self.num_datasets):
left_data = self._buckets[i * 2 + 0] + self._buckets[i * 2 + 1]
while len(left_data) > 0:
if len(left_data) <= self.batch_size[i]:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size[i]]
left_data = left_data[self.batch_size[i]:]
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __len__(self) -> int:
sizes = [0 for _ in range(self.num_datasets)]
for idx in self.sampler:
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
sizes[dataset_source_idx] += 1
if self.drop_last:
lens = 0
for i in range(self.num_datasets):
lens += sizes[i] // self.batch_size[i]
return lens
else:
lens = 0
for i in range(self.num_datasets):
lens += (sizes[i] + self.batch_size[i] -
1) // self.batch_size[i]
return lens
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Iterator, Optional, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class ClassAwareSampler(Sampler):
r"""Sampler that restricts data loading to the label of the dataset.
A class-aware sampling strategy to effectively tackle the
non-uniform class distribution. The length of the training data is
consistent with source data. Simple improvements based on `Relay
Backpropagation for Effective Learning of Deep Convolutional
Neural Networks <https://arxiv.org/abs/1512.05830>`_
The implementation logic is referred to
https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py
Args:
dataset: Dataset used for sampling.
seed (int, optional): random seed used to shuffle the sampler.
This number should be identical across all
processes in the distributed group. Defaults to None.
num_sample_class (int): The number of samples taken from each
per-label list. Defaults to 1.
"""
def __init__(self,
dataset: BaseDataset,
seed: Optional[int] = None,
num_sample_class: int = 1) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.epoch = 0
# Must be the same across all workers. If None, will use a
# random seed shared among workers
# (require synchronization among all workers)
if seed is None:
seed = sync_random_seed()
self.seed = seed
# The number of samples taken from each per-label list
assert num_sample_class > 0 and isinstance(num_sample_class, int)
self.num_sample_class = num_sample_class
# Get per-label image list from dataset
self.cat_dict = self.get_cat2imgs()
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size))
self.total_size = self.num_samples * self.world_size
# get number of images containing each category
self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
# filter labels without images
self.valid_cat_inds = [
i for i, length in enumerate(self.num_cat_imgs) if length != 0
]
self.num_classes = len(self.valid_cat_inds)
def get_cat2imgs(self) -> Dict[int, list]:
"""Get a dict with class as key and img_ids as values.
Returns:
dict[int, list]: A dict of per-label image list,
the item of the dict indicates a label index,
corresponds to the image index that contains the label.
"""
classes = self.dataset.metainfo.get('classes', None)
if classes is None:
raise ValueError('dataset metainfo must contain `classes`')
# sort the label index
cat2imgs = {i: [] for i in range(len(classes))}
for i in range(len(self.dataset)):
cat_ids = set(self.dataset.get_cat_ids(i))
for cat in cat_ids:
cat2imgs[cat].append(i)
return cat2imgs
def __iter__(self) -> Iterator[int]:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
# initialize label list
label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
# initialize each per-label image list
data_iter_dict = dict()
for i in self.valid_cat_inds:
data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)
def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
"""Traverse the categories and extract `num_sample_cls` image
indexes of the corresponding categories one by one."""
id_indices = []
for _ in range(len(cls_list)):
cls_idx = next(cls_list)
for _ in range(num_sample_cls):
id = next(data_dict[cls_idx])
id_indices.append(id)
return id_indices
# deterministically shuffle based on epoch
num_bins = int(
math.ceil(self.total_size * 1.0 / self.num_classes /
self.num_sample_class))
indices = []
for i in range(num_bins):
indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
self.num_sample_class)
# fix extra samples to make it evenly divisible
if len(indices) >= self.total_size:
indices = indices[:self.total_size]
else:
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
class RandomCycleIter:
"""Shuffle the list and do it again after the list have traversed.
The implementation logic is referred to
https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
Example:
>>> label_list = [0, 1, 2, 4, 5]
>>> g = torch.Generator()
>>> g.manual_seed(0)
>>> label_iter_list = RandomCycleIter(label_list, generator=g)
>>> index = next(label_iter_list)
Args:
data (list or ndarray): The data that needs to be shuffled.
generator: An torch.Generator object, which is used in setting the seed
for generating random numbers.
""" # noqa: W605
def __init__(self,
data: Union[list, np.ndarray],
generator: torch.Generator = None) -> None:
self.data = data
self.length = len(data)
self.index = torch.randperm(self.length, generator=generator).numpy()
self.i = 0
self.generator = generator
def __iter__(self) -> Iterator:
return self
def __len__(self) -> int:
return len(self.data)
def __next__(self):
if self.i == self.length:
self.index = torch.randperm(
self.length, generator=self.generator).numpy()
self.i = 0
idx = self.data[self.index[self.i]]
self.i += 1
return idx
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
from .class_aware_sampler import RandomCycleIter
@DATA_SAMPLERS.register_module()
class CustomSampleSizeSampler(Sampler):
def __init__(self,
dataset: Sized,
dataset_size: Sequence[int],
ratio_mode: bool = False,
seed: Optional[int] = None,
round_up: bool = True) -> None:
assert len(dataset.datasets) == len(dataset_size)
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
total_size = 0
total_size_fake = 0
self.dataset_index = []
self.dataset_cycle_iter = []
new_dataset_size = []
for dataset, size in zip(dataset.datasets, dataset_size):
self.dataset_index.append(
list(range(total_size_fake,
len(dataset) + total_size_fake)))
total_size_fake += len(dataset)
if size == -1:
total_size += len(dataset)
self.dataset_cycle_iter.append(None)
new_dataset_size.append(-1)
else:
if ratio_mode:
size = int(size * len(dataset))
assert size <= len(
dataset
), f'dataset size {size} is larger than ' \
f'dataset length {len(dataset)}'
total_size += size
new_dataset_size.append(size)
g = torch.Generator()
g.manual_seed(self.seed)
self.dataset_cycle_iter.append(
RandomCycleIter(self.dataset_index[-1], generator=g))
self.dataset_size = new_dataset_size
if self.round_up:
self.num_samples = math.ceil(total_size / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil((total_size - rank) / world_size)
self.total_size = total_size
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
out_index = []
for data_size, data_index, cycle_iter in zip(self.dataset_size,
self.dataset_index,
self.dataset_cycle_iter):
if data_size == -1:
out_index += data_index
else:
index = [next(cycle_iter) for _ in range(data_size)]
out_index += index
index = torch.randperm(len(out_index), generator=g).numpy().tolist()
indices = [out_index[i] for i in index]
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS
from torch.utils.data import Sampler
@DATA_SAMPLERS.register_module()
class MultiDataSampler(Sampler):
"""The default data sampler for both distributed and non-distributed
environment.
It has several differences from the PyTorch ``DistributedSampler`` as
below:
1. This sampler supports non-distributed environment.
2. The round up behaviors are a little different.
- If ``round_up=True``, this sampler will add extra samples to make the
number of samples is evenly divisible by the world size. And
this behavior is the same as the ``DistributedSampler`` with
``drop_last=False``.
- If ``round_up=False``, this sampler won't remove or add any samples
while the ``DistributedSampler`` with ``drop_last=True`` will remove
tail samples.
Args:
dataset (Sized): The dataset.
dataset_ratio (Sequence(int)) The ratios of different datasets.
seed (int, optional): Random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Defaults to None.
round_up (bool): Whether to add extra samples to make the number of
samples evenly divisible by the world size. Defaults to True.
"""
def __init__(self,
dataset: Sized,
dataset_ratio: Sequence[int],
seed: Optional[int] = None,
round_up: bool = True) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.dataset_ratio = dataset_ratio
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
if self.round_up:
self.num_samples = math.ceil(len(self.dataset) / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil(
(len(self.dataset) - rank) / world_size)
self.total_size = len(self.dataset)
self.sizes = [len(dataset) for dataset in self.dataset.datasets]
dataset_weight = [
torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio)
for i, (r, s) in enumerate(zip(self.dataset_ratio, self.sizes))
]
self.weights = torch.cat(dataset_weight)
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.multinomial(
self.weights, len(self.weights), generator=g,
replacement=True).tolist()
# add extra samples to make it evenly divisible
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
# subsample
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Iterator, List, Optional, Sized, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class MultiSourceSampler(Sampler):
r"""Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
Examples:
>>> dataset_type = 'ConcatDataset'
>>> sub_dataset_type = 'CocoDataset'
>>> data_root = 'data/coco/'
>>> sup_ann = '../coco_semi_annos/instances_train2017.1@10.json'
>>> unsup_ann = '../coco_semi_annos/' \
>>> 'instances_train2017.1@10-unlabeled.json'
>>> dataset = dict(type=dataset_type,
>>> datasets=[
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=sup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=sup_pipeline),
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=unsup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=unsup_pipeline),
>>> ])
>>> train_dataloader = dict(
>>> batch_size=5,
>>> num_workers=5,
>>> persistent_workers=True,
>>> sampler=dict(type='MultiSourceSampler',
>>> batch_size=5, source_ratio=[1, 4]),
>>> batch_sampler=None,
>>> dataset=dataset)
"""
def __init__(self,
dataset: Sized,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
assert hasattr(dataset, 'cumulative_sizes'),\
f'The dataset must be ConcatDataset, but get {dataset}'
assert isinstance(batch_size, int) and batch_size > 0, \
'batch_size must be a positive integer value, ' \
f'but got batch_size={batch_size}'
assert isinstance(source_ratio, list), \
f'source_ratio must be a list, but got source_ratio={source_ratio}'
assert len(source_ratio) == len(dataset.cumulative_sizes), \
'The length of source_ratio must be equal to ' \
f'the number of datasets, but got source_ratio={source_ratio}'
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.cumulative_sizes = [0] + dataset.cumulative_sizes
self.batch_size = batch_size
self.source_ratio = source_ratio
self.num_per_source = [
int(batch_size * sr / sum(source_ratio)) for sr in source_ratio
]
self.num_per_source[0] = batch_size - sum(self.num_per_source[1:])
assert sum(self.num_per_source) == batch_size, \
'The sum of num_per_source must be equal to ' \
f'batch_size, but get {self.num_per_source}'
self.seed = sync_random_seed() if seed is None else seed
self.shuffle = shuffle
self.source2inds = {
source: self._indices_of_rank(len(ds))
for source, ds in enumerate(dataset.datasets)
}
def _infinite_indices(self, sample_size: int) -> Iterator[int]:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
while True:
if self.shuffle:
yield from torch.randperm(sample_size, generator=g).tolist()
else:
yield from torch.arange(sample_size).tolist()
def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
"""Slice the infinite indices by rank."""
yield from itertools.islice(
self._infinite_indices(sample_size), self.rank, None,
self.world_size)
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.source2inds[source]:
idx += self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
def __len__(self) -> int:
return len(self.dataset)
def set_epoch(self, epoch: int) -> None:
"""Not supported in `epoch-based runner."""
pass
@DATA_SAMPLERS.register_module()
class GroupMultiSourceSampler(MultiSourceSampler):
r"""Group Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets but the same group to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
"""
def __init__(self,
dataset: BaseDataset,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
super().__init__(
dataset=dataset,
batch_size=batch_size,
source_ratio=source_ratio,
shuffle=shuffle,
seed=seed)
self._get_source_group_info()
self.group_source2inds = [{
source:
self._indices_of_rank(self.group2size_per_source[source][group])
for source in range(len(dataset.datasets))
} for group in range(len(self.group_ratio))]
def _get_source_group_info(self) -> None:
self.group2size_per_source = [{0: 0, 1: 0}, {0: 0, 1: 0}]
self.group2inds_per_source = [{0: [], 1: []}, {0: [], 1: []}]
for source, dataset in enumerate(self.dataset.datasets):
for idx in range(len(dataset)):
data_info = dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
group = 0 if width < height else 1
self.group2size_per_source[source][group] += 1
self.group2inds_per_source[source][group].append(idx)
self.group_sizes = np.zeros(2, dtype=np.int64)
for group2size in self.group2size_per_source:
for group, size in group2size.items():
self.group_sizes[group] += size
self.group_ratio = self.group_sizes / sum(self.group_sizes)
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
group = np.random.choice(
list(range(len(self.group_ratio))), p=self.group_ratio)
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.group_source2inds[group][source]:
idx = self.group2inds_per_source[source][group][
idx] + self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
# Copyright (c) OpenMMLab. All rights reserved.
import math
import random
from typing import Iterator, Optional, Sized
import numpy as np
from mmengine.dataset import ClassBalancedDataset, ConcatDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
from ..base_video_dataset import BaseVideoDataset
@DATA_SAMPLERS.register_module()
class TrackImgSampler(Sampler):
"""Sampler that providing image-level sampling outputs for video datasets
in tracking tasks. It could be both used in both distributed and
non-distributed environment.
If using the default sampler in pytorch, the subsequent data receiver will
get one video, which is not desired in some cases:
(Take a non-distributed environment as an example)
1. In test mode, we want only one image is fed into the data pipeline. This
is in consideration of memory usage since feeding the whole video commonly
requires a large amount of memory (>=20G on MOTChallenge17 dataset), which
is not available in some machines.
2. In training mode, we may want to make sure all the images in one video
are randomly sampled once in one epoch and this can not be guaranteed in
the default sampler in pytorch.
Args:
dataset (Sized): Dataset used for sampling.
seed (int, optional): random seed used to shuffle the sampler. This
number should be identical across all processes in the distributed
group. Defaults to None.
"""
def __init__(
self,
dataset: Sized,
seed: Optional[int] = None,
) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.epoch = 0
if seed is None:
self.seed = sync_random_seed()
else:
self.seed = seed
self.dataset = dataset
self.indices = []
# Hard code here to handle different dataset wrapper
if isinstance(self.dataset, ConcatDataset):
cat_datasets = self.dataset.datasets
assert isinstance(
cat_datasets[0], BaseVideoDataset
), f'expected BaseVideoDataset, but got {type(cat_datasets[0])}'
self.test_mode = cat_datasets[0].test_mode
assert not self.test_mode, "'ConcatDataset' should not exist in "
'test mode'
for dataset in cat_datasets:
num_videos = len(dataset)
for video_ind in range(num_videos):
self.indices.extend([
(video_ind, frame_ind) for frame_ind in range(
dataset.get_len_per_video(video_ind))
])
elif isinstance(self.dataset, ClassBalancedDataset):
ori_dataset = self.dataset.dataset
assert isinstance(
ori_dataset, BaseVideoDataset
), f'expected BaseVideoDataset, but got {type(ori_dataset)}'
self.test_mode = ori_dataset.test_mode
assert not self.test_mode, "'ClassBalancedDataset' should not "
'exist in test mode'
video_indices = self.dataset.repeat_indices
for index in video_indices:
self.indices.extend([(index, frame_ind) for frame_ind in range(
ori_dataset.get_len_per_video(index))])
else:
assert isinstance(
self.dataset, BaseVideoDataset
), 'TrackImgSampler is only supported in BaseVideoDataset or '
'dataset wrapper: ClassBalancedDataset and ConcatDataset, but '
f'got {type(self.dataset)} '
self.test_mode = self.dataset.test_mode
num_videos = len(self.dataset)
if self.test_mode:
# in test mode, the images belong to the same video must be put
# on the same device.
if num_videos < self.world_size:
raise ValueError(f'only {num_videos} videos loaded,'
f'but {self.world_size} gpus were given.')
chunks = np.array_split(
list(range(num_videos)), self.world_size)
for videos_inds in chunks:
indices_chunk = []
for video_ind in videos_inds:
indices_chunk.extend([
(video_ind, frame_ind) for frame_ind in range(
self.dataset.get_len_per_video(video_ind))
])
self.indices.append(indices_chunk)
else:
for video_ind in range(num_videos):
self.indices.extend([
(video_ind, frame_ind) for frame_ind in range(
self.dataset.get_len_per_video(video_ind))
])
if self.test_mode:
self.num_samples = len(self.indices[self.rank])
self.total_size = sum(
[len(index_list) for index_list in self.indices])
else:
self.num_samples = int(
math.ceil(len(self.indices) * 1.0 / self.world_size))
self.total_size = self.num_samples * self.world_size
def __iter__(self) -> Iterator:
if self.test_mode:
# in test mode, the order of frames can not be shuffled.
indices = self.indices[self.rank]
else:
# deterministically shuffle based on epoch
rng = random.Random(self.epoch + self.seed)
indices = rng.sample(self.indices, len(self.indices))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.world_size]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
from .augment_wrappers import AutoAugment, RandAugment
from .colorspace import (AutoContrast, Brightness, Color, ColorTransform,
Contrast, Equalize, Invert, Posterize, Sharpness,
Solarize, SolarizeAdd)
from .formatting import (ImageToTensor, PackDetInputs, PackReIDInputs,
PackTrackInputs, ToTensor, Transpose)
from .frame_sampling import BaseFrameSample, UniformRefFrameSample
from .geometric import (GeomTransform, Rotate, ShearX, ShearY, TranslateX,
TranslateY)
from .instaboost import InstaBoost
from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations,
LoadEmptyAnnotations, LoadImageFromNDArray,
LoadMultiChannelImageFromFiles, LoadPanopticAnnotations,
LoadProposals, LoadTrackAnnotations)
from .text_transformers import LoadTextAnnotations, RandomSamplingNegPos
from .transformers_glip import GTBoxSubOne_GLIP, RandomFlip_GLIP
from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut,
Expand, FixScaleResize, FixShapeResize,
MinIoURandomCrop, MixUp, Mosaic, Pad,
PhotoMetricDistortion, RandomAffine,
RandomCenterCropPad, RandomCrop, RandomErasing,
RandomFlip, RandomShift, Resize, ResizeShortestEdge,
SegRescale, YOLOXHSVRandomAug)
from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder
__all__ = [
'PackDetInputs', 'ToTensor', 'ImageToTensor', 'Transpose',
'LoadImageFromNDArray', 'LoadAnnotations', 'LoadPanopticAnnotations',
'LoadMultiChannelImageFromFiles', 'LoadProposals', 'Resize', 'RandomFlip',
'RandomCrop', 'SegRescale', 'MinIoURandomCrop', 'Expand',
'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad',
'AutoAugment', 'CutOut', 'ShearX', 'ShearY', 'Rotate', 'Color', 'Equalize',
'Brightness', 'Contrast', 'TranslateX', 'TranslateY', 'RandomShift',
'Mosaic', 'MixUp', 'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste',
'FilterAnnotations', 'Pad', 'GeomTransform', 'ColorTransform',
'RandAugment', 'Sharpness', 'Solarize', 'SolarizeAdd', 'Posterize',
'AutoContrast', 'Invert', 'MultiBranch', 'RandomErasing',
'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp',
'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader',
'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample',
'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize',
'ResizeShortestEdge', 'GTBoxSubOne_GLIP', 'RandomFlip_GLIP',
'RandomSamplingNegPos', 'LoadTextAnnotations'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import numpy as np
from mmcv.transforms import RandomChoice
from mmcv.transforms.utils import cache_randomness
from mmengine.config import ConfigDict
from mmdet.registry import TRANSFORMS
# AutoAugment uses reinforcement learning to search for
# some widely useful data augmentation strategies,
# here we provide AUTOAUG_POLICIES_V0.
# For AUTOAUG_POLICIES_V0, each tuple is an augmentation
# operation of the form (operation, probability, magnitude).
# Each element in policies is a policy that will be applied
# sequentially on the image.
# RandAugment defines a data augmentation search space, RANDAUG_SPACE,
# sampling 1~3 data augmentations each time, and
# setting the magnitude of each data augmentation randomly,
# which will be applied sequentially on the image.
_MAX_LEVEL = 10
AUTOAUG_POLICIES_V0 = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
def policies_v0():
"""Autoaugment policies that was used in AutoAugment Paper."""
policies = list()
for policy_args in AUTOAUG_POLICIES_V0:
policy = list()
for args in policy_args:
policy.append(dict(type=args[0], prob=args[1], level=args[2]))
policies.append(policy)
return policies
RANDAUG_SPACE = [[dict(type='AutoContrast')], [dict(type='Equalize')],
[dict(type='Invert')], [dict(type='Rotate')],
[dict(type='Posterize')], [dict(type='Solarize')],
[dict(type='SolarizeAdd')], [dict(type='Color')],
[dict(type='Contrast')], [dict(type='Brightness')],
[dict(type='Sharpness')], [dict(type='ShearX')],
[dict(type='ShearY')], [dict(type='TranslateX')],
[dict(type='TranslateY')]]
def level_to_mag(level: Optional[int], min_mag: float,
max_mag: float) -> float:
"""Map from level to magnitude."""
if level is None:
return round(np.random.rand() * (max_mag - min_mag) + min_mag, 1)
else:
return round(level / _MAX_LEVEL * (max_mag - min_mag) + min_mag, 1)
@TRANSFORMS.register_module()
class AutoAugment(RandomChoice):
"""Auto augmentation.
This data augmentation is proposed in `AutoAugment: Learning
Augmentation Policies from Data <https://arxiv.org/abs/1805.09501>`_
and in `Learning Data Augmentation Strategies for Object Detection
<https://arxiv.org/pdf/1906.11172>`_.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_bboxes_labels
- gt_masks
- gt_ignore_flags
- gt_seg_map
Added Keys:
- homography_matrix
Args:
policies (List[List[Union[dict, ConfigDict]]]):
The policies of auto augmentation.Each policy in ``policies``
is a specific augmentation policy, and is composed by several
augmentations. When AutoAugment is called, a random policy in
``policies`` will be selected to augment images.
Defaults to policy_v0().
prob (list[float], optional): The probabilities associated
with each policy. The length should be equal to the policy
number and the sum should be 1. If not given, a uniform
distribution will be assumed. Defaults to None.
Examples:
>>> policies = [
>>> [
>>> dict(type='Sharpness', prob=0.0, level=8),
>>> dict(type='ShearX', prob=0.4, level=0,)
>>> ],
>>> [
>>> dict(type='Rotate', prob=0.6, level=10),
>>> dict(type='Color', prob=1.0, level=6)
>>> ]
>>> ]
>>> augmentation = AutoAugment(policies)
>>> img = np.ones(100, 100, 3)
>>> gt_bboxes = np.ones(10, 4)
>>> results = dict(img=img, gt_bboxes=gt_bboxes)
>>> results = augmentation(results)
"""
def __init__(self,
policies: List[List[Union[dict, ConfigDict]]] = policies_v0(),
prob: Optional[List[float]] = None) -> None:
assert isinstance(policies, list) and len(policies) > 0, \
'Policies must be a non-empty list.'
for policy in policies:
assert isinstance(policy, list) and len(policy) > 0, \
'Each policy in policies must be a non-empty list.'
for augment in policy:
assert isinstance(augment, dict) and 'type' in augment, \
'Each specific augmentation must be a dict with key' \
' "type".'
super().__init__(transforms=policies, prob=prob)
self.policies = policies
def __repr__(self) -> str:
return f'{self.__class__.__name__}(policies={self.policies}, ' \
f'prob={self.prob})'
@TRANSFORMS.register_module()
class RandAugment(RandomChoice):
"""Rand augmentation.
This data augmentation is proposed in `RandAugment:
Practical automated data augmentation with a reduced
search space <https://arxiv.org/abs/1909.13719>`_.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_bboxes_labels
- gt_masks
- gt_ignore_flags
- gt_seg_map
Added Keys:
- homography_matrix
Args:
aug_space (List[List[Union[dict, ConfigDict]]]): The augmentation space
of rand augmentation. Each augmentation transform in ``aug_space``
is a specific transform, and is composed by several augmentations.
When RandAugment is called, a random transform in ``aug_space``
will be selected to augment images. Defaults to aug_space.
aug_num (int): Number of augmentation to apply equentially.
Defaults to 2.
prob (list[float], optional): The probabilities associated with
each augmentation. The length should be equal to the
augmentation space and the sum should be 1. If not given,
a uniform distribution will be assumed. Defaults to None.
Examples:
>>> aug_space = [
>>> dict(type='Sharpness'),
>>> dict(type='ShearX'),
>>> dict(type='Color'),
>>> ],
>>> augmentation = RandAugment(aug_space)
>>> img = np.ones(100, 100, 3)
>>> gt_bboxes = np.ones(10, 4)
>>> results = dict(img=img, gt_bboxes=gt_bboxes)
>>> results = augmentation(results)
"""
def __init__(self,
aug_space: List[Union[dict, ConfigDict]] = RANDAUG_SPACE,
aug_num: int = 2,
prob: Optional[List[float]] = None) -> None:
assert isinstance(aug_space, list) and len(aug_space) > 0, \
'Augmentation space must be a non-empty list.'
for aug in aug_space:
assert isinstance(aug, list) and len(aug) == 1, \
'Each augmentation in aug_space must be a list.'
for transform in aug:
assert isinstance(transform, dict) and 'type' in transform, \
'Each specific transform must be a dict with key' \
' "type".'
super().__init__(transforms=aug_space, prob=prob)
self.aug_space = aug_space
self.aug_num = aug_num
@cache_randomness
def random_pipeline_index(self):
indices = np.arange(len(self.transforms))
return np.random.choice(
indices, self.aug_num, p=self.prob, replace=False)
def transform(self, results: dict) -> dict:
"""Transform function to use RandAugment.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with RandAugment.
"""
for idx in self.random_pipeline_index():
results = self.transforms[idx](results)
return results
def __repr__(self) -> str:
return f'{self.__class__.__name__}(' \
f'aug_space={self.aug_space}, '\
f'aug_num={self.aug_num}, ' \
f'prob={self.prob})'
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmdet.registry import TRANSFORMS
from .augment_wrappers import _MAX_LEVEL, level_to_mag
@TRANSFORMS.register_module()
class ColorTransform(BaseTransform):
"""Base class for color transformations. All color transformations need to
inherit from this base class. ``ColorTransform`` unifies the class
attributes and class functions of color transformations (Color, Brightness,
Contrast, Sharpness, Solarize, SolarizeAdd, Equalize, AutoContrast, Invert,
and Posterize), and only distort color channels, without impacting the
locations of the instances.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing the geometric
transformation and should be in range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for color transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for color transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0 <= prob <= 1.0, f'The probability of the transformation ' \
f'should be in range [0,1], got {prob}.'
assert level is None or isinstance(level, int), \
f'The level should be None or type int, got {type(level)}.'
assert level is None or 0 <= level <= _MAX_LEVEL, \
f'The level should be in range [0,{_MAX_LEVEL}], got {level}.'
assert isinstance(min_mag, float), \
f'min_mag should be type float, got {type(min_mag)}.'
assert isinstance(max_mag, float), \
f'max_mag should be type float, got {type(max_mag)}.'
assert min_mag <= max_mag, \
f'min_mag should smaller than max_mag, ' \
f'got min_mag={min_mag} and max_mag={max_mag}'
self.prob = prob
self.level = level
self.min_mag = min_mag
self.max_mag = max_mag
def _transform_img(self, results: dict, mag: float) -> None:
"""Transform the image."""
pass
@cache_randomness
def _random_disable(self):
"""Randomly disable the transform."""
return np.random.rand() > self.prob
@cache_randomness
def _get_mag(self):
"""Get the magnitude of the transform."""
return level_to_mag(self.level, self.min_mag, self.max_mag)
def transform(self, results: dict) -> dict:
"""Transform function for images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Transformed results.
"""
if self._random_disable():
return results
mag = self._get_mag()
self._transform_img(results, mag)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'level={self.level}, '
repr_str += f'min_mag={self.min_mag}, '
repr_str += f'max_mag={self.max_mag})'
return repr_str
@TRANSFORMS.register_module()
class Color(ColorTransform):
"""Adjust the color balance of the image, in a manner similar to the
controls on a colour TV set. A magnitude=0 gives a black & white image,
whereas magnitude=1 gives the original image. The bboxes, masks and
segmentations are not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Color transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Color transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Color transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Color should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Color should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Apply Color transformation to image."""
# NOTE defaultly the image should be BGR format
img = results['img']
results['img'] = mmcv.adjust_color(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Brightness(ColorTransform):
"""Adjust the brightness of the image. A magnitude=0 gives a black image,
whereas magnitude=1 gives the original image. The bboxes, masks and
segmentations are not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Brightness transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Brightness transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Brightness transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Brightness should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Brightness should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Adjust the brightness of image."""
img = results['img']
results['img'] = mmcv.adjust_brightness(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Contrast(ColorTransform):
"""Control the contrast of the image. A magnitude=0 gives a gray image,
whereas magnitude=1 gives the original imageThe bboxes, masks and
segmentations are not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Contrast transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Contrast transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Contrast transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Contrast should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Contrast should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Adjust the image contrast."""
img = results['img']
results['img'] = mmcv.adjust_contrast(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Sharpness(ColorTransform):
"""Adjust images sharpness. A positive magnitude would enhance the
sharpness and a negative magnitude would make the image blurry. A
magnitude=0 gives the origin img.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Sharpness transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Sharpness transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Sharpness transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Sharpness should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Sharpness should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Adjust the image sharpness."""
img = results['img']
results['img'] = mmcv.adjust_sharpness(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Solarize(ColorTransform):
"""Solarize images (Invert all pixels above a threshold value of
magnitude.).
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Solarize transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Solarize transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for Solarize transformation.
Defaults to 256.0.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 256.0) -> None:
assert 0. <= min_mag <= 256.0, f'min_mag for Solarize should be ' \
f'in range [0, 256], got {min_mag}.'
assert 0. <= max_mag <= 256.0, f'max_mag for Solarize should be ' \
f'in range [0, 256], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Invert all pixel values above magnitude."""
img = results['img']
results['img'] = mmcv.solarize(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class SolarizeAdd(ColorTransform):
"""SolarizeAdd images. For each pixel in the image that is less than 128,
add an additional amount to it decided by the magnitude.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing SolarizeAdd
transformation. Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for SolarizeAdd transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for SolarizeAdd transformation.
Defaults to 110.0.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 110.0) -> None:
assert 0. <= min_mag <= 110.0, f'min_mag for SolarizeAdd should be ' \
f'in range [0, 110], got {min_mag}.'
assert 0. <= max_mag <= 110.0, f'max_mag for SolarizeAdd should be ' \
f'in range [0, 110], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""SolarizeAdd the image."""
img = results['img']
img_solarized = np.where(img < 128, np.minimum(img + mag, 255), img)
results['img'] = img_solarized.astype(img.dtype)
@TRANSFORMS.register_module()
class Posterize(ColorTransform):
"""Posterize images (reduce the number of bits for each color channel).
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Posterize
transformation. Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Posterize transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for Posterize transformation.
Defaults to 4.0.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 4.0) -> None:
assert 0. <= min_mag <= 8.0, f'min_mag for Posterize should be ' \
f'in range [0, 8], got {min_mag}.'
assert 0. <= max_mag <= 8.0, f'max_mag for Posterize should be ' \
f'in range [0, 8], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Posterize the image."""
img = results['img']
results['img'] = mmcv.posterize(img, math.ceil(mag)).astype(img.dtype)
@TRANSFORMS.register_module()
class Equalize(ColorTransform):
"""Equalize the image histogram. The bboxes, masks and segmentations are
not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Equalize transformation.
Defaults to 1.0.
level (int, optional): No use for Equalize transformation.
Defaults to None.
min_mag (float): No use for Equalize transformation. Defaults to 0.1.
max_mag (float): No use for Equalize transformation. Defaults to 1.9.
"""
def _transform_img(self, results: dict, mag: float) -> None:
"""Equalizes the histogram of one image."""
img = results['img']
results['img'] = mmcv.imequalize(img).astype(img.dtype)
@TRANSFORMS.register_module()
class AutoContrast(ColorTransform):
"""Auto adjust image contrast.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing AutoContrast should
be in range [0, 1]. Defaults to 1.0.
level (int, optional): No use for AutoContrast transformation.
Defaults to None.
min_mag (float): No use for AutoContrast transformation.
Defaults to 0.1.
max_mag (float): No use for AutoContrast transformation.
Defaults to 1.9.
"""
def _transform_img(self, results: dict, mag: float) -> None:
"""Auto adjust image contrast."""
img = results['img']
results['img'] = mmcv.auto_contrast(img).astype(img.dtype)
@TRANSFORMS.register_module()
class Invert(ColorTransform):
"""Invert images.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing invert therefore should
be in range [0, 1]. Defaults to 1.0.
level (int, optional): No use for Invert transformation.
Defaults to None.
min_mag (float): No use for Invert transformation. Defaults to 0.1.
max_mag (float): No use for Invert transformation. Defaults to 1.9.
"""
def _transform_img(self, results: dict, mag: float) -> None:
"""Invert the image."""
img = results['img']
results['img'] = mmcv.iminvert(img).astype(img.dtype)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
import numpy as np
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData, PixelData
from mmdet.registry import TRANSFORMS
from mmdet.structures import DetDataSample, ReIDDataSample, TrackDataSample
from mmdet.structures.bbox import BaseBoxes
@TRANSFORMS.register_module()
class PackDetInputs(BaseTransform):
"""Pack the inputs data for the detection / semantic segmentation /
panoptic segmentation.
The ``img_meta`` item is always populated. The contents of the
``img_meta`` dictionary depends on ``meta_keys``. By default this includes:
- ``img_id``: id of the image
- ``img_path``: path to the image file
- ``ori_shape``: original shape of the image as a tuple (h, w)
- ``img_shape``: shape of the image input to the network as a tuple \
(h, w). Note that images may be zero padded on the \
bottom/right if the batch tensor is larger than this shape.
- ``scale_factor``: a float indicating the preprocessing scale
- ``flip``: a boolean indicating if image flip transform was used
- ``flip_direction``: the flipping direction
Args:
meta_keys (Sequence[str], optional): Meta keys to be converted to
``mmcv.DataContainer`` and collected in ``data[img_metas]``.
Default: ``('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')``
"""
mapping_table = {
'gt_bboxes': 'bboxes',
'gt_bboxes_labels': 'labels',
'gt_masks': 'masks'
}
def __init__(self,
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')):
self.meta_keys = meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
- 'data_sample' (obj:`DetDataSample`): The annotation info of the
sample.
"""
packed_results = dict()
if 'img' in results:
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
# To improve the computational speed by by 3-5 times, apply:
# If image is not contiguous, use
# `numpy.transpose()` followed by `numpy.ascontiguousarray()`
# If image is already contiguous, use
# `torch.permute()` followed by `torch.contiguous()`
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if not img.flags.c_contiguous:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
else:
img = to_tensor(img).permute(2, 0, 1).contiguous()
packed_results['inputs'] = img
if 'gt_ignore_flags' in results:
valid_idx = np.where(results['gt_ignore_flags'] == 0)[0]
ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0]
data_sample = DetDataSample()
instance_data = InstanceData()
ignore_instance_data = InstanceData()
for key in self.mapping_table.keys():
if key not in results:
continue
if key == 'gt_masks' or isinstance(results[key], BaseBoxes):
if 'gt_ignore_flags' in results:
instance_data[
self.mapping_table[key]] = results[key][valid_idx]
ignore_instance_data[
self.mapping_table[key]] = results[key][ignore_idx]
else:
instance_data[self.mapping_table[key]] = results[key]
else:
if 'gt_ignore_flags' in results:
instance_data[self.mapping_table[key]] = to_tensor(
results[key][valid_idx])
ignore_instance_data[self.mapping_table[key]] = to_tensor(
results[key][ignore_idx])
else:
instance_data[self.mapping_table[key]] = to_tensor(
results[key])
data_sample.gt_instances = instance_data
data_sample.ignored_instances = ignore_instance_data
if 'proposals' in results:
proposals = InstanceData(
bboxes=to_tensor(results['proposals']),
scores=to_tensor(results['proposals_scores']))
data_sample.proposals = proposals
if 'gt_seg_map' in results:
gt_sem_seg_data = dict(
sem_seg=to_tensor(results['gt_seg_map'][None, ...].copy()))
gt_sem_seg_data = PixelData(**gt_sem_seg_data)
if 'ignore_index' in results:
metainfo = dict(ignore_index=results['ignore_index'])
gt_sem_seg_data.set_metainfo(metainfo)
data_sample.gt_sem_seg = gt_sem_seg_data
img_meta = {}
for key in self.meta_keys:
if key in results:
img_meta[key] = results[key]
data_sample.set_metainfo(img_meta)
packed_results['data_samples'] = data_sample
return packed_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str
@TRANSFORMS.register_module()
class ToTensor:
"""Convert some results to :obj:`torch.Tensor` by given keys.
Args:
keys (Sequence[str]): Keys that need to be converted to Tensor.
"""
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
"""Call function to convert data in results to :obj:`torch.Tensor`.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data converted
to :obj:`torch.Tensor`.
"""
for key in self.keys:
results[key] = to_tensor(results[key])
return results
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
@TRANSFORMS.register_module()
class ImageToTensor:
"""Convert image to :obj:`torch.Tensor` by given keys.
The dimension order of input image is (H, W, C). The pipeline will convert
it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
(1, H, W).
Args:
keys (Sequence[str]): Key of images to be converted to Tensor.
"""
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
"""Call function to convert image in results to :obj:`torch.Tensor` and
transpose the channel order.
Args:
results (dict): Result dict contains the image data to convert.
Returns:
dict: The result dict contains the image converted
to :obj:`torch.Tensor` and permuted to (C, H, W) order.
"""
for key in self.keys:
img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
results[key] = to_tensor(img).permute(2, 0, 1).contiguous()
return results
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
@TRANSFORMS.register_module()
class Transpose:
"""Transpose some results by given keys.
Args:
keys (Sequence[str]): Keys of results to be transposed.
order (Sequence[int]): Order of transpose.
"""
def __init__(self, keys, order):
self.keys = keys
self.order = order
def __call__(self, results):
"""Call function to transpose the channel order of data in results.
Args:
results (dict): Result dict contains the data to transpose.
Returns:
dict: The result dict contains the data transposed to \
``self.order``.
"""
for key in self.keys:
results[key] = results[key].transpose(self.order)
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(keys={self.keys}, order={self.order})'
@TRANSFORMS.register_module()
class WrapFieldsToLists:
"""Wrap fields of the data dictionary into lists for evaluation.
This class can be used as a last step of a test or validation
pipeline for single image evaluation or inference.
Example:
>>> test_pipeline = [
>>> dict(type='LoadImageFromFile'),
>>> dict(type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
>>> dict(type='Pad', size_divisor=32),
>>> dict(type='ImageToTensor', keys=['img']),
>>> dict(type='Collect', keys=['img']),
>>> dict(type='WrapFieldsToLists')
>>> ]
"""
def __call__(self, results):
"""Call function to wrap fields into lists.
Args:
results (dict): Result dict contains the data to wrap.
Returns:
dict: The result dict where value of ``self.keys`` are wrapped \
into list.
"""
# Wrap dict fields into lists
for key, val in results.items():
results[key] = [val]
return results
def __repr__(self):
return f'{self.__class__.__name__}()'
@TRANSFORMS.register_module()
class PackTrackInputs(BaseTransform):
"""Pack the inputs data for the multi object tracking and video instance
segmentation. All the information of images are packed to ``inputs``. All
the information except images are packed to ``data_samples``. In order to
get the original annotaiton and meta info, we add `instances` key into meta
keys.
Args:
meta_keys (Sequence[str]): Meta keys to be collected in
``data_sample.metainfo``. Defaults to None.
default_meta_keys (tuple): Default meta keys. Defaults to ('img_id',
'img_path', 'ori_shape', 'img_shape', 'scale_factor',
'flip', 'flip_direction', 'frame_id', 'is_video_data',
'video_id', 'video_length', 'instances').
"""
mapping_table = {
'gt_bboxes': 'bboxes',
'gt_bboxes_labels': 'labels',
'gt_masks': 'masks',
'gt_instances_ids': 'instances_ids'
}
def __init__(self,
meta_keys: Optional[dict] = None,
default_meta_keys: tuple = ('img_id', 'img_path', 'ori_shape',
'img_shape', 'scale_factor',
'flip', 'flip_direction',
'frame_id', 'video_id',
'video_length',
'ori_video_length', 'instances')):
self.meta_keys = default_meta_keys
if meta_keys is not None:
if isinstance(meta_keys, str):
meta_keys = (meta_keys, )
else:
assert isinstance(meta_keys, tuple), \
'meta_keys must be str or tuple'
self.meta_keys += meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (dict[Tensor]): The forward data of models.
- 'data_samples' (obj:`TrackDataSample`): The annotation info of
the samples.
"""
packed_results = dict()
packed_results['inputs'] = dict()
# 1. Pack images
if 'img' in results:
imgs = results['img']
imgs = np.stack(imgs, axis=0)
imgs = imgs.transpose(0, 3, 1, 2)
packed_results['inputs'] = to_tensor(imgs)
# 2. Pack InstanceData
if 'gt_ignore_flags' in results:
gt_ignore_flags_list = results['gt_ignore_flags']
valid_idx_list, ignore_idx_list = [], []
for gt_ignore_flags in gt_ignore_flags_list:
valid_idx = np.where(gt_ignore_flags == 0)[0]
ignore_idx = np.where(gt_ignore_flags == 1)[0]
valid_idx_list.append(valid_idx)
ignore_idx_list.append(ignore_idx)
assert 'img_id' in results, "'img_id' must contained in the results "
'for counting the number of images'
num_imgs = len(results['img_id'])
instance_data_list = [InstanceData() for _ in range(num_imgs)]
ignore_instance_data_list = [InstanceData() for _ in range(num_imgs)]
for key in self.mapping_table.keys():
if key not in results:
continue
if key == 'gt_masks':
mapped_key = self.mapping_table[key]
gt_masks_list = results[key]
if 'gt_ignore_flags' in results:
for i, gt_mask in enumerate(gt_masks_list):
valid_idx, ignore_idx = valid_idx_list[
i], ignore_idx_list[i]
instance_data_list[i][mapped_key] = gt_mask[valid_idx]
ignore_instance_data_list[i][mapped_key] = gt_mask[
ignore_idx]
else:
for i, gt_mask in enumerate(gt_masks_list):
instance_data_list[i][mapped_key] = gt_mask
else:
anns_list = results[key]
if 'gt_ignore_flags' in results:
for i, ann in enumerate(anns_list):
valid_idx, ignore_idx = valid_idx_list[
i], ignore_idx_list[i]
instance_data_list[i][
self.mapping_table[key]] = to_tensor(
ann[valid_idx])
ignore_instance_data_list[i][
self.mapping_table[key]] = to_tensor(
ann[ignore_idx])
else:
for i, ann in enumerate(anns_list):
instance_data_list[i][
self.mapping_table[key]] = to_tensor(ann)
det_data_samples_list = []
for i in range(num_imgs):
det_data_sample = DetDataSample()
det_data_sample.gt_instances = instance_data_list[i]
det_data_sample.ignored_instances = ignore_instance_data_list[i]
det_data_samples_list.append(det_data_sample)
# 3. Pack metainfo
for key in self.meta_keys:
if key not in results:
continue
img_metas_list = results[key]
for i, img_meta in enumerate(img_metas_list):
det_data_samples_list[i].set_metainfo({f'{key}': img_meta})
track_data_sample = TrackDataSample()
track_data_sample.video_data_samples = det_data_samples_list
if 'key_frame_flags' in results:
key_frame_flags = np.asarray(results['key_frame_flags'])
key_frames_inds = np.where(key_frame_flags)[0].tolist()
ref_frames_inds = np.where(~key_frame_flags)[0].tolist()
track_data_sample.set_metainfo(
dict(key_frames_inds=key_frames_inds))
track_data_sample.set_metainfo(
dict(ref_frames_inds=ref_frames_inds))
packed_results['data_samples'] = track_data_sample
return packed_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'meta_keys={self.meta_keys}, '
repr_str += f'default_meta_keys={self.default_meta_keys})'
return repr_str
@TRANSFORMS.register_module()
class PackReIDInputs(BaseTransform):
"""Pack the inputs data for the ReID. The ``meta_info`` item is always
populated. The contents of the ``meta_info`` dictionary depends on
``meta_keys``. By default this includes:
- ``img_path``: path to the image file.
- ``ori_shape``: original shape of the image as a tuple (H, W).
- ``img_shape``: shape of the image input to the network as a tuple
(H, W). Note that images may be zero padded on the bottom/right
if the batch tensor is larger than this shape.
- ``scale``: scale of the image as a tuple (W, H).
- ``scale_factor``: a float indicating the pre-processing scale.
- ``flip``: a boolean indicating if image flip transform was used.
- ``flip_direction``: the flipping direction.
Args:
meta_keys (Sequence[str], optional): The meta keys to saved in the
``metainfo`` of the packed ``data_sample``.
"""
default_meta_keys = ('img_path', 'ori_shape', 'img_shape', 'scale',
'scale_factor')
def __init__(self, meta_keys: Sequence[str] = ()) -> None:
self.meta_keys = self.default_meta_keys
if meta_keys is not None:
if isinstance(meta_keys, str):
meta_keys = (meta_keys, )
else:
assert isinstance(meta_keys, tuple), \
'meta_keys must be str or tuple.'
self.meta_keys += meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (dict[Tensor]): The forward data of models.
- 'data_samples' (obj:`ReIDDataSample`): The meta info of the
sample.
"""
packed_results = dict(inputs=dict(), data_samples=None)
assert 'img' in results, 'Missing the key ``img``.'
_type = type(results['img'])
label = results['gt_label']
if _type == list:
img = results['img']
label = np.stack(label, axis=0) # (N,)
assert all([type(v) == _type for v in results.values()]), \
'All items in the results must have the same type.'
else:
img = [results['img']]
img = np.stack(img, axis=3) # (H, W, C, N)
img = img.transpose(3, 2, 0, 1) # (N, C, H, W)
img = np.ascontiguousarray(img)
packed_results['inputs'] = to_tensor(img)
data_sample = ReIDDataSample()
data_sample.set_gt_label(label)
meta_info = dict()
for key in self.meta_keys:
meta_info[key] = results[key]
data_sample.set_metainfo(meta_info)
packed_results['data_samples'] = data_sample
return packed_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
import random
from collections import defaultdict
from typing import Dict, List, Optional, Union
from mmcv.transforms import BaseTransform
from mmdet.registry import TRANSFORMS
@TRANSFORMS.register_module()
class BaseFrameSample(BaseTransform):
"""Directly get the key frame, no reference frames.
Args:
collect_video_keys (list[str]): The keys of video info to be
collected.
"""
def __init__(self,
collect_video_keys: List[str] = ['video_id', 'video_length']):
self.collect_video_keys = collect_video_keys
def prepare_data(self, video_infos: dict,
sampled_inds: List[int]) -> Dict[str, List]:
"""Prepare data for the subsequent pipeline.
Args:
video_infos (dict): The whole video information.
sampled_inds (list[int]): The sampled frame indices.
Returns:
dict: The processed data information.
"""
frames_anns = video_infos['images']
final_data_info = defaultdict(list)
# for data in frames_anns:
for index in sampled_inds:
data = frames_anns[index]
# copy the info in video-level into img-level
for key in self.collect_video_keys:
if key == 'video_length':
data['ori_video_length'] = video_infos[key]
data['video_length'] = len(sampled_inds)
else:
data[key] = video_infos[key]
# Collate data_list (list of dict to dict of list)
for key, value in data.items():
final_data_info[key].append(value)
return final_data_info
def transform(self, video_infos: dict) -> Optional[Dict[str, List]]:
"""Transform the video information.
Args:
video_infos (dict): The whole video information.
Returns:
dict: The data information of the key frames.
"""
if 'key_frame_id' in video_infos:
key_frame_id = video_infos['key_frame_id']
assert isinstance(video_infos['key_frame_id'], int)
else:
key_frame_id = random.sample(
list(range(video_infos['video_length'])), 1)[0]
results = self.prepare_data(video_infos, [key_frame_id])
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(collect_video_keys={self.collect_video_keys})'
return repr_str
@TRANSFORMS.register_module()
class UniformRefFrameSample(BaseFrameSample):
"""Uniformly sample reference frames.
Args:
num_ref_imgs (int): Number of reference frames to be sampled.
frame_range (int | list[int]): Range of frames to be sampled around
key frame. If int, the range is [-frame_range, frame_range].
Defaults to 10.
filter_key_img (bool): Whether to filter the key frame when
sampling reference frames. Defaults to True.
collect_video_keys (list[str]): The keys of video info to be
collected.
"""
def __init__(self,
num_ref_imgs: int = 1,
frame_range: Union[int, List[int]] = 10,
filter_key_img: bool = True,
collect_video_keys: List[str] = ['video_id', 'video_length']):
self.num_ref_imgs = num_ref_imgs
self.filter_key_img = filter_key_img
if isinstance(frame_range, int):
assert frame_range >= 0, 'frame_range can not be a negative value.'
frame_range = [-frame_range, frame_range]
elif isinstance(frame_range, list):
assert len(frame_range) == 2, 'The length must be 2.'
assert frame_range[0] <= 0 and frame_range[1] >= 0
for i in frame_range:
assert isinstance(i, int), 'Each element must be int.'
else:
raise TypeError('The type of frame_range must be int or list.')
self.frame_range = frame_range
super().__init__(collect_video_keys=collect_video_keys)
def sampling_frames(self, video_length: int, key_frame_id: int):
"""Sampling frames.
Args:
video_length (int): The length of the video.
key_frame_id (int): The key frame id.
Returns:
list[int]: The sampled frame indices.
"""
if video_length > 1:
left = max(0, key_frame_id + self.frame_range[0])
right = min(key_frame_id + self.frame_range[1], video_length - 1)
frame_ids = list(range(0, video_length))
valid_ids = frame_ids[left:right + 1]
if self.filter_key_img and key_frame_id in valid_ids:
valid_ids.remove(key_frame_id)
assert len(
valid_ids
) > 0, 'After filtering key frame, there are no valid frames'
if len(valid_ids) < self.num_ref_imgs:
valid_ids = valid_ids * self.num_ref_imgs
ref_frame_ids = random.sample(valid_ids, self.num_ref_imgs)
else:
ref_frame_ids = [key_frame_id] * self.num_ref_imgs
sampled_frames_ids = [key_frame_id] + ref_frame_ids
sampled_frames_ids = sorted(sampled_frames_ids)
key_frames_ind = sampled_frames_ids.index(key_frame_id)
key_frame_flags = [False] * len(sampled_frames_ids)
key_frame_flags[key_frames_ind] = True
return sampled_frames_ids, key_frame_flags
def transform(self, video_infos: dict) -> Optional[Dict[str, List]]:
"""Transform the video information.
Args:
video_infos (dict): The whole video information.
Returns:
dict: The data information of the sampled frames.
"""
if 'key_frame_id' in video_infos:
key_frame_id = video_infos['key_frame_id']
assert isinstance(video_infos['key_frame_id'], int)
else:
key_frame_id = random.sample(
list(range(video_infos['video_length'])), 1)[0]
(sampled_frames_ids, key_frame_flags) = self.sampling_frames(
video_infos['video_length'], key_frame_id=key_frame_id)
results = self.prepare_data(video_infos, sampled_frames_ids)
results['key_frame_flags'] = key_frame_flags
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(num_ref_imgs={self.num_ref_imgs}, '
repr_str += f'frame_range={self.frame_range}, '
repr_str += f'filter_key_img={self.filter_key_img}, '
repr_str += f'collect_video_keys={self.collect_video_keys})'
return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
import cv2
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import autocast_box_type
from .augment_wrappers import _MAX_LEVEL, level_to_mag
@TRANSFORMS.register_module()
class GeomTransform(BaseTransform):
"""Base class for geometric transformations. All geometric transformations
need to inherit from this base class. ``GeomTransform`` unifies the class
attributes and class functions of geometric transformations (ShearX,
ShearY, Rotate, TranslateX, and TranslateY), and records the homography
matrix.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability for performing the geometric
transformation and should be in range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for geometric transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for geometric transformation.
Defaults to 1.0.
reversal_prob (float): The probability that reverses the geometric
transformation magnitude. Should be in range [0,1].
Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equals ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 1.0,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0 <= prob <= 1.0, f'The probability of the transformation ' \
f'should be in range [0,1], got {prob}.'
assert level is None or isinstance(level, int), \
f'The level should be None or type int, got {type(level)}.'
assert level is None or 0 <= level <= _MAX_LEVEL, \
f'The level should be in range [0,{_MAX_LEVEL}], got {level}.'
assert isinstance(min_mag, float), \
f'min_mag should be type float, got {type(min_mag)}.'
assert isinstance(max_mag, float), \
f'max_mag should be type float, got {type(max_mag)}.'
assert min_mag <= max_mag, \
f'min_mag should smaller than max_mag, ' \
f'got min_mag={min_mag} and max_mag={max_mag}'
assert isinstance(reversal_prob, float), \
f'reversal_prob should be type float, got {type(max_mag)}.'
assert 0 <= reversal_prob <= 1.0, \
f'The reversal probability of the transformation magnitude ' \
f'should be type float, got {type(reversal_prob)}.'
if isinstance(img_border_value, (float, int)):
img_border_value = tuple([float(img_border_value)] * 3)
elif isinstance(img_border_value, tuple):
assert len(img_border_value) == 3, \
f'img_border_value as tuple must have 3 elements, ' \
f'got {len(img_border_value)}.'
img_border_value = tuple([float(val) for val in img_border_value])
else:
raise ValueError(
'img_border_value must be float or tuple with 3 elements.')
assert np.all([0 <= val <= 255 for val in img_border_value]), 'all ' \
'elements of img_border_value should between range [0,255].' \
f'got {img_border_value}.'
self.prob = prob
self.level = level
self.min_mag = min_mag
self.max_mag = max_mag
self.reversal_prob = reversal_prob
self.img_border_value = img_border_value
self.mask_border_value = mask_border_value
self.seg_ignore_label = seg_ignore_label
self.interpolation = interpolation
def _transform_img(self, results: dict, mag: float) -> None:
"""Transform the image."""
pass
def _transform_masks(self, results: dict, mag: float) -> None:
"""Transform the masks."""
pass
def _transform_seg(self, results: dict, mag: float) -> None:
"""Transform the segmentation map."""
pass
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for the geometric transformation."""
return np.eye(3, dtype=np.float32)
def _transform_bboxes(self, results: dict, mag: float) -> None:
"""Transform the bboxes."""
results['gt_bboxes'].project_(self.homography_matrix)
results['gt_bboxes'].clip_(results['img_shape'])
def _record_homography_matrix(self, results: dict) -> None:
"""Record the homography matrix for the geometric transformation."""
if results.get('homography_matrix', None) is None:
results['homography_matrix'] = self.homography_matrix
else:
results['homography_matrix'] = self.homography_matrix @ results[
'homography_matrix']
@cache_randomness
def _random_disable(self):
"""Randomly disable the transform."""
return np.random.rand() > self.prob
@cache_randomness
def _get_mag(self):
"""Get the magnitude of the transform."""
mag = level_to_mag(self.level, self.min_mag, self.max_mag)
return -mag if np.random.rand() > self.reversal_prob else mag
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""Transform function for images, bounding boxes, masks and semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Transformed results.
"""
if self._random_disable():
return results
mag = self._get_mag()
self.homography_matrix = self._get_homography_matrix(results, mag)
self._record_homography_matrix(results)
self._transform_img(results, mag)
if results.get('gt_bboxes', None) is not None:
self._transform_bboxes(results, mag)
if results.get('gt_masks', None) is not None:
self._transform_masks(results, mag)
if results.get('gt_seg_map', None) is not None:
self._transform_seg(results, mag)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'level={self.level}, '
repr_str += f'min_mag={self.min_mag}, '
repr_str += f'max_mag={self.max_mag}, '
repr_str += f'reversal_prob={self.reversal_prob}, '
repr_str += f'img_border_value={self.img_border_value}, '
repr_str += f'mask_border_value={self.mask_border_value}, '
repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
repr_str += f'interpolation={self.interpolation})'
return repr_str
@TRANSFORMS.register_module()
class ShearX(GeomTransform):
"""Shear the images, bboxes, masks and segmentation map horizontally.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability for performing Shear and should be in
range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum angle for the horizontal shear.
Defaults to 0.0.
max_mag (float): The maximum angle for the horizontal shear.
Defaults to 30.0.
reversal_prob (float): The probability that reverses the horizontal
shear magnitude. Should be in range [0,1]. Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equals ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 30.0,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 90., \
f'min_mag angle for ShearX should be ' \
f'in range [0, 90], got {min_mag}.'
assert 0. <= max_mag <= 90., \
f'max_mag angle for ShearX should be ' \
f'in range [0, 90], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
@cache_randomness
def _get_mag(self):
"""Get the magnitude of the transform."""
mag = level_to_mag(self.level, self.min_mag, self.max_mag)
mag = np.tan(mag * np.pi / 180)
return -mag if np.random.rand() > self.reversal_prob else mag
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for ShearX."""
return np.array([[1, mag, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Shear the image horizontally."""
results['img'] = mmcv.imshear(
results['img'],
mag,
direction='horizontal',
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Shear the masks horizontally."""
results['gt_masks'] = results['gt_masks'].shear(
results['img_shape'],
mag,
direction='horizontal',
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Shear the segmentation map horizontally."""
results['gt_seg_map'] = mmcv.imshear(
results['gt_seg_map'],
mag,
direction='horizontal',
border_value=self.seg_ignore_label,
interpolation='nearest')
@TRANSFORMS.register_module()
class ShearY(GeomTransform):
"""Shear the images, bboxes, masks and segmentation map vertically.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability for performing ShearY and should be in
range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum angle for the vertical shear.
Defaults to 0.0.
max_mag (float): The maximum angle for the vertical shear.
Defaults to 30.0.
reversal_prob (float): The probability that reverses the vertical
shear magnitude. Should be in range [0,1]. Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equals ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 30.,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 90., \
f'min_mag angle for ShearY should be ' \
f'in range [0, 90], got {min_mag}.'
assert 0. <= max_mag <= 90., \
f'max_mag angle for ShearY should be ' \
f'in range [0, 90], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
@cache_randomness
def _get_mag(self):
"""Get the magnitude of the transform."""
mag = level_to_mag(self.level, self.min_mag, self.max_mag)
mag = np.tan(mag * np.pi / 180)
return -mag if np.random.rand() > self.reversal_prob else mag
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for ShearY."""
return np.array([[1, 0, 0], [mag, 1, 0], [0, 0, 1]], dtype=np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Shear the image vertically."""
results['img'] = mmcv.imshear(
results['img'],
mag,
direction='vertical',
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Shear the masks vertically."""
results['gt_masks'] = results['gt_masks'].shear(
results['img_shape'],
mag,
direction='vertical',
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Shear the segmentation map vertically."""
results['gt_seg_map'] = mmcv.imshear(
results['gt_seg_map'],
mag,
direction='vertical',
border_value=self.seg_ignore_label,
interpolation='nearest')
@TRANSFORMS.register_module()
class Rotate(GeomTransform):
"""Rotate the images, bboxes, masks and segmentation map.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability for perform transformation and
should be in range 0 to 1. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The maximum angle for rotation.
Defaults to 0.0.
max_mag (float): The maximum angle for rotation.
Defaults to 30.0.
reversal_prob (float): The probability that reverses the rotation
magnitude. Should be in range [0,1]. Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equals ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 30.0,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 180., \
f'min_mag for Rotate should be in range [0,180], got {min_mag}.'
assert 0. <= max_mag <= 180., \
f'max_mag for Rotate should be in range [0,180], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for Rotate."""
img_shape = results['img_shape']
center = ((img_shape[1] - 1) * 0.5, (img_shape[0] - 1) * 0.5)
cv2_rotation_matrix = cv2.getRotationMatrix2D(center, -mag, 1.0)
return np.concatenate(
[cv2_rotation_matrix,
np.array([0, 0, 1]).reshape((1, 3))]).astype(np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Rotate the image."""
results['img'] = mmcv.imrotate(
results['img'],
mag,
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Rotate the masks."""
results['gt_masks'] = results['gt_masks'].rotate(
results['img_shape'],
mag,
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Rotate the segmentation map."""
results['gt_seg_map'] = mmcv.imrotate(
results['gt_seg_map'],
mag,
border_value=self.seg_ignore_label,
interpolation='nearest')
@TRANSFORMS.register_module()
class TranslateX(GeomTransform):
"""Translate the images, bboxes, masks and segmentation map horizontally.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability for perform transformation and
should be in range 0 to 1. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum pixel's offset ratio for horizontal
translation. Defaults to 0.0.
max_mag (float): The maximum pixel's offset ratio for horizontal
translation. Defaults to 0.1.
reversal_prob (float): The probability that reverses the horizontal
translation magnitude. Should be in range [0,1]. Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equals ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 0.1,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 1., \
f'min_mag ratio for TranslateX should be ' \
f'in range [0, 1], got {min_mag}.'
assert 0. <= max_mag <= 1., \
f'max_mag ratio for TranslateX should be ' \
f'in range [0, 1], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for TranslateX."""
mag = int(results['img_shape'][1] * mag)
return np.array([[1, 0, mag], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Translate the image horizontally."""
mag = int(results['img_shape'][1] * mag)
results['img'] = mmcv.imtranslate(
results['img'],
mag,
direction='horizontal',
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Translate the masks horizontally."""
mag = int(results['img_shape'][1] * mag)
results['gt_masks'] = results['gt_masks'].translate(
results['img_shape'],
mag,
direction='horizontal',
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Translate the segmentation map horizontally."""
mag = int(results['img_shape'][1] * mag)
results['gt_seg_map'] = mmcv.imtranslate(
results['gt_seg_map'],
mag,
direction='horizontal',
border_value=self.seg_ignore_label,
interpolation='nearest')
@TRANSFORMS.register_module()
class TranslateY(GeomTransform):
"""Translate the images, bboxes, masks and segmentation map vertically.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- gt_bboxes
- gt_masks
- gt_seg_map
Added Keys:
- homography_matrix
Args:
prob (float): The probability for perform transformation and
should be in range 0 to 1. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum pixel's offset ratio for vertical
translation. Defaults to 0.0.
max_mag (float): The maximum pixel's offset ratio for vertical
translation. Defaults to 0.1.
reversal_prob (float): The probability that reverses the vertical
translation magnitude. Should be in range [0,1]. Defaults to 0.5.
img_border_value (int | float | tuple): The filled values for
image border. If float, the same fill value will be used for
all the three channels of image. If tuple, it should be 3 elements.
Defaults to 128.
mask_border_value (int): The fill value used for masks. Defaults to 0.
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equals ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 0.1,
reversal_prob: float = 0.5,
img_border_value: Union[int, float, tuple] = 128,
mask_border_value: int = 0,
seg_ignore_label: int = 255,
interpolation: str = 'bilinear') -> None:
assert 0. <= min_mag <= 1., \
f'min_mag ratio for TranslateY should be ' \
f'in range [0,1], got {min_mag}.'
assert 0. <= max_mag <= 1., \
f'max_mag ratio for TranslateY should be ' \
f'in range [0,1], got {max_mag}.'
super().__init__(
prob=prob,
level=level,
min_mag=min_mag,
max_mag=max_mag,
reversal_prob=reversal_prob,
img_border_value=img_border_value,
mask_border_value=mask_border_value,
seg_ignore_label=seg_ignore_label,
interpolation=interpolation)
def _get_homography_matrix(self, results: dict, mag: float) -> np.ndarray:
"""Get the homography matrix for TranslateY."""
mag = int(results['img_shape'][0] * mag)
return np.array([[1, 0, 0], [0, 1, mag], [0, 0, 1]], dtype=np.float32)
def _transform_img(self, results: dict, mag: float) -> None:
"""Translate the image vertically."""
mag = int(results['img_shape'][0] * mag)
results['img'] = mmcv.imtranslate(
results['img'],
mag,
direction='vertical',
border_value=self.img_border_value,
interpolation=self.interpolation)
def _transform_masks(self, results: dict, mag: float) -> None:
"""Translate masks vertically."""
mag = int(results['img_shape'][0] * mag)
results['gt_masks'] = results['gt_masks'].translate(
results['img_shape'],
mag,
direction='vertical',
border_value=self.mask_border_value,
interpolation=self.interpolation)
def _transform_seg(self, results: dict, mag: float) -> None:
"""Translate segmentation map vertically."""
mag = int(results['img_shape'][0] * mag)
results['gt_seg_map'] = mmcv.imtranslate(
results['gt_seg_map'],
mag,
direction='vertical',
border_value=self.seg_ignore_label,
interpolation='nearest')
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import numpy as np
from mmcv.transforms import BaseTransform
from mmdet.registry import TRANSFORMS
@TRANSFORMS.register_module()
class InstaBoost(BaseTransform):
r"""Data augmentation method in `InstaBoost: Boosting Instance
Segmentation Via Probability Map Guided Copy-Pasting
<https://arxiv.org/abs/1908.07801>`_.
Refer to https://github.com/GothicAi/Instaboost for implementation details.
Required Keys:
- img (np.uint8)
- instances
Modified Keys:
- img (np.uint8)
- instances
Args:
action_candidate (tuple): Action candidates. "normal", "horizontal", \
"vertical", "skip" are supported. Defaults to ('normal', \
'horizontal', 'skip').
action_prob (tuple): Corresponding action probabilities. Should be \
the same length as action_candidate. Defaults to (1, 0, 0).
scale (tuple): (min scale, max scale). Defaults to (0.8, 1.2).
dx (int): The maximum x-axis shift will be (instance width) / dx.
Defaults to 15.
dy (int): The maximum y-axis shift will be (instance height) / dy.
Defaults to 15.
theta (tuple): (min rotation degree, max rotation degree). \
Defaults to (-1, 1).
color_prob (float): Probability of images for color augmentation.
Defaults to 0.5.
hflag (bool): Whether to use heatmap guided. Defaults to False.
aug_ratio (float): Probability of applying this transformation. \
Defaults to 0.5.
"""
def __init__(self,
action_candidate: tuple = ('normal', 'horizontal', 'skip'),
action_prob: tuple = (1, 0, 0),
scale: tuple = (0.8, 1.2),
dx: int = 15,
dy: int = 15,
theta: tuple = (-1, 1),
color_prob: float = 0.5,
hflag: bool = False,
aug_ratio: float = 0.5) -> None:
import matplotlib
import matplotlib.pyplot as plt
default_backend = plt.get_backend()
try:
import instaboostfast as instaboost
except ImportError:
raise ImportError(
'Please run "pip install instaboostfast" '
'to install instaboostfast first for instaboost augmentation.')
# instaboost will modify the default backend
# and cause visualization to fail.
matplotlib.use(default_backend)
self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
scale, dx, dy, theta,
color_prob, hflag)
self.aug_ratio = aug_ratio
def _load_anns(self, results: dict) -> Tuple[list, list]:
"""Convert raw anns to instaboost expected input format."""
anns = []
ignore_anns = []
for instance in results['instances']:
label = instance['bbox_label']
bbox = instance['bbox']
mask = instance['mask']
x1, y1, x2, y2 = bbox
# assert (x2 - x1) >= 1 and (y2 - y1) >= 1
bbox = [x1, y1, x2 - x1, y2 - y1]
if instance['ignore_flag'] == 0:
anns.append({
'category_id': label,
'segmentation': mask,
'bbox': bbox
})
else:
# Ignore instances without data augmentation
ignore_anns.append(instance)
return anns, ignore_anns
def _parse_anns(self, results: dict, anns: list, ignore_anns: list,
img: np.ndarray) -> dict:
"""Restore the result of instaboost processing to the original anns
format."""
instances = []
for ann in anns:
x1, y1, w, h = ann['bbox']
# TODO: more essential bug need to be fixed in instaboost
if w <= 0 or h <= 0:
continue
bbox = [x1, y1, x1 + w, y1 + h]
instances.append(
dict(
bbox=bbox,
bbox_label=ann['category_id'],
mask=ann['segmentation'],
ignore_flag=0))
instances.extend(ignore_anns)
results['img'] = img
results['instances'] = instances
return results
def transform(self, results) -> dict:
"""The transform function."""
img = results['img']
ori_type = img.dtype
if 'instances' not in results or len(results['instances']) == 0:
return results
anns, ignore_anns = self._load_anns(results)
if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
try:
import instaboostfast as instaboost
except ImportError:
raise ImportError('Please run "pip install instaboostfast" '
'to install instaboostfast first.')
anns, img = instaboost.get_new_data(
anns, img.astype(np.uint8), self.cfg, background=None)
results = self._parse_anns(results, anns, ignore_anns,
img.astype(ori_type))
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(aug_ratio={self.aug_ratio})'
return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch
from mmcv.transforms import BaseTransform
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
from mmcv.transforms import LoadImageFromFile
from mmengine.fileio import get
from mmengine.structures import BaseDataElement
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import get_box_type
from mmdet.structures.bbox.box_type import autocast_box_type
from mmdet.structures.mask import BitmapMasks, PolygonMasks
@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
"""Load an image from ``results['img']``.
Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
:obj:`np.ndarray` in ``results['img']``. Can be used when loading image
from webcam.
Required Keys:
- img
Modified Keys:
- img
- img_path
- img_shape
- ori_shape
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
"""
def transform(self, results: dict) -> dict:
"""Transform function to add image meta information.
Args:
results (dict): Result dict with Webcam read image in
``results['img']``.
Returns:
dict: The dict contains loaded image and meta information.
"""
img = results['img']
if self.to_float32:
img = img.astype(np.float32)
results['img_path'] = None
results['img'] = img
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results
@TRANSFORMS.register_module()
class LoadMultiChannelImageFromFiles(BaseTransform):
"""Load multi-channel images from a list of separate channel files.
Required Keys:
- img_path
Modified Keys:
- img
- img_shape
- ori_shape
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
color_type (str): The flag argument for :func:``mmcv.imfrombytes``.
Defaults to 'unchanged'.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :func:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
file_client_args (dict): Arguments to instantiate the
corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend in mmdet >= 3.0.0rc7. Defaults to None.
"""
def __init__(
self,
to_float32: bool = False,
color_type: str = 'unchanged',
imdecode_backend: str = 'cv2',
file_client_args: dict = None,
backend_args: dict = None,
) -> None:
self.to_float32 = to_float32
self.color_type = color_type
self.imdecode_backend = imdecode_backend
self.backend_args = backend_args
if file_client_args is not None:
raise RuntimeError(
'The `file_client_args` is deprecated, '
'please use `backend_args` instead, please refer to'
'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
)
def transform(self, results: dict) -> dict:
"""Transform functions to load multiple images and get images meta
information.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded images and meta information.
"""
assert isinstance(results['img_path'], list)
img = []
for name in results['img_path']:
img_bytes = get(name, backend_args=self.backend_args)
img.append(
mmcv.imfrombytes(
img_bytes,
flag=self.color_type,
backend=self.imdecode_backend))
img = np.stack(img, axis=-1)
if self.to_float32:
img = img.astype(np.float32)
results['img'] = img
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'to_float32={self.to_float32}, '
f"color_type='{self.color_type}', "
f"imdecode_backend='{self.imdecode_backend}', "
f'backend_args={self.backend_args})')
return repr_str
@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
"""Load and process the ``instances`` and ``seg_map`` annotation provided
by dataset.
The annotation format is as the following:
.. code-block:: python
{
'instances':
[
{
# List of 4 numbers representing the bounding box of the
# instance, in (x1, y1, x2, y2) order.
'bbox': [x1, y1, x2, y2],
# Label of image classification.
'bbox_label': 1,
# Used in instance/panoptic segmentation. The segmentation mask
# of the instance or the information of segments.
# 1. If list[list[float]], it represents a list of polygons,
# one for each connected component of the object. Each
# list[float] is one simple polygon in the format of
# [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute
# coordinates in unit of pixels.
# 2. If dict, it represents the per-pixel segmentation mask in
# COCO's compressed RLE format. The dict should have keys
# “size” and “counts”. Can be loaded by pycocotools
'mask': list[list[float]] or dict,
}
]
# Filename of semantic or panoptic segmentation ground truth file.
'seg_map_path': 'a/b/c'
}
After this module, the annotation has been changed to the format below:
.. code-block:: python
{
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in an image
'gt_bboxes': BaseBoxes(N, 4)
# In int type.
'gt_bboxes_labels': np.ndarray(N, )
# In built-in class
'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
# In uint8 type.
'gt_seg_map': np.ndarray (H, W)
# in (x, y, v) order, float type.
}
Required Keys:
- height
- width
- instances
- bbox (optional)
- bbox_label
- mask (optional)
- ignore_flag
- seg_map_path (optional)
Added Keys:
- gt_bboxes (BaseBoxes[torch.float32])
- gt_bboxes_labels (np.int64)
- gt_masks (BitmapMasks | PolygonMasks)
- gt_seg_map (np.uint8)
- gt_ignore_flags (bool)
Args:
with_bbox (bool): Whether to parse and load the bbox annotation.
Defaults to True.
with_label (bool): Whether to parse and load the label annotation.
Defaults to True.
with_mask (bool): Whether to parse and load the mask annotation.
Default: False.
with_seg (bool): Whether to parse and load the semantic segmentation
annotation. Defaults to False.
poly2mask (bool): Whether to convert mask to bitmap. Default: True.
box_type (str): The box type used to wrap the bboxes. If ``box_type``
is None, gt_bboxes will keep being np.ndarray. Defaults to 'hbox'.
reduce_zero_label (bool): Whether reduce all label value
by 1. Usually used for datasets where 0 is background label.
Defaults to False.
ignore_index (int): The label index to be ignored.
Valid only if reduce_zero_label is true. Defaults is 255.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :fun:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
"""
def __init__(
self,
with_mask: bool = False,
poly2mask: bool = True,
box_type: str = 'hbox',
# use for semseg
reduce_zero_label: bool = False,
ignore_index: int = 255,
**kwargs) -> None:
super(LoadAnnotations, self).__init__(**kwargs)
self.with_mask = with_mask
self.poly2mask = poly2mask
self.box_type = box_type
self.reduce_zero_label = reduce_zero_label
self.ignore_index = ignore_index
def _load_bboxes(self, results: dict) -> None:
"""Private function to load bounding box annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict contains loaded bounding box annotations.
"""
gt_bboxes = []
gt_ignore_flags = []
for instance in results.get('instances', []):
gt_bboxes.append(instance['bbox'])
gt_ignore_flags.append(instance['ignore_flag'])
if self.box_type is None:
results['gt_bboxes'] = np.array(
gt_bboxes, dtype=np.float32).reshape((-1, 4))
else:
_, box_type_cls = get_box_type(self.box_type)
results['gt_bboxes'] = box_type_cls(gt_bboxes, dtype=torch.float32)
results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
def _load_labels(self, results: dict) -> None:
"""Private function to load label annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict contains loaded label annotations.
"""
gt_bboxes_labels = []
for instance in results.get('instances', []):
gt_bboxes_labels.append(instance['bbox_label'])
# TODO: Inconsistent with mmcv, consider how to deal with it later.
results['gt_bboxes_labels'] = np.array(
gt_bboxes_labels, dtype=np.int64)
def _poly2mask(self, mask_ann: Union[list, dict], img_h: int,
img_w: int) -> np.ndarray:
"""Private function to convert masks represented with polygon to
bitmaps.
Args:
mask_ann (list | dict): Polygon mask annotation input.
img_h (int): The height of output mask.
img_w (int): The width of output mask.
Returns:
np.ndarray: The decode bitmap mask of shape (img_h, img_w).
"""
if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
rle = maskUtils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
else:
# rle
rle = mask_ann
mask = maskUtils.decode(rle)
return mask
def _process_masks(self, results: dict) -> list:
"""Process gt_masks and filter invalid polygons.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
list: Processed gt_masks.
"""
gt_masks = []
gt_ignore_flags = []
for instance in results.get('instances', []):
gt_mask = instance['mask']
# If the annotation of segmentation mask is invalid,
# ignore the whole instance.
if isinstance(gt_mask, list):
gt_mask = [
np.array(polygon) for polygon in gt_mask
if len(polygon) % 2 == 0 and len(polygon) >= 6
]
if len(gt_mask) == 0:
# ignore this instance and set gt_mask to a fake mask
instance['ignore_flag'] = 1
gt_mask = [np.zeros(6)]
elif not self.poly2mask:
# `PolygonMasks` requires a ploygon of format List[np.array],
# other formats are invalid.
instance['ignore_flag'] = 1
gt_mask = [np.zeros(6)]
elif isinstance(gt_mask, dict) and \
not (gt_mask.get('counts') is not None and
gt_mask.get('size') is not None and
isinstance(gt_mask['counts'], (list, str))):
# if gt_mask is a dict, it should include `counts` and `size`,
# so that `BitmapMasks` can uncompressed RLE
instance['ignore_flag'] = 1
gt_mask = [np.zeros(6)]
gt_masks.append(gt_mask)
# re-process gt_ignore_flags
gt_ignore_flags.append(instance['ignore_flag'])
results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
return gt_masks
def _load_masks(self, results: dict) -> None:
"""Private function to load mask annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
"""
h, w = results['ori_shape']
gt_masks = self._process_masks(results)
if self.poly2mask:
gt_masks = BitmapMasks(
[self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
else:
# fake polygon masks will be ignored in `PackDetInputs`
gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
results['gt_masks'] = gt_masks
def _load_seg_map(self, results: dict) -> None:
"""Private function to load semantic segmentation annotations.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded semantic segmentation annotations.
"""
if results.get('seg_map_path', None) is None:
return
img_bytes = get(
results['seg_map_path'], backend_args=self.backend_args)
gt_semantic_seg = mmcv.imfrombytes(
img_bytes, flag='unchanged',
backend=self.imdecode_backend).squeeze()
if self.reduce_zero_label:
# avoid using underflow conversion
gt_semantic_seg[gt_semantic_seg == 0] = self.ignore_index
gt_semantic_seg = gt_semantic_seg - 1
gt_semantic_seg[gt_semantic_seg == self.ignore_index -
1] = self.ignore_index
# modify if custom classes
if results.get('label_map', None) is not None:
# Add deep copy to solve bug of repeatedly
# replace `gt_semantic_seg`, which is reported in
# https://github.com/open-mmlab/mmsegmentation/pull/1445/
gt_semantic_seg_copy = gt_semantic_seg.copy()
for old_id, new_id in results['label_map'].items():
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
results['gt_seg_map'] = gt_semantic_seg
results['ignore_index'] = self.ignore_index
def transform(self, results: dict) -> dict:
"""Function to load multiple types annotations.
Args:
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
Returns:
dict: The dict contains loaded bounding box, label and
semantic segmentation.
"""
if self.with_bbox:
self._load_bboxes(results)
if self.with_label:
self._load_labels(results)
if self.with_mask:
self._load_masks(results)
if self.with_seg:
self._load_seg_map(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(with_bbox={self.with_bbox}, '
repr_str += f'with_label={self.with_label}, '
repr_str += f'with_mask={self.with_mask}, '
repr_str += f'with_seg={self.with_seg}, '
repr_str += f'poly2mask={self.poly2mask}, '
repr_str += f"imdecode_backend='{self.imdecode_backend}', "
repr_str += f'backend_args={self.backend_args})'
return repr_str
@TRANSFORMS.register_module()
class LoadPanopticAnnotations(LoadAnnotations):
"""Load multiple types of panoptic annotations.
The annotation format is as the following:
.. code-block:: python
{
'instances':
[
{
# List of 4 numbers representing the bounding box of the
# instance, in (x1, y1, x2, y2) order.
'bbox': [x1, y1, x2, y2],
# Label of image classification.
'bbox_label': 1,
},
...
]
'segments_info':
[
{
# id = cls_id + instance_id * INSTANCE_OFFSET
'id': int,
# Contiguous category id defined in dataset.
'category': int
# Thing flag.
'is_thing': bool
},
...
]
# Filename of semantic or panoptic segmentation ground truth file.
'seg_map_path': 'a/b/c'
}
After this module, the annotation has been changed to the format below:
.. code-block:: python
{
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in an image
'gt_bboxes': BaseBoxes(N, 4)
# In int type.
'gt_bboxes_labels': np.ndarray(N, )
# In built-in class
'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
# In uint8 type.
'gt_seg_map': np.ndarray (H, W)
# in (x, y, v) order, float type.
}
Required Keys:
- height
- width
- instances
- bbox
- bbox_label
- ignore_flag
- segments_info
- id
- category
- is_thing
- seg_map_path
Added Keys:
- gt_bboxes (BaseBoxes[torch.float32])
- gt_bboxes_labels (np.int64)
- gt_masks (BitmapMasks | PolygonMasks)
- gt_seg_map (np.uint8)
- gt_ignore_flags (bool)
Args:
with_bbox (bool): Whether to parse and load the bbox annotation.
Defaults to True.
with_label (bool): Whether to parse and load the label annotation.
Defaults to True.
with_mask (bool): Whether to parse and load the mask annotation.
Defaults to True.
with_seg (bool): Whether to parse and load the semantic segmentation
annotation. Defaults to False.
box_type (str): The box mode used to wrap the bboxes.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :fun:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend in mmdet >= 3.0.0rc7. Defaults to None.
"""
def __init__(self,
with_bbox: bool = True,
with_label: bool = True,
with_mask: bool = True,
with_seg: bool = True,
box_type: str = 'hbox',
imdecode_backend: str = 'cv2',
backend_args: dict = None) -> None:
try:
from panopticapi import utils
except ImportError:
raise ImportError(
'panopticapi is not installed, please install it by: '
'pip install git+https://github.com/cocodataset/'
'panopticapi.git.')
self.rgb2id = utils.rgb2id
super(LoadPanopticAnnotations, self).__init__(
with_bbox=with_bbox,
with_label=with_label,
with_mask=with_mask,
with_seg=with_seg,
with_keypoints=False,
box_type=box_type,
imdecode_backend=imdecode_backend,
backend_args=backend_args)
def _load_masks_and_semantic_segs(self, results: dict) -> None:
"""Private function to load mask and semantic segmentation annotations.
In gt_semantic_seg, the foreground label is from ``0`` to
``num_things - 1``, the background label is from ``num_things`` to
``num_things + num_stuff - 1``, 255 means the ignored label (``VOID``).
Args:
results (dict): Result dict from :obj:``mmdet.CustomDataset``.
"""
# seg_map_path is None, when inference on the dataset without gts.
if results.get('seg_map_path', None) is None:
return
img_bytes = get(
results['seg_map_path'], backend_args=self.backend_args)
pan_png = mmcv.imfrombytes(
img_bytes, flag='color', channel_order='rgb').squeeze()
pan_png = self.rgb2id(pan_png)
gt_masks = []
gt_seg = np.zeros_like(pan_png) + 255 # 255 as ignore
for segment_info in results['segments_info']:
mask = (pan_png == segment_info['id'])
gt_seg = np.where(mask, segment_info['category'], gt_seg)
# The legal thing masks
if segment_info.get('is_thing'):
gt_masks.append(mask.astype(np.uint8))
if self.with_mask:
h, w = results['ori_shape']
gt_masks = BitmapMasks(gt_masks, h, w)
results['gt_masks'] = gt_masks
if self.with_seg:
results['gt_seg_map'] = gt_seg
def transform(self, results: dict) -> dict:
"""Function to load multiple types panoptic annotations.
Args:
results (dict): Result dict from :obj:``mmdet.CustomDataset``.
Returns:
dict: The dict contains loaded bounding box, label, mask and
semantic segmentation annotations.
"""
if self.with_bbox:
self._load_bboxes(results)
if self.with_label:
self._load_labels(results)
if self.with_mask or self.with_seg:
# The tasks completed by '_load_masks' and '_load_semantic_segs'
# in LoadAnnotations are merged to one function.
self._load_masks_and_semantic_segs(results)
return results
@TRANSFORMS.register_module()
class LoadProposals(BaseTransform):
"""Load proposal pipeline.
Required Keys:
- proposals
Modified Keys:
- proposals
Args:
num_max_proposals (int, optional): Maximum number of proposals to load.
If not specified, all proposals will be loaded.
"""
def __init__(self, num_max_proposals: Optional[int] = None) -> None:
self.num_max_proposals = num_max_proposals
def transform(self, results: dict) -> dict:
"""Transform function to load proposals from file.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded proposal annotations.
"""
proposals = results['proposals']
# the type of proposals should be `dict` or `InstanceData`
assert isinstance(proposals, dict) \
or isinstance(proposals, BaseDataElement)
bboxes = proposals['bboxes'].astype(np.float32)
assert bboxes.shape[1] == 4, \
f'Proposals should have shapes (n, 4), but found {bboxes.shape}'
if 'scores' in proposals:
scores = proposals['scores'].astype(np.float32)
assert bboxes.shape[0] == scores.shape[0]
else:
scores = np.zeros(bboxes.shape[0], dtype=np.float32)
if self.num_max_proposals is not None:
# proposals should sort by scores during dumping the proposals
bboxes = bboxes[:self.num_max_proposals]
scores = scores[:self.num_max_proposals]
if len(bboxes) == 0:
bboxes = np.zeros((0, 4), dtype=np.float32)
scores = np.zeros(0, dtype=np.float32)
results['proposals'] = bboxes
results['proposals_scores'] = scores
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(num_max_proposals={self.num_max_proposals})'
@TRANSFORMS.register_module()
class FilterAnnotations(BaseTransform):
"""Filter invalid annotations.
Required Keys:
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
Modified Keys:
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_masks (optional)
- gt_ignore_flags (optional)
Args:
min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth
boxes. Default: (1., 1.)
min_gt_mask_area (int): Minimum foreground area of ground truth masks.
Default: 1
by_box (bool): Filter instances with bounding boxes not meeting the
min_gt_bbox_wh threshold. Default: True
by_mask (bool): Filter instances with masks not meeting
min_gt_mask_area threshold. Default: False
keep_empty (bool): Whether to return None when it
becomes an empty bbox after filtering. Defaults to True.
"""
def __init__(self,
min_gt_bbox_wh: Tuple[int, int] = (1, 1),
min_gt_mask_area: int = 1,
by_box: bool = True,
by_mask: bool = False,
keep_empty: bool = True) -> None:
# TODO: add more filter options
assert by_box or by_mask
self.min_gt_bbox_wh = min_gt_bbox_wh
self.min_gt_mask_area = min_gt_mask_area
self.by_box = by_box
self.by_mask = by_mask
self.keep_empty = keep_empty
@autocast_box_type()
def transform(self, results: dict) -> Union[dict, None]:
"""Transform function to filter annotations.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""
assert 'gt_bboxes' in results
gt_bboxes = results['gt_bboxes']
if gt_bboxes.shape[0] == 0:
return results
tests = []
if self.by_box:
tests.append(
((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
(gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())
if self.by_mask:
assert 'gt_masks' in results
gt_masks = results['gt_masks']
tests.append(gt_masks.areas >= self.min_gt_mask_area)
keep = tests[0]
for t in tests[1:]:
keep = keep & t
if not keep.any():
if self.keep_empty:
return None
keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags')
for key in keys:
if key in results:
results[key] = results[key][keep]
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(min_gt_bbox_wh={self.min_gt_bbox_wh}, ' \
f'keep_empty={self.keep_empty})'
@TRANSFORMS.register_module()
class LoadEmptyAnnotations(BaseTransform):
"""Load Empty Annotations for unlabeled images.
Added Keys:
- gt_bboxes (np.float32)
- gt_bboxes_labels (np.int64)
- gt_masks (BitmapMasks | PolygonMasks)
- gt_seg_map (np.uint8)
- gt_ignore_flags (bool)
Args:
with_bbox (bool): Whether to load the pseudo bbox annotation.
Defaults to True.
with_label (bool): Whether to load the pseudo label annotation.
Defaults to True.
with_mask (bool): Whether to load the pseudo mask annotation.
Default: False.
with_seg (bool): Whether to load the pseudo semantic segmentation
annotation. Defaults to False.
seg_ignore_label (int): The fill value used for segmentation map.
Note this value must equals ``ignore_label`` in ``semantic_head``
of the corresponding config. Defaults to 255.
"""
def __init__(self,
with_bbox: bool = True,
with_label: bool = True,
with_mask: bool = False,
with_seg: bool = False,
seg_ignore_label: int = 255) -> None:
self.with_bbox = with_bbox
self.with_label = with_label
self.with_mask = with_mask
self.with_seg = with_seg
self.seg_ignore_label = seg_ignore_label
def transform(self, results: dict) -> dict:
"""Transform function to load empty annotations.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""
if self.with_bbox:
results['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)
results['gt_ignore_flags'] = np.zeros((0, ), dtype=bool)
if self.with_label:
results['gt_bboxes_labels'] = np.zeros((0, ), dtype=np.int64)
if self.with_mask:
# TODO: support PolygonMasks
h, w = results['img_shape']
gt_masks = np.zeros((0, h, w), dtype=np.uint8)
results['gt_masks'] = BitmapMasks(gt_masks, h, w)
if self.with_seg:
h, w = results['img_shape']
results['gt_seg_map'] = self.seg_ignore_label * np.ones(
(h, w), dtype=np.uint8)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(with_bbox={self.with_bbox}, '
repr_str += f'with_label={self.with_label}, '
repr_str += f'with_mask={self.with_mask}, '
repr_str += f'with_seg={self.with_seg}, '
repr_str += f'seg_ignore_label={self.seg_ignore_label})'
return repr_str
@TRANSFORMS.register_module()
class InferencerLoader(BaseTransform):
"""Load an image from ``results['img']``.
Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
:obj:`np.ndarray` in ``results['img']``. Can be used when loading image
from webcam.
Required Keys:
- img
Modified Keys:
- img
- img_path
- img_shape
- ori_shape
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
"""
def __init__(self, **kwargs) -> None:
super().__init__()
self.from_file = TRANSFORMS.build(
dict(type='LoadImageFromFile', **kwargs))
self.from_ndarray = TRANSFORMS.build(
dict(type='mmdet.LoadImageFromNDArray', **kwargs))
def transform(self, results: Union[str, np.ndarray, dict]) -> dict:
"""Transform function to add image meta information.
Args:
results (str, np.ndarray or dict): The result.
Returns:
dict: The dict contains loaded image and meta information.
"""
if isinstance(results, str):
inputs = dict(img_path=results)
elif isinstance(results, np.ndarray):
inputs = dict(img=results)
elif isinstance(results, dict):
inputs = results
else:
raise NotImplementedError
if 'img' in inputs:
return self.from_ndarray(inputs)
return self.from_file(inputs)
@TRANSFORMS.register_module()
class LoadTrackAnnotations(LoadAnnotations):
"""Load and process the ``instances`` and ``seg_map`` annotation provided
by dataset. It must load ``instances_ids`` which is only used in the
tracking tasks. The annotation format is as the following:
.. code-block:: python
{
'instances':
[
{
# List of 4 numbers representing the bounding box of the
# instance, in (x1, y1, x2, y2) order.
'bbox': [x1, y1, x2, y2],
# Label of image classification.
'bbox_label': 1,
# Used in tracking.
# Id of instances.
'instance_id': 100,
# Used in instance/panoptic segmentation. The segmentation mask
# of the instance or the information of segments.
# 1. If list[list[float]], it represents a list of polygons,
# one for each connected component of the object. Each
# list[float] is one simple polygon in the format of
# [x1, y1, ..., xn, yn] (n >= 3). The Xs and Ys are absolute
# coordinates in unit of pixels.
# 2. If dict, it represents the per-pixel segmentation mask in
# COCO's compressed RLE format. The dict should have keys
# “size” and “counts”. Can be loaded by pycocotools
'mask': list[list[float]] or dict,
}
]
# Filename of semantic or panoptic segmentation ground truth file.
'seg_map_path': 'a/b/c'
}
After this module, the annotation has been changed to the format below:
.. code-block:: python
{
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in an image
'gt_bboxes': np.ndarray(N, 4)
# In int type.
'gt_bboxes_labels': np.ndarray(N, )
# In built-in class
'gt_masks': PolygonMasks (H, W) or BitmapMasks (H, W)
# In uint8 type.
'gt_seg_map': np.ndarray (H, W)
# in (x, y, v) order, float type.
}
Required Keys:
- height (optional)
- width (optional)
- instances
- bbox (optional)
- bbox_label
- instance_id (optional)
- mask (optional)
- ignore_flag (optional)
- seg_map_path (optional)
Added Keys:
- gt_bboxes (np.float32)
- gt_bboxes_labels (np.int32)
- gt_instances_ids (np.int32)
- gt_masks (BitmapMasks | PolygonMasks)
- gt_seg_map (np.uint8)
- gt_ignore_flags (np.bool)
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def _load_bboxes(self, results: dict) -> None:
"""Private function to load bounding box annotations.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded bounding box annotations.
"""
gt_bboxes = []
gt_ignore_flags = []
# TODO: use bbox_type
for instance in results['instances']:
# The datasets which are only format in evaluation don't have
# groundtruth boxes.
if 'bbox' in instance:
gt_bboxes.append(instance['bbox'])
if 'ignore_flag' in instance:
gt_ignore_flags.append(instance['ignore_flag'])
# TODO: check this case
if len(gt_bboxes) != len(gt_ignore_flags):
# There may be no ``gt_ignore_flags`` in some cases, we treat them
# as all False in order to keep the length of ``gt_bboxes`` and
# ``gt_ignore_flags`` the same
gt_ignore_flags = [False] * len(gt_bboxes)
results['gt_bboxes'] = np.array(
gt_bboxes, dtype=np.float32).reshape(-1, 4)
results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
def _load_instances_ids(self, results: dict) -> None:
"""Private function to load instances id annotations.
Args:
results (dict): Result dict from :obj :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict containing instances id annotations.
"""
gt_instances_ids = []
for instance in results['instances']:
gt_instances_ids.append(instance['instance_id'])
results['gt_instances_ids'] = np.array(
gt_instances_ids, dtype=np.int32)
def transform(self, results: dict) -> dict:
"""Function to load multiple types annotations.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded bounding box, label, instances id
and semantic segmentation and keypoints annotations.
"""
results = super().transform(results)
self._load_instances_ids(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(with_bbox={self.with_bbox}, '
repr_str += f'with_label={self.with_label}, '
repr_str += f'with_mask={self.with_mask}, '
repr_str += f'with_seg={self.with_seg}, '
repr_str += f'poly2mask={self.poly2mask}, '
repr_str += f"imdecode_backend='{self.imdecode_backend}', "
repr_str += f'file_client_args={self.file_client_args})'
return repr_str
# Copyright (c) OpenMMLab. All rights reserved.
import json
from mmcv.transforms import BaseTransform
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import BaseBoxes
try:
from transformers import AutoTokenizer
from transformers import BertModel as HFBertModel
except ImportError:
AutoTokenizer = None
HFBertModel = None
import random
import re
import numpy as np
def clean_name(name):
name = re.sub(r'\(.*\)', '', name)
name = re.sub(r'_', ' ', name)
name = re.sub(r' ', ' ', name)
name = name.lower()
return name
def check_for_positive_overflow(gt_bboxes, gt_labels, text, tokenizer,
max_tokens):
# Check if we have too many positive labels
# generate a caption by appending the positive labels
positive_label_list = np.unique(gt_labels).tolist()
# random shuffule so we can sample different annotations
# at different epochs
random.shuffle(positive_label_list)
kept_lables = []
length = 0
for index, label in enumerate(positive_label_list):
label_text = clean_name(text[str(label)]) + '. '
tokenized = tokenizer.tokenize(label_text)
length += len(tokenized)
if length > max_tokens:
break
else:
kept_lables.append(label)
keep_box_index = []
keep_gt_labels = []
for i in range(len(gt_labels)):
if gt_labels[i] in kept_lables:
keep_box_index.append(i)
keep_gt_labels.append(gt_labels[i])
return gt_bboxes[keep_box_index], np.array(
keep_gt_labels, dtype=np.long), length
def generate_senetence_given_labels(positive_label_list, negative_label_list,
text):
label_to_positions = {}
label_list = negative_label_list + positive_label_list
random.shuffle(label_list)
pheso_caption = ''
label_remap_dict = {}
for index, label in enumerate(label_list):
start_index = len(pheso_caption)
pheso_caption += clean_name(text[str(label)])
end_index = len(pheso_caption)
if label in positive_label_list:
label_to_positions[index] = [[start_index, end_index]]
label_remap_dict[int(label)] = index
# if index != len(label_list) - 1:
# pheso_caption += '. '
pheso_caption += '. '
return label_to_positions, pheso_caption, label_remap_dict
@TRANSFORMS.register_module()
class RandomSamplingNegPos(BaseTransform):
def __init__(self,
tokenizer_name,
num_sample_negative=85,
max_tokens=256,
full_sampling_prob=0.5,
label_map_file=None):
if AutoTokenizer is None:
raise RuntimeError(
'transformers is not installed, please install it by: '
'pip install transformers.')
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.num_sample_negative = num_sample_negative
self.full_sampling_prob = full_sampling_prob
self.max_tokens = max_tokens
self.label_map = None
if label_map_file:
with open(label_map_file, 'r') as file:
self.label_map = json.load(file)
def transform(self, results: dict) -> dict:
if 'phrases' in results:
return self.vg_aug(results)
else:
return self.od_aug(results)
def vg_aug(self, results):
gt_bboxes = results['gt_bboxes']
if isinstance(gt_bboxes, BaseBoxes):
gt_bboxes = gt_bboxes.tensor
gt_labels = results['gt_bboxes_labels']
text = results['text'].lower().strip()
if not text.endswith('.'):
text = text + '. '
phrases = results['phrases']
# TODO: add neg
positive_label_list = np.unique(gt_labels).tolist()
label_to_positions = {}
for label in positive_label_list:
label_to_positions[label] = phrases[label]['tokens_positive']
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_labels'] = gt_labels
results['text'] = text
results['tokens_positive'] = label_to_positions
return results
def od_aug(self, results):
gt_bboxes = results['gt_bboxes']
if isinstance(gt_bboxes, BaseBoxes):
gt_bboxes = gt_bboxes.tensor
gt_labels = results['gt_bboxes_labels']
if 'text' not in results:
assert self.label_map is not None
text = self.label_map
else:
text = results['text']
original_box_num = len(gt_labels)
# If the category name is in the format of 'a/b' (in object365),
# we randomly select one of them.
for key, value in text.items():
if '/' in value:
text[key] = random.choice(value.split('/')).strip()
gt_bboxes, gt_labels, positive_caption_length = \
check_for_positive_overflow(gt_bboxes, gt_labels,
text, self.tokenizer, self.max_tokens)
if len(gt_bboxes) < original_box_num:
print('WARNING: removed {} boxes due to positive caption overflow'.
format(original_box_num - len(gt_bboxes)))
valid_negative_indexes = list(text.keys())
positive_label_list = np.unique(gt_labels).tolist()
full_negative = self.num_sample_negative
if full_negative > len(valid_negative_indexes):
full_negative = len(valid_negative_indexes)
outer_prob = random.random()
if outer_prob < self.full_sampling_prob:
# c. probability_full: add both all positive and all negatives
num_negatives = full_negative
else:
if random.random() < 1.0:
num_negatives = np.random.choice(max(1, full_negative)) + 1
else:
num_negatives = full_negative
# Keep some negatives
negative_label_list = set()
if num_negatives != -1:
if num_negatives > len(valid_negative_indexes):
num_negatives = len(valid_negative_indexes)
for i in np.random.choice(
valid_negative_indexes, size=num_negatives, replace=False):
if int(i) not in positive_label_list:
negative_label_list.add(i)
random.shuffle(positive_label_list)
negative_label_list = list(negative_label_list)
random.shuffle(negative_label_list)
negative_max_length = self.max_tokens - positive_caption_length
screened_negative_label_list = []
for negative_label in negative_label_list:
label_text = clean_name(text[str(negative_label)]) + '. '
tokenized = self.tokenizer.tokenize(label_text)
negative_max_length -= len(tokenized)
if negative_max_length > 0:
screened_negative_label_list.append(negative_label)
else:
break
negative_label_list = screened_negative_label_list
label_to_positions, pheso_caption, label_remap_dict = \
generate_senetence_given_labels(positive_label_list,
negative_label_list, text)
# label remap
if len(gt_labels) > 0:
gt_labels = np.vectorize(lambda x: label_remap_dict[x])(gt_labels)
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_labels'] = gt_labels
results['text'] = pheso_caption
results['tokens_positive'] = label_to_positions
return results
@TRANSFORMS.register_module()
class LoadTextAnnotations(BaseTransform):
def transform(self, results: dict) -> dict:
if 'phrases' in results:
tokens_positive = [
phrase['tokens_positive']
for phrase in results['phrases'].values()
]
results['tokens_positive'] = tokens_positive
else:
text = results['text']
results['text'] = list(text.values())
return results
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment