# Copyright (c) OpenMMLab. All rights reserved.
import collections
import copy
from typing import List, Sequence, Union
from mmengine.dataset import BaseDataset
from mmengine.dataset import ConcatDataset as MMENGINE_ConcatDataset
from mmengine.dataset import force_full_init
from mmdet.registry import DATASETS, TRANSFORMS
@DATASETS.register_module()
class MultiImageMixDataset:
"""A wrapper of multiple images mixed dataset.
Suitable for training on multiple images mixed data augmentation like
mosaic and mixup. For the augmentation pipeline of mixed image data,
the `get_indexes` method needs to be provided to obtain the image
indexes, and you can set `skip_flags` to change the pipeline running
process. At the same time, we provide the `dynamic_scale` parameter
to dynamically change the output image size.
Args:
dataset (:obj:`CustomDataset`): The dataset to be mixed.
pipeline (Sequence[dict]): Sequence of transform object or
config dict to be composed.
        skip_type_keys (list[str], optional): Sequence of transform type
            strings to be skipped in the pipeline. Defaults to None.
        max_refetch (int): The maximum number of retry iterations for
            getting valid results from the pipeline. If the number of
            iterations exceeds ``max_refetch`` and the results are still
            None, the iteration is terminated and an error is raised.
            Defaults to 15.
        lazy_init (bool): Whether to defer ``full_init`` until the dataset
            is actually used. Defaults to False.
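
    The config below is a minimal sketch of how this wrapper is usually
    composed; the wrapped dataset, the file paths and the transform
    settings are illustrative only, not prescribed by this class.

    Examples:
        >>> train_dataset = dict(
        ...     type='MultiImageMixDataset',
        ...     dataset=dict(
        ...         type='CocoDataset',
        ...         ann_file='annotations/instances_train2017.json',
        ...         data_prefix=dict(img='train2017/'),
        ...         pipeline=[
        ...             dict(type='LoadImageFromFile'),
        ...             dict(type='LoadAnnotations', with_bbox=True)
        ...         ]),
        ...     pipeline=[
        ...         dict(type='Mosaic', img_scale=(640, 640)),
        ...         dict(type='PackDetInputs')
        ...     ],
        ...     max_refetch=15)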
"""
def __init__(self,
dataset: Union[BaseDataset, dict],
                 pipeline: Sequence[dict],
skip_type_keys: Union[Sequence[str], None] = None,
max_refetch: int = 15,
lazy_init: bool = False) -> None:
assert isinstance(pipeline, collections.abc.Sequence)
if skip_type_keys is not None:
assert all([
isinstance(skip_type_key, str)
for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys
self.pipeline = []
self.pipeline_types = []
for transform in pipeline:
if isinstance(transform, dict):
self.pipeline_types.append(transform['type'])
transform = TRANSFORMS.build(transform)
self.pipeline.append(transform)
else:
                raise TypeError('elements of pipeline must be dicts, '
                                f'but got {type(transform)}')
self.dataset: BaseDataset
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
elif isinstance(dataset, BaseDataset):
self.dataset = dataset
else:
            raise TypeError(
                'dataset must be a dict (config) or a `BaseDataset` '
                f'instance, but got {type(dataset)}')
self._metainfo = self.dataset.metainfo
if hasattr(self.dataset, 'flag'):
self.flag = self.dataset.flag
self.num_samples = len(self.dataset)
self.max_refetch = max_refetch
self._fully_initialized = False
if not lazy_init:
self.full_init()
@property
def metainfo(self) -> dict:
"""Get the meta information of the multi-image-mixed dataset.
Returns:
dict: The meta information of multi-image-mixed dataset.
"""
return copy.deepcopy(self._metainfo)
def full_init(self):
"""Loop to ``full_init`` each dataset."""
if self._fully_initialized:
return
self.dataset.full_init()
self._ori_len = len(self.dataset)
self._fully_initialized = True
@force_full_init
def get_data_info(self, idx: int) -> dict:
"""Get annotation by index.
Args:
            idx (int): Index of data.
        Returns:
            dict: The idx-th annotation of the wrapped dataset.
"""
return self.dataset.get_data_info(idx)
@force_full_init
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
results = copy.deepcopy(self.dataset[idx])
for (transform, transform_type) in zip(self.pipeline,
self.pipeline_types):
if self._skip_type_keys is not None and \
transform_type in self._skip_type_keys:
continue
if hasattr(transform, 'get_indexes'):
for i in range(self.max_refetch):
                    # Make sure the results returned by the loading
                    # pipeline of the original dataset are not None.
indexes = transform.get_indexes(self.dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
mix_results = [
copy.deepcopy(self.dataset[index]) for index in indexes
]
if None not in mix_results:
results['mix_results'] = mix_results
break
else:
raise RuntimeError(
                        'The loading pipeline of the original dataset '
                        'always returns None. Please check the '
                        'correctness of the dataset and its pipeline.')
for i in range(self.max_refetch):
                # Make sure the results returned by the training pipeline
                # of the wrapper are not None.
updated_results = transform(copy.deepcopy(results))
if updated_results is not None:
results = updated_results
break
else:
raise RuntimeError(
                    'The training pipeline of the dataset wrapper '
                    'always returns None. Please check the correctness '
                    'of the dataset and its pipeline.')
if 'mix_results' in results:
results.pop('mix_results')
return results
def update_skip_type_keys(self, skip_type_keys):
"""Update skip_type_keys. It is called by an external hook.
Args:
            skip_type_keys (list[str], optional): Sequence of transform
                type strings to be skipped in the pipeline.
"""
assert all([
isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys
@DATASETS.register_module()
class ConcatDataset(MMENGINE_ConcatDataset):
"""A wrapper of concatenated dataset.
    Same as ``torch.utils.data.dataset.ConcatDataset``, but additionally
    supports ``lazy_init`` and ``get_dataset_source``.
    Note:
        ``ConcatDataset`` should not inherit from ``BaseDataset`` since
        ``get_subset`` and ``get_subset_`` could produce a sub-dataset with
        ambiguous meaning that conflicts with the original dataset. If you
        want to use a sub-dataset of ``ConcatDataset``, you should set the
        ``indices`` argument of the wrapped datasets, which inherit from
        ``BaseDataset``.
Args:
datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
which will be concatenated.
lazy_init (bool, optional): Whether to load annotation during
instantiation. Defaults to False.
ignore_keys (List[str] or str): Ignore the keys that can be
unequal in `dataset.metainfo`. Defaults to None.
`New in version 0.3.0.`
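
    The config below is a minimal sketch; the wrapped datasets and their
    annotation files are illustrative only.

    Examples:
        >>> train_dataset = dict(
        ...     type='ConcatDataset',
        ...     datasets=[
        ...         dict(
        ...             type='CocoDataset',
        ...             ann_file='annotations/part1.json',
        ...             data_prefix=dict(img='images_part1/')),
        ...         dict(
        ...             type='CocoDataset',
        ...             ann_file='annotations/part2.json',
        ...             data_prefix=dict(img='images_part2/'))
        ...     ])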
"""
def __init__(self,
datasets: Sequence[Union[BaseDataset, dict]],
lazy_init: bool = False,
ignore_keys: Union[str, List[str], None] = None):
self.datasets: List[BaseDataset] = []
for i, dataset in enumerate(datasets):
if isinstance(dataset, dict):
self.datasets.append(DATASETS.build(dataset))
elif isinstance(dataset, BaseDataset):
self.datasets.append(dataset)
else:
raise TypeError(
'elements in datasets sequence should be config or '
f'`BaseDataset` instance, but got {type(dataset)}')
if ignore_keys is None:
self.ignore_keys = []
elif isinstance(ignore_keys, str):
self.ignore_keys = [ignore_keys]
elif isinstance(ignore_keys, list):
self.ignore_keys = ignore_keys
else:
raise TypeError('ignore_keys should be a list or str, '
f'but got {type(ignore_keys)}')
meta_keys: set = set()
for dataset in self.datasets:
meta_keys |= dataset.metainfo.keys()
# if the metainfo of multiple datasets are the same, use metainfo
# of the first dataset, else the metainfo is a list with metainfo
# of all the datasets
is_all_same = True
self._metainfo_first = self.datasets[0].metainfo
for i, dataset in enumerate(self.datasets, 1):
for key in meta_keys:
if key in self.ignore_keys:
continue
                if (key not in dataset.metainfo
                        or key not in self._metainfo_first):
                    is_all_same = False
                    break
                if self._metainfo_first[key] != dataset.metainfo[key]:
                    is_all_same = False
                    break
if is_all_same:
self._metainfo = self.datasets[0].metainfo
else:
self._metainfo = [dataset.metainfo for dataset in self.datasets]
self._fully_initialized = False
if not lazy_init:
self.full_init()
if is_all_same:
self._metainfo.update(
dict(cumulative_sizes=self.cumulative_sizes))
else:
for i, dataset in enumerate(self.datasets):
self._metainfo[i].update(
dict(cumulative_sizes=self.cumulative_sizes))
def get_dataset_source(self, idx: int) -> int:
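        """Return the index of the dataset that sample ``idx`` comes from.
        Args:
            idx (int): Global index over the concatenated dataset.
        Returns:
            int: Index of the source dataset in ``self.datasets``.
        """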
dataset_idx, _ = self._get_ori_dataset_idx(idx)
return dataset_idx
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class DeepFashionDataset(CocoDataset):
"""Dataset for DeepFashion."""
METAINFO = {
'classes': ('top', 'skirt', 'leggings', 'dress', 'outer', 'pants',
'bag', 'neckwear', 'headwear', 'eyeglass', 'belt',
'footwear', 'hair', 'skin', 'face'),
# palette is a list of color tuples, which is used for visualization.
'palette': [(0, 192, 64), (0, 64, 96), (128, 192, 192), (0, 64, 64),
(0, 192, 224), (0, 192, 192), (128, 192, 64), (0, 192, 96),
(128, 32, 192), (0, 0, 224), (0, 0, 64), (0, 160, 192),
(128, 0, 96), (128, 0, 192), (0, 32, 192)]
}
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Optional
import numpy as np
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
try:
from d_cube import D3
except ImportError:
D3 = None
from .api_wrappers import COCO
@DATASETS.register_module()
class DODDataset(BaseDetDataset):
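    """Dataset for the Described Object Detection (D^3) benchmark.

    Annotations are read through the ``d_cube`` toolkit
    (``pip install ddd-dataset``), so ``data_prefix`` is expected to
    provide an ``anno`` key pointing to the pickled sentence annotations
    and an ``img`` key pointing to the image folder, in addition to the
    COCO-style ``ann_file``.

    The config below is a minimal sketch; the concrete paths are
    illustrative only.

    Examples:
        >>> dataset = dict(
        ...     type='DODDataset',
        ...     data_root='data/d3/',
        ...     ann_file='d3_json/d3_full_annotations.json',
        ...     data_prefix=dict(img='d3_images/', anno='d3_pkl/'))
    """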
def __init__(self,
*args,
data_root: Optional[str] = '',
data_prefix: dict = dict(img_path=''),
**kwargs) -> None:
if D3 is None:
            raise ImportError(
                'Package d_cube is not installed. Please run '
                '"pip install ddd-dataset".')
pkl_anno_path = osp.join(data_root, data_prefix['anno'])
self.img_root = osp.join(data_root, data_prefix['img'])
self.d3 = D3(self.img_root, pkl_anno_path)
sent_infos = self.d3.load_sents()
classes = tuple([sent_info['raw_sent'] for sent_info in sent_infos])
super().__init__(
*args,
data_root=data_root,
data_prefix=data_prefix,
metainfo={'classes': classes},
**kwargs)
def load_data_list(self) -> List[dict]:
coco = COCO(self.ann_file)
data_list = []
img_ids = self.d3.get_img_ids()
for img_id in img_ids:
data_info = {}
img_info = self.d3.load_imgs(img_id)[0]
file_name = img_info['file_name']
img_path = osp.join(self.img_root, file_name)
data_info['img_path'] = img_path
data_info['img_id'] = img_id
data_info['height'] = img_info['height']
data_info['width'] = img_info['width']
group_ids = self.d3.get_group_ids(img_ids=[img_id])
sent_ids = self.d3.get_sent_ids(group_ids=group_ids)
sent_list = self.d3.load_sents(sent_ids=sent_ids)
text_list = [sent['raw_sent'] for sent in sent_list]
ann_ids = coco.get_ann_ids(img_ids=[img_id])
anno = coco.load_anns(ann_ids)
data_info['text'] = text_list
            data_info['sent_ids'] = np.array(sent_ids)
data_info['custom_entities'] = True
instances = []
for i, ann in enumerate(anno):
instance = {}
x1, y1, w, h = ann['bbox']
bbox = [x1, y1, x1 + w, y1 + h]
instance['ignore_flag'] = 0
instance['bbox'] = bbox
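                # Category ids in the annotations are 1-based; convert
                # them to 0-based labels.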
instance['bbox_label'] = ann['category_id'] - 1
instances.append(instance)
data_info['instances'] = instances
data_list.append(data_info)
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import List
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
try:
from dsdl.dataset import DSDLDataset
except ImportError:
DSDLDataset = None
@DATASETS.register_module()
class DSDLDetDataset(BaseDetDataset):
"""Dataset for dsdl detection.
Args:
        with_bbox (bool): Whether to load bboxes. Defaults to True.
        with_polygon (bool): Whether to load polygons. Defaults to False.
        with_mask (bool): Whether to load seg map masks. Defaults to False.
        with_imagelevel_label (bool): Whether to load image-level labels.
            Defaults to False.
        with_hierarchy (bool): Whether to load hierarchy information.
            Defaults to False.
        specific_key_path (dict): Paths of specific keys that cannot be
            loaded by their field names.
        pre_transform (dict): Pre-transform functions applied before
            loading.
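
    The config below is a minimal sketch of a typical instantiation; the
    yaml path, the prefixes and the filter settings are illustrative only.

    Examples:
        >>> dataset = dict(
        ...     type='DSDLDetDataset',
        ...     data_root='data/my_dsdl_dataset/',
        ...     ann_file='set-train/train.yaml',
        ...     data_prefix=dict(img_path='original/'),
        ...     with_bbox=True,
        ...     filter_cfg=dict(filter_empty_gt=True, min_size=32))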
"""
METAINFO = {}
def __init__(self,
with_bbox: bool = True,
with_polygon: bool = False,
with_mask: bool = False,
with_imagelevel_label: bool = False,
with_hierarchy: bool = False,
specific_key_path: dict = {},
pre_transform: dict = {},
**kwargs) -> None:
if DSDLDataset is None:
raise RuntimeError(
'Package dsdl is not installed. Please run "pip install dsdl".'
)
self.with_hierarchy = with_hierarchy
self.specific_key_path = specific_key_path
loc_config = dict(type='LocalFileReader', working_dir='')
if kwargs.get('data_root'):
kwargs['ann_file'] = os.path.join(kwargs['data_root'],
kwargs['ann_file'])
self.required_fields = ['Image', 'ImageShape', 'Label', 'ignore_flag']
if with_bbox:
self.required_fields.append('Bbox')
if with_polygon:
self.required_fields.append('Polygon')
if with_mask:
self.required_fields.append('LabelMap')
if with_imagelevel_label:
self.required_fields.append('image_level_labels')
assert 'image_level_labels' in specific_key_path.keys(
), '`image_level_labels` not specified in `specific_key_path` !'
self.extra_keys = [
key for key in self.specific_key_path.keys()
if key not in self.required_fields
]
self.dsdldataset = DSDLDataset(
dsdl_yaml=kwargs['ann_file'],
location_config=loc_config,
required_fields=self.required_fields,
specific_key_path=specific_key_path,
transform=pre_transform,
)
BaseDetDataset.__init__(self, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load data info from an dsdl yaml file named as ``self.ann_file``
Returns:
List[dict]: A list of data info.
"""
if self.with_hierarchy:
# get classes_names and relation_matrix
classes_names, relation_matrix = \
self.dsdldataset.class_dom.get_hierarchy_info()
self._metainfo['classes'] = tuple(classes_names)
self._metainfo['RELATION_MATRIX'] = relation_matrix
else:
self._metainfo['classes'] = tuple(self.dsdldataset.class_names)
data_list = []
for i, data in enumerate(self.dsdldataset):
# basic image info, including image id, path and size.
datainfo = dict(
img_id=i,
img_path=os.path.join(self.data_prefix['img_path'],
data['Image'][0].location),
width=data['ImageShape'][0].width,
height=data['ImageShape'][0].height,
)
# get image label info
if 'image_level_labels' in data.keys():
if self.with_hierarchy:
# get leaf node name when using hierarchy classes
datainfo['image_level_labels'] = [
self._metainfo['classes'].index(i.leaf_node_name)
for i in data['image_level_labels']
]
else:
datainfo['image_level_labels'] = [
self._metainfo['classes'].index(i.name)
for i in data['image_level_labels']
]
# get semantic segmentation info
if 'LabelMap' in data.keys():
datainfo['seg_map_path'] = data['LabelMap']
# load instance info
instances = []
if 'Bbox' in data.keys():
for idx in range(len(data['Bbox'])):
bbox = data['Bbox'][idx]
if self.with_hierarchy:
# get leaf node name when using hierarchy classes
label = data['Label'][idx].leaf_node_name
label_index = self._metainfo['classes'].index(label)
else:
label = data['Label'][idx].name
label_index = self._metainfo['classes'].index(label)
instance = {}
instance['bbox'] = bbox.xyxy
instance['bbox_label'] = label_index
if 'ignore_flag' in data.keys():
# get ignore flag
instance['ignore_flag'] = data['ignore_flag'][idx]
else:
instance['ignore_flag'] = 0
if 'Polygon' in data.keys():
# get polygon info
polygon = data['Polygon'][idx]
instance['mask'] = polygon.openmmlabformat
for key in self.extra_keys:
# load extra instance info
instance[key] = data[key][idx]
instances.append(instance)
datainfo['instances'] = instances
# append a standard sample in data list
if len(datainfo['instances']) > 0:
data_list.append(datainfo)
return data_list
def filter_data(self) -> List[dict]:
"""Filter annotations according to filter_cfg.
Returns:
List[dict]: Filtered results.
"""
if self.test_mode:
return self.data_list
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \
if self.filter_cfg is not None else False
min_size = self.filter_cfg.get('min_size', 0) \
if self.filter_cfg is not None else 0
valid_data_list = []
for i, data_info in enumerate(self.data_list):
width = data_info['width']
height = data_info['height']
if filter_empty_gt and len(data_info['instances']) == 0:
continue
if min(width, height) >= min_size:
valid_data_list.append(data_info)
return valid_data_list
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List
from pycocotools.coco import COCO
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
def convert_phrase_ids(phrase_ids: list) -> list:
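    """Remap arbitrary phrase ids to consecutive labels starting from 0.

    The relative order of the unique ids is preserved. For example,
    ``convert_phrase_ids([7, 2, 7, 9])`` returns ``[1, 0, 1, 2]``.
    """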
unique_elements = sorted(set(phrase_ids))
element_to_new_label = {
element: label
for label, element in enumerate(unique_elements)
}
phrase_ids = [element_to_new_label[element] for element in phrase_ids]
return phrase_ids
@DATASETS.register_module()
class Flickr30kDataset(BaseDetDataset):
"""Flickr30K Dataset."""
def load_data_list(self) -> List[dict]:
self.coco = COCO(self.ann_file)
        self.ids = sorted(self.coco.imgs.keys())
data_list = []
for img_id in self.ids:
if isinstance(img_id, str):
ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
else:
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
coco_img = self.coco.loadImgs(img_id)[0]
caption = coco_img['caption']
file_name = coco_img['file_name']
img_path = osp.join(self.data_prefix['img'], file_name)
width = coco_img['width']
height = coco_img['height']
tokens_positive = coco_img['tokens_positive_eval']
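            # Each entry of ``tokens_positive_eval`` is a list of
            # [start, end) character spans into ``caption``; the first
            # span of each entry is sliced out to recover the phrase text.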
phrases = [caption[i[0][0]:i[0][1]] for i in tokens_positive]
phrase_ids = []
instances = []
annos = self.coco.loadAnns(ann_ids)
for anno in annos:
instance = {
'bbox': [
anno['bbox'][0], anno['bbox'][1],
anno['bbox'][0] + anno['bbox'][2],
anno['bbox'][1] + anno['bbox'][3]
],
'bbox_label':
anno['category_id'],
'ignore_flag':
anno['iscrowd']
}
phrase_ids.append(anno['phrase_ids'])
instances.append(instance)
phrase_ids = convert_phrase_ids(phrase_ids)
data_list.append(
dict(
img_path=img_path,
img_id=img_id,
height=height,
width=width,
instances=instances,
text=caption,
phrase_ids=phrase_ids,
tokens_positive=tokens_positive,
phrases=phrases,
))
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class iSAIDDataset(CocoDataset):
"""Dataset for iSAID instance segmentation.
iSAID: A Large-scale Dataset for Instance Segmentation
in Aerial Images.
    For more details, please refer to ``projects/iSAID/README.md``.
"""
METAINFO = dict(
classes=('background', 'ship', 'store_tank', 'baseball_diamond',
'tennis_court', 'basketball_court', 'Ground_Track_Field',
'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
'Harbor'),
palette=[(0, 0, 0), (0, 0, 63), (0, 63, 63), (0, 63, 0), (0, 63, 127),
(0, 63, 191), (0, 63, 255), (0, 127, 63), (0, 127, 127),
(0, 0, 127), (0, 0, 191), (0, 0, 255), (0, 191, 127),
(0, 127, 191), (0, 127, 255), (0, 100, 155)])
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class LVISV05Dataset(CocoDataset):
"""LVIS v0.5 dataset for detection."""
METAINFO = {
'classes':
('acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron',
'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke',
'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award',
'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack',
'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball',
'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage',
'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel',
'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat',
'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop',
'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball',
'bead', 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed',
'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle',
'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle',
'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder',
'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage',
'birdhouse', 'birthday_cake', 'birthday_card', 'biscuit_(bread)',
'pirate_flag', 'black_sheep', 'blackboard', 'blanket', 'blazer',
'blender', 'blimp', 'blinker', 'blueberry', 'boar', 'gameboard',
'boat', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt',
'bolt', 'bonnet', 'book', 'book_bag', 'bookcase', 'booklet',
'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener',
'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie',
'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin',
'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase',
'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie',
'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board',
'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed',
'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife',
'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon',
'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)',
'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan',
'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player',
'celery', 'cellular_telephone', 'chain_mail', 'chair',
'chaise_longue', 'champagne', 'chandelier', 'chap', 'checkbook',
'checkerboard', 'cherry', 'chessboard',
'chest_of_drawers_(furniture)', 'chicken_(animal)', 'chicken_wire',
'chickpea', 'Chihuahua', 'chili_(vegetable)', 'chime', 'chinaware',
'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar',
'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider',
'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
'clasp', 'cleansing_agent', 'clementine', 'clip', 'clipboard',
'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag',
'coaster', 'coat', 'coat_hanger', 'coatrack', 'cock', 'coconut',
'coffee_filter', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
'coin', 'colander', 'coleslaw', 'coloring_material',
'combination_lock', 'pacifier', 'comic_book', 'computer_keyboard',
'concrete_mixer', 'cone', 'control', 'convertible_(automobile)',
'sofa_bed', 'cookie', 'cookie_jar', 'cooking_utensil',
'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew',
'edible_corn', 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset',
'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell',
'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon',
'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot',
'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship',
'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube',
'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler',
'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool',
'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog',
'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask',
'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper',
'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan',
'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel',
'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)',
'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose',
'fireplace', 'fireplug', 'fish', 'fish_(food)', 'fishbowl',
'fishing_boat', 'fishing_rod', 'flag', 'flagpole', 'flamingo',
'flannel', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)',
'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal',
'folding_chair', 'food_processor', 'football_(American)',
'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car',
'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice',
'fruit_salad', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag',
'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle',
'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda',
'gift_wrap', 'ginger', 'giraffe', 'cincture',
'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater',
'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag',
'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush',
'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock',
'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil',
'headband', 'headboard', 'headlight', 'headscarf', 'headset',
'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater',
'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus',
'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood',
'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod',
'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean',
'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick',
'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)',
'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat',
'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp',
'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy',
'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard',
'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion',
'speaker_(stereo_equipment)', 'loveseat', 'machine_gun', 'magazine',
'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth',
'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini',
'mascot', 'mashed_potato', 'masher', 'mask', 'mast',
'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup',
'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone',
'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
'mound_(baseball)', 'mouse_(animal_rodent)',
'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
'music_stool', 'musical_instrument', 'nailfile', 'nameplate',
'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest',
'newsstand', 'nightshirt', 'nosebag_(for_animals)',
'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker',
'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'oregano',
'ostrich', 'ottoman', 'overalls_(clothing)', 'owl', 'packet',
'inkpad', 'pad', 'paddle', 'padlock', 'paintbox', 'paintbrush',
'painting', 'pajamas', 'palette', 'pan_(for_cooking)',
'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya',
'paperclip', 'paper_plate', 'paper_towel', 'paperback_book',
'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
'parchment', 'parka', 'parking_meter', 'parrot',
'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard',
'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener',
'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper',
'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood',
'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
'plate', 'platter', 'playing_card', 'playpen', 'pliers',
'plow_(farm_equipment)', 'pocket_watch', 'pocketknife',
'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt',
'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait',
'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune',
'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher',
'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit',
'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish',
'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
'recliner', 'record_player', 'red_cabbage', 'reflector',
'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate',
'Rollerblade', 'rolling_pin', 'root_beer',
'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)',
'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag',
'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer',
'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)',
'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker',
'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)',
'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog',
'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag',
'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag',
'shovel', 'shower_head', 'shower_curtain', 'shredder_(for_paper)',
'sieve', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski',
'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'sled', 'sleeping_bag',
'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake',
'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock',
'soda_fountain', 'carbonated_water', 'sofa', 'softball',
'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'sponge',
'spoon', 'sportswear', 'spotlight', 'squirrel',
'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)',
'steak_(food)', 'steak_knife', 'steamer_(kitchen_appliance)',
'steering_wheel', 'stencil', 'stepladder', 'step_stool',
'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup',
'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light', 'stove',
'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop',
'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table',
'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag',
'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)',
'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
'telephone_pole', 'telephoto_lens', 'television_camera',
'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)',
'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip',
'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella',
'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve',
'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin',
'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon',
'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet',
'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch',
'water_bottle', 'water_cooler', 'water_faucet', 'water_filter',
'water_heater', 'water_jug', 'water_gun', 'water_scooter',
'water_ski', 'water_tower', 'watering_can', 'watermelon',
'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit',
'wheel', 'wheelchair', 'whipped_cream', 'whiskey', 'whistle', 'wick',
'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
'wineglass', 'wing_chair', 'blinder_(for_horses)', 'wok', 'wolf',
'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht',
'yak', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini'),
'palette':
None
}
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
try:
import lvis
if getattr(lvis, '__version__', '0') >= '10.5.3':
warnings.warn(
'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
UserWarning)
from lvis import LVIS
except ImportError:
raise ImportError(
'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
)
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.lvis = LVIS(local_path)
self.cat_ids = self.lvis.get_cat_ids()
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)
img_ids = self.lvis.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.lvis.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
if raw_img_info['file_name'].startswith('COCO'):
                # Convert from the COCO 2014 file naming convention of
# COCO_[train/val/test]2014_000000000000.jpg to the 2017
# naming convention of 000000000000.jpg
# (LVIS v1 will fix this naming issue)
raw_img_info['file_name'] = raw_img_info['file_name'][-16:]
ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.lvis.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.lvis
return data_list
LVISDataset = LVISV05Dataset
DATASETS.register_module(name='LVISDataset', module=LVISDataset)
@DATASETS.register_module()
class LVISV1Dataset(LVISDataset):
"""LVIS v1 dataset for detection."""
METAINFO = {
'classes':
('aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
'bottle_opener', 'bouquet', 'bow_(weapon)',
'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl',
'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders',
'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread',
'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach',
'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket',
'horse_buggy', 'bull', 'bulldog', 'bulldozer', 'bullet_train',
'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed',
'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter',
'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet',
'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder',
'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can',
'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane',
'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen',
'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
'cash_register', 'casserole', 'cassette', 'cast', 'cat',
'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery',
'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard',
'cherry', 'chessboard', 'chicken_(animal)', 'chickpea',
'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard',
'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower',
'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)',
'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
'coin', 'colander', 'coleslaw', 'coloring_material',
'combination_lock', 'pacifier', 'comic_book', 'compass',
'computer_keyboard', 'condiment', 'cone', 'control',
'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove',
'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat',
'dress_suit', 'dresser', 'drill', 'drone', 'dropper',
'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle',
'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg',
'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair',
'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot',
'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret',
'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine',
'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine',
'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug',
'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod',
'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash',
'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
'food_processor', 'football_(American)', 'football_helmet',
'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge',
'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose',
'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin',
'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger',
'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove',
'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart',
'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater',
'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun',
'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger',
'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass',
'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle',
'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil',
'headband', 'headboard', 'headlight', 'headscarf', 'headset',
'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade',
'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat',
'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange',
'manger', 'manhole', 'map', 'marker', 'martini', 'mascot',
'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)',
'matchbox', 'mattress', 'measuring_cup', 'measuring_stick',
'meatball', 'medicine', 'melon', 'microphone', 'microscope',
'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake',
'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)',
'money', 'monitor_(computer_equipment) computer_monitor', 'monkey',
'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle',
'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad',
'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument',
'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle',
'nest', 'newspaper', 'newsstand', 'nightshirt',
'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook',
'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven',
'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
'padlock', 'paintbrush', 'painting', 'pajamas', 'palette',
'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
'parasol', 'parchment', 'parka', 'parking_meter', 'parrot',
'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller',
'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin',
'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt',
'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver',
'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry',
'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
'recliner', 'record_player', 'reflector', 'remote_control',
'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
'rolling_pin', 'root_beer', 'router_(computer_equipment)',
'rubber_band', 'runner_(carpet)', 'plastic_bag',
'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew',
'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove',
'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants',
'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit',
'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
'tambourine', 'army_tank', 'tank_(storage_vessel)',
'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
'telephone_pole', 'telephoto_lens', 'television_camera',
'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod',
'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban',
'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
'washbasin', 'automatic_washer', 'watch', 'water_bottle',
'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
'water_gun', 'water_scooter', 'water_ski', 'water_tower',
'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
'yoke_(animal_equipment)', 'zebra', 'zucchini'),
'palette':
None
}
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
try:
import lvis
if getattr(lvis, '__version__', '0') >= '10.5.3':
warnings.warn(
'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
UserWarning)
from lvis import LVIS
except ImportError:
raise ImportError(
'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
)
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.lvis = LVIS(local_path)
self.cat_ids = self.lvis.get_cat_ids()
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)
img_ids = self.lvis.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.lvis.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
            # LVIS v1 uses coco_url instead of file_name, e.g.
            # http://images.cocodataset.org/train2017/000000391895.jpg;
            # the train/val split is specified in the url.
raw_img_info['file_name'] = raw_img_info['coco_url'].replace(
'http://images.cocodataset.org/', '')
ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.lvis.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.lvis
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List
from mmengine.fileio import get_local_path
from mmdet.datasets import BaseDetDataset
from mmdet.registry import DATASETS
from .api_wrappers import COCO
@DATASETS.register_module()
class MDETRStyleRefCocoDataset(BaseDetDataset):
"""RefCOCO dataset.
    Only evaluation is supported for now.
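
    The config below is a minimal sketch; the annotation file and image
    prefix are illustrative only.

    Examples:
        >>> val_dataset = dict(
        ...     type='MDETRStyleRefCocoDataset',
        ...     data_root='data/coco/',
        ...     ann_file='mdetr_annotations/final_refexp_val.json',
        ...     data_prefix=dict(img='train2014/'),
        ...     test_mode=True)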
"""
def load_data_list(self) -> List[dict]:
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
coco = COCO(local_path)
img_ids = coco.get_img_ids()
data_infos = []
for img_id in img_ids:
raw_img_info = coco.load_imgs([img_id])[0]
ann_ids = coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = coco.load_anns(ann_ids)
data_info = {}
img_path = osp.join(self.data_prefix['img'],
raw_img_info['file_name'])
data_info['img_path'] = img_path
data_info['img_id'] = img_id
data_info['height'] = raw_img_info['height']
data_info['width'] = raw_img_info['width']
data_info['dataset_mode'] = raw_img_info['dataset_name']
data_info['text'] = raw_img_info['caption']
data_info['custom_entities'] = False
data_info['tokens_positive'] = -1
instances = []
for i, ann in enumerate(raw_ann_info):
instance = {}
x1, y1, w, h = ann['bbox']
bbox = [x1, y1, x1 + w, y1 + h]
instance['bbox'] = bbox
instance['bbox_label'] = ann['category_id']
instance['ignore_flag'] = 0
instances.append(instance)
data_info['instances'] = instances
data_infos.append(data_info)
return data_infos
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Union
from mmdet.registry import DATASETS
from .base_video_dataset import BaseVideoDataset
@DATASETS.register_module()
class MOTChallengeDataset(BaseVideoDataset):
"""Dataset for MOTChallenge.
Args:
        visibility_thr (float, optional): The minimum visibility
            of the objects during training. Defaults to -1.
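
    The config below is a minimal sketch; the annotation file and image
    prefix are illustrative only.

    Examples:
        >>> dataset = dict(
        ...     type='MOTChallengeDataset',
        ...     data_root='data/MOT17/',
        ...     visibility_thr=-1,
        ...     ann_file='annotations/half-train_cocoformat.json',
        ...     data_prefix=dict(img_path='train/'))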
"""
METAINFO = {
'classes':
('pedestrian', 'person_on_vehicle', 'car', 'bicycle', 'motorbike',
'non_mot_vehicle', 'static_person', 'distractor', 'occluder',
'occluder_on_ground', 'occluder_full', 'reflection', 'crowd')
}
def __init__(self, visibility_thr: float = -1, *args, **kwargs):
self.visibility_thr = visibility_thr
super().__init__(*args, **kwargs)
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
"""Parse raw annotation to target format. The difference between this
function and the one in ``BaseVideoDataset`` is that the parsing here
adds ``visibility`` and ``mot_conf``.
Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``.
Returns:
Union[dict, List[dict]]: Parsed annotation.
"""
img_info = raw_data_info['raw_img_info']
ann_info = raw_data_info['raw_ann_info']
data_info = {}
data_info.update(img_info)
if self.data_prefix.get('img_path', None) is not None:
img_path = osp.join(self.data_prefix['img_path'],
img_info['file_name'])
else:
img_path = img_info['file_name']
data_info['img_path'] = img_path
instances = []
for i, ann in enumerate(ann_info):
instance = {}
if (not self.test_mode) and (ann['visibility'] <
self.visibility_thr):
continue
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann['area'] <= 0 or w < 1 or h < 1:
continue
if ann['category_id'] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if ann.get('iscrowd', False):
instance['ignore_flag'] = 1
else:
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = self.cat2label[ann['category_id']]
instance['instance_id'] = ann['instance_id']
instance['category_id'] = ann['category_id']
instance['mot_conf'] = ann['mot_conf']
instance['visibility'] = ann['visibility']
if len(instance) > 0:
instances.append(instance)
if not self.test_mode:
assert len(instances) > 0, f'No valid instances found in ' \
f'image {data_info["img_path"]}!'
data_info['instances'] = instances
return data_info
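# A minimal, hypothetical config sketch for this dataset: visibility_thr is
# only applied during training (see parse_data_info above), and the paths and
# class subset here are assumptions about a typical MOT17 layout.
_example_mot_train_cfg = dict(
    type='MOTChallengeDataset',
    data_root='data/MOT17/',
    visibility_thr=-1,
    ann_file='annotations/half-train_cocoformat.json',
    data_prefix=dict(img_path='train'),
    metainfo=dict(classes=('pedestrian', )))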
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import List
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .api_wrappers import COCO
from .coco import CocoDataset
# images exist in annotations but not in image folder.
objv2_ignore_list = [
osp.join('patch16', 'objects365_v2_00908726.jpg'),
osp.join('patch6', 'objects365_v1_00320532.jpg'),
osp.join('patch6', 'objects365_v1_00320534.jpg'),
]
@DATASETS.register_module()
class Objects365V1Dataset(CocoDataset):
"""Objects365 v1 dataset for detection."""
METAINFO = {
'classes':
('person', 'sneakers', 'chair', 'hat', 'lamp', 'bottle',
'cabinet/shelf', 'cup', 'car', 'glasses', 'picture/frame', 'desk',
'handbag', 'street lights', 'book', 'plate', 'helmet',
'leather shoes', 'pillow', 'glove', 'potted plant', 'bracelet',
'flower', 'tv', 'storage box', 'vase', 'bench', 'wine glass', 'boots',
'bowl', 'dining table', 'umbrella', 'boat', 'flag', 'speaker',
'trash bin/can', 'stool', 'backpack', 'couch', 'belt', 'carpet',
'basket', 'towel/napkin', 'slippers', 'barrel/bucket', 'coffee table',
'suv', 'toy', 'tie', 'bed', 'traffic light', 'pen/pencil',
'microphone', 'sandals', 'canned', 'necklace', 'mirror', 'faucet',
'bicycle', 'bread', 'high heels', 'ring', 'van', 'watch', 'sink',
'horse', 'fish', 'apple', 'camera', 'candle', 'teddy bear', 'cake',
'motorcycle', 'wild bird', 'laptop', 'knife', 'traffic sign',
'cell phone', 'paddle', 'truck', 'cow', 'power outlet', 'clock',
'drum', 'fork', 'bus', 'hanger', 'nightstand', 'pot/pan', 'sheep',
'guitar', 'traffic cone', 'tea pot', 'keyboard', 'tripod', 'hockey',
'fan', 'dog', 'spoon', 'blackboard/whiteboard', 'balloon',
'air conditioner', 'cymbal', 'mouse', 'telephone', 'pickup truck',
'orange', 'banana', 'airplane', 'luggage', 'skis', 'soccer',
'trolley', 'oven', 'remote', 'baseball glove', 'paper towel',
'refrigerator', 'train', 'tomato', 'machinery vehicle', 'tent',
'shampoo/shower gel', 'head phone', 'lantern', 'donut',
'cleaning products', 'sailboat', 'tangerine', 'pizza', 'kite',
'computer box', 'elephant', 'toiletries', 'gas stove', 'broccoli',
'toilet', 'stroller', 'shovel', 'baseball bat', 'microwave',
'skateboard', 'surfboard', 'surveillance camera', 'gun', 'life saver',
'cat', 'lemon', 'liquid soap', 'zebra', 'duck', 'sports car',
'giraffe', 'pumpkin', 'piano', 'stop sign', 'radiator', 'converter',
'tissue ', 'carrot', 'washing machine', 'vent', 'cookies',
'cutting/chopping board', 'tennis racket', 'candy',
'skating and skiing shoes', 'scissors', 'folder', 'baseball',
'strawberry', 'bow tie', 'pigeon', 'pepper', 'coffee machine',
'bathtub', 'snowboard', 'suitcase', 'grapes', 'ladder', 'pear',
'american football', 'basketball', 'potato', 'paint brush', 'printer',
'billiards', 'fire hydrant', 'goose', 'projector', 'sausage',
'fire extinguisher', 'extension cord', 'facial mask', 'tennis ball',
'chopsticks', 'electronic stove and gas stove', 'pie', 'frisbee',
'kettle', 'hamburger', 'golf club', 'cucumber', 'clutch', 'blender',
'tong', 'slide', 'hot dog', 'toothbrush', 'facial cleanser', 'mango',
'deer', 'egg', 'violin', 'marker', 'ship', 'chicken', 'onion',
'ice cream', 'tape', 'wheelchair', 'plum', 'bar soap', 'scale',
'watermelon', 'cabbage', 'router/modem', 'golf ball', 'pine apple',
'crane', 'fire truck', 'peach', 'cello', 'notepaper', 'tricycle',
'toaster', 'helicopter', 'green beans', 'brush', 'carriage', 'cigar',
'earphone', 'penguin', 'hurdle', 'swing', 'radio', 'CD',
'parking meter', 'swan', 'garlic', 'french fries', 'horn', 'avocado',
'saxophone', 'trumpet', 'sandwich', 'cue', 'kiwi fruit', 'bear',
'fishing rod', 'cherry', 'tablet', 'green vegetables', 'nuts', 'corn',
'key', 'screwdriver', 'globe', 'broom', 'pliers', 'volleyball',
'hammer', 'eggplant', 'trophy', 'dates', 'board eraser', 'rice',
'tape measure/ruler', 'dumbbell', 'hamimelon', 'stapler', 'camel',
'lettuce', 'goldfish', 'meat balls', 'medal', 'toothpaste',
'antelope', 'shrimp', 'rickshaw', 'trombone', 'pomegranate',
'coconut', 'jellyfish', 'mushroom', 'calculator', 'treadmill',
'butterfly', 'egg tart', 'cheese', 'pig', 'pomelo', 'race car',
'rice cooker', 'tuba', 'crosswalk sign', 'papaya', 'hair drier',
'green onion', 'chips', 'dolphin', 'sushi', 'urinal', 'donkey',
'electric drill', 'spring rolls', 'tortoise/turtle', 'parrot',
'flute', 'measuring cup', 'shark', 'steak', 'poker card',
'binoculars', 'llama', 'radish', 'noodles', 'yak', 'mop', 'crab',
'microscope', 'barbell', 'bread/bun', 'baozi', 'lion', 'red cabbage',
'polar bear', 'lighter', 'seal', 'mangosteen', 'comb', 'eraser',
'pitaya', 'scallop', 'pencil case', 'saw', 'table tennis paddle',
'okra', 'starfish', 'eagle', 'monkey', 'durian', 'game board',
'rabbit', 'french horn', 'ambulance', 'asparagus', 'hoverboard',
'pasta', 'target', 'hotair balloon', 'chainsaw', 'lobster', 'iron',
'flashlight'),
'palette':
None
}
COCOAPI = COCO
# ann_id is unique in coco dataset.
ANN_ID_UNIQUE = True
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
        # The 'categories' lists in objects365_train.json and
        # objects365_val.json are inconsistent, so sort the list (and dict)
        # before getting cat_ids.
cats = self.coco.cats
sorted_cats = {i: cats[i] for i in sorted(cats)}
self.coco.cats = sorted_cats
categories = self.coco.dataset['categories']
sorted_categories = sorted(categories, key=lambda i: i['id'])
self.coco.dataset['categories'] = sorted_categories
# The order of returned `cat_ids` will not
# change with the order of the `classes`
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
img_ids = self.coco.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
return data_list
@DATASETS.register_module()
class Objects365V2Dataset(CocoDataset):
"""Objects365 v2 dataset for detection."""
METAINFO = {
'classes':
('Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp',
'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf',
'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet',
'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower',
'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots',
'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt',
'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker',
'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool',
'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum',
'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle',
'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned',
'Truck', 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel',
'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed',
'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple',
'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck',
'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock',
'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger',
'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine',
'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle',
'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane',
'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage',
'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone',
'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane',
'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat',
'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza',
'Elephant', 'Skateboard', 'Surfboard', 'Gun',
'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot',
'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper',
'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks',
'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board',
'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder',
'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball',
'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle',
'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck',
'Billards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club',
'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear',
'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong',
'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask',
'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide',
'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee',
'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon',
'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon',
'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog',
'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer',
'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple',
'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle',
'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone',
'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion',
'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom',
'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit',
'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese',
'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue',
'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap',
'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut',
'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak',
'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate',
'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker',
'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal',
'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin',
'Electric Drill', 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill',
'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi',
'Target', 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case',
'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop',
'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle',
'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster',
'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling',
'Table Tennis '),
'palette':
None
}
COCOAPI = COCO
# ann_id is unique in coco dataset.
ANN_ID_UNIQUE = True
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
# The order of returned `cat_ids` will not
# change with the order of the `classes`
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
img_ids = self.coco.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
# file_name should be `patchX/xxx.jpg`
file_name = osp.join(
osp.split(osp.split(raw_img_info['file_name'])[0])[-1],
osp.split(raw_img_info['file_name'])[-1])
if file_name in objv2_ignore_list:
continue
raw_img_info['file_name'] = file_name
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
return data_list
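# A minimal, hypothetical config sketch: load_data_list above rewrites every
# file_name to the `patchX/xxx.jpg` form and skips the three entries in
# objv2_ignore_list, so data_prefix only has to point at the directory holding
# the patch folders. The paths below are assumptions about a typical layout.
_example_obj365v2_train_cfg = dict(
    type='Objects365V2Dataset',
    data_root='data/Objects365/Obj365_v2/',
    ann_file='annotations/zhiyuan_objv2_train.json',
    data_prefix=dict(img='train/'))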
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List, Optional
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class ODVGDataset(BaseDetDataset):
"""object detection and visual grounding dataset."""
def __init__(self,
*args,
data_root: str = '',
label_map_file: Optional[str] = None,
need_text: bool = True,
**kwargs) -> None:
self.dataset_mode = 'VG'
self.need_text = need_text
if label_map_file:
label_map_file = osp.join(data_root, label_map_file)
with open(label_map_file, 'r') as file:
self.label_map = json.load(file)
self.dataset_mode = 'OD'
super().__init__(*args, data_root=data_root, **kwargs)
assert self.return_classes is True
def load_data_list(self) -> List[dict]:
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
data_list = [json.loads(line) for line in f]
out_data_list = []
for data in data_list:
data_info = {}
img_path = osp.join(self.data_prefix['img'], data['filename'])
data_info['img_path'] = img_path
data_info['height'] = data['height']
data_info['width'] = data['width']
if self.dataset_mode == 'OD':
if self.need_text:
data_info['text'] = self.label_map
anno = data.get('detection', {})
                instances = list(anno.get('instances', []))
bboxes = [obj['bbox'] for obj in instances]
bbox_labels = [str(obj['label']) for obj in instances]
instances = []
for bbox, label in zip(bboxes, bbox_labels):
instance = {}
x1, y1, x2, y2 = bbox
inter_w = max(0, min(x2, data['width']) - max(x1, 0))
inter_h = max(0, min(y2, data['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if (x2 - x1) < 1 or (y2 - y1) < 1:
continue
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = int(label)
instances.append(instance)
data_info['instances'] = instances
data_info['dataset_mode'] = self.dataset_mode
out_data_list.append(data_info)
else:
anno = data['grounding']
data_info['text'] = anno['caption']
regions = anno['regions']
instances = []
phrases = {}
for i, region in enumerate(regions):
bbox = region['bbox']
phrase = region['phrase']
tokens_positive = region['tokens_positive']
if not isinstance(bbox[0], list):
bbox = [bbox]
for box in bbox:
instance = {}
x1, y1, x2, y2 = box
inter_w = max(0, min(x2, data['width']) - max(x1, 0))
inter_h = max(0, min(y2, data['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if (x2 - x1) < 1 or (y2 - y1) < 1:
continue
instance['ignore_flag'] = 0
instance['bbox'] = box
instance['bbox_label'] = i
phrases[i] = {
'phrase': phrase,
'tokens_positive': tokens_positive
}
instances.append(instance)
data_info['instances'] = instances
data_info['phrases'] = phrases
data_info['dataset_mode'] = self.dataset_mode
out_data_list.append(data_info)
del data_list
return out_data_list
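# Two hand-written example records of the JSON-lines format that
# load_data_list above consumes, one per mode. Only the key names mirror what
# the parser reads; file names, sizes, labels and boxes are made up.
_EXAMPLE_OD_LINE = {
    'filename': '000001.jpg', 'height': 480, 'width': 640,
    'detection': {
        'instances': [{'bbox': [10, 20, 110, 220], 'label': 3}]
    }
}  # parsed when a label_map_file is given (dataset_mode == 'OD')
_EXAMPLE_VG_LINE = {
    'filename': '000002.jpg', 'height': 480, 'width': 640,
    'grounding': {
        'caption': 'a dog next to a red car',
        'regions': [{'bbox': [5, 5, 200, 300], 'phrase': 'a dog',
                     'tokens_positive': [[0, 5]]}]
    }
}  # parsed when no label_map_file is given (dataset_mode == 'VG')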
# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os.path as osp
from collections import defaultdict
from typing import Dict, List, Optional
import numpy as np
from mmengine.fileio import get_local_path, load
from mmengine.utils import is_abs
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class OpenImagesDataset(BaseDetDataset):
"""Open Images dataset for detection.
Args:
ann_file (str): Annotation file path.
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
meta_file (str): File path to get image metas.
hierarchy_file (str): The file path of the class hierarchy.
image_level_ann_file (str): Human-verified image level annotation,
which is used in evaluation.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
"""
METAINFO: dict = dict(dataset_type='oid_v6')
def __init__(self,
label_file: str,
meta_file: str,
hierarchy_file: str,
image_level_ann_file: Optional[str] = None,
**kwargs) -> None:
self.label_file = label_file
self.meta_file = meta_file
self.hierarchy_file = hierarchy_file
self.image_level_ann_file = image_level_ann_file
super().__init__(**kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
last_img_id = None
instances = []
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
if last_img_id is None:
last_img_id = img_id
label_id = line[2]
assert label_id in self.label_id_mapping
label = int(self.label_id_mapping[label_id])
bbox = [
float(line[4]), # xmin
float(line[6]), # ymin
float(line[5]), # xmax
float(line[7]) # ymax
]
                    is_occluded = int(line[8]) == 1
                    is_truncated = int(line[9]) == 1
                    is_group_of = int(line[10]) == 1
                    is_depiction = int(line[11]) == 1
                    is_inside = int(line[12]) == 1
instance = dict(
bbox=bbox,
bbox_label=label,
ignore_flag=0,
is_occluded=is_occluded,
is_truncated=is_truncated,
is_group_of=is_group_of,
is_depiction=is_depiction,
is_inside=is_inside)
last_img_path = osp.join(self.data_prefix['img'],
f'{last_img_id}.jpg')
if img_id != last_img_id:
# switch to a new image, record previous image's data.
data_info = dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
)
data_list.append(data_info)
instances = []
instances.append(instance)
last_img_id = img_id
data_list.append(
dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = data_list[i]['img_id']
assert f'{img_id}.jpg' == osp.split(meta['filename'])[-1]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
            tuple: Class names of OpenImages and the label id mapping.
"""
index_list = []
classes_names = []
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
# self.cat2label[line[0]] = line[1]
classes_names.append(line[1])
index_list.append(line[0])
index_mapping = {index: i for i, index in enumerate(index_list)}
return classes_names, index_mapping
def _parse_img_level_ann(self,
img_level_ann_file: str) -> Dict[str, List[dict]]:
"""Parse image level annotations from csv style ann_file.
Args:
img_level_ann_file (str): CSV style image level annotation
file path.
Returns:
Dict[str, List[dict]]: Annotations where item of the defaultdict
indicates an image, each of which has (n) dicts.
Keys of dicts are:
- `image_level_label` (int): Label id.
- `confidence` (float): Labels that are human-verified to be
present in an image have confidence = 1 (positive labels).
Labels that are human-verified to be absent from an image
have confidence = 0 (negative labels). Machine-generated
labels have fractional confidences, generally >= 0.5.
The higher the confidence, the smaller the chance for
the label to be a false positive.
"""
item_lists = defaultdict(list)
with get_local_path(
img_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
item_lists[img_id].append(
dict(
image_level_label=int(
self.label_id_mapping[line[2]]),
confidence=float(line[3])))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file. Hierarchy
for 600 classes can be found at https://storage.googleapis.com/openimag
es/2018_04/bbox_labels_600_hierarchy_visualizer/circle.html.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
""" # noqa
hierarchy = load(
hierarchy_file, file_format='json', backend_args=self.backend_args)
class_num = len(self._metainfo['classes'])
relation_matrix = np.eye(class_num, class_num)
relation_matrix = self._convert_hierarchy_tree(hierarchy,
relation_matrix)
return relation_matrix
def _convert_hierarchy_tree(self,
hierarchy_map: dict,
relation_matrix: np.ndarray,
parents: list = [],
get_all_parents: bool = True) -> np.ndarray:
"""Get matrix of the corresponding relationship between the parent
class and the child class.
Args:
hierarchy_map (dict): Including label name and corresponding
subcategory. Keys of dicts are:
                - `LabelName` (str): Name of the label.
- `Subcategory` (dict | list): Corresponding subcategory(ies).
relation_matrix (ndarray): The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
parents (list): Corresponding parent class.
            get_all_parents (bool): Whether to get all parent names.
                Defaults to True.
Returns:
ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
"""
if 'Subcategory' in hierarchy_map:
for node in hierarchy_map['Subcategory']:
if 'LabelName' in node:
children_name = node['LabelName']
children_index = self.label_id_mapping[children_name]
children = [children_index]
else:
continue
if len(parents) > 0:
for parent_index in parents:
if get_all_parents:
children.append(parent_index)
relation_matrix[children_index, parent_index] = 1
relation_matrix = self._convert_hierarchy_tree(
node, relation_matrix, parents=children)
return relation_matrix
def _join_prefix(self):
"""Join ``self.data_root`` with annotation path."""
super()._join_prefix()
if not is_abs(self.label_file) and self.label_file:
self.label_file = osp.join(self.data_root, self.label_file)
if not is_abs(self.meta_file) and self.meta_file:
self.meta_file = osp.join(self.data_root, self.meta_file)
if not is_abs(self.hierarchy_file) and self.hierarchy_file:
self.hierarchy_file = osp.join(self.data_root, self.hierarchy_file)
if self.image_level_ann_file and not is_abs(self.image_level_ann_file):
self.image_level_ann_file = osp.join(self.data_root,
self.image_level_ann_file)
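# A small, hand-made fragment of the hierarchy JSON consumed by
# _get_relation_matrix/_convert_hierarchy_tree above: starting from np.eye,
# every child class also gets a 1 in each ancestor's column. The MID strings
# here are placeholders, not real Open Images identifiers.
_EXAMPLE_HIERARCHY = {
    'LabelName': '/m/root00',            # virtual root node
    'Subcategory': [
        {
            'LabelName': '/m/vehicle',   # parent class
            'Subcategory': [{'LabelName': '/m/car'}]  # child class
        },
    ]
}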
@DATASETS.register_module()
class OpenImagesChallengeDataset(OpenImagesDataset):
"""Open Images Challenge dataset for detection.
Args:
ann_file (str): Open Images Challenge box annotation in txt format.
"""
METAINFO: dict = dict(dataset_type='oid_challenge')
def __init__(self, ann_file: str, **kwargs) -> None:
if not ann_file.endswith('txt'):
raise TypeError('The annotation file of Open Images Challenge '
'should be a txt file.')
super().__init__(ann_file=ann_file, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
lines = f.readlines()
i = 0
while i < len(lines):
instances = []
filename = lines[i].rstrip()
i += 2
img_gt_size = int(lines[i])
i += 1
for j in range(img_gt_size):
sp = lines[i + j].split()
instances.append(
dict(
bbox=[
float(sp[1]),
float(sp[2]),
float(sp[3]),
float(sp[4])
],
bbox_label=int(sp[0]) - 1, # labels begin from 1
ignore_flag=0,
                        is_group_ofs=int(sp[5]) == 1))
i += img_gt_size
data_list.append(
dict(
img_path=osp.join(self.data_prefix['img'], filename),
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = osp.split(data_list[i]['img_path'])[-1][:-4]
assert img_id == osp.split(meta['filename'])[-1][:-4]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
data_list[i]['img_id'] = img_id
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
            tuple: Class names of OpenImages and the label id mapping.
"""
label_list = []
id_list = []
index_mapping = {}
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
label_name = line[0]
label_id = int(line[2])
label_list.append(line[1])
id_list.append(label_id)
index_mapping[label_name] = label_id - 1
indexes = np.argsort(id_list)
classes_names = []
for index in indexes:
classes_names.append(label_list[index])
return classes_names, index_mapping
def _parse_img_level_ann(self, image_level_ann_file):
"""Parse image level annotations from csv style ann_file.
Args:
image_level_ann_file (str): CSV style image level annotation
file path.
Returns:
defaultdict[list[dict]]: Annotations where item of the defaultdict
indicates an image, each of which has (n) dicts.
Keys of dicts are:
                - `image_level_label` (int): Label id.
                - `confidence` (float): Confidence of the label.
"""
item_lists = defaultdict(list)
with get_local_path(
image_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
                for i, line in enumerate(reader):
                    if i == 0:
                        continue
                    img_id = line[0]
                    label_id = line[1]
                    assert label_id in self.label_id_mapping
                    image_level_label = int(
                        self.label_id_mapping[label_id])
                    confidence = float(line[2])
                    item_lists[img_id].append(
                        dict(
                            image_level_label=image_level_label,
                            confidence=confidence))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
"""
with get_local_path(
hierarchy_file, backend_args=self.backend_args) as local_path:
class_label_tree = np.load(local_path, allow_pickle=True)
return class_label_tree[1:, 1:]
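# An illustrative record (values made up) of the txt layout walked by
# load_data_list above: an image path, one line that the parser skips over
# (the `i += 2` step), the number of boxes, then that many lines of
# "label x1 y1 x2 y2 is_group_of" with normalized coordinates and
# 1-based labels.
_EXAMPLE_CHALLENGE_RECORD = (
    'validation/0001eeaf4aed83f9.jpg\n'
    '0\n'
    '2\n'
    '12 0.10 0.20 0.50 0.60 0\n'
    '45 0.30 0.05 0.90 0.40 1\n')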
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import os.path as osp
import random
from typing import Dict, List
import mmengine
from mmengine.dataset import BaseDataset
from mmdet.registry import DATASETS
@DATASETS.register_module()
class RefCocoDataset(BaseDataset):
"""RefCOCO dataset.
    The `Refcoco` and `Refcoco+` datasets are based on
`ReferItGame: Referring to Objects in Photographs of Natural Scenes
<http://tamaraberg.com/papers/referit.pdf>`_.
The `Refcocog` dataset is based on
`Generation and Comprehension of Unambiguous Object Descriptions
<https://arxiv.org/abs/1511.02283>`_.
Args:
ann_file (str): Annotation file path.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
        data_prefix (dict): Prefix for training data.
split_file (str): Split file path.
split (str): Split name. Defaults to 'train'.
text_mode (str): Text mode. Defaults to 'random'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
def __init__(self,
data_root: str,
ann_file: str,
split_file: str,
data_prefix: Dict,
split: str = 'train',
text_mode: str = 'random',
**kwargs):
self.split_file = split_file
self.split = split
assert text_mode in ['original', 'random', 'concat', 'select_first']
self.text_mode = text_mode
super().__init__(
data_root=data_root,
data_prefix=data_prefix,
ann_file=ann_file,
**kwargs,
)
def _join_prefix(self):
if not mmengine.is_abs(self.split_file) and self.split_file:
self.split_file = osp.join(self.data_root, self.split_file)
return super()._join_prefix()
def _init_refs(self):
"""Initialize the refs for RefCOCO."""
anns, imgs = {}, {}
for ann in self.instances['annotations']:
anns[ann['id']] = ann
for img in self.instances['images']:
imgs[img['id']] = img
refs, ref_to_ann = {}, {}
for ref in self.splits:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
# add mapping related to ref
refs[ref_id] = ref
ref_to_ann[ref_id] = anns[ann_id]
self.refs = refs
self.ref_to_ann = ref_to_ann
def load_data_list(self) -> List[dict]:
"""Load data list."""
self.splits = mmengine.load(self.split_file, file_format='pkl')
self.instances = mmengine.load(self.ann_file, file_format='json')
self._init_refs()
img_prefix = self.data_prefix['img_path']
ref_ids = [
ref['ref_id'] for ref in self.splits if ref['split'] == self.split
]
full_anno = []
for ref_id in ref_ids:
ref = self.refs[ref_id]
ann = self.ref_to_ann[ref_id]
ann.update(ref)
full_anno.append(ann)
image_id_list = []
final_anno = {}
for anno in full_anno:
image_id_list.append(anno['image_id'])
final_anno[anno['ann_id']] = anno
annotations = [value for key, value in final_anno.items()]
coco_train_id = []
image_annot = {}
for i in range(len(self.instances['images'])):
coco_train_id.append(self.instances['images'][i]['id'])
image_annot[self.instances['images'][i]
['id']] = self.instances['images'][i]
images = []
for image_id in list(set(image_id_list)):
images += [image_annot[image_id]]
data_list = []
grounding_dict = collections.defaultdict(list)
for anno in annotations:
image_id = int(anno['image_id'])
grounding_dict[image_id].append(anno)
join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
for image in images:
img_id = image['id']
instances = []
sentences = []
for grounding_anno in grounding_dict[img_id]:
texts = [x['raw'].lower() for x in grounding_anno['sentences']]
                # randomly select one text
if self.text_mode == 'random':
idx = random.randint(0, len(texts) - 1)
text = [texts[idx]]
# concat all texts
elif self.text_mode == 'concat':
text = [''.join(texts)]
# select the first text
elif self.text_mode == 'select_first':
text = [texts[0]]
# use all texts
elif self.text_mode == 'original':
text = texts
else:
raise ValueError(f'Invalid text mode "{self.text_mode}".')
ins = [{
'mask': grounding_anno['segmentation'],
'ignore_flag': 0
}] * len(text)
instances.extend(ins)
sentences.extend(text)
data_info = {
'img_path': join_path(img_prefix, image['file_name']),
'img_id': img_id,
'instances': instances,
'text': sentences
}
data_list.append(data_info)
if len(data_list) == 0:
raise ValueError(f'No sample in split "{self.split}".')
return data_list
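# A minimal, hypothetical config sketch for this dataset; the RefCOCO
# annotation/split file names below follow the common public release layout
# but are assumptions here, and text_mode picks one of the four branches
# handled in load_data_list above.
_example_refcoco_val_cfg = dict(
    type='RefCocoDataset',
    data_root='data/coco/',
    data_prefix=dict(img_path='train2014/'),
    ann_file='refcoco/instances.json',
    split_file='refcoco/refs(unc).p',
    split='val',
    text_mode='select_first')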
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from collections import defaultdict
from typing import Any, Dict, List
import numpy as np
from mmengine.dataset import BaseDataset
from mmengine.utils import check_file_exist
from mmdet.registry import DATASETS
@DATASETS.register_module()
class ReIDDataset(BaseDataset):
"""Dataset for ReID.
Args:
triplet_sampler (dict, optional): The sampler for hard mining
triplet loss. Defaults to None.
            Expected keys:
            - num_ids (int): The number of person ids.
            - ins_per_id (int): The number of images for each person.
"""
def __init__(self, triplet_sampler: dict = None, *args, **kwargs):
self.triplet_sampler = triplet_sampler
super().__init__(*args, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ''self.ann_file''.
Returns:
list[dict]: A list of annotation.
"""
assert isinstance(self.ann_file, str)
check_file_exist(self.ann_file)
data_list = []
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
info = dict(img_prefix=self.data_prefix)
if self.data_prefix['img_path'] is not None:
info['img_path'] = osp.join(self.data_prefix['img_path'],
filename)
else:
info['img_path'] = filename
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_list.append(info)
self._parse_ann_info(data_list)
return data_list
def _parse_ann_info(self, data_list: List[dict]):
"""Parse person id annotations."""
index_tmp_dic = defaultdict(list) # pid->[idx1,...,idxN]
self.index_dic = dict() # pid->array([idx1,...,idxN])
for idx, info in enumerate(data_list):
pid = info['gt_label']
index_tmp_dic[int(pid)].append(idx)
for pid, idxs in index_tmp_dic.items():
self.index_dic[pid] = np.asarray(idxs, dtype=np.int64)
self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)
def prepare_data(self, idx: int) -> Any:
"""Get data processed by ''self.pipeline''.
Args:
idx (int): The index of ''data_info''
Returns:
Any: Depends on ''self.pipeline''
"""
data_info = self.get_data_info(idx)
if self.triplet_sampler is not None:
img_info = self.triplet_sampling(data_info['gt_label'],
**self.triplet_sampler)
data_info = copy.deepcopy(img_info) # triplet -> list
else:
data_info = copy.deepcopy(data_info) # no triplet -> dict
return self.pipeline(data_info)
def triplet_sampling(self,
pos_pid,
num_ids: int = 8,
ins_per_id: int = 4) -> Dict:
"""Triplet sampler for hard mining triplet loss. First, for one
pos_pid, random sample ins_per_id images with same person id.
        Then, random sample num_ids - 1 negative ids.
Finally, random sample ins_per_id images for each negative id.
Args:
pos_pid (ndarray): The person id of the anchor.
num_ids (int): The number of person ids.
ins_per_id (int): The number of images for each person.
Returns:
Dict: Annotation information of num_ids X ins_per_id images.
"""
assert len(self.pids) >= num_ids, \
'The number of person ids in the training set must ' \
'be greater than the number of person ids in the sample.'
pos_idxs = self.index_dic[int(
pos_pid)] # all positive idxs for pos_pid
idxs_list = []
        # select positive samples
idxs_list.extend(pos_idxs[np.random.choice(
pos_idxs.shape[0], ins_per_id, replace=True)])
# select negative ids
neg_pids = np.random.choice(
[i for i, _ in enumerate(self.pids) if i != pos_pid],
num_ids - 1,
replace=False)
        # select negative samples for each negative id
for neg_pid in neg_pids:
neg_idxs = self.index_dic[neg_pid]
idxs_list.extend(neg_idxs[np.random.choice(
neg_idxs.shape[0], ins_per_id, replace=True)])
# return the final triplet batch
triplet_img_infos = []
for idx in idxs_list:
triplet_img_infos.append(copy.deepcopy(self.get_data_info(idx)))
# Collect data_list scatters (list of dict -> dict of list)
out = dict()
for key in triplet_img_infos[0].keys():
out[key] = [_info[key] for _info in triplet_img_infos]
return out
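# A minimal, hypothetical config sketch: with triplet_sampler set,
# prepare_data returns num_ids * ins_per_id images per call (8 * 4 = 32 here),
# grouped for a hard-mining triplet loss. The paths are assumptions about a
# typical ReID layout.
_example_reid_train_cfg = dict(
    type='ReIDDataset',
    triplet_sampler=dict(num_ids=8, ins_per_id=4),
    data_prefix=dict(img_path='data/MOT17/reid/imgs'),
    ann_file='data/MOT17/reid/meta/train_80.txt')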
# Copyright (c) OpenMMLab. All rights reserved.
from .batch_sampler import (AspectRatioBatchSampler,
MultiDataAspectRatioBatchSampler,
TrackAspectRatioBatchSampler)
from .class_aware_sampler import ClassAwareSampler
from .custom_sample_size_sampler import CustomSampleSizeSampler
from .multi_data_sampler import MultiDataSampler
from .multi_source_sampler import GroupMultiSourceSampler, MultiSourceSampler
from .track_img_sampler import TrackImgSampler
__all__ = [
'ClassAwareSampler', 'AspectRatioBatchSampler', 'MultiSourceSampler',
'GroupMultiSourceSampler', 'TrackImgSampler',
'TrackAspectRatioBatchSampler', 'MultiDataSampler',
'MultiDataAspectRatioBatchSampler', 'CustomSampleSizeSampler'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
from torch.utils.data import BatchSampler, Sampler
from mmdet.datasets.samplers.track_img_sampler import TrackImgSampler
from mmdet.registry import DATA_SAMPLERS
# TODO: maybe replace with a data_loader wrapper
@DATA_SAMPLERS.register_module()
class AspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __init__(self,
sampler: Sampler,
batch_size: int,
drop_last: bool = False) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
# two groups for w < h and w >= h
self._aspect_ratio_buckets = [[] for _ in range(2)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
def __len__(self) -> int:
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
@DATA_SAMPLERS.register_module()
class TrackAspectRatioBatchSampler(AspectRatioBatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
# hard code to solve TrackImgSampler
if isinstance(self.sampler, TrackImgSampler):
video_idx, _ = idx
else:
video_idx = idx
# video_idx
data_info = self.sampler.dataset.get_data_info(video_idx)
# data_info {video_id, images, video_length}
img_data_info = data_info['images'][0]
width, height = img_data_info['width'], img_data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
@DATA_SAMPLERS.register_module()
class MultiDataAspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch for multi-source datasets.
Args:
sampler (Sampler): Base sampler.
        batch_size (Sequence[int]): Size of mini-batch for multi-source
            datasets.
        num_datasets (int): Number of multi-source datasets.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __init__(self,
sampler: Sampler,
batch_size: Sequence[int],
num_datasets: int,
drop_last: bool = True) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
self.sampler = sampler
self.batch_size = batch_size
self.num_datasets = num_datasets
self.drop_last = drop_last
# two groups for w < h and w >= h for each dataset --> 2 * num_datasets
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
aspect_ratio_bucket_id = 0 if width < height else 1
bucket_id = dataset_source_idx * 2 + aspect_ratio_bucket_id
bucket = self._buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size[dataset_source_idx]:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
for i in range(self.num_datasets):
left_data = self._buckets[i * 2 + 0] + self._buckets[i * 2 + 1]
while len(left_data) > 0:
if len(left_data) <= self.batch_size[i]:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size[i]]
left_data = left_data[self.batch_size[i]:]
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __len__(self) -> int:
sizes = [0 for _ in range(self.num_datasets)]
for idx in self.sampler:
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
sizes[dataset_source_idx] += 1
if self.drop_last:
lens = 0
for i in range(self.num_datasets):
lens += sizes[i] // self.batch_size[i]
return lens
else:
lens = 0
for i in range(self.num_datasets):
lens += (sizes[i] + self.batch_size[i] -
1) // self.batch_size[i]
return lens
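# A self-contained sketch, guarded so it never runs on import, of how
# AspectRatioBatchSampler groups indices. The wrapper only needs
# `sampler.dataset.get_data_info(idx)` to report width/height, so a toy
# dataset/sampler pair is enough; every name below is illustrative.
if __name__ == '__main__':

    class _ToyDataset:
        # three portrait images (w < h) and two landscape ones (w >= h)
        _shapes = [(480, 640), (640, 480), (500, 800), (800, 500), (300, 600)]

        def __len__(self):
            return len(self._shapes)

        def get_data_info(self, idx):
            width, height = self._shapes[idx]
            return dict(width=width, height=height)

    class _ToySampler(Sampler):

        def __init__(self, dataset):
            self.dataset = dataset

        def __iter__(self):
            return iter(range(len(self.dataset)))

        def __len__(self):
            return len(self.dataset)

    batch_sampler = AspectRatioBatchSampler(
        _ToySampler(_ToyDataset()), batch_size=2, drop_last=False)
    # portrait indices 0 and 2 pair up, landscape 1 and 3 pair up,
    # and the leftover 4 is yielded alone: [[0, 2], [1, 3], [4]]
    print(list(batch_sampler))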
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Iterator, Optional, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class ClassAwareSampler(Sampler):
r"""Sampler that restricts data loading to the label of the dataset.
A class-aware sampling strategy to effectively tackle the
non-uniform class distribution. The length of the training data is
consistent with source data. Simple improvements based on `Relay
Backpropagation for Effective Learning of Deep Convolutional
Neural Networks <https://arxiv.org/abs/1512.05830>`_
    The implementation logic is based on
https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py
Args:
dataset: Dataset used for sampling.
seed (int, optional): random seed used to shuffle the sampler.
This number should be identical across all
processes in the distributed group. Defaults to None.
num_sample_class (int): The number of samples taken from each
per-label list. Defaults to 1.
"""
def __init__(self,
dataset: BaseDataset,
seed: Optional[int] = None,
num_sample_class: int = 1) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.epoch = 0
# Must be the same across all workers. If None, will use a
# random seed shared among workers
# (require synchronization among all workers)
if seed is None:
seed = sync_random_seed()
self.seed = seed
# The number of samples taken from each per-label list
assert num_sample_class > 0 and isinstance(num_sample_class, int)
self.num_sample_class = num_sample_class
# Get per-label image list from dataset
self.cat_dict = self.get_cat2imgs()
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size))
self.total_size = self.num_samples * self.world_size
# get number of images containing each category
self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
# filter labels without images
self.valid_cat_inds = [
i for i, length in enumerate(self.num_cat_imgs) if length != 0
]
self.num_classes = len(self.valid_cat_inds)
def get_cat2imgs(self) -> Dict[int, list]:
"""Get a dict with class as key and img_ids as values.
Returns:
            dict[int, list]: A dict of per-label image lists; each key is a
                label index and the value lists the indices of the images
                that contain that label.
"""
classes = self.dataset.metainfo.get('classes', None)
if classes is None:
raise ValueError('dataset metainfo must contain `classes`')
# sort the label index
cat2imgs = {i: [] for i in range(len(classes))}
for i in range(len(self.dataset)):
cat_ids = set(self.dataset.get_cat_ids(i))
for cat in cat_ids:
cat2imgs[cat].append(i)
return cat2imgs
def __iter__(self) -> Iterator[int]:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
# initialize label list
label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
# initialize each per-label image list
data_iter_dict = dict()
for i in self.valid_cat_inds:
data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)
def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
"""Traverse the categories and extract `num_sample_cls` image
indexes of the corresponding categories one by one."""
id_indices = []
for _ in range(len(cls_list)):
cls_idx = next(cls_list)
for _ in range(num_sample_cls):
id = next(data_dict[cls_idx])
id_indices.append(id)
return id_indices
# deterministically shuffle based on epoch
num_bins = int(
math.ceil(self.total_size * 1.0 / self.num_classes /
self.num_sample_class))
indices = []
for i in range(num_bins):
indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
self.num_sample_class)
# fix extra samples to make it evenly divisible
if len(indices) >= self.total_size:
indices = indices[:self.total_size]
else:
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
class RandomCycleIter:
"""Shuffle the list and do it again after the list have traversed.
    The implementation logic is based on
https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
Example:
>>> label_list = [0, 1, 2, 4, 5]
>>> g = torch.Generator()
>>> g.manual_seed(0)
>>> label_iter_list = RandomCycleIter(label_list, generator=g)
>>> index = next(label_iter_list)
Args:
data (list or ndarray): The data that needs to be shuffled.
        generator: A torch.Generator object, which is used to set the seed
            for generating random numbers.
""" # noqa: W605
def __init__(self,
data: Union[list, np.ndarray],
generator: torch.Generator = None) -> None:
self.data = data
self.length = len(data)
self.index = torch.randperm(self.length, generator=generator).numpy()
self.i = 0
self.generator = generator
def __iter__(self) -> Iterator:
return self
def __len__(self) -> int:
return len(self.data)
def __next__(self):
if self.i == self.length:
self.index = torch.randperm(
self.length, generator=self.generator).numpy()
self.i = 0
idx = self.data[self.index[self.i]]
self.i += 1
return idx
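# A short, guarded demonstration of RandomCycleIter: it yields items forever
# and silently reshuffles once a full pass over the data has been made, which
# is how ClassAwareSampler above cycles over categories and their image lists.
if __name__ == '__main__':
    g = torch.Generator()
    g.manual_seed(0)
    cycle_iter = RandomCycleIter([0, 1, 2, 4, 5], generator=g)
    # seven draws: the first five cover one shuffled pass, the last two
    # already come from a freshly reshuffled pass
    print([next(cycle_iter) for _ in range(7)])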
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
from .class_aware_sampler import RandomCycleIter
@DATA_SAMPLERS.register_module()
class CustomSampleSizeSampler(Sampler):
def __init__(self,
dataset: Sized,
dataset_size: Sequence[int],
ratio_mode: bool = False,
seed: Optional[int] = None,
round_up: bool = True) -> None:
assert len(dataset.datasets) == len(dataset_size)
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
total_size = 0
total_size_fake = 0
self.dataset_index = []
self.dataset_cycle_iter = []
new_dataset_size = []
for dataset, size in zip(dataset.datasets, dataset_size):
self.dataset_index.append(
list(range(total_size_fake,
len(dataset) + total_size_fake)))
total_size_fake += len(dataset)
if size == -1:
total_size += len(dataset)
self.dataset_cycle_iter.append(None)
new_dataset_size.append(-1)
else:
if ratio_mode:
size = int(size * len(dataset))
assert size <= len(
dataset
), f'dataset size {size} is larger than ' \
f'dataset length {len(dataset)}'
total_size += size
new_dataset_size.append(size)
g = torch.Generator()
g.manual_seed(self.seed)
self.dataset_cycle_iter.append(
RandomCycleIter(self.dataset_index[-1], generator=g))
self.dataset_size = new_dataset_size
if self.round_up:
self.num_samples = math.ceil(total_size / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil((total_size - rank) / world_size)
self.total_size = total_size
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
out_index = []
for data_size, data_index, cycle_iter in zip(self.dataset_size,
self.dataset_index,
self.dataset_cycle_iter):
if data_size == -1:
out_index += data_index
else:
index = [next(cycle_iter) for _ in range(data_size)]
out_index += index
index = torch.randperm(len(out_index), generator=g).numpy().tolist()
indices = [out_index[i] for i in index]
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
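# A minimal, hypothetical sampler config for a ConcatDataset with two sources:
# -1 keeps the first sub-dataset whole, while 5000 samples are drawn from the
# second each epoch; with ratio_mode=True the second entry would instead be
# read as a fraction of that sub-dataset's length.
_example_custom_size_sampler_cfg = dict(
    type='CustomSampleSizeSampler', dataset_size=[-1, 5000], ratio_mode=False)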
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS
from torch.utils.data import Sampler
@DATA_SAMPLERS.register_module()
class MultiDataSampler(Sampler):
"""The default data sampler for both distributed and non-distributed
    environments.
It has several differences from the PyTorch ``DistributedSampler`` as
below:
1. This sampler supports non-distributed environment.
2. The round up behaviors are a little different.
- If ``round_up=True``, this sampler will add extra samples to make the
         number of samples evenly divisible by the world size. And
this behavior is the same as the ``DistributedSampler`` with
``drop_last=False``.
- If ``round_up=False``, this sampler won't remove or add any samples
while the ``DistributedSampler`` with ``drop_last=True`` will remove
tail samples.
Args:
dataset (Sized): The dataset.
        dataset_ratio (Sequence[int]): The ratios of different datasets.
seed (int, optional): Random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Defaults to None.
round_up (bool): Whether to add extra samples to make the number of
samples evenly divisible by the world size. Defaults to True.
"""
def __init__(self,
dataset: Sized,
dataset_ratio: Sequence[int],
seed: Optional[int] = None,
round_up: bool = True) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.dataset_ratio = dataset_ratio
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
if self.round_up:
self.num_samples = math.ceil(len(self.dataset) / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil(
(len(self.dataset) - rank) / world_size)
self.total_size = len(self.dataset)
self.sizes = [len(dataset) for dataset in self.dataset.datasets]
dataset_weight = [
torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio)
            for r, s in zip(self.dataset_ratio, self.sizes)
]
self.weights = torch.cat(dataset_weight)
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.multinomial(
self.weights, len(self.weights), generator=g,
replacement=True).tolist()
# add extra samples to make it evenly divisible
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
# subsample
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
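

# --- Usage sketch (not part of the library) ---------------------------------
# A minimal, hypothetical demo of ``MultiDataSampler`` on two toy datasets.
# Only public PyTorch APIs are assumed; ``ConcatDataset`` stands in for the
# concatenated dataset that would normally be built from an mmengine config.
def _demo_multi_data_sampler() -> None:
    from torch.utils.data import ConcatDataset

    # Plain ``range`` objects are enough: ``ConcatDataset`` only needs
    # ``__len__`` and ``__getitem__``.
    concat = ConcatDataset([range(100), range(50)])

    # Ask for twice as many samples from the second (smaller) dataset.
    sampler = MultiDataSampler(concat, dataset_ratio=[1, 2], seed=0)
    sampler.set_epoch(0)

    picks = list(sampler)  # one epoch of global indices, len(concat) of them
    from_first = sum(1 for i in picks if i < 100)
    # Expect roughly a 1:2 split between the two sources.
    print(from_first, len(picks) - from_first)
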
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Iterator, List, Optional, Sized, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class MultiSourceSampler(Sampler):
r"""Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
        shuffle (bool): Whether to shuffle the dataset. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
Examples:
>>> dataset_type = 'ConcatDataset'
>>> sub_dataset_type = 'CocoDataset'
>>> data_root = 'data/coco/'
>>> sup_ann = '../coco_semi_annos/instances_train2017.1@10.json'
>>> unsup_ann = '../coco_semi_annos/' \
>>> 'instances_train2017.1@10-unlabeled.json'
>>> dataset = dict(type=dataset_type,
>>> datasets=[
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=sup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=sup_pipeline),
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=unsup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=unsup_pipeline),
>>> ])
>>> train_dataloader = dict(
>>> batch_size=5,
>>> num_workers=5,
>>> persistent_workers=True,
>>> sampler=dict(type='MultiSourceSampler',
>>> batch_size=5, source_ratio=[1, 4]),
>>> batch_sampler=None,
>>> dataset=dataset)
"""
def __init__(self,
dataset: Sized,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
assert hasattr(dataset, 'cumulative_sizes'),\
            f'The dataset must be a ConcatDataset, but got {dataset}'
assert isinstance(batch_size, int) and batch_size > 0, \
'batch_size must be a positive integer value, ' \
f'but got batch_size={batch_size}'
assert isinstance(source_ratio, list), \
f'source_ratio must be a list, but got source_ratio={source_ratio}'
assert len(source_ratio) == len(dataset.cumulative_sizes), \
'The length of source_ratio must be equal to ' \
f'the number of datasets, but got source_ratio={source_ratio}'
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.cumulative_sizes = [0] + dataset.cumulative_sizes
self.batch_size = batch_size
self.source_ratio = source_ratio
self.num_per_source = [
int(batch_size * sr / sum(source_ratio)) for sr in source_ratio
]
self.num_per_source[0] = batch_size - sum(self.num_per_source[1:])
assert sum(self.num_per_source) == batch_size, \
'The sum of num_per_source must be equal to ' \
            f'batch_size, but got {self.num_per_source}'
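        # Worked example matching the docstring config: batch_size=5 and
        # source_ratio=[1, 4] give int(5 * 1 / 5) = 1 and int(5 * 4 / 5) = 4,
        # i.e. 1 labeled and 4 unlabeled samples per mini-batch. The first
        # entry absorbs any rounding remainder so the counts always sum to
        # batch_size.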
self.seed = sync_random_seed() if seed is None else seed
self.shuffle = shuffle
self.source2inds = {
source: self._indices_of_rank(len(ds))
for source, ds in enumerate(dataset.datasets)
}
def _infinite_indices(self, sample_size: int) -> Iterator[int]:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
while True:
if self.shuffle:
yield from torch.randperm(sample_size, generator=g).tolist()
else:
yield from torch.arange(sample_size).tolist()
def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
"""Slice the infinite indices by rank."""
yield from itertools.islice(
self._infinite_indices(sample_size), self.rank, None,
self.world_size)
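    # Note on ``_indices_of_rank`` above: with world_size=2, the infinite
    # shuffled stream i0, i1, i2, i3, ... is split so that rank 0 sees
    # i0, i2, ... and rank 1 sees i1, i3, ..., keeping the ranks disjoint
    # while only the seed needs to be shared across processes.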
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.source2inds[source]:
idx += self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
def __len__(self) -> int:
return len(self.dataset)
def set_epoch(self, epoch: int) -> None:
"""Not supported in `epoch-based runner."""
pass
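

# --- Usage sketch (not part of the library) ---------------------------------
# A minimal, hypothetical demo of ``MultiSourceSampler``: two toy sources
# mixed 1:4 inside every batch, mirroring the semi-supervised example in the
# class docstring. torch's ``ConcatDataset`` provides the ``datasets`` and
# ``cumulative_sizes`` attributes the sampler relies on.
def _demo_multi_source_sampler() -> None:
    import itertools

    from torch.utils.data import ConcatDataset

    concat = ConcatDataset([range(30), range(300)])
    sampler = MultiSourceSampler(
        concat, batch_size=5, source_ratio=[1, 4], seed=0)

    # The sampler is infinite, so slice out just the first two batches.
    indices = list(itertools.islice(iter(sampler), 10))
    for start in range(0, 10, 5):
        batch = indices[start:start + 5]
        # Each batch holds 1 index from source 0 (global index < 30) and
        # 4 indices from source 1 (global index >= 30).
        print(batch)
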
@DATA_SAMPLERS.register_module()
class GroupMultiSourceSampler(MultiSourceSampler):
r"""Group Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets but the same group to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
        shuffle (bool): Whether to shuffle the dataset. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
"""
def __init__(self,
dataset: BaseDataset,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
super().__init__(
dataset=dataset,
batch_size=batch_size,
source_ratio=source_ratio,
shuffle=shuffle,
seed=seed)
self._get_source_group_info()
self.group_source2inds = [{
source:
self._indices_of_rank(self.group2size_per_source[source][group])
for source in range(len(dataset.datasets))
} for group in range(len(self.group_ratio))]
def _get_source_group_info(self) -> None:
self.group2size_per_source = [{0: 0, 1: 0}, {0: 0, 1: 0}]
self.group2inds_per_source = [{0: [], 1: []}, {0: [], 1: []}]
for source, dataset in enumerate(self.dataset.datasets):
for idx in range(len(dataset)):
data_info = dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
group = 0 if width < height else 1
self.group2size_per_source[source][group] += 1
self.group2inds_per_source[source][group].append(idx)
self.group_sizes = np.zeros(2, dtype=np.int64)
for group2size in self.group2size_per_source:
for group, size in group2size.items():
self.group_sizes[group] += size
self.group_ratio = self.group_sizes / sum(self.group_sizes)
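        # Bookkeeping above: group 0 collects portrait samples
        # (width < height) and group 1 landscape samples, per source. The
        # dicts are hard-coded for exactly two sources, matching the
        # two-dataset (e.g. labeled/unlabeled) use case, and ``group_ratio``
        # is the overall portrait/landscape proportion used to pick a group
        # for every batch in ``__iter__``.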
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
group = np.random.choice(
list(range(len(self.group_ratio))), p=self.group_ratio)
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.group_source2inds[group][source]:
idx = self.group2inds_per_source[source][group][
idx] + self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
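

# --- Usage sketch (not part of the library) ---------------------------------
# A minimal, hypothetical demo of ``GroupMultiSourceSampler``. The stub
# dataset below implements only the pieces the sampler touches (``__len__``,
# ``__getitem__`` and ``get_data_info`` returning 'width'/'height'); in real
# configs these come from mmdet dataset classes such as ``CocoDataset``.
def _demo_group_multi_source_sampler() -> None:
    import itertools

    from torch.utils.data import ConcatDataset

    class _StubDetDataset:
        """Fake detection dataset with alternating aspect ratios."""

        def __init__(self, num: int) -> None:
            self.num = num

        def __len__(self) -> int:
            return self.num

        def __getitem__(self, idx: int) -> int:
            return idx

        def get_data_info(self, idx: int) -> dict:
            # Even indices are portrait (group 0), odd ones landscape
            # (group 1), so every source contributes to both groups and the
            # sampler never starves while filling a batch.
            if idx % 2 == 0:
                return dict(width=480, height=640)
            return dict(width=640, height=480)

    concat = ConcatDataset([_StubDetDataset(20), _StubDetDataset(80)])
    sampler = GroupMultiSourceSampler(
        concat, batch_size=5, source_ratio=[1, 4], seed=0)

    # Every batch of 5 mixes 1 index from source 0 with 4 from source 1,
    # and all 5 samples share the same aspect-ratio group.
    print(list(itertools.islice(iter(sampler), 5)))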