Commit b12850fe authored by dengjb's avatar dengjb
Browse files

update codes

parent 6515fb96
Pipeline #1046 failed with stages
in 0 seconds
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class iSAIDDataset(CocoDataset):
"""Dataset for iSAID instance segmentation.
iSAID: A Large-scale Dataset for Instance Segmentation
in Aerial Images.
For more detail, please refer to "projects/iSAID/README.md"
"""
METAINFO = dict(
classes=('background', 'ship', 'store_tank', 'baseball_diamond',
'tennis_court', 'basketball_court', 'Ground_Track_Field',
'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
'Harbor'),
palette=[(0, 0, 0), (0, 0, 63), (0, 63, 63), (0, 63, 0), (0, 63, 127),
(0, 63, 191), (0, 63, 255), (0, 127, 63), (0, 127, 127),
(0, 0, 127), (0, 0, 191), (0, 0, 255), (0, 191, 127),
(0, 127, 191), (0, 127, 255), (0, 100, 155)])
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class LVISV05Dataset(CocoDataset):
"""LVIS v0.5 dataset for detection."""
METAINFO = {
'classes':
('acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron',
'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke',
'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award',
'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack',
'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball',
'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage',
'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel',
'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat',
'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop',
'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball',
'bead', 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed',
'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle',
'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle',
'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder',
'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage',
'birdhouse', 'birthday_cake', 'birthday_card', 'biscuit_(bread)',
'pirate_flag', 'black_sheep', 'blackboard', 'blanket', 'blazer',
'blender', 'blimp', 'blinker', 'blueberry', 'boar', 'gameboard',
'boat', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt',
'bolt', 'bonnet', 'book', 'book_bag', 'bookcase', 'booklet',
'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener',
'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie',
'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin',
'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase',
'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie',
'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board',
'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed',
'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife',
'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon',
'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)',
'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan',
'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player',
'celery', 'cellular_telephone', 'chain_mail', 'chair',
'chaise_longue', 'champagne', 'chandelier', 'chap', 'checkbook',
'checkerboard', 'cherry', 'chessboard',
'chest_of_drawers_(furniture)', 'chicken_(animal)', 'chicken_wire',
'chickpea', 'Chihuahua', 'chili_(vegetable)', 'chime', 'chinaware',
'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar',
'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider',
'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
'clasp', 'cleansing_agent', 'clementine', 'clip', 'clipboard',
'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag',
'coaster', 'coat', 'coat_hanger', 'coatrack', 'cock', 'coconut',
'coffee_filter', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
'coin', 'colander', 'coleslaw', 'coloring_material',
'combination_lock', 'pacifier', 'comic_book', 'computer_keyboard',
'concrete_mixer', 'cone', 'control', 'convertible_(automobile)',
'sofa_bed', 'cookie', 'cookie_jar', 'cooking_utensil',
'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew',
'edible_corn', 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset',
'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell',
'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon',
'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot',
'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship',
'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube',
'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler',
'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool',
'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog',
'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask',
'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper',
'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan',
'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel',
'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)',
'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose',
'fireplace', 'fireplug', 'fish', 'fish_(food)', 'fishbowl',
'fishing_boat', 'fishing_rod', 'flag', 'flagpole', 'flamingo',
'flannel', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)',
'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal',
'folding_chair', 'food_processor', 'football_(American)',
'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car',
'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice',
'fruit_salad', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag',
'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle',
'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda',
'gift_wrap', 'ginger', 'giraffe', 'cincture',
'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater',
'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag',
'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush',
'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock',
'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil',
'headband', 'headboard', 'headlight', 'headscarf', 'headset',
'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater',
'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus',
'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood',
'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod',
'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean',
'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick',
'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)',
'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat',
'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp',
'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy',
'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard',
'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion',
'speaker_(stereo_equipment)', 'loveseat', 'machine_gun', 'magazine',
'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth',
'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini',
'mascot', 'mashed_potato', 'masher', 'mask', 'mast',
'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup',
'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone',
'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
'mound_(baseball)', 'mouse_(animal_rodent)',
'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
'music_stool', 'musical_instrument', 'nailfile', 'nameplate',
'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest',
'newsstand', 'nightshirt', 'nosebag_(for_animals)',
'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker',
'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'oregano',
'ostrich', 'ottoman', 'overalls_(clothing)', 'owl', 'packet',
'inkpad', 'pad', 'paddle', 'padlock', 'paintbox', 'paintbrush',
'painting', 'pajamas', 'palette', 'pan_(for_cooking)',
'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya',
'paperclip', 'paper_plate', 'paper_towel', 'paperback_book',
'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
'parchment', 'parka', 'parking_meter', 'parrot',
'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard',
'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener',
'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper',
'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood',
'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
'plate', 'platter', 'playing_card', 'playpen', 'pliers',
'plow_(farm_equipment)', 'pocket_watch', 'pocketknife',
'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt',
'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait',
'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune',
'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher',
'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit',
'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish',
'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
'recliner', 'record_player', 'red_cabbage', 'reflector',
'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate',
'Rollerblade', 'rolling_pin', 'root_beer',
'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)',
'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag',
'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer',
'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)',
'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker',
'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)',
'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog',
'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag',
'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag',
'shovel', 'shower_head', 'shower_curtain', 'shredder_(for_paper)',
'sieve', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski',
'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'sled', 'sleeping_bag',
'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake',
'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock',
'soda_fountain', 'carbonated_water', 'sofa', 'softball',
'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'sponge',
'spoon', 'sportswear', 'spotlight', 'squirrel',
'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)',
'steak_(food)', 'steak_knife', 'steamer_(kitchen_appliance)',
'steering_wheel', 'stencil', 'stepladder', 'step_stool',
'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup',
'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light', 'stove',
'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop',
'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table',
'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag',
'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)',
'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
'telephone_pole', 'telephoto_lens', 'television_camera',
'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)',
'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip',
'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella',
'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve',
'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin',
'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon',
'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet',
'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch',
'water_bottle', 'water_cooler', 'water_faucet', 'water_filter',
'water_heater', 'water_jug', 'water_gun', 'water_scooter',
'water_ski', 'water_tower', 'watering_can', 'watermelon',
'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit',
'wheel', 'wheelchair', 'whipped_cream', 'whiskey', 'whistle', 'wick',
'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
'wineglass', 'wing_chair', 'blinder_(for_horses)', 'wok', 'wolf',
'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht',
'yak', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini'),
'palette':
None
}
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
try:
import lvis
if getattr(lvis, '__version__', '0') >= '10.5.3':
warnings.warn(
'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
UserWarning)
from lvis import LVIS
except ImportError:
raise ImportError(
'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
)
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.lvis = LVIS(local_path)
self.cat_ids = self.lvis.get_cat_ids()
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)
img_ids = self.lvis.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.lvis.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
if raw_img_info['file_name'].startswith('COCO'):
# Convert form the COCO 2014 file naming convention of
# COCO_[train/val/test]2014_000000000000.jpg to the 2017
# naming convention of 000000000000.jpg
# (LVIS v1 will fix this naming issue)
raw_img_info['file_name'] = raw_img_info['file_name'][-16:]
ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.lvis.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.lvis
return data_list
LVISDataset = LVISV05Dataset
DATASETS.register_module(name='LVISDataset', module=LVISDataset)
@DATASETS.register_module()
class LVISV1Dataset(LVISDataset):
"""LVIS v1 dataset for detection."""
METAINFO = {
'classes':
('aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
'bottle_opener', 'bouquet', 'bow_(weapon)',
'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl',
'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders',
'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread',
'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach',
'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket',
'horse_buggy', 'bull', 'bulldog', 'bulldozer', 'bullet_train',
'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed',
'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter',
'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet',
'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder',
'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can',
'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane',
'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen',
'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
'cash_register', 'casserole', 'cassette', 'cast', 'cat',
'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery',
'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard',
'cherry', 'chessboard', 'chicken_(animal)', 'chickpea',
'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard',
'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower',
'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)',
'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil',
'coin', 'colander', 'coleslaw', 'coloring_material',
'combination_lock', 'pacifier', 'comic_book', 'compass',
'computer_keyboard', 'condiment', 'cone', 'control',
'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table',
'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove',
'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat',
'dress_suit', 'dresser', 'drill', 'drone', 'dropper',
'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle',
'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg',
'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair',
'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot',
'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret',
'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine',
'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine',
'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug',
'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod',
'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash',
'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
'food_processor', 'football_(American)', 'football_helmet',
'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge',
'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose',
'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin',
'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger',
'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove',
'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart',
'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater',
'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun',
'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger',
'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass',
'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle',
'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil',
'headband', 'headboard', 'headlight', 'headscarf', 'headset',
'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade',
'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat',
'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange',
'manger', 'manhole', 'map', 'marker', 'martini', 'mascot',
'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)',
'matchbox', 'mattress', 'measuring_cup', 'measuring_stick',
'meatball', 'medicine', 'melon', 'microphone', 'microscope',
'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake',
'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)',
'money', 'monitor_(computer_equipment) computer_monitor', 'monkey',
'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle',
'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad',
'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument',
'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle',
'nest', 'newspaper', 'newsstand', 'nightshirt',
'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook',
'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven',
'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
'padlock', 'paintbrush', 'painting', 'pajamas', 'palette',
'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
'parasol', 'parchment', 'parka', 'parking_meter', 'parrot',
'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot',
'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn',
'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller',
'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin',
'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt',
'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver',
'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry',
'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
'recliner', 'record_player', 'reflector', 'remote_control',
'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
'rolling_pin', 'root_beer', 'router_(computer_equipment)',
'rubber_band', 'runner_(carpet)', 'plastic_bag',
'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew',
'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove',
'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants',
'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit',
'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
'tambourine', 'army_tank', 'tank_(storage_vessel)',
'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
'telephone_pole', 'telephoto_lens', 'television_camera',
'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer',
'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster',
'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs',
'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover',
'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy',
'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike',
'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray',
'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod',
'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban',
'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
'washbasin', 'automatic_washer', 'watch', 'water_bottle',
'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
'water_gun', 'water_scooter', 'water_ski', 'water_tower',
'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
'yoke_(animal_equipment)', 'zebra', 'zucchini'),
'palette':
None
}
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
try:
import lvis
if getattr(lvis, '__version__', '0') >= '10.5.3':
warnings.warn(
'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
UserWarning)
from lvis import LVIS
except ImportError:
raise ImportError(
'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
)
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.lvis = LVIS(local_path)
self.cat_ids = self.lvis.get_cat_ids()
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.lvis.cat_img_map)
img_ids = self.lvis.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.lvis.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
# coco_url is used in LVISv1 instead of file_name
# e.g. http://images.cocodataset.org/train2017/000000391895.jpg
# train/val split in specified in url
raw_img_info['file_name'] = raw_img_info['coco_url'].replace(
'http://images.cocodataset.org/', '')
ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.lvis.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.lvis
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List
from mmengine.fileio import get_local_path
from mmdet.datasets import BaseDetDataset
from mmdet.registry import DATASETS
from .api_wrappers import COCO
@DATASETS.register_module()
class MDETRStyleRefCocoDataset(BaseDetDataset):
"""RefCOCO dataset.
Only support evaluation now.
"""
def load_data_list(self) -> List[dict]:
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
coco = COCO(local_path)
img_ids = coco.get_img_ids()
data_infos = []
for img_id in img_ids:
raw_img_info = coco.load_imgs([img_id])[0]
ann_ids = coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = coco.load_anns(ann_ids)
data_info = {}
img_path = osp.join(self.data_prefix['img'],
raw_img_info['file_name'])
data_info['img_path'] = img_path
data_info['img_id'] = img_id
data_info['height'] = raw_img_info['height']
data_info['width'] = raw_img_info['width']
data_info['dataset_mode'] = raw_img_info['dataset_name']
data_info['text'] = raw_img_info['caption']
data_info['custom_entities'] = False
data_info['tokens_positive'] = -1
instances = []
for i, ann in enumerate(raw_ann_info):
instance = {}
x1, y1, w, h = ann['bbox']
bbox = [x1, y1, x1 + w, y1 + h]
instance['bbox'] = bbox
instance['bbox_label'] = ann['category_id']
instance['ignore_flag'] = 0
instances.append(instance)
data_info['instances'] = instances
data_infos.append(data_info)
return data_infos
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Union
from mmdet.registry import DATASETS
from .base_video_dataset import BaseVideoDataset
@DATASETS.register_module()
class MOTChallengeDataset(BaseVideoDataset):
"""Dataset for MOTChallenge.
Args:
visibility_thr (float, optional): The minimum visibility
for the objects during training. Default to -1.
"""
METAINFO = {
'classes':
('pedestrian', 'person_on_vehicle', 'car', 'bicycle', 'motorbike',
'non_mot_vehicle', 'static_person', 'distractor', 'occluder',
'occluder_on_ground', 'occluder_full', 'reflection', 'crowd')
}
def __init__(self, visibility_thr: float = -1, *args, **kwargs):
self.visibility_thr = visibility_thr
super().__init__(*args, **kwargs)
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
"""Parse raw annotation to target format. The difference between this
function and the one in ``BaseVideoDataset`` is that the parsing here
adds ``visibility`` and ``mot_conf``.
Args:
raw_data_info (dict): Raw data information load from ``ann_file``
Returns:
Union[dict, List[dict]]: Parsed annotation.
"""
img_info = raw_data_info['raw_img_info']
ann_info = raw_data_info['raw_ann_info']
data_info = {}
data_info.update(img_info)
if self.data_prefix.get('img_path', None) is not None:
img_path = osp.join(self.data_prefix['img_path'],
img_info['file_name'])
else:
img_path = img_info['file_name']
data_info['img_path'] = img_path
instances = []
for i, ann in enumerate(ann_info):
instance = {}
if (not self.test_mode) and (ann['visibility'] <
self.visibility_thr):
continue
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann['area'] <= 0 or w < 1 or h < 1:
continue
if ann['category_id'] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if ann.get('iscrowd', False):
instance['ignore_flag'] = 1
else:
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = self.cat2label[ann['category_id']]
instance['instance_id'] = ann['instance_id']
instance['category_id'] = ann['category_id']
instance['mot_conf'] = ann['mot_conf']
instance['visibility'] = ann['visibility']
if len(instance) > 0:
instances.append(instance)
if not self.test_mode:
assert len(instances) > 0, f'No valid instances found in ' \
f'image {data_info["img_path"]}!'
data_info['instances'] = instances
return data_info
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import List
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .api_wrappers import COCO
from .coco import CocoDataset
# images exist in annotations but not in image folder.
objv2_ignore_list = [
osp.join('patch16', 'objects365_v2_00908726.jpg'),
osp.join('patch6', 'objects365_v1_00320532.jpg'),
osp.join('patch6', 'objects365_v1_00320534.jpg'),
]
@DATASETS.register_module()
class Objects365V1Dataset(CocoDataset):
"""Objects365 v1 dataset for detection."""
METAINFO = {
'classes':
('person', 'sneakers', 'chair', 'hat', 'lamp', 'bottle',
'cabinet/shelf', 'cup', 'car', 'glasses', 'picture/frame', 'desk',
'handbag', 'street lights', 'book', 'plate', 'helmet',
'leather shoes', 'pillow', 'glove', 'potted plant', 'bracelet',
'flower', 'tv', 'storage box', 'vase', 'bench', 'wine glass', 'boots',
'bowl', 'dining table', 'umbrella', 'boat', 'flag', 'speaker',
'trash bin/can', 'stool', 'backpack', 'couch', 'belt', 'carpet',
'basket', 'towel/napkin', 'slippers', 'barrel/bucket', 'coffee table',
'suv', 'toy', 'tie', 'bed', 'traffic light', 'pen/pencil',
'microphone', 'sandals', 'canned', 'necklace', 'mirror', 'faucet',
'bicycle', 'bread', 'high heels', 'ring', 'van', 'watch', 'sink',
'horse', 'fish', 'apple', 'camera', 'candle', 'teddy bear', 'cake',
'motorcycle', 'wild bird', 'laptop', 'knife', 'traffic sign',
'cell phone', 'paddle', 'truck', 'cow', 'power outlet', 'clock',
'drum', 'fork', 'bus', 'hanger', 'nightstand', 'pot/pan', 'sheep',
'guitar', 'traffic cone', 'tea pot', 'keyboard', 'tripod', 'hockey',
'fan', 'dog', 'spoon', 'blackboard/whiteboard', 'balloon',
'air conditioner', 'cymbal', 'mouse', 'telephone', 'pickup truck',
'orange', 'banana', 'airplane', 'luggage', 'skis', 'soccer',
'trolley', 'oven', 'remote', 'baseball glove', 'paper towel',
'refrigerator', 'train', 'tomato', 'machinery vehicle', 'tent',
'shampoo/shower gel', 'head phone', 'lantern', 'donut',
'cleaning products', 'sailboat', 'tangerine', 'pizza', 'kite',
'computer box', 'elephant', 'toiletries', 'gas stove', 'broccoli',
'toilet', 'stroller', 'shovel', 'baseball bat', 'microwave',
'skateboard', 'surfboard', 'surveillance camera', 'gun', 'life saver',
'cat', 'lemon', 'liquid soap', 'zebra', 'duck', 'sports car',
'giraffe', 'pumpkin', 'piano', 'stop sign', 'radiator', 'converter',
'tissue ', 'carrot', 'washing machine', 'vent', 'cookies',
'cutting/chopping board', 'tennis racket', 'candy',
'skating and skiing shoes', 'scissors', 'folder', 'baseball',
'strawberry', 'bow tie', 'pigeon', 'pepper', 'coffee machine',
'bathtub', 'snowboard', 'suitcase', 'grapes', 'ladder', 'pear',
'american football', 'basketball', 'potato', 'paint brush', 'printer',
'billiards', 'fire hydrant', 'goose', 'projector', 'sausage',
'fire extinguisher', 'extension cord', 'facial mask', 'tennis ball',
'chopsticks', 'electronic stove and gas stove', 'pie', 'frisbee',
'kettle', 'hamburger', 'golf club', 'cucumber', 'clutch', 'blender',
'tong', 'slide', 'hot dog', 'toothbrush', 'facial cleanser', 'mango',
'deer', 'egg', 'violin', 'marker', 'ship', 'chicken', 'onion',
'ice cream', 'tape', 'wheelchair', 'plum', 'bar soap', 'scale',
'watermelon', 'cabbage', 'router/modem', 'golf ball', 'pine apple',
'crane', 'fire truck', 'peach', 'cello', 'notepaper', 'tricycle',
'toaster', 'helicopter', 'green beans', 'brush', 'carriage', 'cigar',
'earphone', 'penguin', 'hurdle', 'swing', 'radio', 'CD',
'parking meter', 'swan', 'garlic', 'french fries', 'horn', 'avocado',
'saxophone', 'trumpet', 'sandwich', 'cue', 'kiwi fruit', 'bear',
'fishing rod', 'cherry', 'tablet', 'green vegetables', 'nuts', 'corn',
'key', 'screwdriver', 'globe', 'broom', 'pliers', 'volleyball',
'hammer', 'eggplant', 'trophy', 'dates', 'board eraser', 'rice',
'tape measure/ruler', 'dumbbell', 'hamimelon', 'stapler', 'camel',
'lettuce', 'goldfish', 'meat balls', 'medal', 'toothpaste',
'antelope', 'shrimp', 'rickshaw', 'trombone', 'pomegranate',
'coconut', 'jellyfish', 'mushroom', 'calculator', 'treadmill',
'butterfly', 'egg tart', 'cheese', 'pig', 'pomelo', 'race car',
'rice cooker', 'tuba', 'crosswalk sign', 'papaya', 'hair drier',
'green onion', 'chips', 'dolphin', 'sushi', 'urinal', 'donkey',
'electric drill', 'spring rolls', 'tortoise/turtle', 'parrot',
'flute', 'measuring cup', 'shark', 'steak', 'poker card',
'binoculars', 'llama', 'radish', 'noodles', 'yak', 'mop', 'crab',
'microscope', 'barbell', 'bread/bun', 'baozi', 'lion', 'red cabbage',
'polar bear', 'lighter', 'seal', 'mangosteen', 'comb', 'eraser',
'pitaya', 'scallop', 'pencil case', 'saw', 'table tennis paddle',
'okra', 'starfish', 'eagle', 'monkey', 'durian', 'game board',
'rabbit', 'french horn', 'ambulance', 'asparagus', 'hoverboard',
'pasta', 'target', 'hotair balloon', 'chainsaw', 'lobster', 'iron',
'flashlight'),
'palette':
None
}
COCOAPI = COCO
# ann_id is unique in coco dataset.
ANN_ID_UNIQUE = True
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
# 'categories' list in objects365_train.json and objects365_val.json
# is inconsistent, need sort list(or dict) before get cat_ids.
cats = self.coco.cats
sorted_cats = {i: cats[i] for i in sorted(cats)}
self.coco.cats = sorted_cats
categories = self.coco.dataset['categories']
sorted_categories = sorted(categories, key=lambda i: i['id'])
self.coco.dataset['categories'] = sorted_categories
# The order of returned `cat_ids` will not
# change with the order of the `classes`
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
img_ids = self.coco.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
return data_list
@DATASETS.register_module()
class Objects365V2Dataset(CocoDataset):
"""Objects365 v2 dataset for detection."""
METAINFO = {
'classes':
('Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp',
'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf',
'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet',
'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower',
'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots',
'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt',
'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker',
'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool',
'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum',
'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle',
'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned',
'Truck', 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel',
'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed',
'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple',
'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck',
'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock',
'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger',
'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine',
'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle',
'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane',
'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage',
'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone',
'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane',
'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat',
'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza',
'Elephant', 'Skateboard', 'Surfboard', 'Gun',
'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot',
'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper',
'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks',
'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board',
'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder',
'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball',
'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle',
'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck',
'Billards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club',
'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear',
'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong',
'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask',
'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide',
'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee',
'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon',
'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon',
'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog',
'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer',
'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple',
'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle',
'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone',
'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion',
'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom',
'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit',
'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese',
'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue',
'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap',
'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut',
'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak',
'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate',
'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker',
'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal',
'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin',
'Electric Drill', 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill',
'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi',
'Target', 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case',
'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop',
'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle',
'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster',
'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling',
'Table Tennis '),
'palette':
None
}
COCOAPI = COCO
# ann_id is unique in coco dataset.
ANN_ID_UNIQUE = True
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
# The order of returned `cat_ids` will not
# change with the order of the `classes`
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
img_ids = self.coco.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
# file_name should be `patchX/xxx.jpg`
file_name = osp.join(
osp.split(osp.split(raw_img_info['file_name'])[0])[-1],
osp.split(raw_img_info['file_name'])[-1])
if file_name in objv2_ignore_list:
continue
raw_img_info['file_name'] = file_name
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List, Optional
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class ODVGDataset(BaseDetDataset):
"""object detection and visual grounding dataset."""
def __init__(self,
*args,
data_root: str = '',
label_map_file: Optional[str] = None,
need_text: bool = True,
**kwargs) -> None:
self.dataset_mode = 'VG'
self.need_text = need_text
if label_map_file:
label_map_file = osp.join(data_root, label_map_file)
with open(label_map_file, 'r') as file:
self.label_map = json.load(file)
self.dataset_mode = 'OD'
super().__init__(*args, data_root=data_root, **kwargs)
assert self.return_classes is True
def load_data_list(self) -> List[dict]:
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
data_list = [json.loads(line) for line in f]
out_data_list = []
for data in data_list:
data_info = {}
img_path = osp.join(self.data_prefix['img'], data['filename'])
data_info['img_path'] = img_path
data_info['height'] = data['height']
data_info['width'] = data['width']
if self.dataset_mode == 'OD':
if self.need_text:
data_info['text'] = self.label_map
anno = data.get('detection', {})
instances = [obj for obj in anno.get('instances', [])]
bboxes = [obj['bbox'] for obj in instances]
bbox_labels = [str(obj['label']) for obj in instances]
instances = []
for bbox, label in zip(bboxes, bbox_labels):
instance = {}
x1, y1, x2, y2 = bbox
inter_w = max(0, min(x2, data['width']) - max(x1, 0))
inter_h = max(0, min(y2, data['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if (x2 - x1) < 1 or (y2 - y1) < 1:
continue
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = int(label)
instances.append(instance)
data_info['instances'] = instances
data_info['dataset_mode'] = self.dataset_mode
out_data_list.append(data_info)
else:
anno = data['grounding']
data_info['text'] = anno['caption']
regions = anno['regions']
instances = []
phrases = {}
for i, region in enumerate(regions):
bbox = region['bbox']
phrase = region['phrase']
tokens_positive = region['tokens_positive']
if not isinstance(bbox[0], list):
bbox = [bbox]
for box in bbox:
instance = {}
x1, y1, x2, y2 = box
inter_w = max(0, min(x2, data['width']) - max(x1, 0))
inter_h = max(0, min(y2, data['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if (x2 - x1) < 1 or (y2 - y1) < 1:
continue
instance['ignore_flag'] = 0
instance['bbox'] = box
instance['bbox_label'] = i
phrases[i] = {
'phrase': phrase,
'tokens_positive': tokens_positive
}
instances.append(instance)
data_info['instances'] = instances
data_info['phrases'] = phrases
data_info['dataset_mode'] = self.dataset_mode
out_data_list.append(data_info)
del data_list
return out_data_list
# Copyright (c) OpenMMLab. All rights reserved.
import csv
import os.path as osp
from collections import defaultdict
from typing import Dict, List, Optional
import numpy as np
from mmengine.fileio import get_local_path, load
from mmengine.utils import is_abs
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class OpenImagesDataset(BaseDetDataset):
"""Open Images dataset for detection.
Args:
ann_file (str): Annotation file path.
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
meta_file (str): File path to get image metas.
hierarchy_file (str): The file path of the class hierarchy.
image_level_ann_file (str): Human-verified image level annotation,
which is used in evaluation.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
"""
METAINFO: dict = dict(dataset_type='oid_v6')
def __init__(self,
label_file: str,
meta_file: str,
hierarchy_file: str,
image_level_ann_file: Optional[str] = None,
**kwargs) -> None:
self.label_file = label_file
self.meta_file = meta_file
self.hierarchy_file = hierarchy_file
self.image_level_ann_file = image_level_ann_file
super().__init__(**kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
last_img_id = None
instances = []
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
if last_img_id is None:
last_img_id = img_id
label_id = line[2]
assert label_id in self.label_id_mapping
label = int(self.label_id_mapping[label_id])
bbox = [
float(line[4]), # xmin
float(line[6]), # ymin
float(line[5]), # xmax
float(line[7]) # ymax
]
is_occluded = True if int(line[8]) == 1 else False
is_truncated = True if int(line[9]) == 1 else False
is_group_of = True if int(line[10]) == 1 else False
is_depiction = True if int(line[11]) == 1 else False
is_inside = True if int(line[12]) == 1 else False
instance = dict(
bbox=bbox,
bbox_label=label,
ignore_flag=0,
is_occluded=is_occluded,
is_truncated=is_truncated,
is_group_of=is_group_of,
is_depiction=is_depiction,
is_inside=is_inside)
last_img_path = osp.join(self.data_prefix['img'],
f'{last_img_id}.jpg')
if img_id != last_img_id:
# switch to a new image, record previous image's data.
data_info = dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
)
data_list.append(data_info)
instances = []
instances.append(instance)
last_img_id = img_id
data_list.append(
dict(
img_path=last_img_path,
img_id=last_img_id,
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = data_list[i]['img_id']
assert f'{img_id}.jpg' == osp.split(meta['filename'])[-1]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
tuple: Class name of OpenImages.
"""
index_list = []
classes_names = []
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
# self.cat2label[line[0]] = line[1]
classes_names.append(line[1])
index_list.append(line[0])
index_mapping = {index: i for i, index in enumerate(index_list)}
return classes_names, index_mapping
def _parse_img_level_ann(self,
img_level_ann_file: str) -> Dict[str, List[dict]]:
"""Parse image level annotations from csv style ann_file.
Args:
img_level_ann_file (str): CSV style image level annotation
file path.
Returns:
Dict[str, List[dict]]: Annotations where item of the defaultdict
indicates an image, each of which has (n) dicts.
Keys of dicts are:
- `image_level_label` (int): Label id.
- `confidence` (float): Labels that are human-verified to be
present in an image have confidence = 1 (positive labels).
Labels that are human-verified to be absent from an image
have confidence = 0 (negative labels). Machine-generated
labels have fractional confidences, generally >= 0.5.
The higher the confidence, the smaller the chance for
the label to be a false positive.
"""
item_lists = defaultdict(list)
with get_local_path(
img_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for i, line in enumerate(reader):
if i == 0:
continue
img_id = line[0]
item_lists[img_id].append(
dict(
image_level_label=int(
self.label_id_mapping[line[2]]),
confidence=float(line[3])))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file. Hierarchy
for 600 classes can be found at https://storage.googleapis.com/openimag
es/2018_04/bbox_labels_600_hierarchy_visualizer/circle.html.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
""" # noqa
hierarchy = load(
hierarchy_file, file_format='json', backend_args=self.backend_args)
class_num = len(self._metainfo['classes'])
relation_matrix = np.eye(class_num, class_num)
relation_matrix = self._convert_hierarchy_tree(hierarchy,
relation_matrix)
return relation_matrix
def _convert_hierarchy_tree(self,
hierarchy_map: dict,
relation_matrix: np.ndarray,
parents: list = [],
get_all_parents: bool = True) -> np.ndarray:
"""Get matrix of the corresponding relationship between the parent
class and the child class.
Args:
hierarchy_map (dict): Including label name and corresponding
subcategory. Keys of dicts are:
- `LabeName` (str): Name of the label.
- `Subcategory` (dict | list): Corresponding subcategory(ies).
relation_matrix (ndarray): The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
parents (list): Corresponding parent class.
get_all_parents (bool): Whether get all parent names.
Default: True
Returns:
ndarray: The matrix of the corresponding relationship between
the parent class and the child class, of shape
(class_num, class_num).
"""
if 'Subcategory' in hierarchy_map:
for node in hierarchy_map['Subcategory']:
if 'LabelName' in node:
children_name = node['LabelName']
children_index = self.label_id_mapping[children_name]
children = [children_index]
else:
continue
if len(parents) > 0:
for parent_index in parents:
if get_all_parents:
children.append(parent_index)
relation_matrix[children_index, parent_index] = 1
relation_matrix = self._convert_hierarchy_tree(
node, relation_matrix, parents=children)
return relation_matrix
def _join_prefix(self):
"""Join ``self.data_root`` with annotation path."""
super()._join_prefix()
if not is_abs(self.label_file) and self.label_file:
self.label_file = osp.join(self.data_root, self.label_file)
if not is_abs(self.meta_file) and self.meta_file:
self.meta_file = osp.join(self.data_root, self.meta_file)
if not is_abs(self.hierarchy_file) and self.hierarchy_file:
self.hierarchy_file = osp.join(self.data_root, self.hierarchy_file)
if self.image_level_ann_file and not is_abs(self.image_level_ann_file):
self.image_level_ann_file = osp.join(self.data_root,
self.image_level_ann_file)
@DATASETS.register_module()
class OpenImagesChallengeDataset(OpenImagesDataset):
"""Open Images Challenge dataset for detection.
Args:
ann_file (str): Open Images Challenge box annotation in txt format.
"""
METAINFO: dict = dict(dataset_type='oid_challenge')
def __init__(self, ann_file: str, **kwargs) -> None:
if not ann_file.endswith('txt'):
raise TypeError('The annotation file of Open Images Challenge '
'should be a txt file.')
super().__init__(ann_file=ann_file, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
"""
classes_names, label_id_mapping = self._parse_label_file(
self.label_file)
self._metainfo['classes'] = classes_names
self.label_id_mapping = label_id_mapping
if self.image_level_ann_file is not None:
img_level_anns = self._parse_img_level_ann(
self.image_level_ann_file)
else:
img_level_anns = None
# OpenImagesMetric can get the relation matrix from the dataset meta
relation_matrix = self._get_relation_matrix(self.hierarchy_file)
self._metainfo['RELATION_MATRIX'] = relation_matrix
data_list = []
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
lines = f.readlines()
i = 0
while i < len(lines):
instances = []
filename = lines[i].rstrip()
i += 2
img_gt_size = int(lines[i])
i += 1
for j in range(img_gt_size):
sp = lines[i + j].split()
instances.append(
dict(
bbox=[
float(sp[1]),
float(sp[2]),
float(sp[3]),
float(sp[4])
],
bbox_label=int(sp[0]) - 1, # labels begin from 1
ignore_flag=0,
is_group_ofs=True if int(sp[5]) == 1 else False))
i += img_gt_size
data_list.append(
dict(
img_path=osp.join(self.data_prefix['img'], filename),
instances=instances,
))
# add image metas to data list
img_metas = load(
self.meta_file, file_format='pkl', backend_args=self.backend_args)
assert len(img_metas) == len(data_list)
for i, meta in enumerate(img_metas):
img_id = osp.split(data_list[i]['img_path'])[-1][:-4]
assert img_id == osp.split(meta['filename'])[-1][:-4]
h, w = meta['ori_shape'][:2]
data_list[i]['height'] = h
data_list[i]['width'] = w
data_list[i]['img_id'] = img_id
# denormalize bboxes
for j in range(len(data_list[i]['instances'])):
data_list[i]['instances'][j]['bbox'][0] *= w
data_list[i]['instances'][j]['bbox'][2] *= w
data_list[i]['instances'][j]['bbox'][1] *= h
data_list[i]['instances'][j]['bbox'][3] *= h
# add image-level annotation
if img_level_anns is not None:
img_labels = []
confidences = []
img_ann_list = img_level_anns.get(img_id, [])
for ann in img_ann_list:
img_labels.append(int(ann['image_level_label']))
confidences.append(float(ann['confidence']))
data_list[i]['image_level_labels'] = np.array(
img_labels, dtype=np.int64)
data_list[i]['confidences'] = np.array(
confidences, dtype=np.float32)
return data_list
def _parse_label_file(self, label_file: str) -> tuple:
"""Get classes name and index mapping from cls-label-description file.
Args:
label_file (str): File path of the label description file that
maps the classes names in MID format to their short
descriptions.
Returns:
tuple: Class name of OpenImages.
"""
label_list = []
id_list = []
index_mapping = {}
with get_local_path(
label_file, backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
for line in reader:
label_name = line[0]
label_id = int(line[2])
label_list.append(line[1])
id_list.append(label_id)
index_mapping[label_name] = label_id - 1
indexes = np.argsort(id_list)
classes_names = []
for index in indexes:
classes_names.append(label_list[index])
return classes_names, index_mapping
def _parse_img_level_ann(self, image_level_ann_file):
"""Parse image level annotations from csv style ann_file.
Args:
image_level_ann_file (str): CSV style image level annotation
file path.
Returns:
defaultdict[list[dict]]: Annotations where item of the defaultdict
indicates an image, each of which has (n) dicts.
Keys of dicts are:
- `image_level_label` (int): of shape 1.
- `confidence` (float): of shape 1.
"""
item_lists = defaultdict(list)
with get_local_path(
image_level_ann_file,
backend_args=self.backend_args) as local_path:
with open(local_path, 'r') as f:
reader = csv.reader(f)
i = -1
for line in reader:
i += 1
if i == 0:
continue
else:
img_id = line[0]
label_id = line[1]
assert label_id in self.label_id_mapping
image_level_label = int(
self.label_id_mapping[label_id])
confidence = float(line[2])
item_lists[img_id].append(
dict(
image_level_label=image_level_label,
confidence=confidence))
return item_lists
def _get_relation_matrix(self, hierarchy_file: str) -> np.ndarray:
"""Get the matrix of class hierarchy from the hierarchy file.
Args:
hierarchy_file (str): File path to the hierarchy for classes.
Returns:
np.ndarray: The matrix of the corresponding
relationship between the parent class and the child class,
of shape (class_num, class_num).
"""
with get_local_path(
hierarchy_file, backend_args=self.backend_args) as local_path:
class_label_tree = np.load(local_path, allow_pickle=True)
return class_label_tree[1:, 1:]
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import os.path as osp
import random
from typing import Dict, List
import mmengine
from mmengine.dataset import BaseDataset
from mmdet.registry import DATASETS
@DATASETS.register_module()
class RefCocoDataset(BaseDataset):
"""RefCOCO dataset.
The `Refcoco` and `Refcoco+` dataset is based on
`ReferItGame: Referring to Objects in Photographs of Natural Scenes
<http://tamaraberg.com/papers/referit.pdf>`_.
The `Refcocog` dataset is based on
`Generation and Comprehension of Unambiguous Object Descriptions
<https://arxiv.org/abs/1511.02283>`_.
Args:
ann_file (str): Annotation file path.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (str): Prefix for training data.
split_file (str): Split file path.
split (str): Split name. Defaults to 'train'.
text_mode (str): Text mode. Defaults to 'random'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
def __init__(self,
data_root: str,
ann_file: str,
split_file: str,
data_prefix: Dict,
split: str = 'train',
text_mode: str = 'random',
**kwargs):
self.split_file = split_file
self.split = split
assert text_mode in ['original', 'random', 'concat', 'select_first']
self.text_mode = text_mode
super().__init__(
data_root=data_root,
data_prefix=data_prefix,
ann_file=ann_file,
**kwargs,
)
def _join_prefix(self):
if not mmengine.is_abs(self.split_file) and self.split_file:
self.split_file = osp.join(self.data_root, self.split_file)
return super()._join_prefix()
def _init_refs(self):
"""Initialize the refs for RefCOCO."""
anns, imgs = {}, {}
for ann in self.instances['annotations']:
anns[ann['id']] = ann
for img in self.instances['images']:
imgs[img['id']] = img
refs, ref_to_ann = {}, {}
for ref in self.splits:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
# add mapping related to ref
refs[ref_id] = ref
ref_to_ann[ref_id] = anns[ann_id]
self.refs = refs
self.ref_to_ann = ref_to_ann
def load_data_list(self) -> List[dict]:
"""Load data list."""
self.splits = mmengine.load(self.split_file, file_format='pkl')
self.instances = mmengine.load(self.ann_file, file_format='json')
self._init_refs()
img_prefix = self.data_prefix['img_path']
ref_ids = [
ref['ref_id'] for ref in self.splits if ref['split'] == self.split
]
full_anno = []
for ref_id in ref_ids:
ref = self.refs[ref_id]
ann = self.ref_to_ann[ref_id]
ann.update(ref)
full_anno.append(ann)
image_id_list = []
final_anno = {}
for anno in full_anno:
image_id_list.append(anno['image_id'])
final_anno[anno['ann_id']] = anno
annotations = [value for key, value in final_anno.items()]
coco_train_id = []
image_annot = {}
for i in range(len(self.instances['images'])):
coco_train_id.append(self.instances['images'][i]['id'])
image_annot[self.instances['images'][i]
['id']] = self.instances['images'][i]
images = []
for image_id in list(set(image_id_list)):
images += [image_annot[image_id]]
data_list = []
grounding_dict = collections.defaultdict(list)
for anno in annotations:
image_id = int(anno['image_id'])
grounding_dict[image_id].append(anno)
join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
for image in images:
img_id = image['id']
instances = []
sentences = []
for grounding_anno in grounding_dict[img_id]:
texts = [x['raw'].lower() for x in grounding_anno['sentences']]
# random select one text
if self.text_mode == 'random':
idx = random.randint(0, len(texts) - 1)
text = [texts[idx]]
# concat all texts
elif self.text_mode == 'concat':
text = [''.join(texts)]
# select the first text
elif self.text_mode == 'select_first':
text = [texts[0]]
# use all texts
elif self.text_mode == 'original':
text = texts
else:
raise ValueError(f'Invalid text mode "{self.text_mode}".')
ins = [{
'mask': grounding_anno['segmentation'],
'ignore_flag': 0
}] * len(text)
instances.extend(ins)
sentences.extend(text)
data_info = {
'img_path': join_path(img_prefix, image['file_name']),
'img_id': img_id,
'instances': instances,
'text': sentences
}
data_list.append(data_info)
if len(data_list) == 0:
raise ValueError(f'No sample in split "{self.split}".')
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from collections import defaultdict
from typing import Any, Dict, List
import numpy as np
from mmengine.dataset import BaseDataset
from mmengine.utils import check_file_exist
from mmdet.registry import DATASETS
@DATASETS.register_module()
class ReIDDataset(BaseDataset):
"""Dataset for ReID.
Args:
triplet_sampler (dict, optional): The sampler for hard mining
triplet loss. Defaults to None.
keys: num_ids (int): The number of person ids.
ins_per_id (int): The number of image for each person.
"""
def __init__(self, triplet_sampler: dict = None, *args, **kwargs):
self.triplet_sampler = triplet_sampler
super().__init__(*args, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ''self.ann_file''.
Returns:
list[dict]: A list of annotation.
"""
assert isinstance(self.ann_file, str)
check_file_exist(self.ann_file)
data_list = []
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
info = dict(img_prefix=self.data_prefix)
if self.data_prefix['img_path'] is not None:
info['img_path'] = osp.join(self.data_prefix['img_path'],
filename)
else:
info['img_path'] = filename
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_list.append(info)
self._parse_ann_info(data_list)
return data_list
def _parse_ann_info(self, data_list: List[dict]):
"""Parse person id annotations."""
index_tmp_dic = defaultdict(list) # pid->[idx1,...,idxN]
self.index_dic = dict() # pid->array([idx1,...,idxN])
for idx, info in enumerate(data_list):
pid = info['gt_label']
index_tmp_dic[int(pid)].append(idx)
for pid, idxs in index_tmp_dic.items():
self.index_dic[pid] = np.asarray(idxs, dtype=np.int64)
self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)
def prepare_data(self, idx: int) -> Any:
"""Get data processed by ''self.pipeline''.
Args:
idx (int): The index of ''data_info''
Returns:
Any: Depends on ''self.pipeline''
"""
data_info = self.get_data_info(idx)
if self.triplet_sampler is not None:
img_info = self.triplet_sampling(data_info['gt_label'],
**self.triplet_sampler)
data_info = copy.deepcopy(img_info) # triplet -> list
else:
data_info = copy.deepcopy(data_info) # no triplet -> dict
return self.pipeline(data_info)
def triplet_sampling(self,
pos_pid,
num_ids: int = 8,
ins_per_id: int = 4) -> Dict:
"""Triplet sampler for hard mining triplet loss. First, for one
pos_pid, random sample ins_per_id images with same person id.
Then, random sample num_ids - 1 images for each negative id.
Finally, random sample ins_per_id images for each negative id.
Args:
pos_pid (ndarray): The person id of the anchor.
num_ids (int): The number of person ids.
ins_per_id (int): The number of images for each person.
Returns:
Dict: Annotation information of num_ids X ins_per_id images.
"""
assert len(self.pids) >= num_ids, \
'The number of person ids in the training set must ' \
'be greater than the number of person ids in the sample.'
pos_idxs = self.index_dic[int(
pos_pid)] # all positive idxs for pos_pid
idxs_list = []
# select positive samplers
idxs_list.extend(pos_idxs[np.random.choice(
pos_idxs.shape[0], ins_per_id, replace=True)])
# select negative ids
neg_pids = np.random.choice(
[i for i, _ in enumerate(self.pids) if i != pos_pid],
num_ids - 1,
replace=False)
# select negative samplers for each negative id
for neg_pid in neg_pids:
neg_idxs = self.index_dic[neg_pid]
idxs_list.extend(neg_idxs[np.random.choice(
neg_idxs.shape[0], ins_per_id, replace=True)])
# return the final triplet batch
triplet_img_infos = []
for idx in idxs_list:
triplet_img_infos.append(copy.deepcopy(self.get_data_info(idx)))
# Collect data_list scatters (list of dict -> dict of list)
out = dict()
for key in triplet_img_infos[0].keys():
out[key] = [_info[key] for _info in triplet_img_infos]
return out
# Copyright (c) OpenMMLab. All rights reserved.
from .batch_sampler import (AspectRatioBatchSampler,
MultiDataAspectRatioBatchSampler,
TrackAspectRatioBatchSampler)
from .class_aware_sampler import ClassAwareSampler
from .custom_sample_size_sampler import CustomSampleSizeSampler
from .multi_data_sampler import MultiDataSampler
from .multi_source_sampler import GroupMultiSourceSampler, MultiSourceSampler
from .track_img_sampler import TrackImgSampler
__all__ = [
'ClassAwareSampler', 'AspectRatioBatchSampler', 'MultiSourceSampler',
'GroupMultiSourceSampler', 'TrackImgSampler',
'TrackAspectRatioBatchSampler', 'MultiDataSampler',
'MultiDataAspectRatioBatchSampler', 'CustomSampleSizeSampler'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
from torch.utils.data import BatchSampler, Sampler
from mmdet.datasets.samplers.track_img_sampler import TrackImgSampler
from mmdet.registry import DATA_SAMPLERS
# TODO: maybe replace with a data_loader wrapper
@DATA_SAMPLERS.register_module()
class AspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __init__(self,
sampler: Sampler,
batch_size: int,
drop_last: bool = False) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
# two groups for w < h and w >= h
self._aspect_ratio_buckets = [[] for _ in range(2)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
def __len__(self) -> int:
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
@DATA_SAMPLERS.register_module()
class TrackAspectRatioBatchSampler(AspectRatioBatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
# hard code to solve TrackImgSampler
if isinstance(self.sampler, TrackImgSampler):
video_idx, _ = idx
else:
video_idx = idx
# video_idx
data_info = self.sampler.dataset.get_data_info(video_idx)
# data_info {video_id, images, video_length}
img_data_info = data_info['images'][0]
width, height = img_data_info['width'], img_data_info['height']
bucket_id = 0 if width < height else 1
bucket = self._aspect_ratio_buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[
1]
self._aspect_ratio_buckets = [[] for _ in range(2)]
while len(left_data) > 0:
if len(left_data) <= self.batch_size:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size]
left_data = left_data[self.batch_size:]
@DATA_SAMPLERS.register_module()
class MultiDataAspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio (< 1 or.
>= 1) into a same batch for multi-source datasets.
Args:
sampler (Sampler): Base sampler.
batch_size (Sequence(int)): Size of mini-batch for multi-source
datasets.
num_datasets(int): Number of multi-source datasets.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
"""
def __init__(self,
sampler: Sampler,
batch_size: Sequence[int],
num_datasets: int,
drop_last: bool = True) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
self.sampler = sampler
self.batch_size = batch_size
self.num_datasets = num_datasets
self.drop_last = drop_last
# two groups for w < h and w >= h for each dataset --> 2 * num_datasets
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __iter__(self) -> Sequence[int]:
for idx in self.sampler:
data_info = self.sampler.dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
aspect_ratio_bucket_id = 0 if width < height else 1
bucket_id = dataset_source_idx * 2 + aspect_ratio_bucket_id
bucket = self._buckets[bucket_id]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size[dataset_source_idx]:
yield bucket[:]
del bucket[:]
# yield the rest data and reset the bucket
for i in range(self.num_datasets):
left_data = self._buckets[i * 2 + 0] + self._buckets[i * 2 + 1]
while len(left_data) > 0:
if len(left_data) <= self.batch_size[i]:
if not self.drop_last:
yield left_data[:]
left_data = []
else:
yield left_data[:self.batch_size[i]]
left_data = left_data[self.batch_size[i]:]
self._buckets = [[] for _ in range(2 * self.num_datasets)]
def __len__(self) -> int:
sizes = [0 for _ in range(self.num_datasets)]
for idx in self.sampler:
dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
sizes[dataset_source_idx] += 1
if self.drop_last:
lens = 0
for i in range(self.num_datasets):
lens += sizes[i] // self.batch_size[i]
return lens
else:
lens = 0
for i in range(self.num_datasets):
lens += (sizes[i] + self.batch_size[i] -
1) // self.batch_size[i]
return lens
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Iterator, Optional, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class ClassAwareSampler(Sampler):
r"""Sampler that restricts data loading to the label of the dataset.
A class-aware sampling strategy to effectively tackle the
non-uniform class distribution. The length of the training data is
consistent with source data. Simple improvements based on `Relay
Backpropagation for Effective Learning of Deep Convolutional
Neural Networks <https://arxiv.org/abs/1512.05830>`_
The implementation logic is referred to
https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py
Args:
dataset: Dataset used for sampling.
seed (int, optional): random seed used to shuffle the sampler.
This number should be identical across all
processes in the distributed group. Defaults to None.
num_sample_class (int): The number of samples taken from each
per-label list. Defaults to 1.
"""
def __init__(self,
dataset: BaseDataset,
seed: Optional[int] = None,
num_sample_class: int = 1) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.epoch = 0
# Must be the same across all workers. If None, will use a
# random seed shared among workers
# (require synchronization among all workers)
if seed is None:
seed = sync_random_seed()
self.seed = seed
# The number of samples taken from each per-label list
assert num_sample_class > 0 and isinstance(num_sample_class, int)
self.num_sample_class = num_sample_class
# Get per-label image list from dataset
self.cat_dict = self.get_cat2imgs()
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size))
self.total_size = self.num_samples * self.world_size
# get number of images containing each category
self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
# filter labels without images
self.valid_cat_inds = [
i for i, length in enumerate(self.num_cat_imgs) if length != 0
]
self.num_classes = len(self.valid_cat_inds)
def get_cat2imgs(self) -> Dict[int, list]:
"""Get a dict with class as key and img_ids as values.
Returns:
dict[int, list]: A dict of per-label image list,
the item of the dict indicates a label index,
corresponds to the image index that contains the label.
"""
classes = self.dataset.metainfo.get('classes', None)
if classes is None:
raise ValueError('dataset metainfo must contain `classes`')
# sort the label index
cat2imgs = {i: [] for i in range(len(classes))}
for i in range(len(self.dataset)):
cat_ids = set(self.dataset.get_cat_ids(i))
for cat in cat_ids:
cat2imgs[cat].append(i)
return cat2imgs
def __iter__(self) -> Iterator[int]:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
# initialize label list
label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
# initialize each per-label image list
data_iter_dict = dict()
for i in self.valid_cat_inds:
data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)
def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
"""Traverse the categories and extract `num_sample_cls` image
indexes of the corresponding categories one by one."""
id_indices = []
for _ in range(len(cls_list)):
cls_idx = next(cls_list)
for _ in range(num_sample_cls):
id = next(data_dict[cls_idx])
id_indices.append(id)
return id_indices
# deterministically shuffle based on epoch
num_bins = int(
math.ceil(self.total_size * 1.0 / self.num_classes /
self.num_sample_class))
indices = []
for i in range(num_bins):
indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
self.num_sample_class)
# fix extra samples to make it evenly divisible
if len(indices) >= self.total_size:
indices = indices[:self.total_size]
else:
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
class RandomCycleIter:
"""Shuffle the list and do it again after the list have traversed.
The implementation logic is referred to
https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
Example:
>>> label_list = [0, 1, 2, 4, 5]
>>> g = torch.Generator()
>>> g.manual_seed(0)
>>> label_iter_list = RandomCycleIter(label_list, generator=g)
>>> index = next(label_iter_list)
Args:
data (list or ndarray): The data that needs to be shuffled.
generator: An torch.Generator object, which is used in setting the seed
for generating random numbers.
""" # noqa: W605
def __init__(self,
data: Union[list, np.ndarray],
generator: torch.Generator = None) -> None:
self.data = data
self.length = len(data)
self.index = torch.randperm(self.length, generator=generator).numpy()
self.i = 0
self.generator = generator
def __iter__(self) -> Iterator:
return self
def __len__(self) -> int:
return len(self.data)
def __next__(self):
if self.i == self.length:
self.index = torch.randperm(
self.length, generator=self.generator).numpy()
self.i = 0
idx = self.data[self.index[self.i]]
self.i += 1
return idx
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
from .class_aware_sampler import RandomCycleIter
@DATA_SAMPLERS.register_module()
class CustomSampleSizeSampler(Sampler):
def __init__(self,
dataset: Sized,
dataset_size: Sequence[int],
ratio_mode: bool = False,
seed: Optional[int] = None,
round_up: bool = True) -> None:
assert len(dataset.datasets) == len(dataset_size)
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
total_size = 0
total_size_fake = 0
self.dataset_index = []
self.dataset_cycle_iter = []
new_dataset_size = []
for dataset, size in zip(dataset.datasets, dataset_size):
self.dataset_index.append(
list(range(total_size_fake,
len(dataset) + total_size_fake)))
total_size_fake += len(dataset)
if size == -1:
total_size += len(dataset)
self.dataset_cycle_iter.append(None)
new_dataset_size.append(-1)
else:
if ratio_mode:
size = int(size * len(dataset))
assert size <= len(
dataset
), f'dataset size {size} is larger than ' \
f'dataset length {len(dataset)}'
total_size += size
new_dataset_size.append(size)
g = torch.Generator()
g.manual_seed(self.seed)
self.dataset_cycle_iter.append(
RandomCycleIter(self.dataset_index[-1], generator=g))
self.dataset_size = new_dataset_size
if self.round_up:
self.num_samples = math.ceil(total_size / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil((total_size - rank) / world_size)
self.total_size = total_size
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
out_index = []
for data_size, data_index, cycle_iter in zip(self.dataset_size,
self.dataset_index,
self.dataset_cycle_iter):
if data_size == -1:
out_index += data_index
else:
index = [next(cycle_iter) for _ in range(data_size)]
out_index += index
index = torch.randperm(len(out_index), generator=g).numpy().tolist()
indices = [out_index[i] for i in index]
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS
from torch.utils.data import Sampler
@DATA_SAMPLERS.register_module()
class MultiDataSampler(Sampler):
"""The default data sampler for both distributed and non-distributed
environment.
It has several differences from the PyTorch ``DistributedSampler`` as
below:
1. This sampler supports non-distributed environment.
2. The round up behaviors are a little different.
- If ``round_up=True``, this sampler will add extra samples to make the
number of samples is evenly divisible by the world size. And
this behavior is the same as the ``DistributedSampler`` with
``drop_last=False``.
- If ``round_up=False``, this sampler won't remove or add any samples
while the ``DistributedSampler`` with ``drop_last=True`` will remove
tail samples.
Args:
dataset (Sized): The dataset.
dataset_ratio (Sequence(int)) The ratios of different datasets.
seed (int, optional): Random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Defaults to None.
round_up (bool): Whether to add extra samples to make the number of
samples evenly divisible by the world size. Defaults to True.
"""
def __init__(self,
dataset: Sized,
dataset_ratio: Sequence[int],
seed: Optional[int] = None,
round_up: bool = True) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.dataset_ratio = dataset_ratio
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.round_up = round_up
if self.round_up:
self.num_samples = math.ceil(len(self.dataset) / world_size)
self.total_size = self.num_samples * self.world_size
else:
self.num_samples = math.ceil(
(len(self.dataset) - rank) / world_size)
self.total_size = len(self.dataset)
self.sizes = [len(dataset) for dataset in self.dataset.datasets]
dataset_weight = [
torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio)
for i, (r, s) in enumerate(zip(self.dataset_ratio, self.sizes))
]
self.weights = torch.cat(dataset_weight)
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.multinomial(
self.weights, len(self.weights), generator=g,
replacement=True).tolist()
# add extra samples to make it evenly divisible
if self.round_up:
indices = (
indices *
int(self.total_size / len(indices) + 1))[:self.total_size]
# subsample
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
def __len__(self) -> int:
"""The number of samples in this rank."""
return self.num_samples
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Iterator, List, Optional, Sized, Union
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class MultiSourceSampler(Sampler):
r"""Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
Examples:
>>> dataset_type = 'ConcatDataset'
>>> sub_dataset_type = 'CocoDataset'
>>> data_root = 'data/coco/'
>>> sup_ann = '../coco_semi_annos/instances_train2017.1@10.json'
>>> unsup_ann = '../coco_semi_annos/' \
>>> 'instances_train2017.1@10-unlabeled.json'
>>> dataset = dict(type=dataset_type,
>>> datasets=[
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=sup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=sup_pipeline),
>>> dict(
>>> type=sub_dataset_type,
>>> data_root=data_root,
>>> ann_file=unsup_ann,
>>> data_prefix=dict(img='train2017/'),
>>> filter_cfg=dict(filter_empty_gt=True, min_size=32),
>>> pipeline=unsup_pipeline),
>>> ])
>>> train_dataloader = dict(
>>> batch_size=5,
>>> num_workers=5,
>>> persistent_workers=True,
>>> sampler=dict(type='MultiSourceSampler',
>>> batch_size=5, source_ratio=[1, 4]),
>>> batch_sampler=None,
>>> dataset=dataset)
"""
def __init__(self,
dataset: Sized,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
assert hasattr(dataset, 'cumulative_sizes'),\
f'The dataset must be ConcatDataset, but get {dataset}'
assert isinstance(batch_size, int) and batch_size > 0, \
'batch_size must be a positive integer value, ' \
f'but got batch_size={batch_size}'
assert isinstance(source_ratio, list), \
f'source_ratio must be a list, but got source_ratio={source_ratio}'
assert len(source_ratio) == len(dataset.cumulative_sizes), \
'The length of source_ratio must be equal to ' \
f'the number of datasets, but got source_ratio={source_ratio}'
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.cumulative_sizes = [0] + dataset.cumulative_sizes
self.batch_size = batch_size
self.source_ratio = source_ratio
self.num_per_source = [
int(batch_size * sr / sum(source_ratio)) for sr in source_ratio
]
self.num_per_source[0] = batch_size - sum(self.num_per_source[1:])
assert sum(self.num_per_source) == batch_size, \
'The sum of num_per_source must be equal to ' \
f'batch_size, but get {self.num_per_source}'
self.seed = sync_random_seed() if seed is None else seed
self.shuffle = shuffle
self.source2inds = {
source: self._indices_of_rank(len(ds))
for source, ds in enumerate(dataset.datasets)
}
def _infinite_indices(self, sample_size: int) -> Iterator[int]:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
while True:
if self.shuffle:
yield from torch.randperm(sample_size, generator=g).tolist()
else:
yield from torch.arange(sample_size).tolist()
def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
"""Slice the infinite indices by rank."""
yield from itertools.islice(
self._infinite_indices(sample_size), self.rank, None,
self.world_size)
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.source2inds[source]:
idx += self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
def __len__(self) -> int:
return len(self.dataset)
def set_epoch(self, epoch: int) -> None:
"""Not supported in `epoch-based runner."""
pass
@DATA_SAMPLERS.register_module()
class GroupMultiSourceSampler(MultiSourceSampler):
r"""Group Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets but the same group to form batches.
Args:
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
source_ratio (list[int | float]): The sampling ratio of different
source datasets in a mini-batch.
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
"""
def __init__(self,
dataset: BaseDataset,
batch_size: int,
source_ratio: List[Union[int, float]],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
super().__init__(
dataset=dataset,
batch_size=batch_size,
source_ratio=source_ratio,
shuffle=shuffle,
seed=seed)
self._get_source_group_info()
self.group_source2inds = [{
source:
self._indices_of_rank(self.group2size_per_source[source][group])
for source in range(len(dataset.datasets))
} for group in range(len(self.group_ratio))]
def _get_source_group_info(self) -> None:
self.group2size_per_source = [{0: 0, 1: 0}, {0: 0, 1: 0}]
self.group2inds_per_source = [{0: [], 1: []}, {0: [], 1: []}]
for source, dataset in enumerate(self.dataset.datasets):
for idx in range(len(dataset)):
data_info = dataset.get_data_info(idx)
width, height = data_info['width'], data_info['height']
group = 0 if width < height else 1
self.group2size_per_source[source][group] += 1
self.group2inds_per_source[source][group].append(idx)
self.group_sizes = np.zeros(2, dtype=np.int64)
for group2size in self.group2size_per_source:
for group, size in group2size.items():
self.group_sizes[group] += size
self.group_ratio = self.group_sizes / sum(self.group_sizes)
def __iter__(self) -> Iterator[int]:
batch_buffer = []
while True:
group = np.random.choice(
list(range(len(self.group_ratio))), p=self.group_ratio)
for source, num in enumerate(self.num_per_source):
batch_buffer_per_source = []
for idx in self.group_source2inds[group][source]:
idx = self.group2inds_per_source[source][group][
idx] + self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
if len(batch_buffer_per_source) == num:
batch_buffer += batch_buffer_per_source
break
yield from batch_buffer
batch_buffer = []
# Copyright (c) OpenMMLab. All rights reserved.
import math
import random
from typing import Iterator, Optional, Sized
import numpy as np
from mmengine.dataset import ClassBalancedDataset, ConcatDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from mmdet.registry import DATA_SAMPLERS
from ..base_video_dataset import BaseVideoDataset
@DATA_SAMPLERS.register_module()
class TrackImgSampler(Sampler):
"""Sampler that providing image-level sampling outputs for video datasets
in tracking tasks. It could be both used in both distributed and
non-distributed environment.
If using the default sampler in pytorch, the subsequent data receiver will
get one video, which is not desired in some cases:
(Take a non-distributed environment as an example)
1. In test mode, we want only one image is fed into the data pipeline. This
is in consideration of memory usage since feeding the whole video commonly
requires a large amount of memory (>=20G on MOTChallenge17 dataset), which
is not available in some machines.
2. In training mode, we may want to make sure all the images in one video
are randomly sampled once in one epoch and this can not be guaranteed in
the default sampler in pytorch.
Args:
dataset (Sized): Dataset used for sampling.
seed (int, optional): random seed used to shuffle the sampler. This
number should be identical across all processes in the distributed
group. Defaults to None.
"""
def __init__(
self,
dataset: Sized,
seed: Optional[int] = None,
) -> None:
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.epoch = 0
if seed is None:
self.seed = sync_random_seed()
else:
self.seed = seed
self.dataset = dataset
self.indices = []
# Hard code here to handle different dataset wrapper
if isinstance(self.dataset, ConcatDataset):
cat_datasets = self.dataset.datasets
assert isinstance(
cat_datasets[0], BaseVideoDataset
), f'expected BaseVideoDataset, but got {type(cat_datasets[0])}'
self.test_mode = cat_datasets[0].test_mode
assert not self.test_mode, "'ConcatDataset' should not exist in "
'test mode'
for dataset in cat_datasets:
num_videos = len(dataset)
for video_ind in range(num_videos):
self.indices.extend([
(video_ind, frame_ind) for frame_ind in range(
dataset.get_len_per_video(video_ind))
])
elif isinstance(self.dataset, ClassBalancedDataset):
ori_dataset = self.dataset.dataset
assert isinstance(
ori_dataset, BaseVideoDataset
), f'expected BaseVideoDataset, but got {type(ori_dataset)}'
self.test_mode = ori_dataset.test_mode
assert not self.test_mode, "'ClassBalancedDataset' should not "
'exist in test mode'
video_indices = self.dataset.repeat_indices
for index in video_indices:
self.indices.extend([(index, frame_ind) for frame_ind in range(
ori_dataset.get_len_per_video(index))])
else:
assert isinstance(
self.dataset, BaseVideoDataset
), 'TrackImgSampler is only supported in BaseVideoDataset or '
'dataset wrapper: ClassBalancedDataset and ConcatDataset, but '
f'got {type(self.dataset)} '
self.test_mode = self.dataset.test_mode
num_videos = len(self.dataset)
if self.test_mode:
# in test mode, the images belong to the same video must be put
# on the same device.
if num_videos < self.world_size:
raise ValueError(f'only {num_videos} videos loaded,'
f'but {self.world_size} gpus were given.')
chunks = np.array_split(
list(range(num_videos)), self.world_size)
for videos_inds in chunks:
indices_chunk = []
for video_ind in videos_inds:
indices_chunk.extend([
(video_ind, frame_ind) for frame_ind in range(
self.dataset.get_len_per_video(video_ind))
])
self.indices.append(indices_chunk)
else:
for video_ind in range(num_videos):
self.indices.extend([
(video_ind, frame_ind) for frame_ind in range(
self.dataset.get_len_per_video(video_ind))
])
if self.test_mode:
self.num_samples = len(self.indices[self.rank])
self.total_size = sum(
[len(index_list) for index_list in self.indices])
else:
self.num_samples = int(
math.ceil(len(self.indices) * 1.0 / self.world_size))
self.total_size = self.num_samples * self.world_size
def __iter__(self) -> Iterator:
if self.test_mode:
# in test mode, the order of frames can not be shuffled.
indices = self.indices[self.rank]
else:
# deterministically shuffle based on epoch
rng = random.Random(self.epoch + self.seed)
indices = rng.sample(self.indices, len(self.indices))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.world_size]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
# Copyright (c) OpenMMLab. All rights reserved.
from .augment_wrappers import AutoAugment, RandAugment
from .colorspace import (AutoContrast, Brightness, Color, ColorTransform,
Contrast, Equalize, Invert, Posterize, Sharpness,
Solarize, SolarizeAdd)
from .formatting import (ImageToTensor, PackDetInputs, PackReIDInputs,
PackTrackInputs, ToTensor, Transpose)
from .frame_sampling import BaseFrameSample, UniformRefFrameSample
from .geometric import (GeomTransform, Rotate, ShearX, ShearY, TranslateX,
TranslateY)
from .instaboost import InstaBoost
from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations,
LoadEmptyAnnotations, LoadImageFromNDArray,
LoadMultiChannelImageFromFiles, LoadPanopticAnnotations,
LoadProposals, LoadTrackAnnotations)
from .text_transformers import LoadTextAnnotations, RandomSamplingNegPos
from .transformers_glip import GTBoxSubOne_GLIP, RandomFlip_GLIP
from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut,
Expand, FixScaleResize, FixShapeResize,
MinIoURandomCrop, MixUp, Mosaic, Pad,
PhotoMetricDistortion, RandomAffine,
RandomCenterCropPad, RandomCrop, RandomErasing,
RandomFlip, RandomShift, Resize, ResizeShortestEdge,
SegRescale, YOLOXHSVRandomAug)
from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder
__all__ = [
'PackDetInputs', 'ToTensor', 'ImageToTensor', 'Transpose',
'LoadImageFromNDArray', 'LoadAnnotations', 'LoadPanopticAnnotations',
'LoadMultiChannelImageFromFiles', 'LoadProposals', 'Resize', 'RandomFlip',
'RandomCrop', 'SegRescale', 'MinIoURandomCrop', 'Expand',
'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad',
'AutoAugment', 'CutOut', 'ShearX', 'ShearY', 'Rotate', 'Color', 'Equalize',
'Brightness', 'Contrast', 'TranslateX', 'TranslateY', 'RandomShift',
'Mosaic', 'MixUp', 'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste',
'FilterAnnotations', 'Pad', 'GeomTransform', 'ColorTransform',
'RandAugment', 'Sharpness', 'Solarize', 'SolarizeAdd', 'Posterize',
'AutoContrast', 'Invert', 'MultiBranch', 'RandomErasing',
'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp',
'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader',
'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample',
'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize',
'ResizeShortestEdge', 'GTBoxSubOne_GLIP', 'RandomFlip_GLIP',
'RandomSamplingNegPos', 'LoadTextAnnotations'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import numpy as np
from mmcv.transforms import RandomChoice
from mmcv.transforms.utils import cache_randomness
from mmengine.config import ConfigDict
from mmdet.registry import TRANSFORMS
# AutoAugment uses reinforcement learning to search for
# some widely useful data augmentation strategies,
# here we provide AUTOAUG_POLICIES_V0.
# For AUTOAUG_POLICIES_V0, each tuple is an augmentation
# operation of the form (operation, probability, magnitude).
# Each element in policies is a policy that will be applied
# sequentially on the image.
# RandAugment defines a data augmentation search space, RANDAUG_SPACE,
# sampling 1~3 data augmentations each time, and
# setting the magnitude of each data augmentation randomly,
# which will be applied sequentially on the image.
_MAX_LEVEL = 10
AUTOAUG_POLICIES_V0 = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
def policies_v0():
"""Autoaugment policies that was used in AutoAugment Paper."""
policies = list()
for policy_args in AUTOAUG_POLICIES_V0:
policy = list()
for args in policy_args:
policy.append(dict(type=args[0], prob=args[1], level=args[2]))
policies.append(policy)
return policies
RANDAUG_SPACE = [[dict(type='AutoContrast')], [dict(type='Equalize')],
[dict(type='Invert')], [dict(type='Rotate')],
[dict(type='Posterize')], [dict(type='Solarize')],
[dict(type='SolarizeAdd')], [dict(type='Color')],
[dict(type='Contrast')], [dict(type='Brightness')],
[dict(type='Sharpness')], [dict(type='ShearX')],
[dict(type='ShearY')], [dict(type='TranslateX')],
[dict(type='TranslateY')]]
def level_to_mag(level: Optional[int], min_mag: float,
max_mag: float) -> float:
"""Map from level to magnitude."""
if level is None:
return round(np.random.rand() * (max_mag - min_mag) + min_mag, 1)
else:
return round(level / _MAX_LEVEL * (max_mag - min_mag) + min_mag, 1)
@TRANSFORMS.register_module()
class AutoAugment(RandomChoice):
"""Auto augmentation.
This data augmentation is proposed in `AutoAugment: Learning
Augmentation Policies from Data <https://arxiv.org/abs/1805.09501>`_
and in `Learning Data Augmentation Strategies for Object Detection
<https://arxiv.org/pdf/1906.11172>`_.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_bboxes_labels
- gt_masks
- gt_ignore_flags
- gt_seg_map
Added Keys:
- homography_matrix
Args:
policies (List[List[Union[dict, ConfigDict]]]):
The policies of auto augmentation.Each policy in ``policies``
is a specific augmentation policy, and is composed by several
augmentations. When AutoAugment is called, a random policy in
``policies`` will be selected to augment images.
Defaults to policy_v0().
prob (list[float], optional): The probabilities associated
with each policy. The length should be equal to the policy
number and the sum should be 1. If not given, a uniform
distribution will be assumed. Defaults to None.
Examples:
>>> policies = [
>>> [
>>> dict(type='Sharpness', prob=0.0, level=8),
>>> dict(type='ShearX', prob=0.4, level=0,)
>>> ],
>>> [
>>> dict(type='Rotate', prob=0.6, level=10),
>>> dict(type='Color', prob=1.0, level=6)
>>> ]
>>> ]
>>> augmentation = AutoAugment(policies)
>>> img = np.ones(100, 100, 3)
>>> gt_bboxes = np.ones(10, 4)
>>> results = dict(img=img, gt_bboxes=gt_bboxes)
>>> results = augmentation(results)
"""
def __init__(self,
policies: List[List[Union[dict, ConfigDict]]] = policies_v0(),
prob: Optional[List[float]] = None) -> None:
assert isinstance(policies, list) and len(policies) > 0, \
'Policies must be a non-empty list.'
for policy in policies:
assert isinstance(policy, list) and len(policy) > 0, \
'Each policy in policies must be a non-empty list.'
for augment in policy:
assert isinstance(augment, dict) and 'type' in augment, \
'Each specific augmentation must be a dict with key' \
' "type".'
super().__init__(transforms=policies, prob=prob)
self.policies = policies
def __repr__(self) -> str:
return f'{self.__class__.__name__}(policies={self.policies}, ' \
f'prob={self.prob})'
@TRANSFORMS.register_module()
class RandAugment(RandomChoice):
"""Rand augmentation.
This data augmentation is proposed in `RandAugment:
Practical automated data augmentation with a reduced
search space <https://arxiv.org/abs/1909.13719>`_.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_masks (BitmapMasks | PolygonMasks) (optional)
- gt_ignore_flags (bool) (optional)
- gt_seg_map (np.uint8) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes
- gt_bboxes_labels
- gt_masks
- gt_ignore_flags
- gt_seg_map
Added Keys:
- homography_matrix
Args:
aug_space (List[List[Union[dict, ConfigDict]]]): The augmentation space
of rand augmentation. Each augmentation transform in ``aug_space``
is a specific transform, and is composed by several augmentations.
When RandAugment is called, a random transform in ``aug_space``
will be selected to augment images. Defaults to aug_space.
aug_num (int): Number of augmentation to apply equentially.
Defaults to 2.
prob (list[float], optional): The probabilities associated with
each augmentation. The length should be equal to the
augmentation space and the sum should be 1. If not given,
a uniform distribution will be assumed. Defaults to None.
Examples:
>>> aug_space = [
>>> dict(type='Sharpness'),
>>> dict(type='ShearX'),
>>> dict(type='Color'),
>>> ],
>>> augmentation = RandAugment(aug_space)
>>> img = np.ones(100, 100, 3)
>>> gt_bboxes = np.ones(10, 4)
>>> results = dict(img=img, gt_bboxes=gt_bboxes)
>>> results = augmentation(results)
"""
def __init__(self,
aug_space: List[Union[dict, ConfigDict]] = RANDAUG_SPACE,
aug_num: int = 2,
prob: Optional[List[float]] = None) -> None:
assert isinstance(aug_space, list) and len(aug_space) > 0, \
'Augmentation space must be a non-empty list.'
for aug in aug_space:
assert isinstance(aug, list) and len(aug) == 1, \
'Each augmentation in aug_space must be a list.'
for transform in aug:
assert isinstance(transform, dict) and 'type' in transform, \
'Each specific transform must be a dict with key' \
' "type".'
super().__init__(transforms=aug_space, prob=prob)
self.aug_space = aug_space
self.aug_num = aug_num
@cache_randomness
def random_pipeline_index(self):
indices = np.arange(len(self.transforms))
return np.random.choice(
indices, self.aug_num, p=self.prob, replace=False)
def transform(self, results: dict) -> dict:
"""Transform function to use RandAugment.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with RandAugment.
"""
for idx in self.random_pipeline_index():
results = self.transforms[idx](results)
return results
def __repr__(self) -> str:
return f'{self.__class__.__name__}(' \
f'aug_space={self.aug_space}, '\
f'aug_num={self.aug_num}, ' \
f'prob={self.prob})'
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional
import mmcv
import numpy as np
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmdet.registry import TRANSFORMS
from .augment_wrappers import _MAX_LEVEL, level_to_mag
@TRANSFORMS.register_module()
class ColorTransform(BaseTransform):
"""Base class for color transformations. All color transformations need to
inherit from this base class. ``ColorTransform`` unifies the class
attributes and class functions of color transformations (Color, Brightness,
Contrast, Sharpness, Solarize, SolarizeAdd, Equalize, AutoContrast, Invert,
and Posterize), and only distort color channels, without impacting the
locations of the instances.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing the geometric
transformation and should be in range [0, 1]. Defaults to 1.0.
level (int, optional): The level should be in range [0, _MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for color transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for color transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0 <= prob <= 1.0, f'The probability of the transformation ' \
f'should be in range [0,1], got {prob}.'
assert level is None or isinstance(level, int), \
f'The level should be None or type int, got {type(level)}.'
assert level is None or 0 <= level <= _MAX_LEVEL, \
f'The level should be in range [0,{_MAX_LEVEL}], got {level}.'
assert isinstance(min_mag, float), \
f'min_mag should be type float, got {type(min_mag)}.'
assert isinstance(max_mag, float), \
f'max_mag should be type float, got {type(max_mag)}.'
assert min_mag <= max_mag, \
f'min_mag should smaller than max_mag, ' \
f'got min_mag={min_mag} and max_mag={max_mag}'
self.prob = prob
self.level = level
self.min_mag = min_mag
self.max_mag = max_mag
def _transform_img(self, results: dict, mag: float) -> None:
"""Transform the image."""
pass
@cache_randomness
def _random_disable(self):
"""Randomly disable the transform."""
return np.random.rand() > self.prob
@cache_randomness
def _get_mag(self):
"""Get the magnitude of the transform."""
return level_to_mag(self.level, self.min_mag, self.max_mag)
def transform(self, results: dict) -> dict:
"""Transform function for images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Transformed results.
"""
if self._random_disable():
return results
mag = self._get_mag()
self._transform_img(results, mag)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'level={self.level}, '
repr_str += f'min_mag={self.min_mag}, '
repr_str += f'max_mag={self.max_mag})'
return repr_str
@TRANSFORMS.register_module()
class Color(ColorTransform):
"""Adjust the color balance of the image, in a manner similar to the
controls on a colour TV set. A magnitude=0 gives a black & white image,
whereas magnitude=1 gives the original image. The bboxes, masks and
segmentations are not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Color transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Color transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Color transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Color should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Color should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Apply Color transformation to image."""
# NOTE defaultly the image should be BGR format
img = results['img']
results['img'] = mmcv.adjust_color(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Brightness(ColorTransform):
"""Adjust the brightness of the image. A magnitude=0 gives a black image,
whereas magnitude=1 gives the original image. The bboxes, masks and
segmentations are not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Brightness transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Brightness transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Brightness transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Brightness should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Brightness should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Adjust the brightness of image."""
img = results['img']
results['img'] = mmcv.adjust_brightness(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Contrast(ColorTransform):
"""Control the contrast of the image. A magnitude=0 gives a gray image,
whereas magnitude=1 gives the original imageThe bboxes, masks and
segmentations are not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Contrast transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Contrast transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Contrast transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Contrast should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Contrast should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Adjust the image contrast."""
img = results['img']
results['img'] = mmcv.adjust_contrast(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Sharpness(ColorTransform):
"""Adjust images sharpness. A positive magnitude would enhance the
sharpness and a negative magnitude would make the image blurry. A
magnitude=0 gives the origin img.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Sharpness transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Sharpness transformation.
Defaults to 0.1.
max_mag (float): The maximum magnitude for Sharpness transformation.
Defaults to 1.9.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.1,
max_mag: float = 1.9) -> None:
assert 0. <= min_mag <= 2.0, \
f'min_mag for Sharpness should be in range [0,2], got {min_mag}.'
assert 0. <= max_mag <= 2.0, \
f'max_mag for Sharpness should be in range [0,2], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Adjust the image sharpness."""
img = results['img']
results['img'] = mmcv.adjust_sharpness(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class Solarize(ColorTransform):
"""Solarize images (Invert all pixels above a threshold value of
magnitude.).
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Solarize transformation.
Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Solarize transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for Solarize transformation.
Defaults to 256.0.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 256.0) -> None:
assert 0. <= min_mag <= 256.0, f'min_mag for Solarize should be ' \
f'in range [0, 256], got {min_mag}.'
assert 0. <= max_mag <= 256.0, f'max_mag for Solarize should be ' \
f'in range [0, 256], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Invert all pixel values above magnitude."""
img = results['img']
results['img'] = mmcv.solarize(img, mag).astype(img.dtype)
@TRANSFORMS.register_module()
class SolarizeAdd(ColorTransform):
"""SolarizeAdd images. For each pixel in the image that is less than 128,
add an additional amount to it decided by the magnitude.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing SolarizeAdd
transformation. Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for SolarizeAdd transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for SolarizeAdd transformation.
Defaults to 110.0.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 110.0) -> None:
assert 0. <= min_mag <= 110.0, f'min_mag for SolarizeAdd should be ' \
f'in range [0, 110], got {min_mag}.'
assert 0. <= max_mag <= 110.0, f'max_mag for SolarizeAdd should be ' \
f'in range [0, 110], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""SolarizeAdd the image."""
img = results['img']
img_solarized = np.where(img < 128, np.minimum(img + mag, 255), img)
results['img'] = img_solarized.astype(img.dtype)
@TRANSFORMS.register_module()
class Posterize(ColorTransform):
"""Posterize images (reduce the number of bits for each color channel).
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Posterize
transformation. Defaults to 1.0.
level (int, optional): Should be in range [0,_MAX_LEVEL].
If level is None, it will generate from [0, _MAX_LEVEL] randomly.
Defaults to None.
min_mag (float): The minimum magnitude for Posterize transformation.
Defaults to 0.0.
max_mag (float): The maximum magnitude for Posterize transformation.
Defaults to 4.0.
"""
def __init__(self,
prob: float = 1.0,
level: Optional[int] = None,
min_mag: float = 0.0,
max_mag: float = 4.0) -> None:
assert 0. <= min_mag <= 8.0, f'min_mag for Posterize should be ' \
f'in range [0, 8], got {min_mag}.'
assert 0. <= max_mag <= 8.0, f'max_mag for Posterize should be ' \
f'in range [0, 8], got {max_mag}.'
super().__init__(
prob=prob, level=level, min_mag=min_mag, max_mag=max_mag)
def _transform_img(self, results: dict, mag: float) -> None:
"""Posterize the image."""
img = results['img']
results['img'] = mmcv.posterize(img, math.ceil(mag)).astype(img.dtype)
@TRANSFORMS.register_module()
class Equalize(ColorTransform):
"""Equalize the image histogram. The bboxes, masks and segmentations are
not modified.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing Equalize transformation.
Defaults to 1.0.
level (int, optional): No use for Equalize transformation.
Defaults to None.
min_mag (float): No use for Equalize transformation. Defaults to 0.1.
max_mag (float): No use for Equalize transformation. Defaults to 1.9.
"""
def _transform_img(self, results: dict, mag: float) -> None:
"""Equalizes the histogram of one image."""
img = results['img']
results['img'] = mmcv.imequalize(img).astype(img.dtype)
@TRANSFORMS.register_module()
class AutoContrast(ColorTransform):
"""Auto adjust image contrast.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing AutoContrast should
be in range [0, 1]. Defaults to 1.0.
level (int, optional): No use for AutoContrast transformation.
Defaults to None.
min_mag (float): No use for AutoContrast transformation.
Defaults to 0.1.
max_mag (float): No use for AutoContrast transformation.
Defaults to 1.9.
"""
def _transform_img(self, results: dict, mag: float) -> None:
"""Auto adjust image contrast."""
img = results['img']
results['img'] = mmcv.auto_contrast(img).astype(img.dtype)
@TRANSFORMS.register_module()
class Invert(ColorTransform):
"""Invert images.
Required Keys:
- img
Modified Keys:
- img
Args:
prob (float): The probability for performing invert therefore should
be in range [0, 1]. Defaults to 1.0.
level (int, optional): No use for Invert transformation.
Defaults to None.
min_mag (float): No use for Invert transformation. Defaults to 0.1.
max_mag (float): No use for Invert transformation. Defaults to 1.9.
"""
def _transform_img(self, results: dict, mag: float) -> None:
"""Invert the image."""
img = results['img']
results['img'] = mmcv.iminvert(img).astype(img.dtype)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
import numpy as np
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData, PixelData
from mmdet.registry import TRANSFORMS
from mmdet.structures import DetDataSample, ReIDDataSample, TrackDataSample
from mmdet.structures.bbox import BaseBoxes
@TRANSFORMS.register_module()
class PackDetInputs(BaseTransform):
"""Pack the inputs data for the detection / semantic segmentation /
panoptic segmentation.
The ``img_meta`` item is always populated. The contents of the
``img_meta`` dictionary depends on ``meta_keys``. By default this includes:
- ``img_id``: id of the image
- ``img_path``: path to the image file
- ``ori_shape``: original shape of the image as a tuple (h, w)
- ``img_shape``: shape of the image input to the network as a tuple \
(h, w). Note that images may be zero padded on the \
bottom/right if the batch tensor is larger than this shape.
- ``scale_factor``: a float indicating the preprocessing scale
- ``flip``: a boolean indicating if image flip transform was used
- ``flip_direction``: the flipping direction
Args:
meta_keys (Sequence[str], optional): Meta keys to be converted to
``mmcv.DataContainer`` and collected in ``data[img_metas]``.
Default: ``('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')``
"""
mapping_table = {
'gt_bboxes': 'bboxes',
'gt_bboxes_labels': 'labels',
'gt_masks': 'masks'
}
def __init__(self,
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')):
self.meta_keys = meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
- 'data_sample' (obj:`DetDataSample`): The annotation info of the
sample.
"""
packed_results = dict()
if 'img' in results:
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
# To improve the computational speed by by 3-5 times, apply:
# If image is not contiguous, use
# `numpy.transpose()` followed by `numpy.ascontiguousarray()`
# If image is already contiguous, use
# `torch.permute()` followed by `torch.contiguous()`
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if not img.flags.c_contiguous:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
else:
img = to_tensor(img).permute(2, 0, 1).contiguous()
packed_results['inputs'] = img
if 'gt_ignore_flags' in results:
valid_idx = np.where(results['gt_ignore_flags'] == 0)[0]
ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0]
data_sample = DetDataSample()
instance_data = InstanceData()
ignore_instance_data = InstanceData()
for key in self.mapping_table.keys():
if key not in results:
continue
if key == 'gt_masks' or isinstance(results[key], BaseBoxes):
if 'gt_ignore_flags' in results:
instance_data[
self.mapping_table[key]] = results[key][valid_idx]
ignore_instance_data[
self.mapping_table[key]] = results[key][ignore_idx]
else:
instance_data[self.mapping_table[key]] = results[key]
else:
if 'gt_ignore_flags' in results:
instance_data[self.mapping_table[key]] = to_tensor(
results[key][valid_idx])
ignore_instance_data[self.mapping_table[key]] = to_tensor(
results[key][ignore_idx])
else:
instance_data[self.mapping_table[key]] = to_tensor(
results[key])
data_sample.gt_instances = instance_data
data_sample.ignored_instances = ignore_instance_data
if 'proposals' in results:
proposals = InstanceData(
bboxes=to_tensor(results['proposals']),
scores=to_tensor(results['proposals_scores']))
data_sample.proposals = proposals
if 'gt_seg_map' in results:
gt_sem_seg_data = dict(
sem_seg=to_tensor(results['gt_seg_map'][None, ...].copy()))
gt_sem_seg_data = PixelData(**gt_sem_seg_data)
if 'ignore_index' in results:
metainfo = dict(ignore_index=results['ignore_index'])
gt_sem_seg_data.set_metainfo(metainfo)
data_sample.gt_sem_seg = gt_sem_seg_data
img_meta = {}
for key in self.meta_keys:
if key in results:
img_meta[key] = results[key]
data_sample.set_metainfo(img_meta)
packed_results['data_samples'] = data_sample
return packed_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str
@TRANSFORMS.register_module()
class ToTensor:
"""Convert some results to :obj:`torch.Tensor` by given keys.
Args:
keys (Sequence[str]): Keys that need to be converted to Tensor.
"""
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
"""Call function to convert data in results to :obj:`torch.Tensor`.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data converted
to :obj:`torch.Tensor`.
"""
for key in self.keys:
results[key] = to_tensor(results[key])
return results
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
@TRANSFORMS.register_module()
class ImageToTensor:
"""Convert image to :obj:`torch.Tensor` by given keys.
The dimension order of input image is (H, W, C). The pipeline will convert
it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
(1, H, W).
Args:
keys (Sequence[str]): Key of images to be converted to Tensor.
"""
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
"""Call function to convert image in results to :obj:`torch.Tensor` and
transpose the channel order.
Args:
results (dict): Result dict contains the image data to convert.
Returns:
dict: The result dict contains the image converted
to :obj:`torch.Tensor` and permuted to (C, H, W) order.
"""
for key in self.keys:
img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
results[key] = to_tensor(img).permute(2, 0, 1).contiguous()
return results
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
@TRANSFORMS.register_module()
class Transpose:
"""Transpose some results by given keys.
Args:
keys (Sequence[str]): Keys of results to be transposed.
order (Sequence[int]): Order of transpose.
"""
def __init__(self, keys, order):
self.keys = keys
self.order = order
def __call__(self, results):
"""Call function to transpose the channel order of data in results.
Args:
results (dict): Result dict contains the data to transpose.
Returns:
dict: The result dict contains the data transposed to \
``self.order``.
"""
for key in self.keys:
results[key] = results[key].transpose(self.order)
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(keys={self.keys}, order={self.order})'
@TRANSFORMS.register_module()
class WrapFieldsToLists:
"""Wrap fields of the data dictionary into lists for evaluation.
This class can be used as a last step of a test or validation
pipeline for single image evaluation or inference.
Example:
>>> test_pipeline = [
>>> dict(type='LoadImageFromFile'),
>>> dict(type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
>>> dict(type='Pad', size_divisor=32),
>>> dict(type='ImageToTensor', keys=['img']),
>>> dict(type='Collect', keys=['img']),
>>> dict(type='WrapFieldsToLists')
>>> ]
"""
def __call__(self, results):
"""Call function to wrap fields into lists.
Args:
results (dict): Result dict contains the data to wrap.
Returns:
dict: The result dict where value of ``self.keys`` are wrapped \
into list.
"""
# Wrap dict fields into lists
for key, val in results.items():
results[key] = [val]
return results
def __repr__(self):
return f'{self.__class__.__name__}()'
@TRANSFORMS.register_module()
class PackTrackInputs(BaseTransform):
"""Pack the inputs data for the multi object tracking and video instance
segmentation. All the information of images are packed to ``inputs``. All
the information except images are packed to ``data_samples``. In order to
get the original annotaiton and meta info, we add `instances` key into meta
keys.
Args:
meta_keys (Sequence[str]): Meta keys to be collected in
``data_sample.metainfo``. Defaults to None.
default_meta_keys (tuple): Default meta keys. Defaults to ('img_id',
'img_path', 'ori_shape', 'img_shape', 'scale_factor',
'flip', 'flip_direction', 'frame_id', 'is_video_data',
'video_id', 'video_length', 'instances').
"""
mapping_table = {
'gt_bboxes': 'bboxes',
'gt_bboxes_labels': 'labels',
'gt_masks': 'masks',
'gt_instances_ids': 'instances_ids'
}
def __init__(self,
meta_keys: Optional[dict] = None,
default_meta_keys: tuple = ('img_id', 'img_path', 'ori_shape',
'img_shape', 'scale_factor',
'flip', 'flip_direction',
'frame_id', 'video_id',
'video_length',
'ori_video_length', 'instances')):
self.meta_keys = default_meta_keys
if meta_keys is not None:
if isinstance(meta_keys, str):
meta_keys = (meta_keys, )
else:
assert isinstance(meta_keys, tuple), \
'meta_keys must be str or tuple'
self.meta_keys += meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (dict[Tensor]): The forward data of models.
- 'data_samples' (obj:`TrackDataSample`): The annotation info of
the samples.
"""
packed_results = dict()
packed_results['inputs'] = dict()
# 1. Pack images
if 'img' in results:
imgs = results['img']
imgs = np.stack(imgs, axis=0)
imgs = imgs.transpose(0, 3, 1, 2)
packed_results['inputs'] = to_tensor(imgs)
# 2. Pack InstanceData
if 'gt_ignore_flags' in results:
gt_ignore_flags_list = results['gt_ignore_flags']
valid_idx_list, ignore_idx_list = [], []
for gt_ignore_flags in gt_ignore_flags_list:
valid_idx = np.where(gt_ignore_flags == 0)[0]
ignore_idx = np.where(gt_ignore_flags == 1)[0]
valid_idx_list.append(valid_idx)
ignore_idx_list.append(ignore_idx)
assert 'img_id' in results, "'img_id' must contained in the results "
'for counting the number of images'
num_imgs = len(results['img_id'])
instance_data_list = [InstanceData() for _ in range(num_imgs)]
ignore_instance_data_list = [InstanceData() for _ in range(num_imgs)]
for key in self.mapping_table.keys():
if key not in results:
continue
if key == 'gt_masks':
mapped_key = self.mapping_table[key]
gt_masks_list = results[key]
if 'gt_ignore_flags' in results:
for i, gt_mask in enumerate(gt_masks_list):
valid_idx, ignore_idx = valid_idx_list[
i], ignore_idx_list[i]
instance_data_list[i][mapped_key] = gt_mask[valid_idx]
ignore_instance_data_list[i][mapped_key] = gt_mask[
ignore_idx]
else:
for i, gt_mask in enumerate(gt_masks_list):
instance_data_list[i][mapped_key] = gt_mask
else:
anns_list = results[key]
if 'gt_ignore_flags' in results:
for i, ann in enumerate(anns_list):
valid_idx, ignore_idx = valid_idx_list[
i], ignore_idx_list[i]
instance_data_list[i][
self.mapping_table[key]] = to_tensor(
ann[valid_idx])
ignore_instance_data_list[i][
self.mapping_table[key]] = to_tensor(
ann[ignore_idx])
else:
for i, ann in enumerate(anns_list):
instance_data_list[i][
self.mapping_table[key]] = to_tensor(ann)
det_data_samples_list = []
for i in range(num_imgs):
det_data_sample = DetDataSample()
det_data_sample.gt_instances = instance_data_list[i]
det_data_sample.ignored_instances = ignore_instance_data_list[i]
det_data_samples_list.append(det_data_sample)
# 3. Pack metainfo
for key in self.meta_keys:
if key not in results:
continue
img_metas_list = results[key]
for i, img_meta in enumerate(img_metas_list):
det_data_samples_list[i].set_metainfo({f'{key}': img_meta})
track_data_sample = TrackDataSample()
track_data_sample.video_data_samples = det_data_samples_list
if 'key_frame_flags' in results:
key_frame_flags = np.asarray(results['key_frame_flags'])
key_frames_inds = np.where(key_frame_flags)[0].tolist()
ref_frames_inds = np.where(~key_frame_flags)[0].tolist()
track_data_sample.set_metainfo(
dict(key_frames_inds=key_frames_inds))
track_data_sample.set_metainfo(
dict(ref_frames_inds=ref_frames_inds))
packed_results['data_samples'] = track_data_sample
return packed_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'meta_keys={self.meta_keys}, '
repr_str += f'default_meta_keys={self.default_meta_keys})'
return repr_str
@TRANSFORMS.register_module()
class PackReIDInputs(BaseTransform):
"""Pack the inputs data for the ReID. The ``meta_info`` item is always
populated. The contents of the ``meta_info`` dictionary depends on
``meta_keys``. By default this includes:
- ``img_path``: path to the image file.
- ``ori_shape``: original shape of the image as a tuple (H, W).
- ``img_shape``: shape of the image input to the network as a tuple
(H, W). Note that images may be zero padded on the bottom/right
if the batch tensor is larger than this shape.
- ``scale``: scale of the image as a tuple (W, H).
- ``scale_factor``: a float indicating the pre-processing scale.
- ``flip``: a boolean indicating if image flip transform was used.
- ``flip_direction``: the flipping direction.
Args:
meta_keys (Sequence[str], optional): The meta keys to saved in the
``metainfo`` of the packed ``data_sample``.
"""
default_meta_keys = ('img_path', 'ori_shape', 'img_shape', 'scale',
'scale_factor')
def __init__(self, meta_keys: Sequence[str] = ()) -> None:
self.meta_keys = self.default_meta_keys
if meta_keys is not None:
if isinstance(meta_keys, str):
meta_keys = (meta_keys, )
else:
assert isinstance(meta_keys, tuple), \
'meta_keys must be str or tuple.'
self.meta_keys += meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (dict[Tensor]): The forward data of models.
- 'data_samples' (obj:`ReIDDataSample`): The meta info of the
sample.
"""
packed_results = dict(inputs=dict(), data_samples=None)
assert 'img' in results, 'Missing the key ``img``.'
_type = type(results['img'])
label = results['gt_label']
if _type == list:
img = results['img']
label = np.stack(label, axis=0) # (N,)
assert all([type(v) == _type for v in results.values()]), \
'All items in the results must have the same type.'
else:
img = [results['img']]
img = np.stack(img, axis=3) # (H, W, C, N)
img = img.transpose(3, 2, 0, 1) # (N, C, H, W)
img = np.ascontiguousarray(img)
packed_results['inputs'] = to_tensor(img)
data_sample = ReIDDataSample()
data_sample.set_gt_label(label)
meta_info = dict()
for key in self.meta_keys:
meta_info[key] = results[key]
data_sample.set_metainfo(meta_info)
packed_results['data_samples'] = data_sample
return packed_results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment