Commit a8562a56 authored by luopl

Initial commit

# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class DeepFashionDataset(CocoDataset):
"""Dataset for DeepFashion."""
METAINFO = {
'classes': ('top', 'skirt', 'leggings', 'dress', 'outer', 'pants',
'bag', 'neckwear', 'headwear', 'eyeglass', 'belt',
'footwear', 'hair', 'skin', 'face'),
# palette is a list of color tuples, which is used for visualization.
'palette': [(0, 192, 64), (0, 64, 96), (128, 192, 192), (0, 64, 64),
(0, 192, 224), (0, 192, 192), (128, 192, 64), (0, 192, 96),
(128, 32, 192), (0, 0, 224), (0, 0, 64), (0, 160, 192),
(128, 0, 96), (128, 0, 192), (0, 32, 192)]
}
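# Illustrative config sketch (the paths below are placeholders, not files
# shipped with this commit): DeepFashionDataset only overrides METAINFO, so
# it is configured exactly like CocoDataset through the DATASETS registry.
deepfashion_example = dict(
    type='DeepFashionDataset',
    data_root='data/DeepFashion/In-shop/',
    ann_file='Anno/segmentation/DeepFashion_segmentation_train.json',
    data_prefix=dict(img='Img/'),
    pipeline=[])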
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Optional
import numpy as np
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
try:
from d_cube import D3
except ImportError:
D3 = None
from .api_wrappers import COCO
@DATASETS.register_module()
class DODDataset(BaseDetDataset):
def __init__(self,
*args,
data_root: Optional[str] = '',
data_prefix: dict = dict(img_path=''),
**kwargs) -> None:
if D3 is None:
raise ImportError(
'Please install d3 by `pip install ddd-dataset`.')
pkl_anno_path = osp.join(data_root, data_prefix['anno'])
self.img_root = osp.join(data_root, data_prefix['img'])
self.d3 = D3(self.img_root, pkl_anno_path)
sent_infos = self.d3.load_sents()
classes = tuple([sent_info['raw_sent'] for sent_info in sent_infos])
super().__init__(
*args,
data_root=data_root,
data_prefix=data_prefix,
metainfo={'classes': classes},
**kwargs)
def load_data_list(self) -> List[dict]:
coco = COCO(self.ann_file)
data_list = []
img_ids = self.d3.get_img_ids()
for img_id in img_ids:
data_info = {}
img_info = self.d3.load_imgs(img_id)[0]
file_name = img_info['file_name']
img_path = osp.join(self.img_root, file_name)
data_info['img_path'] = img_path
data_info['img_id'] = img_id
data_info['height'] = img_info['height']
data_info['width'] = img_info['width']
group_ids = self.d3.get_group_ids(img_ids=[img_id])
sent_ids = self.d3.get_sent_ids(group_ids=group_ids)
sent_list = self.d3.load_sents(sent_ids=sent_ids)
text_list = [sent['raw_sent'] for sent in sent_list]
ann_ids = coco.get_ann_ids(img_ids=[img_id])
anno = coco.load_anns(ann_ids)
data_info['text'] = text_list
data_info['sent_ids'] = np.array(sent_ids)
data_info['custom_entities'] = True
instances = []
for i, ann in enumerate(anno):
instance = {}
x1, y1, w, h = ann['bbox']
bbox = [x1, y1, x1 + w, y1 + h]
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = ann['category_id'] - 1
instances.append(instance)
data_info['instances'] = instances
data_list.append(data_info)
return data_list
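# Usage sketch (illustrative; the paths are placeholders): DODDataset reads
# the d-cube pkl annotations from data_prefix['anno'], the images from
# data_prefix['img'], and the COCO-style boxes from ann_file.
dod_example = dict(
    type='DODDataset',
    data_root='data/d3/',
    ann_file='d3_json/d3_full_annotations.json',
    data_prefix=dict(img='d3_images/', anno='d3_pkl/'),
    pipeline=[])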
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import List
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
try:
from dsdl.dataset import DSDLDataset
except ImportError:
DSDLDataset = None
@DATASETS.register_module()
class DSDLDetDataset(BaseDetDataset):
"""Dataset for dsdl detection.
Args:
with_bbox(bool): Load bbox or not, defaults to be True.
with_polygon(bool): Load polygon or not, defaults to be False.
with_mask(bool): Load seg map mask or not, defaults to be False.
with_imagelevel_label(bool): Load image level label or not,
defaults to be False.
with_hierarchy(bool): Load hierarchy information or not,
defaults to be False.
specific_key_path(dict): Path of specific key which can not
be loaded by it's field name.
pre_transform(dict): pre-transform functions before loading.
"""
METAINFO = {}
def __init__(self,
with_bbox: bool = True,
with_polygon: bool = False,
with_mask: bool = False,
with_imagelevel_label: bool = False,
with_hierarchy: bool = False,
specific_key_path: dict = {},
pre_transform: dict = {},
**kwargs) -> None:
if DSDLDataset is None:
raise RuntimeError(
'Package dsdl is not installed. Please run "pip install dsdl".'
)
self.with_hierarchy = with_hierarchy
self.specific_key_path = specific_key_path
loc_config = dict(type='LocalFileReader', working_dir='')
if kwargs.get('data_root'):
kwargs['ann_file'] = os.path.join(kwargs['data_root'],
kwargs['ann_file'])
self.required_fields = ['Image', 'ImageShape', 'Label', 'ignore_flag']
if with_bbox:
self.required_fields.append('Bbox')
if with_polygon:
self.required_fields.append('Polygon')
if with_mask:
self.required_fields.append('LabelMap')
if with_imagelevel_label:
self.required_fields.append('image_level_labels')
assert 'image_level_labels' in specific_key_path.keys(
), '`image_level_labels` not specified in `specific_key_path` !'
self.extra_keys = [
key for key in self.specific_key_path.keys()
if key not in self.required_fields
]
self.dsdldataset = DSDLDataset(
dsdl_yaml=kwargs['ann_file'],
location_config=loc_config,
required_fields=self.required_fields,
specific_key_path=specific_key_path,
transform=pre_transform,
)
BaseDetDataset.__init__(self, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load data info from an dsdl yaml file named as ``self.ann_file``
Returns:
List[dict]: A list of data info.
"""
if self.with_hierarchy:
# get classes_names and relation_matrix
classes_names, relation_matrix = \
self.dsdldataset.class_dom.get_hierarchy_info()
self._metainfo['classes'] = tuple(classes_names)
self._metainfo['RELATION_MATRIX'] = relation_matrix
else:
self._metainfo['classes'] = tuple(self.dsdldataset.class_names)
data_list = []
for i, data in enumerate(self.dsdldataset):
# basic image info, including image id, path and size.
datainfo = dict(
img_id=i,
img_path=os.path.join(self.data_prefix['img_path'],
data['Image'][0].location),
width=data['ImageShape'][0].width,
height=data['ImageShape'][0].height,
)
# get image label info
if 'image_level_labels' in data.keys():
if self.with_hierarchy:
# get leaf node name when using hierarchy classes
datainfo['image_level_labels'] = [
self._metainfo['classes'].index(i.leaf_node_name)
for i in data['image_level_labels']
]
else:
datainfo['image_level_labels'] = [
self._metainfo['classes'].index(i.name)
for i in data['image_level_labels']
]
# get semantic segmentation info
if 'LabelMap' in data.keys():
datainfo['seg_map_path'] = data['LabelMap']
# load instance info
instances = []
if 'Bbox' in data.keys():
for idx in range(len(data['Bbox'])):
bbox = data['Bbox'][idx]
if self.with_hierarchy:
# get leaf node name when using hierarchy classes
label = data['Label'][idx].leaf_node_name
label_index = self._metainfo['classes'].index(label)
else:
label = data['Label'][idx].name
label_index = self._metainfo['classes'].index(label)
instance = {}
instance['bbox'] = bbox.xyxy
instance['bbox_label'] = label_index
if 'ignore_flag' in data.keys():
# get ignore flag
instance['ignore_flag'] = data['ignore_flag'][idx]
else:
instance['ignore_flag'] = 0
if 'Polygon' in data.keys():
# get polygon info
polygon = data['Polygon'][idx]
instance['mask'] = polygon.openmmlabformat
for key in self.extra_keys:
# load extra instance info
instance[key] = data[key][idx]
instances.append(instance)
datainfo['instances'] = instances
# append a standard sample in data list
if len(datainfo['instances']) > 0:
data_list.append(datainfo)
return data_list
def filter_data(self) -> List[dict]:
"""Filter annotations according to filter_cfg.
Returns:
List[dict]: Filtered results.
"""
if self.test_mode:
return self.data_list
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \
if self.filter_cfg is not None else False
min_size = self.filter_cfg.get('min_size', 0) \
if self.filter_cfg is not None else 0
valid_data_list = []
for i, data_info in enumerate(self.data_list):
width = data_info['width']
height = data_info['height']
if filter_empty_gt and len(data_info['instances']) == 0:
continue
if min(width, height) >= min_size:
valid_data_list.append(data_info)
return valid_data_list
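# Usage sketch (illustrative; the yaml path is a placeholder): annotations
# are read from a DSDL yaml file, and filter_cfg drives filter_data() at
# training time.
dsdl_example = dict(
    type='DSDLDetDataset',
    data_root='data/dsdl_demo/',
    ann_file='set-train/train.yaml',
    data_prefix=dict(img_path='original/'),
    with_bbox=True,
    filter_cfg=dict(filter_empty_gt=True, min_size=32),
    pipeline=[])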
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List
from pycocotools.coco import COCO
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
def convert_phrase_ids(phrase_ids: list) -> list:
unique_elements = sorted(set(phrase_ids))
element_to_new_label = {
element: label
for label, element in enumerate(unique_elements)
}
phrase_ids = [element_to_new_label[element] for element in phrase_ids]
return phrase_ids
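# Worked example (illustrative, not used by the dataset itself):
# convert_phrase_ids maps arbitrary phrase ids onto a dense 0-based label
# space while preserving their sorted order.
assert convert_phrase_ids([7, 3, 7, 12]) == [1, 0, 1, 2]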
@DATASETS.register_module()
class Flickr30kDataset(BaseDetDataset):
"""Flickr30K Dataset."""
def load_data_list(self) -> List[dict]:
self.coco = COCO(self.ann_file)
self.ids = sorted(list(self.coco.imgs.keys()))
data_list = []
for img_id in self.ids:
if isinstance(img_id, str):
ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
else:
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
coco_img = self.coco.loadImgs(img_id)[0]
caption = coco_img['caption']
file_name = coco_img['file_name']
img_path = osp.join(self.data_prefix['img'], file_name)
width = coco_img['width']
height = coco_img['height']
tokens_positive = coco_img['tokens_positive_eval']
phrases = [caption[i[0][0]:i[0][1]] for i in tokens_positive]
phrase_ids = []
instances = []
annos = self.coco.loadAnns(ann_ids)
for anno in annos:
instance = {
'bbox': [
anno['bbox'][0], anno['bbox'][1],
anno['bbox'][0] + anno['bbox'][2],
anno['bbox'][1] + anno['bbox'][3]
],
'bbox_label':
anno['category_id'],
'ignore_flag':
anno['iscrowd']
}
phrase_ids.append(anno['phrase_ids'])
instances.append(instance)
phrase_ids = convert_phrase_ids(phrase_ids)
data_list.append(
dict(
img_path=img_path,
img_id=img_id,
height=height,
width=width,
instances=instances,
text=caption,
phrase_ids=phrase_ids,
tokens_positive=tokens_positive,
phrases=phrases,
))
return data_list
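# Illustrative sketch of the phrase extraction above (the caption and spans
# are made up): each entry of `tokens_positive_eval` holds [start, end)
# character spans into the caption, and only the first span is used for the
# phrase text.
_caption = 'A man in a red shirt holds a guitar'
_tokens_positive = [[[2, 5]], [[11, 20]], [[29, 35]]]
_phrases = [_caption[span[0][0]:span[0][1]] for span in _tokens_positive]
assert _phrases == ['man', 'red shirt', 'guitar']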
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class iSAIDDataset(CocoDataset):
"""Dataset for iSAID instance segmentation.
iSAID: A Large-scale Dataset for Instance Segmentation
in Aerial Images.
For more detail, please refer to "projects/iSAID/README.md"
"""
METAINFO = dict(
classes=('background', 'ship', 'store_tank', 'baseball_diamond',
'tennis_court', 'basketball_court', 'Ground_Track_Field',
'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
'Harbor'),
palette=[(0, 0, 0), (0, 0, 63), (0, 63, 63), (0, 63, 0), (0, 63, 127),
(0, 63, 191), (0, 63, 255), (0, 127, 63), (0, 127, 127),
(0, 0, 127), (0, 0, 191), (0, 0, 255), (0, 191, 127),
(0, 127, 191), (0, 127, 255), (0, 100, 155)])
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class LVISV05Dataset(CocoDataset):
"""LVIS v0.5 dataset for detection."""
METAINFO = {
'classes':
('acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron',
'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke',
'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award',
'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack',
'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball',
'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage',
'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel',
'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat',
'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop',
'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball',
'bead', 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed',
'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle',
'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle',
'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder',
'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage',
'birdhouse', 'birthday_cake', 'birthday_card', 'biscuit_(bread)',
'pirate_flag', 'black_sheep', 'blackboard', 'blanket', 'blazer',
'blender', 'blimp', 'blinker', 'blueberry', 'boar', 'gameboard',
'boat', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt',
'bolt', 'bonnet', 'book', 'book_bag', 'bookcase', 'booklet',
'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener',
'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie',
'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin',
'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase',
'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie',
'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board',
'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed',
'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife',
'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon',
'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)',
'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan',
'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player',
'celery', 'cellular_telephone', 'chain_mail', 'chair',
'chaise_longue', 'champagne', 'chandelier', 'chap', 'checkbook',
'checkerboard', 'cherry', 'chessboard',
'chest_of_drawers_(furniture)', 'chicken_(animal)', 'chicken_wire',
'chickpea', 'Chihuahua', 'chili_(vegetable)', 'chime', 'chinaware',
'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar',
'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider',
'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
'clasp', 'cleansing_agent', 'clementine', 'clip', 'clipboard',
'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag',
'coaster', 'coat', 'coat_hanger', 'coatrack', 'cock', 'coconut',
'coffee_filter', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
'coin', 'colander', 'coleslaw', 'coloring_material',
'combination_lock', 'pacifier', 'comic_book', 'computer_keyboard',
'concrete_mixer', 'cone', 'control', 'convertible_(automobile)',
'sofa_bed', 'cookie', 'cookie_jar', 'cooking_utensil',
'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew',
'edible_corn', 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset',
'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell',
'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon',
'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot',
'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship',
'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube',
'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler',
'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool',
'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog',
'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask',
'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper',
'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan',
'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel',
'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)',
'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose',
'fireplace', 'fireplug', 'fish', 'fish_(food)', 'fishbowl',
'fishing_boat', 'fishing_rod', 'flag', 'flagpole', 'flamingo',
'flannel', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)',
'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal',
'folding_chair', 'food_processor', 'football_(American)',
'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car',
'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice',
'fruit_salad', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag',
'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle',
'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda',
'gift_wrap', 'ginger', 'giraffe', 'cincture',
'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater',
'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag',
'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush',
'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock',
'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil',
'headband', 'headboard', 'headlight', 'headscarf', 'headset',
'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater',
'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus',
'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood',
'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod',
'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean',
'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick',
'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)',
'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat',
'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp',
'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy',
'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard',
'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion',
'speaker_(stereo_equipment)', 'loveseat', 'machine_gun', 'magazine',
'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth',
'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini',
'mascot', 'mashed_potato', 'masher', 'mask', 'mast',
'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup',
'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone',
'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
'mound_(baseball)', 'mouse_(animal_rodent)',
'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
'music_stool', 'musical_instrument', 'nailfile', 'nameplate',
'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest',
'newsstand', 'nightshirt', 'nosebag_(for_animals)',
'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker',
'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'oregano',
'ostrich', 'ottoman', 'overalls_(clothing)', 'owl', 'packet',
'inkpad', 'pad', 'paddle', 'padlock', 'paintbox', 'paintbrush',
'painting', 'pajamas', 'palette', 'pan_(for_cooking)',
'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya',
'paperclip', 'paper_plate', 'paper_towel', 'paperback_book',
'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
'parchment', 'parka', 'parking_meter', 'parrot',
'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard',
'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener',
'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper',
'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood',
'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
'plate', 'platter', 'playing_card', 'playpen', 'pliers',
'plow_(farm_equipment)', 'pocket_watch', 'pocketknife',
'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt',
'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait',
'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune',
'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher',
'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit',
'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish',
'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
'recliner', 'record_player', 'red_cabbage', 'reflector',
'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate',
'Rollerblade', 'rolling_pin', 'root_beer',
'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)',
'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag',
'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer',
'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)',
'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker',
'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)',
'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog',
'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag',
'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag',
'shovel', 'shower_head', 'shower_curtain', 'shredder_(for_paper)',
'sieve', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski',
'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'sled', 'sleeping_bag',
'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake',
'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock',
'soda_fountain', 'carbonated_water', 'sofa', 'softball',
'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'sponge',
'spoon', 'sportswear', 'spotlight', 'squirrel',
'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)',
'steak_(food)', 'steak_knife', 'steamer_(kitchen_appliance)',
'steering_wheel', 'stencil', 'stepladder', 'step_stool',
'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup',
'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light', 'stove',
'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop',
'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table',
'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag',
'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)',
'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
'telephone_pole', 'telephoto_lens', 'television_camera',
'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)',
'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip',
'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella',
'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve',
'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin',
'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon',
'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet',
'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch',
'water_bottle', 'water_cooler', 'water_faucet', 'water_filter',
'water_heater', 'water_jug', 'water_gun', 'water_scooter',
'water_ski', 'water_tower', 'watering_can', 'watermelon',
'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit',
'wheel', 'wheelchair', 'whipped_cream', 'whiskey', 'whistle', 'wick',
'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
'wineglass', 'wing_chair', 'blinder_(for_horses)', 'wok', 'wolf',
'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht',
'yak', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini'),
'palette':
None
}
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
try:
import lvis
if getattr(lvis, '__version__', '0') >= '10.5.3':
warnings.warn(
'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
UserWarning)
from lvis import LVIS
except ImportError:
raise ImportError(
'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
)
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.lvis = LVIS(local_path)
self.cat_ids = self.lvis.get_cat_ids()
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)
img_ids = self.lvis.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.lvis.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
if raw_img_info['file_name'].startswith('COCO'):
# Convert from the COCO 2014 file naming convention of
# COCO_[train/val/test]2014_000000000000.jpg to the 2017
# naming convention of 000000000000.jpg
# (LVIS v1 will fix this naming issue)
raw_img_info['file_name'] = raw_img_info['file_name'][-16:]
ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.lvis.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.lvis
return data_list
LVISDataset = LVISV05Dataset
DATASETS.register_module(name='LVISDataset', module=LVISDataset)
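# Note on the alias above: configs that still use type='LVISDataset' resolve
# to the v0.5 implementation. Illustrative sketch of the COCO-2014 -> 2017
# file-name normalisation done in load_data_list (example file name only):
_v05_name = 'COCO_val2014_000000391895.jpg'
assert _v05_name[-16:] == '000000391895.jpg'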
@DATASETS.register_module()
class LVISV1Dataset(LVISDataset):
"""LVIS v1 dataset for detection."""
METAINFO = {
'classes':
('aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
'bottle_opener', 'bouquet', 'bow_(weapon)',
'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl',
'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders',
'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread',
'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach',
'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket',
'horse_buggy', 'bull', 'bulldog', 'bulldozer', 'bullet_train',
'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed',
'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter',
'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet',
'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder',
'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can',
'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane',
'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen',
'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
'cash_register', 'casserole', 'cassette', 'cast', 'cat',
'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery',
'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard',
'cherry', 'chessboard', 'chicken_(animal)', 'chickpea',
'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard',
'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower',
'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)',
'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
'coin', 'colander', 'coleslaw', 'coloring_material',
'combination_lock', 'pacifier', 'comic_book', 'compass',
'computer_keyboard', 'condiment', 'cone', 'control',
'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove',
'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat',
'dress_suit', 'dresser', 'drill', 'drone', 'dropper',
'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle',
'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg',
'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair',
'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot',
'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret',
'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine',
'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine',
'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug',
'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod',
'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash',
'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
'food_processor', 'football_(American)', 'football_helmet',
'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge',
'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose',
'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin',
'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger',
'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove',
'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart',
'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater',
'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun',
'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger',
'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass',
'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle',
'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil',
'headband', 'headboard', 'headlight', 'headscarf', 'headset',
'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade',
'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat',
'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange',
'manger', 'manhole', 'map', 'marker', 'martini', 'mascot',
'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)',
'matchbox', 'mattress', 'measuring_cup', 'measuring_stick',
'meatball', 'medicine', 'melon', 'microphone', 'microscope',
'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake',
'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)',
'money', 'monitor_(computer_equipment) computer_monitor', 'monkey',
'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle',
'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad',
'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument',
'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle',
'nest', 'newspaper', 'newsstand', 'nightshirt',
'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook',
'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven',
'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
'padlock', 'paintbrush', 'painting', 'pajamas', 'palette',
'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
'parasol', 'parchment', 'parka', 'parking_meter', 'parrot',
'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller',
'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin',
'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt',
'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver',
'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry',
'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
'recliner', 'record_player', 'reflector', 'remote_control',
'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
'rolling_pin', 'root_beer', 'router_(computer_equipment)',
'rubber_band', 'runner_(carpet)', 'plastic_bag',
'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew',
'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove',
'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants',
'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit',
'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
'tambourine', 'army_tank', 'tank_(storage_vessel)',
'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
'telephone_pole', 'telephoto_lens', 'television_camera',
'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod',
'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban',
'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
'washbasin', 'automatic_washer', 'watch', 'water_bottle',
'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
'water_gun', 'water_scooter', 'water_ski', 'water_tower',
'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
'yoke_(animal_equipment)', 'zebra', 'zucchini'),
'palette':
None
}
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
try:
import lvis
if getattr(lvis, '__version__', '0') >= '10.5.3':
warnings.warn(
'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
UserWarning)
from lvis import LVIS
except ImportError:
raise ImportError(
'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
)
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.lvis = LVIS(local_path)
self.cat_ids = self.lvis.get_cat_ids()
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)
img_ids = self.lvis.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.lvis.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
# coco_url is used in LVISv1 instead of file_name
# e.g. http://images.cocodataset.org/train2017/000000391895.jpg
# the train/val split is specified in the url
raw_img_info['file_name'] = raw_img_info['coco_url'].replace(
'http://images.cocodataset.org/', '')
ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.lvis.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.lvis
return data_list
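# Illustrative sketch of the coco_url handling above: stripping the host
# leaves 'train2017/...' or 'val2017/...', so the split stays encoded in the
# relative file name (the URL is an example, not read from disk).
_coco_url = 'http://images.cocodataset.org/train2017/000000391895.jpg'
assert _coco_url.replace('http://images.cocodataset.org/',
                         '') == 'train2017/000000391895.jpg'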
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List
from mmengine.fileio import get_local_path
from mmdet.datasets import BaseDetDataset
from mmdet.registry import DATASETS
from .api_wrappers import COCO
@DATASETS.register_module()
class MDETRStyleRefCocoDataset(BaseDetDataset):
"""RefCOCO dataset.
Only supports evaluation for now.
"""
def load_data_list(self) -> List[dict]:
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
coco = COCO(local_path)
img_ids = coco.get_img_ids()
data_infos = []
for img_id in img_ids:
raw_img_info = coco.load_imgs([img_id])[0]
ann_ids = coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = coco.load_anns(ann_ids)
data_info = {}
img_path = osp.join(self.data_prefix['img'],
raw_img_info['file_name'])
data_info['img_path'] = img_path
data_info['img_id'] = img_id
data_info['height'] = raw_img_info['height']
data_info['width'] = raw_img_info['width']
data_info['dataset_mode'] = raw_img_info['dataset_name']
data_info['text'] = raw_img_info['caption']
data_info['custom_entities'] = False
data_info['tokens_positive'] = -1
instances = []
for i, ann in enumerate(raw_ann_info):
instance = {}
x1, y1, w, h = ann['bbox']
bbox = [x1, y1, x1 + w, y1 + h]
instance['bbox'] = bbox
instance['bbox_label'] = ann['category_id']
instance['ignore_flag'] = 0
instances.append(instance)
data_info['instances'] = instances
data_infos.append(data_info)
return data_infos
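# Usage sketch (illustrative; the annotation path is a placeholder): the
# class only supports evaluation, so it is typically wired into a val/test
# dataloader with an MDETR-style RefCOCO json whose image entries carry
# 'caption' and 'dataset_name'.
refcoco_example = dict(
    type='MDETRStyleRefCocoDataset',
    data_root='data/coco/',
    ann_file='mdetr_annotations/final_refexp_val.json',
    data_prefix=dict(img='train2014/'),
    test_mode=True,
    pipeline=[])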
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Union
from mmdet.registry import DATASETS
from .base_video_dataset import BaseVideoDataset
@DATASETS.register_module()
class MOTChallengeDataset(BaseVideoDataset):
"""Dataset for MOTChallenge.
Args:
visibility_thr (float, optional): The minimum visibility of objects
kept during training. Defaults to -1.
"""
METAINFO = {
'classes':
('pedestrian', 'person_on_vehicle', 'car', 'bicycle', 'motorbike',
'non_mot_vehicle', 'static_person', 'distractor', 'occluder',
'occluder_on_ground', 'occluder_full', 'reflection', 'crowd')
}
def __init__(self, visibility_thr: float = -1, *args, **kwargs):
self.visibility_thr = visibility_thr
super().__init__(*args, **kwargs)
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
"""Parse raw annotation to target format. The difference between this
function and the one in ``BaseVideoDataset`` is that the parsing here
adds ``visibility`` and ``mot_conf``.
Args:
raw_data_info (dict): Raw data information loaded from ``ann_file``.
Returns:
Union[dict, List[dict]]: Parsed annotation.
"""
img_info = raw_data_info['raw_img_info']
ann_info = raw_data_info['raw_ann_info']
data_info = {}
data_info.update(img_info)
if self.data_prefix.get('img_path', None) is not None:
img_path = osp.join(self.data_prefix['img_path'],
img_info['file_name'])
else:
img_path = img_info['file_name']
data_info['img_path'] = img_path
instances = []
for i, ann in enumerate(ann_info):
instance = {}
if (not self.test_mode) and (ann['visibility'] <
self.visibility_thr):
continue
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann['area'] <= 0 or w < 1 or h < 1:
continue
if ann['category_id'] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if ann.get('iscrowd', False):
instance['ignore_flag'] = 1
else:
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = self.cat2label[ann['category_id']]
instance['instance_id'] = ann['instance_id']
instance['category_id'] = ann['category_id']
instance['mot_conf'] = ann['mot_conf']
instance['visibility'] = ann['visibility']
if len(instance) > 0:
instances.append(instance)
if not self.test_mode:
assert len(instances) > 0, f'No valid instances found in ' \
f'image {data_info["img_path"]}!'
data_info['instances'] = instances
return data_info
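# Illustrative sketch of the box validity check in parse_data_info (numbers
# are made up): a box is kept only if it overlaps the image at all, in
# addition to the visibility, area and category checks above.
_img_w, _img_h = 1920, 1080
_x1, _y1, _w, _h = 1900, 40, 50, 30  # partly outside the right edge: kept
_inter_w = max(0, min(_x1 + _w, _img_w) - max(_x1, 0))
_inter_h = max(0, min(_y1 + _h, _img_h) - max(_y1, 0))
assert _inter_w * _inter_h > 0
_x1 = 1950  # fully outside the image: dropped by the check above
assert max(0, min(_x1 + _w, _img_w) - max(_x1, 0)) == 0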
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import List
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .api_wrappers import COCO
from .coco import CocoDataset
# Images that exist in the annotations but not in the image folder.
objv2_ignore_list = [
osp.join('patch16', 'objects365_v2_00908726.jpg'),
osp.join('patch6', 'objects365_v1_00320532.jpg'),
osp.join('patch6', 'objects365_v1_00320534.jpg'),
]
@DATASETS.register_module()
class Objects365V1Dataset(CocoDataset):
"""Objects365 v1 dataset for detection."""
METAINFO = {
'classes':
('person', 'sneakers', 'chair', 'hat', 'lamp', 'bottle',
'cabinet/shelf', 'cup', 'car', 'glasses', 'picture/frame', 'desk',
'handbag', 'street lights', 'book', 'plate', 'helmet',
'leather shoes', 'pillow', 'glove', 'potted plant', 'bracelet',
'flower', 'tv', 'storage box', 'vase', 'bench', 'wine glass', 'boots',
'bowl', 'dining table', 'umbrella', 'boat', 'flag', 'speaker',
'trash bin/can', 'stool', 'backpack', 'couch', 'belt', 'carpet',
'basket', 'towel/napkin', 'slippers', 'barrel/bucket', 'coffee table',
'suv', 'toy', 'tie', 'bed', 'traffic light', 'pen/pencil',
'microphone', 'sandals', 'canned', 'necklace', 'mirror', 'faucet',
'bicycle', 'bread', 'high heels', 'ring', 'van', 'watch', 'sink',
'horse', 'fish', 'apple', 'camera', 'candle', 'teddy bear', 'cake',
'motorcycle', 'wild bird', 'laptop', 'knife', 'traffic sign',
'cell phone', 'paddle', 'truck', 'cow', 'power outlet', 'clock',
'drum', 'fork', 'bus', 'hanger', 'nightstand', 'pot/pan', 'sheep',
'guitar', 'traffic cone', 'tea pot', 'keyboard', 'tripod', 'hockey',
'fan', 'dog', 'spoon', 'blackboard/whiteboard', 'balloon',
'air conditioner', 'cymbal', 'mouse', 'telephone', 'pickup truck',
'orange', 'banana', 'airplane', 'luggage', 'skis', 'soccer',
'trolley', 'oven', 'remote', 'baseball glove', 'paper towel',
'refrigerator', 'train', 'tomato', 'machinery vehicle', 'tent',
'shampoo/shower gel', 'head phone', 'lantern', 'donut',
'cleaning products', 'sailboat', 'tangerine', 'pizza', 'kite',
'computer box', 'elephant', 'toiletries', 'gas stove', 'broccoli',
'toilet', 'stroller', 'shovel', 'baseball bat', 'microwave',
'skateboard', 'surfboard', 'surveillance camera', 'gun', 'life saver',
'cat', 'lemon', 'liquid soap', 'zebra', 'duck', 'sports car',
'giraffe', 'pumpkin', 'piano', 'stop sign', 'radiator', 'converter',
'tissue ', 'carrot', 'washing machine', 'vent', 'cookies',
'cutting/chopping board', 'tennis racket', 'candy',
'skating and skiing shoes', 'scissors', 'folder', 'baseball',
'strawberry', 'bow tie', 'pigeon', 'pepper', 'coffee machine',
'bathtub', 'snowboard', 'suitcase', 'grapes', 'ladder', 'pear',
'american football', 'basketball', 'potato', 'paint brush', 'printer',
'billiards', 'fire hydrant', 'goose', 'projector', 'sausage',
'fire extinguisher', 'extension cord', 'facial mask', 'tennis ball',
'chopsticks', 'electronic stove and gas stove', 'pie', 'frisbee',
'kettle', 'hamburger', 'golf club', 'cucumber', 'clutch', 'blender',
'tong', 'slide', 'hot dog', 'toothbrush', 'facial cleanser', 'mango',
'deer', 'egg', 'violin', 'marker', 'ship', 'chicken', 'onion',
'ice cream', 'tape', 'wheelchair', 'plum', 'bar soap', 'scale',
'watermelon', 'cabbage', 'router/modem', 'golf ball', 'pine apple',
'crane', 'fire truck', 'peach', 'cello', 'notepaper', 'tricycle',
'toaster', 'helicopter', 'green beans', 'brush', 'carriage', 'cigar',
'earphone', 'penguin', 'hurdle', 'swing', 'radio', 'CD',
'parking meter', 'swan', 'garlic', 'french fries', 'horn', 'avocado',
'saxophone', 'trumpet', 'sandwich', 'cue', 'kiwi fruit', 'bear',
'fishing rod', 'cherry', 'tablet', 'green vegetables', 'nuts', 'corn',
'key', 'screwdriver', 'globe', 'broom', 'pliers', 'volleyball',
'hammer', 'eggplant', 'trophy', 'dates', 'board eraser', 'rice',
'tape measure/ruler', 'dumbbell', 'hamimelon', 'stapler', 'camel',
'lettuce', 'goldfish', 'meat balls', 'medal', 'toothpaste',
'antelope', 'shrimp', 'rickshaw', 'trombone', 'pomegranate',
'coconut', 'jellyfish', 'mushroom', 'calculator', 'treadmill',
'butterfly', 'egg tart', 'cheese', 'pig', 'pomelo', 'race car',
'rice cooker', 'tuba', 'crosswalk sign', 'papaya', 'hair drier',
'green onion', 'chips', 'dolphin', 'sushi', 'urinal', 'donkey',
'electric drill', 'spring rolls', 'tortoise/turtle', 'parrot',
'flute', 'measuring cup', 'shark', 'steak', 'poker card',
'binoculars', 'llama', 'radish', 'noodles', 'yak', 'mop', 'crab',
'microscope', 'barbell', 'bread/bun', 'baozi', 'lion', 'red cabbage',
'polar bear', 'lighter', 'seal', 'mangosteen', 'comb', 'eraser',
'pitaya', 'scallop', 'pencil case', 'saw', 'table tennis paddle',
'okra', 'starfish', 'eagle', 'monkey', 'durian', 'game board',
'rabbit', 'french horn', 'ambulance', 'asparagus', 'hoverboard',
'pasta', 'target', 'hotair balloon', 'chainsaw', 'lobster', 'iron',
'flashlight'),
'palette':
None
}
COCOAPI = COCO
# ann_id is unique in coco dataset.
ANN_ID_UNIQUE = True
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
# The 'categories' lists in objects365_train.json and objects365_val.json
# are inconsistent, so sort both the cats dict and the categories list
# before calling get_cat_ids.
cats = self.coco.cats
sorted_cats = {i: cats[i] for i in sorted(cats)}
self.coco.cats = sorted_cats
categories = self.coco.dataset['categories']
sorted_categories = sorted(categories, key=lambda i: i['id'])
self.coco.dataset['categories'] = sorted_categories
# The order of returned `cat_ids` will not
# change with the order of the `classes`
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
img_ids = self.coco.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
return data_list
@DATASETS.register_module()
class Objects365V2Dataset(CocoDataset):
"""Objects365 v2 dataset for detection."""
METAINFO = {
'classes':
('Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp',
'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf',
'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet',
'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower',
'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots',
'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt',
'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker',
'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool',
'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum',
'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle',
'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned',
'Truck', 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel',
'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed',
'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple',
'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck',
'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock',
'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger',
'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine',
'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle',
'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane',
'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage',
'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone',
'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane',
'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat',
'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza',
'Elephant', 'Skateboard', 'Surfboard', 'Gun',
'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot',
'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper',
'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks',
'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board',
'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder',
'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball',
'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle',
'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck',
'Billards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club',
'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear',
'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong',
'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask',
'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide',
'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee',
'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon',
'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon',
'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog',
'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer',
'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple',
'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle',
'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone',
'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion',
'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom',
'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit',
'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese',
'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue',
'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap',
'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut',
'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak',
'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate',
'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker',
'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal',
'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin',
'Electric Drill', 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill',
'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi',
'Target', 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case',
'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop',
'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle',
'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster',
'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling',
'Table Tennis '),
'palette':
None
}
COCOAPI = COCO
# ann_id is unique in coco dataset.
ANN_ID_UNIQUE = True
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
            List[dict]: A list of annotations.
""" # noqa: E501
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
# The order of returned `cat_ids` will not
# change with the order of the `classes`
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
img_ids = self.coco.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
# file_name should be `patchX/xxx.jpg`
file_name = osp.join(
osp.split(osp.split(raw_img_info['file_name'])[0])[-1],
osp.split(raw_img_info['file_name'])[-1])
if file_name in objv2_ignore_list:
continue
raw_img_info['file_name'] = file_name
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List, Optional
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class ODVGDataset(BaseDetDataset):
"""object detection and visual grounding dataset."""
def __init__(self,
*args,
data_root: str = '',
label_map_file: Optional[str] = None,
need_text: bool = True,
**kwargs) -> None:
self.dataset_mode = 'VG'
self.need_text = need_text
if label_map_file:
label_map_file = osp.join(data_root, label_map_file)
with open(label_map_file, 'r') as file:
self.label_map = json.load(file)
self.dataset_mode = 'OD'
super().__init__(*args, data_root=data_root, **kwargs)
assert self.return_classes is True
def load_data_list(self) -> List[dict]:
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
data_list = [json.loads(line) for line in f]
out_data_list = []
for data in data_list:
data_info = {}
img_path = osp.join(self.data_prefix['img'], data['filename'])
data_info['img_path'] = img_path
data_info['height'] = data['height']
data_info['width'] = data['width']
if self.dataset_mode == 'OD':
if self.need_text:
data_info['text'] = self.label_map
anno = data.get('detection', {})
instances = [obj for obj in anno.get('instances', [])]
bboxes = [obj['bbox'] for obj in instances]
bbox_labels = [str(obj['label']) for obj in instances]
instances = []
for bbox, label in zip(bboxes, bbox_labels):
instance = {}
x1, y1, x2, y2 = bbox
inter_w = max(0, min(x2, data['width']) - max(x1, 0))
inter_h = max(0, min(y2, data['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if (x2 - x1) < 1 or (y2 - y1) < 1:
continue
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = int(label)
instances.append(instance)
data_info['instances'] = instances
data_info['dataset_mode'] = self.dataset_mode
out_data_list.append(data_info)
else:
anno = data['grounding']
data_info['text'] = anno['caption']
regions = anno['regions']
instances = []
phrases = {}
for i, region in enumerate(regions):
bbox = region['bbox']
phrase = region['phrase']
tokens_positive = region['tokens_positive']
if not isinstance(bbox[0], list):
bbox = [bbox]
for box in bbox:
instance = {}
x1, y1, x2, y2 = box
inter_w = max(0, min(x2, data['width']) - max(x1, 0))
inter_h = max(0, min(y2, data['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if (x2 - x1) < 1 or (y2 - y1) < 1:
continue
instance['ignore_flag'] = 0
instance['bbox'] = box
instance['bbox_label'] = i
phrases[i] = {
'phrase': phrase,
'tokens_positive': tokens_positive
}
instances.append(instance)
data_info['instances'] = instances
data_info['phrases'] = phrases
data_info['dataset_mode'] = self.dataset_mode
out_data_list.append(data_info)
del data_list
return out_data_list
# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os.path as osp
from collections import defaultdict
from typing import Dict, List, Optional
import numpy as np
from mmengine.fileio import get_local_path, load
from mmengine.utils import is_abs
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class OpenImagesDataset(BaseDetDataset):
"""Open Images dataset for detection.
Args:
ann_file (str): Annotation file path.
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
meta_file (str): File path to get image metas.
hierarchy_file (str): The file path of the class hierarchy.
image_level_ann_file (str): Human-verified image level annotation,
which is used in evaluation.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
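    Examples:
        A minimal config sketch; the annotation file names below are only
        illustrative and may differ from the files you actually download:
        >>> train_dataset = dict(
        ...     type='OpenImagesDataset',
        ...     data_root='data/OpenImages/',
        ...     ann_file='annotations/oidv6-train-annotations-bbox.csv',
        ...     label_file='annotations/class-descriptions-boxable.csv',
        ...     meta_file='annotations/train-image-metas.pkl',
        ...     hierarchy_file='annotations/bbox_labels_600_hierarchy.json',
        ...     data_prefix=dict(img='OpenImages/train/'))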
"""
METAINFO: dict = dict(dataset_type='oid_v6')
def __init__(self,
label_file: str,
meta_file: str,
hierarchy_file: str,
image_level_ann_file: Optional[str] = None,
**kwargs) -> None:
self.label_file = label_file
self.meta_file = meta_file
self.hierarchy_file = hierarchy_file
self.image_level_ann_file = image_level_ann_file
super().__init__(**kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
            List[dict]: A list of annotations.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
last_img_id = None
instances = []
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
if last_img_id is None:
last_img_id = img_id
label_id = line[2]
assert label_id in self.label_id_mapping
label = int(self.label_id_mapping[label_id])
bbox = [
float(line[4]), # xmin
float(line[6]), # ymin
float(line[5]), # xmax
float(line[7]) # ymax
]
is_occluded = True if int(line[8]) == 1 else False
is_truncated = True if int(line[9]) == 1 else False
is_group_of = True if int(line[10]) == 1 else False
is_depiction = True if int(line[11]) == 1 else False
is_inside = True if int(line[12]) == 1 else False
instance = dict(
bbox=bbox,
bbox_label=label,
ignore_flag=0,
is_occluded=is_occluded,
is_truncated=is_truncated,
is_group_of=is_group_of,
is_depiction=is_depiction,
is_inside=is_inside)
last_img_path = osp.join(self.data_prefix['img'],
f'{last_img_id}.jpg')
if img_id != last_img_id:
# switch to a new image, record previous image's data.
data_info = dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
)
data_list.append(data_info)
instances = []
instances.append(instance)
last_img_id = img_id
data_list.append(
dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = data_list[i]['img_id']
assert f'{img_id}.jpg' == osp.split(meta['filename'])[-1]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
tuple: Class name of OpenImages.
"""
index_list = []
classes_names = []
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
# self.cat2label[line[0]] = line[1]
classes_names.append(line[1])
index_list.append(line[0])
index_mapping = {index: i for i, index in enumerate(index_list)}
return classes_names, index_mapping
def _parse_img_level_ann(self,
img_level_ann_file: str) -> Dict[str, List[dict]]:
"""Parse image level annotations from csv style ann_file.
Args:
img_level_ann_file (str): CSV style image level annotation
file path.
Returns:
Dict[str, List[dict]]: Annotations where item of the defaultdict
indicates an image, each of which has (n) dicts.
Keys of dicts are:
- `image_level_label` (int): Label id.
- `confidence` (float): Labels that are human-verified to be
present in an image have confidence = 1 (positive labels).
Labels that are human-verified to be absent from an image
have confidence = 0 (negative labels). Machine-generated
labels have fractional confidences, generally >= 0.5.
The higher the confidence, the smaller the chance for
the label to be a false positive.
"""
item_lists = defaultdict(list)
with get_local_path(
img_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
item_lists[img_id].append(
dict(
image_level_label=int(
self.label_id_mapping[line[2]]),
confidence=float(line[3])))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file. Hierarchy
for 600 classes can be found at https://storage.googleapis.com/openimag
es/2018_04/bbox_labels_600_hierarchy_visualizer/circle.html.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
""" # noqa
hierarchy = load(
hierarchy_file, file_format='json', backend_args=self.backend_args)
class_num = len(self._metainfo['classes'])
relation_matrix = np.eye(class_num, class_num)
relation_matrix = self._convert_hierarchy_tree(hierarchy,
relation_matrix)
return relation_matrix
def _convert_hierarchy_tree(self,
hierarchy_map: dict,
relation_matrix: np.ndarray,
parents: list = [],
get_all_parents: bool = True) -> np.ndarray:
"""Get matrix of the corresponding relationship between the parent
class and the child class.
Args:
hierarchy_map (dict): Including label name and corresponding
subcategory. Keys of dicts are:
                - `LabelName` (str): Name of the label.
- `Subcategory` (dict | list): Corresponding subcategory(ies).
relation_matrix (ndarray): The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
parents (list): Corresponding parent class.
get_all_parents (bool): Whether get all parent names.
Default: True
Returns:
ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
"""
if 'Subcategory' in hierarchy_map:
for node in hierarchy_map['Subcategory']:
if 'LabelName' in node:
children_name = node['LabelName']
children_index = self.label_id_mapping[children_name]
children = [children_index]
else:
continue
if len(parents) > 0:
for parent_index in parents:
if get_all_parents:
children.append(parent_index)
relation_matrix[children_index, parent_index] = 1
relation_matrix = self._convert_hierarchy_tree(
node, relation_matrix, parents=children)
return relation_matrix
def _join_prefix(self):
"""Join ``self.data_root`` with annotation path."""
super()._join_prefix()
if not is_abs(self.label_file) and self.label_file:
self.label_file = osp.join(self.data_root, self.label_file)
if not is_abs(self.meta_file) and self.meta_file:
self.meta_file = osp.join(self.data_root, self.meta_file)
if not is_abs(self.hierarchy_file) and self.hierarchy_file:
self.hierarchy_file = osp.join(self.data_root, self.hierarchy_file)
if self.image_level_ann_file and not is_abs(self.image_level_ann_file):
self.image_level_ann_file = osp.join(self.data_root,
self.image_level_ann_file)
@DATASETS.register_module()
class OpenImagesChallengeDataset(OpenImagesDataset):
"""Open Images Challenge dataset for detection.
Args:
ann_file (str): Open Images Challenge box annotation in txt format.
"""
METAINFO: dict = dict(dataset_type='oid_challenge')
def __init__(self, ann_file: str, **kwargs) -> None:
if not ann_file.endswith('txt'):
raise TypeError('The annotation file of Open Images Challenge '
'should be a txt file.')
super().__init__(ann_file=ann_file, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
            List[dict]: A list of annotations.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
lines = f.readlines()
i = 0
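        # Layout of the txt annotation, as parsed below: for every image,
        # one line with the relative image path, one line that is skipped,
        # one line with the number of ground-truth boxes, and then that
        # many lines of '<label> <x1> <y1> <x2> <y2> <is_group_of>' with
        # normalized coordinates (labels start from 1).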
while i < len(lines):
instances = []
filename = lines[i].rstrip()
i += 2
img_gt_size = int(lines[i])
i += 1
for j in range(img_gt_size):
sp = lines[i + j].split()
instances.append(
dict(
bbox=[
float(sp[1]),
float(sp[2]),
float(sp[3]),
float(sp[4])
],
bbox_label=int(sp[0]) - 1, # labels begin from 1
ignore_flag=0,
is_group_ofs=True if int(sp[5]) == 1 else False))
i += img_gt_size
data_list.append(
dict(
img_path=osp.join(self.data_prefix['img'], filename),
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = osp.split(data_list[i]['img_path'])[-1][:-4]
assert img_id == osp.split(meta['filename'])[-1][:-4]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
data_list[i]['img_id'] = img_id
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
tuple: Class name of OpenImages.
"""
label_list = []
id_list = []
index_mapping = {}
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
label_name = line[0]
label_id = int(line[2])
label_list.append(line[1])
id_list.append(label_id)
index_mapping[label_name] = label_id - 1
indexes = np.argsort(id_list)
classes_names = []
for index in indexes:
classes_names.append(label_list[index])
return classes_names, index_mapping
def _parse_img_level_ann(self, image_level_ann_file):
"""Parse image level annotations from csv style ann_file.
Args:
image_level_ann_file (str): CSV style image level annotation
file path.
Returns:
defaultdict[list[dict]]: Annotations where item of the defaultdict
indicates an image, each of which has (n) dicts.
Keys of dicts are:
                - `image_level_label` (int): Label id.
                - `confidence` (float): Confidence score of the label.
"""
item_lists = defaultdict(list)
with get_local_path(
image_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
i = -1
for line in reader:
i += 1
if i == 0:
continue
else:
img_id = line[0]
label_id = line[1]
assert label_id in self.label_id_mapping
image_level_label = int(
self.label_id_mapping[label_id])
confidence = float(line[2])
item_lists[img_id].append(
dict(
image_level_label=image_level_label,
confidence=confidence))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
"""
with get_local_path(
hierarchy_file, backend_args=self.backend_args) as local_path:
class_label_tree = np.load(local_path, allow_pickle=True)
return class_label_tree[1:, 1:]
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import os.path as osp
import random
from typing import Dict, List
import mmengine
from mmengine.dataset import BaseDataset
from mmdet.registry import DATASETS
@DATASETS.register_module()
class RefCocoDataset(BaseDataset):
"""RefCOCO dataset.
    The `Refcoco` and `Refcoco+` datasets are based on
`ReferItGame: Referring to Objects in Photographs of Natural Scenes
<http://tamaraberg.com/papers/referit.pdf>`_.
The `Refcocog` dataset is based on
`Generation and Comprehension of Unambiguous Object Descriptions
<https://arxiv.org/abs/1511.02283>`_.
Args:
ann_file (str): Annotation file path.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
        data_prefix (dict): Prefix for training data.
split_file (str): Split file path.
split (str): Split name. Defaults to 'train'.
        text_mode (str): Text mode, one of 'original', 'random', 'concat'
            or 'select_first'. Defaults to 'random'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
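    Examples:
        A minimal config sketch for RefCOCO; the annotation and split file
        names are illustrative and depend on how the data was prepared:
        >>> dataset = dict(
        ...     type='RefCocoDataset',
        ...     data_root='data/coco/',
        ...     data_prefix=dict(img_path='train2014/'),
        ...     ann_file='refcoco/instances.json',
        ...     split_file='refcoco/refs(unc).p',
        ...     split='train',
        ...     text_mode='random')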
"""
def __init__(self,
data_root: str,
ann_file: str,
split_file: str,
data_prefix: Dict,
split: str = 'train',
text_mode: str = 'random',
**kwargs):
self.split_file = split_file
self.split = split
assert text_mode in ['original', 'random', 'concat', 'select_first']
self.text_mode = text_mode
super().__init__(
data_root=data_root,
data_prefix=data_prefix,
ann_file=ann_file,
**kwargs,
)
def _join_prefix(self):
if not mmengine.is_abs(self.split_file) and self.split_file:
self.split_file = osp.join(self.data_root, self.split_file)
return super()._join_prefix()
def _init_refs(self):
"""Initialize the refs for RefCOCO."""
anns, imgs = {}, {}
for ann in self.instances['annotations']:
anns[ann['id']] = ann
for img in self.instances['images']:
imgs[img['id']] = img
refs, ref_to_ann = {}, {}
for ref in self.splits:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
# add mapping related to ref
refs[ref_id] = ref
ref_to_ann[ref_id] = anns[ann_id]
self.refs = refs
self.ref_to_ann = ref_to_ann
def load_data_list(self) -> List[dict]:
"""Load data list."""
self.splits = mmengine.load(self.split_file, file_format='pkl')
self.instances = mmengine.load(self.ann_file, file_format='json')
self._init_refs()
img_prefix = self.data_prefix['img_path']
ref_ids = [
ref['ref_id'] for ref in self.splits if ref['split'] == self.split
]
full_anno = []
for ref_id in ref_ids:
ref = self.refs[ref_id]
ann = self.ref_to_ann[ref_id]
ann.update(ref)
full_anno.append(ann)
image_id_list = []
final_anno = {}
for anno in full_anno:
image_id_list.append(anno['image_id'])
final_anno[anno['ann_id']] = anno
annotations = [value for key, value in final_anno.items()]
coco_train_id = []
image_annot = {}
for i in range(len(self.instances['images'])):
coco_train_id.append(self.instances['images'][i]['id'])
image_annot[self.instances['images'][i]
['id']] = self.instances['images'][i]
images = []
for image_id in list(set(image_id_list)):
images += [image_annot[image_id]]
data_list = []
grounding_dict = collections.defaultdict(list)
for anno in annotations:
image_id = int(anno['image_id'])
grounding_dict[image_id].append(anno)
join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
for image in images:
img_id = image['id']
instances = []
sentences = []
for grounding_anno in grounding_dict[img_id]:
texts = [x['raw'].lower() for x in grounding_anno['sentences']]
# random select one text
if self.text_mode == 'random':
idx = random.randint(0, len(texts) - 1)
text = [texts[idx]]
# concat all texts
elif self.text_mode == 'concat':
text = [''.join(texts)]
# select the first text
elif self.text_mode == 'select_first':
text = [texts[0]]
# use all texts
elif self.text_mode == 'original':
text = texts
else:
raise ValueError(f'Invalid text mode "{self.text_mode}".')
ins = [{
'mask': grounding_anno['segmentation'],
'ignore_flag': 0
}] * len(text)
instances.extend(ins)
sentences.extend(text)
data_info = {
'img_path': join_path(img_prefix, image['file_name']),
'img_id': img_id,
'instances': instances,
'text': sentences
}
data_list.append(data_info)
if len(data_list) == 0:
raise ValueError(f'No sample in split "{self.split}".')
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from collections import defaultdict
from typing import Any, Dict, List
import numpy as np
from mmengine.dataset import BaseDataset
from mmengine.utils import check_file_exist
from mmdet.registry import DATASETS
@DATASETS.register_module()
class ReIDDataset(BaseDataset):
"""Dataset for ReID.
Args:
triplet_sampler (dict, optional): The sampler for hard mining
triplet loss. Defaults to None.
keys: num_ids (int): The number of person ids.
            ins_per_id (int): The number of images for each person.
"""
def __init__(self, triplet_sampler: dict = None, *args, **kwargs):
self.triplet_sampler = triplet_sampler
super().__init__(*args, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ''self.ann_file''.
Returns:
            list[dict]: A list of annotations.
"""
assert isinstance(self.ann_file, str)
check_file_exist(self.ann_file)
data_list = []
with open(self.ann_file) as f:
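            # Each non-empty line is expected to contain
            # '<image filename> <person id>' separated by a single space,
            # e.g. '0001_01.jpg 0' (the file name here is illustrative).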
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
info = dict(img_prefix=self.data_prefix)
if self.data_prefix['img_path'] is not None:
info['img_path'] = osp.join(self.data_prefix['img_path'],
filename)
else:
info['img_path'] = filename
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_list.append(info)
self._parse_ann_info(data_list)
return data_list
def _parse_ann_info(self, data_list: List[dict]):
"""Parse person id annotations."""
index_tmp_dic = defaultdict(list) # pid->[idx1,...,idxN]
self.index_dic = dict() # pid->array([idx1,...,idxN])
for idx, info in enumerate(data_list):
pid = info['gt_label']
index_tmp_dic[int(pid)].append(idx)
for pid, idxs in index_tmp_dic.items():
self.index_dic[pid] = np.asarray(idxs, dtype=np.int64)
self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)
def prepare_data(self, idx: int) -> Any:
"""Get data processed by ''self.pipeline''.
Args:
idx (int): The index of ''data_info''
Returns:
Any: Depends on ''self.pipeline''
"""
data_info = self.get_data_info(idx)
if self.triplet_sampler is not None:
img_info = self.triplet_sampling(data_info['gt_label'],
**self.triplet_sampler)
data_info = copy.deepcopy(img_info) # triplet -> list
else:
data_info = copy.deepcopy(data_info) # no triplet -> dict
return self.pipeline(data_info)
def triplet_sampling(self,
pos_pid,
num_ids: int = 8,
ins_per_id: int = 4) -> Dict:
"""Triplet sampler for hard mining triplet loss. First, for one
pos_pid, random sample ins_per_id images with same person id.
        Then, random sample num_ids - 1 negative person ids.
Finally, random sample ins_per_id images for each negative id.
Args:
pos_pid (ndarray): The person id of the anchor.
num_ids (int): The number of person ids.
ins_per_id (int): The number of images for each person.
Returns:
Dict: Annotation information of num_ids X ins_per_id images.
"""
        assert len(self.pids) >= num_ids, \
            'The number of person ids in the training set must be ' \
            'greater than or equal to the number of sampled person ids.'
pos_idxs = self.index_dic[int(
pos_pid)] # all positive idxs for pos_pid
idxs_list = []
# select positive samplers
idxs_list.extend(pos_idxs[np.random.choice(
pos_idxs.shape[0], ins_per_id, replace=True)])
# select negative ids
neg_pids = np.random.choice(
[i for i, _ in enumerate(self.pids) if i != pos_pid],
num_ids - 1,
replace=False)
# select negative samplers for each negative id
for neg_pid in neg_pids:
neg_idxs = self.index_dic[neg_pid]
idxs_list.extend(neg_idxs[np.random.choice(
neg_idxs.shape[0], ins_per_id, replace=True)])
# return the final triplet batch
triplet_img_infos = []
for idx in idxs_list:
triplet_img_infos.append(copy.deepcopy(self.get_data_info(idx)))
# Collect data_list scatters (list of dict -> dict of list)
out = dict()
for key in triplet_img_infos[0].keys():
out[key] = [_info[key] for _info in triplet_img_infos]
return out
# Copyright (c) OpenMMLab. All rights reserved.
from .batch_sampler import (AspectRatioBatchSampler,
MultiDataAspectRatioBatchSampler,
TrackAspectRatioBatchSampler)
from .class_aware_sampler import ClassAwareSampler
from .custom_sample_size_sampler import CustomSampleSizeSampler
from .multi_data_sampler import MultiDataSampler
from .multi_source_sampler import GroupMultiSourceSampler, MultiSourceSampler
from .track_img_sampler import TrackImgSampler
__all__ = [
'ClassAwareSampler', 'AspectRatioBatchSampler', 'MultiSourceSampler',
'GroupMultiSourceSampler', 'TrackImgSampler',
'TrackAspectRatioBatchSampler', 'MultiDataSampler',
'MultiDataAspectRatioBatchSampler', 'CustomSampleSizeSampler'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
from torch.utils.data import BatchSampler, Sampler
from mmdet.datasets.samplers.track_img_sampler import TrackImgSampler
from mmdet.registry import DATA_SAMPLERS
# TODO: maybe replace with a data_loader wrapper
@DATA_SAMPLERS.register_module()
class AspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
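    Examples:
        A minimal dataloader config sketch (``train_dataset`` is assumed to
        be defined elsewhere; other fields are omitted):
        >>> train_dataloader = dict(
        ...     batch_size=2,
        ...     sampler=dict(type='DefaultSampler', shuffle=True),
        ...     batch_sampler=dict(type='AspectRatioBatchSampler'),
        ...     dataset=train_dataset)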
"""
def __init__(self,
sampler: Sampler,
batch_size: int,
drop_last: bool = False) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
# two groups for w < h and w >= h
self._aspect_ratio_buckets = [[] for _ in range(2)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
def __len__(self) -> int:
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
@DATA_SAMPLERS.register_module()
class TrackAspectRatioBatchSampler(AspectRatioBatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
# hard code to solve TrackImgSampler
if isinstance(self.sampler, TrackImgSampler):
video_idx, _ = idx
else:
video_idx = idx
# video_idx
data_info = self.sampler.dataset.get_data_info(video_idx)
# data_info {video_id, images, video_length}
img_data_info = data_info['images'][0]
width, height = img_data_info['width'], img_data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
@DATA_SAMPLERS.register_module()
class MultiDataAspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch for multi-source datasets.
Args:
sampler (Sampler): Base sampler.
batch_size (Sequence(int)): Size of mini-batch for multi-source
datasets.
num_datasets(int): Number of multi-source datasets.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __init__(self,
sampler: Sampler,
batch_size: Sequence[int],
num_datasets: int,
drop_last: bool = True) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
self.sampler = sampler
self.batch_size = batch_size
self.num_datasets = num_datasets
self.drop_last = drop_last
# two groups for w < h and w >= h for each dataset --> 2 * num_datasets
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
aspect_ratio_bucket_id = 0 if width < height else 1
bucket_id = dataset_source_idx * 2 + aspect_ratio_bucket_id
bucket = self._buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size[dataset_source_idx]:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
for i in range(self.num_datasets):
left_data = self._buckets[i * 2 + 0] + self._buckets[i * 2 + 1]
while len(left_data) > 0:
if len(left_data) <= self.batch_size[i]:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size[i]]
left_data = left_data[self.batch_size[i]:]
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __len__(self) -> int:
sizes = [0 for _ in range(self.num_datasets)]
for idx in self.sampler:
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
sizes[dataset_source_idx] += 1
if self.drop_last:
lens = 0
for i in range(self.num_datasets):
lens += sizes[i] // self.batch_size[i]
return lens
else:
lens = 0
for i in range(self.num_datasets):
lens += (sizes[i] + self.batch_size[i] -
1) // self.batch_size[i]
return lens
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Iterator, Optional, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class ClassAwareSampler(Sampler):
r"""Sampler that restricts data loading to the label of the dataset.
A class-aware sampling strategy to effectively tackle the
non-uniform class distribution. The length of the training data is
consistent with source data. Simple improvements based on `Relay
Backpropagation for Effective Learning of Deep Convolutional
Neural Networks <https://arxiv.org/abs/1512.05830>`_
The implementation logic is referred to
https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py
Args:
dataset: Dataset used for sampling.
seed (int, optional): random seed used to shuffle the sampler.
This number should be identical across all
processes in the distributed group. Defaults to None.
num_sample_class (int): The number of samples taken from each
per-label list. Defaults to 1.
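    Examples:
        A minimal dataloader config sketch (``train_dataset`` is assumed to
        be defined elsewhere):
        >>> train_dataloader = dict(
        ...     batch_size=2,
        ...     sampler=dict(type='ClassAwareSampler', num_sample_class=1),
        ...     dataset=train_dataset)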
"""
def __init__(self,
dataset: BaseDataset,
seed: Optional[int] = None,
num_sample_class: int = 1) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.epoch = 0
# Must be the same across all workers. If None, will use a
# random seed shared among workers
# (require synchronization among all workers)
if seed is None:
seed = sync_random_seed()
self.seed = seed
# The number of samples taken from each per-label list
assert num_sample_class > 0 and isinstance(num_sample_class, int)
self.num_sample_class = num_sample_class
# Get per-label image list from dataset
self.cat_dict = self.get_cat2imgs()
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size))
self.total_size = self.num_samples * self.world_size
# get number of images containing each category
self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
# filter labels without images
self.valid_cat_inds = [
i for i, length in enumerate(self.num_cat_imgs) if length != 0
]
self.num_classes = len(self.valid_cat_inds)
def get_cat2imgs(self) -> Dict[int, list]:
"""Get a dict with class as key and img_ids as values.
Returns:
dict[int, list]: A dict of per-label image list,
the item of the dict indicates a label index,
corresponds to the image index that contains the label.
"""
classes = self.dataset.metainfo.get('classes', None)
if classes is None:
raise ValueError('dataset metainfo must contain `classes`')
# sort the label index
cat2imgs = {i: [] for i in range(len(classes))}
for i in range(len(self.dataset)):
cat_ids = set(self.dataset.get_cat_ids(i))
for cat in cat_ids:
cat2imgs[cat].append(i)
return cat2imgs
def __iter__(self) -> Iterator[int]:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
# initialize label list
label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
# initialize each per-label image list
data_iter_dict = dict()
for i in self.valid_cat_inds:
data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)
def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
"""Traverse the categories and extract `num_sample_cls` image
indexes of the corresponding categories one by one."""
id_indices = []
for _ in range(len(cls_list)):
cls_idx = next(cls_list)
for _ in range(num_sample_cls):
id = next(data_dict[cls_idx])
id_indices.append(id)
return id_indices
# deterministically shuffle based on epoch
num_bins = int(
math.ceil(self.total_size * 1.0 / self.num_classes /
self.num_sample_class))
indices = []
for i in range(num_bins):
indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
self.num_sample_class)
# fix extra samples to make it evenly divisible
if len(indices) >= self.total_size:
indices = indices[:self.total_size]
else:
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
class RandomCycleIter:
"""Shuffle the list and do it again after the list have traversed.
The implementation logic is referred to
https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
Example:
>>> label_list = [0, 1, 2, 4, 5]
>>> g = torch.Generator()
>>> g.manual_seed(0)
>>> label_iter_list = RandomCycleIter(label_list, generator=g)
>>> index = next(label_iter_list)
Args:
data (list or ndarray): The data that needs to be shuffled.
        generator: A torch.Generator object, which is used in setting the seed
for generating random numbers.
""" # noqa: W605
def __init__(self,
data: Union[list, np.ndarray],
generator: torch.Generator = None) -> None:
self.data = data
self.length = len(data)
self.index = torch.randperm(self.length, generator=generator).numpy()
self.i = 0
self.generator = generator
def __iter__(self) -> Iterator:
return self
def __len__(self) -> int:
return len(self.data)
def __next__(self):
if self.i == self.length:
self.index = torch.randperm(
self.length, generator=self.generator).numpy()
self.i = 0
idx = self.data[self.index[self.i]]
self.i += 1
return idx
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
from .class_aware_sampler import RandomCycleIter
@DATA_SAMPLERS.register_module()
class CustomSampleSizeSampler(Sampler):
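    """Sampler that draws a custom number of samples from each sub-dataset.
    The wrapped dataset is expected to expose its sub-datasets via
    ``dataset.datasets`` (e.g. a concat-style dataset). For each sub-dataset,
    ``dataset_size`` gives the number of samples drawn per epoch, where -1
    keeps the full sub-dataset and, with ``ratio_mode=True``, the value is
    interpreted as a fraction of the sub-dataset length.
    Args:
        dataset (Sized): The wrapped dataset.
        dataset_size (Sequence[int]): Sample size for each sub-dataset;
            -1 means using all samples of that sub-dataset.
        ratio_mode (bool): Whether to interpret ``dataset_size`` as ratios
            of the sub-dataset lengths. Defaults to False.
        seed (int, optional): Random seed shared across all processes.
            If None, a synchronized random seed is used. Defaults to None.
        round_up (bool): Whether to add extra samples to make the number of
            samples evenly divisible by the world size. Defaults to True.
    """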
def __init__(self,
dataset: Sized,
dataset_size: Sequence[int],
ratio_mode: bool = False,
seed: Optional[int] = None,
round_up: bool = True) -> None:
assert len(dataset.datasets) == len(dataset_size)
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
total_size = 0
total_size_fake = 0
self.dataset_index = []
self.dataset_cycle_iter = []
new_dataset_size = []
for dataset, size in zip(dataset.datasets, dataset_size):
self.dataset_index.append(
list(range(total_size_fake,
len(dataset) + total_size_fake)))
total_size_fake += len(dataset)
if size == -1:
total_size += len(dataset)
self.dataset_cycle_iter.append(None)
new_dataset_size.append(-1)
else:
if ratio_mode:
size = int(size * len(dataset))
assert size <= len(
dataset
), f'dataset size {size} is larger than ' \
f'dataset length {len(dataset)}'
total_size += size
new_dataset_size.append(size)
g = torch.Generator()
g.manual_seed(self.seed)
self.dataset_cycle_iter.append(
RandomCycleIter(self.dataset_index[-1], generator=g))
self.dataset_size = new_dataset_size
if self.round_up:
self.num_samples = math.ceil(total_size / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil((total_size - rank) / world_size)
self.total_size = total_size
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
out_index = []
for data_size, data_index, cycle_iter in zip(self.dataset_size,
self.dataset_index,
self.dataset_cycle_iter):
if data_size == -1:
out_index += data_index
else:
index = [next(cycle_iter) for _ in range(data_size)]
out_index += index
index = torch.randperm(len(out_index), generator=g).numpy().tolist()
indices = [out_index[i] for i in index]
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS
from torch.utils.data import Sampler
@DATA_SAMPLERS.register_module()
class MultiDataSampler(Sampler):
"""The default data sampler for both distributed and non-distributed
    environments.
It has several differences from the PyTorch ``DistributedSampler`` as
below:
1. This sampler supports non-distributed environment.
2. The round up behaviors are a little different.
- If ``round_up=True``, this sampler will add extra samples to make the
      number of samples evenly divisible by the world size. And
this behavior is the same as the ``DistributedSampler`` with
``drop_last=False``.
- If ``round_up=False``, this sampler won't remove or add any samples
while the ``DistributedSampler`` with ``drop_last=True`` will remove
tail samples.
Args:
dataset (Sized): The dataset.
        dataset_ratio (Sequence(int)): The ratios of different datasets.
seed (int, optional): Random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Defaults to None.
round_up (bool): Whether to add extra samples to make the number of
samples evenly divisible by the world size. Defaults to True.
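    Examples:
        A minimal config sketch; ``dataset_a`` and ``dataset_b`` are assumed
        to be sub-dataset configs defined elsewhere:
        >>> dataset = dict(
        ...     type='ConcatDataset', datasets=[dataset_a, dataset_b])
        >>> train_dataloader = dict(
        ...     batch_size=2,
        ...     sampler=dict(type='MultiDataSampler', dataset_ratio=[1, 2]),
        ...     dataset=dataset)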
"""
def __init__(self,
dataset: Sized,
dataset_ratio: Sequence[int],
seed: Optional[int] = None,
round_up: bool = True) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.dataset_ratio = dataset_ratio
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
if self.round_up:
self.num_samples = math.ceil(len(self.dataset) / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil(
(len(self.dataset) - rank) / world_size)
self.total_size = len(self.dataset)
self.sizes = [len(dataset) for dataset in self.dataset.datasets]
dataset_weight = [
torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio)
for i, (r, s) in enumerate(zip(self.dataset_ratio, self.sizes))
]
self.weights = torch.cat(dataset_weight)
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.multinomial(
self.weights, len(self.weights), generator=g,
replacement=True).tolist()
# add extra samples to make it evenly divisible
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
# subsample
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Iterator, List, Optional, Sized, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class MultiSourceSampler(Sampler):
r"""Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
Examples:
>>> dataset_type = 'ConcatDataset'
>>> sub_dataset_type = 'CocoDataset'
>>> data_root = 'data/coco/'
>>> sup_ann = '../coco_semi_annos/instances_train2017.1@10.json'
>>> unsup_ann = '../coco_semi_annos/' \
>>> 'instances_train2017.1@10-unlabeled.json'
>>> dataset = dict(type=dataset_type,
>>> datasets=[
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=sup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=sup_pipeline),
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=unsup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=unsup_pipeline),
>>> ])
>>> train_dataloader = dict(
>>> batch_size=5,
>>> num_workers=5,
>>> persistent_workers=True,
>>> sampler=dict(type='MultiSourceSampler',
>>> batch_size=5, source_ratio=[1, 4]),
>>> batch_sampler=None,
>>> dataset=dataset)
"""
def __init__(self,
dataset: Sized,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
assert hasattr(dataset, 'cumulative_sizes'),\
            f'The dataset must be ConcatDataset, but got {dataset}'
assert isinstance(batch_size, int) and batch_size > 0, \
'batch_size must be a positive integer value, ' \
f'but got batch_size={batch_size}'
assert isinstance(source_ratio, list), \
f'source_ratio must be a list, but got source_ratio={source_ratio}'
assert len(source_ratio) == len(dataset.cumulative_sizes), \
'The length of source_ratio must be equal to ' \
f'the number of datasets, but got source_ratio={source_ratio}'
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.cumulative_sizes = [0] + dataset.cumulative_sizes
self.batch_size = batch_size
self.source_ratio = source_ratio
self.num_per_source = [
int(batch_size * sr / sum(source_ratio)) for sr in source_ratio
]
self.num_per_source[0] = batch_size - sum(self.num_per_source[1:])
assert sum(self.num_per_source) == batch_size, \
'The sum of num_per_source must be equal to ' \
            f'batch_size, but got {self.num_per_source}'
self.seed = sync_random_seed() if seed is None else seed
self.shuffle = shuffle
self.source2inds = {
source: self._indices_of_rank(len(ds))
for source, ds in enumerate(dataset.datasets)
}
def _infinite_indices(self, sample_size: int) -> Iterator[int]:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
while True:
if self.shuffle:
yield from torch.randperm(sample_size, generator=g).tolist()
else:
yield from torch.arange(sample_size).tolist()
def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
"""Slice the infinite indices by rank."""
yield from itertools.islice(
self._infinite_indices(sample_size), self.rank, None,
self.world_size)
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.source2inds[source]:
idx += self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
def __len__(self) -> int:
return len(self.dataset)
def set_epoch(self, epoch: int) -> None:
"""Not supported in `epoch-based runner."""
pass
@DATA_SAMPLERS.register_module()
class GroupMultiSourceSampler(MultiSourceSampler):
r"""Group Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets but the same group to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
"""
def __init__(self,
dataset: BaseDataset,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
super().__init__(
dataset=dataset,
batch_size=batch_size,
source_ratio=source_ratio,
shuffle=shuffle,
seed=seed)
self._get_source_group_info()
self.group_source2inds = [{
source:
self._indices_of_rank(self.group2size_per_source[source][group])
for source in range(len(dataset.datasets))
} for group in range(len(self.group_ratio))]
def _get_source_group_info(self) -> None:
self.group2size_per_source = [{0: 0, 1: 0}, {0: 0, 1: 0}]
self.group2inds_per_source = [{0: [], 1: []}, {0: [], 1: []}]
for source, dataset in enumerate(self.dataset.datasets):
for idx in range(len(dataset)):
data_info = dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
group = 0 if width < height else 1
self.group2size_per_source[source][group] += 1
self.group2inds_per_source[source][group].append(idx)
self.group_sizes = np.zeros(2, dtype=np.int64)
for group2size in self.group2size_per_source:
for group, size in group2size.items():
self.group_sizes[group] += size
self.group_ratio = self.group_sizes / sum(self.group_sizes)
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
group = np.random.choice(
list(range(len(self.group_ratio))), p=self.group_ratio)
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.group_source2inds[group][source]:
idx = self.group2inds_per_source[source][group][
idx] + self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
# Copyright (c) OpenMMLab. All rights reserved.
import math
import random
from typing import Iterator, Optional, Sized
import numpy as np
from mmengine.dataset import ClassBalancedDataset, ConcatDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
from ..base_video_dataset import BaseVideoDataset
@DATA_SAMPLERS.register_module()
class TrackImgSampler(Sampler):
"""Sampler that providing image-level sampling outputs for video datasets
in tracking tasks. It could be both used in both distributed and
non-distributed environment.
If using the default sampler in pytorch, the subsequent data receiver will
get one video, which is not desired in some cases:
(Take a non-distributed environment as an example)
    1. In test mode, we want only one image to be fed into the data pipeline.
    This is in consideration of memory usage, since feeding the whole video
    commonly requires a large amount of memory (>=20G on the MOTChallenge17
    dataset), which is not available on some machines.
    2. In training mode, we may want to make sure all the images in one video
    are randomly sampled once in one epoch, and this cannot be guaranteed by
    the default sampler in pytorch.
Args:
dataset (Sized): Dataset used for sampling.
        seed (int, optional): Random seed used to shuffle the sampler. This
number should be identical across all processes in the distributed
group. Defaults to None.
"""
def __init__(
self,
dataset: Sized,
seed: Optional[int] = None,
) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.epoch = 0
if seed is None:
self.seed = sync_random_seed()
else:
self.seed = seed
self.dataset = dataset
self.indices = []
        # Hard-coded here to handle different dataset wrappers.
if isinstance(self.dataset, ConcatDataset):
cat_datasets = self.dataset.datasets
assert isinstance(
cat_datasets[0], BaseVideoDataset
), f'expected BaseVideoDataset, but got {type(cat_datasets[0])}'
self.test_mode = cat_datasets[0].test_mode
assert not self.test_mode, "'ConcatDataset' should not exist in "
'test mode'
for dataset in cat_datasets:
num_videos = len(dataset)
for video_ind in range(num_videos):
self.indices.extend([
(video_ind, frame_ind) for frame_ind in range(
dataset.get_len_per_video(video_ind))
])
elif isinstance(self.dataset, ClassBalancedDataset):
ori_dataset = self.dataset.dataset
assert isinstance(
ori_dataset, BaseVideoDataset
), f'expected BaseVideoDataset, but got {type(ori_dataset)}'
self.test_mode = ori_dataset.test_mode
assert not self.test_mode, "'ClassBalancedDataset' should not "
'exist in test mode'
video_indices = self.dataset.repeat_indices
for index in video_indices:
self.indices.extend([(index, frame_ind) for frame_ind in range(
ori_dataset.get_len_per_video(index))])
else:
            assert isinstance(self.dataset, BaseVideoDataset), (
                'TrackImgSampler is only supported in BaseVideoDataset or '
                'the dataset wrappers ClassBalancedDataset and ConcatDataset, '
                f'but got {type(self.dataset)}')
self.test_mode = self.dataset.test_mode
num_videos = len(self.dataset)
if self.test_mode:
                # In test mode, the images belonging to the same video must be
                # put on the same device.
if num_videos < self.world_size:
                    raise ValueError(f'only {num_videos} videos loaded, '
                                     f'but {self.world_size} gpus were given.')
chunks = np.array_split(
list(range(num_videos)), self.world_size)
for videos_inds in chunks:
indices_chunk = []
for video_ind in videos_inds:
indices_chunk.extend([
(video_ind, frame_ind) for frame_ind in range(
self.dataset.get_len_per_video(video_ind))
])
self.indices.append(indices_chunk)
else:
for video_ind in range(num_videos):
self.indices.extend([
(video_ind, frame_ind) for frame_ind in range(
self.dataset.get_len_per_video(video_ind))
])
if self.test_mode:
self.num_samples = len(self.indices[self.rank])
self.total_size = sum(
[len(index_list) for index_list in self.indices])
else:
self.num_samples = int(
math.ceil(len(self.indices) * 1.0 / self.world_size))
self.total_size = self.num_samples * self.world_size
def __iter__(self) -> Iterator:
if self.test_mode:
# in test mode, the order of frames can not be shuffled.
indices = self.indices[self.rank]
else:
# deterministically shuffle based on epoch
rng = random.Random(self.epoch + self.seed)
indices = rng.sample(self.indices, len(self.indices))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.world_size]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
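# Illustrative config sketch (values assumed for this note, not part of the
# original file): the sampler is referenced by type name in an MMEngine
# dataloader config, and each yielded index is a (video_ind, frame_ind) tuple
# consumed by the video dataset:
#
#     train_dataloader = dict(
#         batch_size=2,
#         sampler=dict(type='TrackImgSampler'),
#         dataset=dict(type='MOTChallengeDataset', ...))
#
# For the per-epoch shuffle in __iter__ to change between epochs, set_epoch()
# has to be called once per epoch (MMEngine's DistSamplerSeedHook does this in
# epoch-based training).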