Commit a8562a56 authored by luopl's avatar luopl
Browse files

Initial commit

parents
Pipeline #1564 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0
from mmengine.config import read_base
with read_base():
from .rtmdet_l_8xb32_300e_coco import *
model.update(
dict(
backbone=dict(deepen_factor=0.67, widen_factor=0.75),
neck=dict(
in_channels=[192, 384, 768], out_channels=192, num_csp_blocks=2),
bbox_head=dict(in_channels=192, feat_channels=192)))
# Copyright (c) OpenMMLab. All rights reserved.
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0
from mmengine.config import read_base
with read_base():
from .rtmdet_l_8xb32_300e_coco import *
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import RandomResize
from mmengine.hooks.ema_hook import EMAHook
from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import LoadAnnotations
from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic,
Pad, RandomCrop, RandomFlip,
Resize, YOLOXHSVRandomAug)
from mmdet.engine.hooks.pipeline_switch_hook import PipelineSwitchHook
from mmdet.models.layers.ema import ExpMomentumEMA
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa
model.update(
dict(
backbone=dict(
deepen_factor=0.33,
widen_factor=0.5,
init_cfg=dict(
type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
neck=dict(
in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1),
bbox_head=dict(in_channels=128, feat_channels=128, exp_on_reg=False)))
train_pipeline = [
dict(type=LoadImageFromFile, backend_args=backend_args),
dict(type=LoadAnnotations, with_bbox=True),
dict(type=CachedMosaic, img_scale=(640, 640), pad_val=114.0),
dict(
type=RandomResize,
scale=(1280, 1280),
ratio_range=(0.5, 2.0),
resize_type=Resize,
keep_ratio=True),
dict(type=RandomCrop, crop_size=(640, 640)),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(
type=CachedMixUp,
img_scale=(640, 640),
ratio_range=(1.0, 1.0),
max_cached_images=20,
pad_val=(114, 114, 114)),
dict(type=PackDetInputs)
]
train_pipeline_stage2 = [
dict(type=LoadImageFromFile, backend_args=backend_args),
dict(type=LoadAnnotations, with_bbox=True),
dict(
type=RandomResize,
scale=(640, 640),
ratio_range=(0.5, 2.0),
resize_type=Resize,
keep_ratio=True),
dict(type=RandomCrop, crop_size=(640, 640)),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(type=PackDetInputs)
]
train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))
custom_hooks = [
dict(
type=EMAHook,
ema_type=ExpMomentumEMA,
momentum=0.0002,
update_buffers=True,
priority=49),
dict(
type=PipelineSwitchHook,
switch_epoch=280,
switch_pipeline=train_pipeline_stage2)
]
# Copyright (c) OpenMMLab. All rights reserved.
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0
from mmengine.config import read_base
with read_base():
from .rtmdet_s_8xb32_300e_coco import *
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import RandomResize
from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import LoadAnnotations
from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic,
Pad, RandomCrop, RandomFlip,
Resize, YOLOXHSVRandomAug)
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth' # noqa
model.update(
dict(
backbone=dict(
deepen_factor=0.167,
widen_factor=0.375,
init_cfg=dict(
type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
neck=dict(
in_channels=[96, 192, 384], out_channels=96, num_csp_blocks=1),
bbox_head=dict(in_channels=96, feat_channels=96, exp_on_reg=False)))
train_pipeline = [
dict(type=LoadImageFromFile, backend_args=backend_args),
dict(type=LoadAnnotations, with_bbox=True),
dict(
type=CachedMosaic,
img_scale=(640, 640),
pad_val=114.0,
max_cached_images=20,
random_pop=False),
dict(
type=RandomResize,
scale=(1280, 1280),
ratio_range=(0.5, 2.0),
resize_type=Resize,
keep_ratio=True),
dict(type=RandomCrop, crop_size=(640, 640)),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(
type=CachedMixUp,
img_scale=(640, 640),
ratio_range=(1.0, 1.0),
max_cached_images=10,
random_pop=False,
pad_val=(114, 114, 114),
prob=0.5),
dict(type=PackDetInputs)
]
train_dataloader.update(dict(dataset=dict(pipeline=train_pipeline)))
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import TestTimeAug
from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import LoadAnnotations
from mmdet.datasets.transforms.transforms import Pad, RandomFlip, Resize
from mmdet.models.test_time_augs.det_tta import DetTTAModel
tta_model = dict(
type=DetTTAModel,
tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.6), max_per_img=100))
img_scales = [(640, 640), (320, 320), (960, 960)]
tta_pipeline = [
dict(type=LoadImageFromFile, backend_args=None),
dict(
type=TestTimeAug,
transforms=[
[dict(type=Resize, scale=s, keep_ratio=True) for s in img_scales],
[
# ``RandomFlip`` must be placed before ``Pad``, otherwise
# bounding box coordinates after flipping cannot be
# recovered correctly.
dict(type=RandomFlip, prob=1.),
dict(type=RandomFlip, prob=0.)
],
[
dict(
type=Pad,
size=(960, 960),
pad_val=dict(img=(114, 114, 114))),
],
[dict(type=LoadAnnotations, with_bbox=True)],
[
dict(
type=PackDetInputs,
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction'))
]
])
]
# Copyright (c) OpenMMLab. All rights reserved.
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0
from mmengine.config import read_base
with read_base():
from .rtmdet_l_8xb32_300e_coco import *
model.update(
dict(
backbone=dict(deepen_factor=1.33, widen_factor=1.25),
neck=dict(
in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4),
bbox_head=dict(in_channels=320, feat_channels=320)))
# Copyright (c) OpenMMLab. All rights reserved.
from .ade20k import (ADE20KInstanceDataset, ADE20KPanopticDataset,
ADE20KSegDataset)
from .base_det_dataset import BaseDetDataset
from .base_semseg_dataset import BaseSegDataset
from .base_video_dataset import BaseVideoDataset
from .cityscapes import CityscapesDataset
from .coco import CocoDataset
from .coco_caption import CocoCaptionDataset
from .coco_panoptic import CocoPanopticDataset
from .coco_semantic import CocoSegDataset
from .crowdhuman import CrowdHumanDataset
from .dataset_wrappers import ConcatDataset, MultiImageMixDataset
from .deepfashion import DeepFashionDataset
from .dod import DODDataset
from .dsdl import DSDLDetDataset
from .flickr30k import Flickr30kDataset
from .isaid import iSAIDDataset
from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
from .mdetr_style_refcoco import MDETRStyleRefCocoDataset
from .mot_challenge_dataset import MOTChallengeDataset
from .objects365 import Objects365V1Dataset, Objects365V2Dataset
from .odvg import ODVGDataset
from .openimages import OpenImagesChallengeDataset, OpenImagesDataset
from .refcoco import RefCocoDataset
from .reid_dataset import ReIDDataset
from .samplers import (AspectRatioBatchSampler, ClassAwareSampler,
CustomSampleSizeSampler, GroupMultiSourceSampler,
MultiSourceSampler, TrackAspectRatioBatchSampler,
TrackImgSampler)
from .utils import get_loading_pipeline
from .v3det import V3DetDataset
from .voc import VOCDataset
from .wider_face import WIDERFaceDataset
from .xml_style import XMLDataset
from .youtube_vis_dataset import YouTubeVISDataset
__all__ = [
'XMLDataset', 'CocoDataset', 'DeepFashionDataset', 'VOCDataset',
'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset', 'LVISV1Dataset',
'WIDERFaceDataset', 'get_loading_pipeline', 'CocoPanopticDataset',
'MultiImageMixDataset', 'OpenImagesDataset', 'OpenImagesChallengeDataset',
'AspectRatioBatchSampler', 'ClassAwareSampler', 'MultiSourceSampler',
'GroupMultiSourceSampler', 'BaseDetDataset', 'CrowdHumanDataset',
'Objects365V1Dataset', 'Objects365V2Dataset', 'DSDLDetDataset',
'BaseVideoDataset', 'MOTChallengeDataset', 'TrackImgSampler',
'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler',
'ADE20KPanopticDataset', 'CocoCaptionDataset', 'RefCocoDataset',
'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDataset',
'ADE20KInstanceDataset', 'iSAIDDataset', 'V3DetDataset', 'ConcatDataset',
'ODVGDataset', 'MDETRStyleRefCocoDataset', 'DODDataset',
'CustomSampleSizeSampler', 'Flickr30kDataset'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List
from mmengine import fileio
from mmdet.registry import DATASETS
from .base_semseg_dataset import BaseSegDataset
from .coco import CocoDataset
from .coco_panoptic import CocoPanopticDataset
ADE_PALETTE = [(120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50),
(4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255),
(230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7),
(150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82),
(143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3),
(0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255),
(255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220),
(255, 9, 92), (112, 9, 255), (8, 255, 214), (7, 255, 224),
(255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255),
(224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7),
(255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153),
(6, 51, 255), (235, 12, 255), (160, 150, 20), (0, 163, 255),
(140, 140, 140), (250, 10, 15), (20, 255, 0), (31, 255, 0),
(255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255),
(255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255),
(11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255),
(0, 255, 112), (0, 255, 133), (255, 0, 0), (255, 163, 0),
(255, 102, 0), (194, 255, 0), (0, 143, 255), (51, 255, 0),
(0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255),
(173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255),
(255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20),
(255, 184, 184), (0, 31, 255), (0, 255, 61), (0, 71, 255),
(255, 0, 204), (0, 255, 194), (0, 255, 82), (0, 10, 255),
(0, 112, 255), (51, 0, 255), (0, 194, 255), (0, 122, 255),
(0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0),
(143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0),
(8, 184, 170), (133, 0, 255), (0, 255, 92), (184, 0, 255),
(255, 0, 31), (0, 184, 255), (0, 214, 255), (255, 0, 112),
(92, 255, 0), (0, 224, 255), (112, 224, 255), (70, 184, 160),
(163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163),
(255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0),
(255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, 0),
(10, 190, 212), (214, 255, 0), (0, 204, 255), (20, 0, 255),
(255, 255, 0), (0, 153, 255), (0, 41, 255), (0, 255, 204),
(41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255),
(71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255),
(184, 255, 0), (0, 133, 255), (255, 214, 0), (25, 194, 194),
(102, 255, 0), (92, 0, 255)]
@DATASETS.register_module()
class ADE20KPanopticDataset(CocoPanopticDataset):
METAINFO = {
'classes':
('bed', 'window', 'cabinet', 'person', 'door', 'table', 'curtain',
'chair', 'car', 'painting, picture', 'sofa', 'shelf', 'mirror',
'armchair', 'seat', 'fence', 'desk', 'wardrobe, closet, press',
'lamp', 'tub', 'rail', 'cushion', 'box', 'column, pillar',
'signboard, sign', 'chest of drawers, chest, bureau, dresser',
'counter', 'sink', 'fireplace', 'refrigerator, icebox', 'stairs',
'case, display case, showcase, vitrine',
'pool table, billiard table, snooker table', 'pillow',
'screen door, screen', 'bookcase', 'coffee table',
'toilet, can, commode, crapper, pot, potty, stool, throne', 'flower',
'book', 'bench', 'countertop', 'stove', 'palm, palm tree',
'kitchen island', 'computer', 'swivel chair', 'boat',
'arcade machine', 'bus', 'towel', 'light', 'truck', 'chandelier',
'awning, sunshade, sunblind', 'street lamp', 'booth', 'tv',
'airplane', 'clothes', 'pole',
'bannister, banister, balustrade, balusters, handrail',
'ottoman, pouf, pouffe, puff, hassock', 'bottle', 'van', 'ship',
'fountain', 'washer, automatic washer, washing machine',
'plaything, toy', 'stool', 'barrel, cask', 'basket, handbasket',
'bag', 'minibike, motorbike', 'oven', 'ball', 'food, solid food',
'step, stair', 'trade name', 'microwave', 'pot', 'animal', 'bicycle',
'dishwasher', 'screen', 'sculpture', 'hood, exhaust hood', 'sconce',
'vase', 'traffic light', 'tray', 'trash can', 'fan', 'plate',
'monitor', 'bulletin board', 'radiator', 'glass, drinking glass',
'clock', 'flag', 'wall', 'building', 'sky', 'floor', 'tree',
'ceiling', 'road, route', 'grass', 'sidewalk, pavement',
'earth, ground', 'mountain, mount', 'plant', 'water', 'house', 'sea',
'rug', 'field', 'rock, stone', 'base, pedestal, stand', 'sand',
'skyscraper', 'grandstand, covered stand', 'path', 'runway',
'stairway, staircase', 'river', 'bridge, span', 'blind, screen',
'hill', 'bar', 'hovel, hut, hutch, shack, shanty', 'tower',
'dirt track', 'land, ground, soil',
'escalator, moving staircase, moving stairway',
'buffet, counter, sideboard',
'poster, posting, placard, notice, bill, card', 'stage',
'conveyer belt, conveyor belt, conveyer, conveyor, transporter',
'canopy', 'pool', 'falls', 'tent', 'cradle', 'tank, storage tank',
'lake', 'blanket, cover', 'pier', 'crt screen', 'shower'),
'thing_classes':
('bed', 'window', 'cabinet', 'person', 'door', 'table', 'curtain',
'chair', 'car', 'painting, picture', 'sofa', 'shelf', 'mirror',
'armchair', 'seat', 'fence', 'desk', 'wardrobe, closet, press',
'lamp', 'tub', 'rail', 'cushion', 'box', 'column, pillar',
'signboard, sign', 'chest of drawers, chest, bureau, dresser',
'counter', 'sink', 'fireplace', 'refrigerator, icebox', 'stairs',
'case, display case, showcase, vitrine',
'pool table, billiard table, snooker table', 'pillow',
'screen door, screen', 'bookcase', 'coffee table',
'toilet, can, commode, crapper, pot, potty, stool, throne', 'flower',
'book', 'bench', 'countertop', 'stove', 'palm, palm tree',
'kitchen island', 'computer', 'swivel chair', 'boat',
'arcade machine', 'bus', 'towel', 'light', 'truck', 'chandelier',
'awning, sunshade, sunblind', 'street lamp', 'booth', 'tv',
'airplane', 'clothes', 'pole',
'bannister, banister, balustrade, balusters, handrail',
'ottoman, pouf, pouffe, puff, hassock', 'bottle', 'van', 'ship',
'fountain', 'washer, automatic washer, washing machine',
'plaything, toy', 'stool', 'barrel, cask', 'basket, handbasket',
'bag', 'minibike, motorbike', 'oven', 'ball', 'food, solid food',
'step, stair', 'trade name', 'microwave', 'pot', 'animal', 'bicycle',
'dishwasher', 'screen', 'sculpture', 'hood, exhaust hood', 'sconce',
'vase', 'traffic light', 'tray', 'trash can', 'fan', 'plate',
'monitor', 'bulletin board', 'radiator', 'glass, drinking glass',
'clock', 'flag'),
'stuff_classes':
('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road, route',
'grass', 'sidewalk, pavement', 'earth, ground', 'mountain, mount',
'plant', 'water', 'house', 'sea', 'rug', 'field', 'rock, stone',
'base, pedestal, stand', 'sand', 'skyscraper',
'grandstand, covered stand', 'path', 'runway', 'stairway, staircase',
'river', 'bridge, span', 'blind, screen', 'hill', 'bar',
'hovel, hut, hutch, shack, shanty', 'tower', 'dirt track',
'land, ground, soil', 'escalator, moving staircase, moving stairway',
'buffet, counter, sideboard',
'poster, posting, placard, notice, bill, card', 'stage',
'conveyer belt, conveyor belt, conveyer, conveyor, transporter',
'canopy', 'pool', 'falls', 'tent', 'cradle', 'tank, storage tank',
'lake', 'blanket, cover', 'pier', 'crt screen', 'shower'),
'palette':
ADE_PALETTE
}
@DATASETS.register_module()
class ADE20KInstanceDataset(CocoDataset):
METAINFO = {
'classes':
('bed', 'windowpane', 'cabinet', 'person', 'door', 'table', 'curtain',
'chair', 'car', 'painting', 'sofa', 'shelf', 'mirror', 'armchair',
'seat', 'fence', 'desk', 'wardrobe', 'lamp', 'bathtub', 'railing',
'cushion', 'box', 'column', 'signboard', 'chest of drawers',
'counter', 'sink', 'fireplace', 'refrigerator', 'stairs', 'case',
'pool table', 'pillow', 'screen door', 'bookcase', 'coffee table',
'toilet', 'flower', 'book', 'bench', 'countertop', 'stove', 'palm',
'kitchen island', 'computer', 'swivel chair', 'boat',
'arcade machine', 'bus', 'towel', 'light', 'truck', 'chandelier',
'awning', 'streetlight', 'booth', 'television receiver', 'airplane',
'apparel', 'pole', 'bannister', 'ottoman', 'bottle', 'van', 'ship',
'fountain', 'washer', 'plaything', 'stool', 'barrel', 'basket', 'bag',
'minibike', 'oven', 'ball', 'food', 'step', 'trade name', 'microwave',
'pot', 'animal', 'bicycle', 'dishwasher', 'screen', 'sculpture',
'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan',
'plate', 'monitor', 'bulletin board', 'radiator', 'glass', 'clock',
'flag'),
'palette': [(204, 5, 255), (230, 230, 230), (224, 5, 255),
(150, 5, 61), (8, 255, 51), (255, 6, 82), (255, 51, 7),
(204, 70, 3), (0, 102, 200), (255, 6, 51), (11, 102, 255),
(255, 7, 71), (220, 220, 220), (8, 255, 214),
(7, 255, 224), (255, 184, 6), (10, 255, 71), (7, 255, 255),
(224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7),
(0, 255, 20), (255, 8, 41), (255, 5, 153), (6, 51, 255),
(235, 12, 255), (0, 163, 255), (250, 10, 15), (20, 255, 0),
(255, 224, 0), (0, 0, 255), (255, 71, 0), (0, 235, 255),
(0, 173, 255), (0, 255, 245), (0, 255, 112), (0, 255, 133),
(255, 0, 0), (255, 163, 0), (194, 255, 0), (0, 143, 255),
(51, 255, 0), (0, 82, 255), (0, 255, 41), (0, 255, 173),
(10, 0, 255), (173, 255, 0), (255, 92, 0), (255, 0, 245),
(255, 0, 102), (255, 173, 0), (255, 0, 20), (0, 31, 255),
(0, 255, 61), (0, 71, 255), (255, 0, 204), (0, 255, 194),
(0, 255, 82), (0, 112, 255), (51, 0, 255), (0, 122, 255),
(255, 153, 0), (0, 255, 10), (163, 255, 0), (255, 235, 0),
(8, 184, 170), (184, 0, 255), (255, 0, 31), (0, 214, 255),
(255, 0, 112), (92, 255, 0), (70, 184, 160), (163, 0, 255),
(71, 255, 0), (255, 0, 163), (255, 204, 0), (255, 0, 143),
(133, 255, 0), (255, 0, 235), (245, 0, 255), (255, 0, 122),
(255, 245, 0), (214, 255, 0), (0, 204, 255), (255, 255, 0),
(0, 153, 255), (0, 41, 255), (0, 255, 204), (41, 0, 255),
(41, 255, 0), (173, 0, 255), (0, 245, 255), (0, 255, 184),
(0, 92, 255), (184, 255, 0), (255, 214, 0), (25, 194, 194),
(102, 255, 0), (92, 0, 255)],
}
@DATASETS.register_module()
class ADE20KSegDataset(BaseSegDataset):
"""ADE20K dataset.
In segmentation map annotation for ADE20K, 0 stands for background, which
is not included in 150 categories. The ``img_suffix`` is fixed to '.jpg',
and ``seg_map_suffix`` is fixed to '.png'.
"""
METAINFO = dict(
classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk',
'person', 'earth', 'door', 'table', 'mountain', 'plant',
'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
'screen door', 'stairway', 'river', 'bridge', 'bookcase',
'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
'bench', 'countertop', 'stove', 'palm', 'kitchen island',
'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
'chandelier', 'awning', 'streetlight', 'booth',
'television receiver', 'airplane', 'dirt track', 'apparel',
'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
'conveyer belt', 'canopy', 'washer', 'plaything',
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
'clock', 'flag'),
palette=ADE_PALETTE)
def __init__(self,
img_suffix='.jpg',
seg_map_suffix='.png',
return_classes=False,
**kwargs) -> None:
self.return_classes = return_classes
super().__init__(
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotation from directory or annotation file.
Returns:
List[dict]: All data info of dataset.
"""
data_list = []
img_dir = self.data_prefix.get('img_path', None)
ann_dir = self.data_prefix.get('seg_map_path', None)
for img in fileio.list_dir_or_file(
dir_path=img_dir,
list_dir=False,
suffix=self.img_suffix,
recursive=True,
backend_args=self.backend_args):
data_info = dict(img_path=osp.join(img_dir, img))
if ann_dir is not None:
seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
data_info['label_map'] = self.label_map
if self.return_classes:
data_info['text'] = list(self._metainfo['classes'])
data_list.append(data_info)
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
from .coco_api import COCO, COCOeval, COCOPanoptic
from .cocoeval_mp import COCOevalMP
__all__ = ['COCO', 'COCOeval', 'COCOPanoptic', 'COCOevalMP']
# Copyright (c) OpenMMLab. All rights reserved.
# This file add snake case alias for coco api
import warnings
from collections import defaultdict
from typing import List, Optional, Union
import pycocotools
from pycocotools.coco import COCO as _COCO
from pycocotools.cocoeval import COCOeval as _COCOeval
class COCO(_COCO):
"""This class is almost the same as official pycocotools package.
It implements some snake case function aliases. So that the COCO class has
the same interface as LVIS class.
"""
def __init__(self, annotation_file=None):
if getattr(pycocotools, '__version__', '0') >= '12.0.2':
warnings.warn(
'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501
UserWarning)
super().__init__(annotation_file=annotation_file)
self.img_ann_map = self.imgToAnns
self.cat_img_map = self.catToImgs
def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)
def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
return self.getCatIds(cat_names, sup_names, cat_ids)
def get_img_ids(self, img_ids=[], cat_ids=[]):
return self.getImgIds(img_ids, cat_ids)
def load_anns(self, ids):
return self.loadAnns(ids)
def load_cats(self, ids):
return self.loadCats(ids)
def load_imgs(self, ids):
return self.loadImgs(ids)
# just for the ease of import
COCOeval = _COCOeval
class COCOPanoptic(COCO):
"""This wrapper is for loading the panoptic style annotation file.
The format is shown in the CocoPanopticDataset class.
Args:
annotation_file (str, optional): Path of annotation file.
Defaults to None.
"""
def __init__(self, annotation_file: Optional[str] = None) -> None:
super(COCOPanoptic, self).__init__(annotation_file)
def createIndex(self) -> None:
"""Create index."""
# create index
print('creating index...')
# anns stores 'segment_id -> annotation'
anns, cats, imgs = {}, {}, {}
img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list)
if 'annotations' in self.dataset:
for ann in self.dataset['annotations']:
for seg_ann in ann['segments_info']:
# to match with instance.json
seg_ann['image_id'] = ann['image_id']
img_to_anns[ann['image_id']].append(seg_ann)
# segment_id is not unique in coco dataset orz...
# annotations from different images but
# may have same segment_id
if seg_ann['id'] in anns.keys():
anns[seg_ann['id']].append(seg_ann)
else:
anns[seg_ann['id']] = [seg_ann]
# filter out annotations from other images
img_to_anns_ = defaultdict(list)
for k, v in img_to_anns.items():
img_to_anns_[k] = [x for x in v if x['image_id'] == k]
img_to_anns = img_to_anns_
if 'images' in self.dataset:
for img_info in self.dataset['images']:
img_info['segm_file'] = img_info['file_name'].replace(
'.jpg', '.png')
imgs[img_info['id']] = img_info
if 'categories' in self.dataset:
for cat in self.dataset['categories']:
cats[cat['id']] = cat
if 'annotations' in self.dataset and 'categories' in self.dataset:
for ann in self.dataset['annotations']:
for seg_ann in ann['segments_info']:
cat_to_imgs[seg_ann['category_id']].append(ann['image_id'])
print('index created!')
self.anns = anns
self.imgToAnns = img_to_anns
self.catToImgs = cat_to_imgs
self.imgs = imgs
self.cats = cats
def load_anns(self,
ids: Union[List[int], int] = []) -> Optional[List[dict]]:
"""Load anns with the specified ids.
``self.anns`` is a list of annotation lists instead of a
list of annotations.
Args:
ids (Union[List[int], int]): Integer ids specifying anns.
Returns:
anns (List[dict], optional): Loaded ann objects.
"""
anns = []
if hasattr(ids, '__iter__') and hasattr(ids, '__len__'):
# self.anns is a list of annotation lists instead of
# a list of annotations
for id in ids:
anns += self.anns[id]
return anns
elif type(ids) == int:
return self.anns[ids]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import itertools
import time
from collections import defaultdict
import numpy as np
import torch.multiprocessing as mp
from mmengine.logging import MMLogger
from pycocotools.cocoeval import COCOeval
from tqdm import tqdm
class COCOevalMP(COCOeval):
def _prepare(self):
'''
Prepare ._gts and ._dts for evaluation based on params
:return: None
'''
def _toMask(anns, coco):
# modify ann['segmentation'] by reference
for ann in anns:
rle = coco.annToRLE(ann)
ann['segmentation'] = rle
p = self.params
if p.useCats:
gts = []
dts = []
img_ids = set(p.imgIds)
cat_ids = set(p.catIds)
for gt in self.cocoGt.dataset['annotations']:
if (gt['category_id'] in cat_ids) and (gt['image_id']
in img_ids):
gts.append(gt)
for dt in self.cocoDt.dataset['annotations']:
if (dt['category_id'] in cat_ids) and (dt['image_id']
in img_ids):
dts.append(dt)
# gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) # noqa
# dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) # noqa
# gts=self.cocoGt.dataset['annotations']
# dts=self.cocoDt.dataset['annotations']
else:
gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
# convert ground truth to mask if iouType == 'segm'
if p.iouType == 'segm':
_toMask(gts, self.cocoGt)
_toMask(dts, self.cocoDt)
# set ignore flag
for gt in gts:
gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0
gt['ignore'] = 'iscrowd' in gt and gt['iscrowd']
if p.iouType == 'keypoints':
gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
self._gts = defaultdict(list) # gt for evaluation
self._dts = defaultdict(list) # dt for evaluation
for gt in gts:
self._gts[gt['image_id'], gt['category_id']].append(gt)
for dt in dts:
self._dts[dt['image_id'], dt['category_id']].append(dt)
self.evalImgs = defaultdict(
list) # per-image per-category evaluation results
self.eval = {} # accumulated evaluation results
def evaluate(self):
"""Run per image evaluation on given images and store results (a list
of dict) in self.evalImgs.
:return: None
"""
tic = time.time()
print('Running per image evaluation...')
p = self.params
# add backward compatibility if useSegm is specified in params
if p.useSegm is not None:
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
print('useSegm (deprecated) is not None. Running {} evaluation'.
format(p.iouType))
print('Evaluate annotation type *{}*'.format(p.iouType))
p.imgIds = list(np.unique(p.imgIds))
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)
self.params = p
# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]
nproc = 8
split_size = len(catIds) // nproc
mp_params = []
for i in range(nproc):
begin = i * split_size
end = (i + 1) * split_size
if i == nproc - 1:
end = len(catIds)
mp_params.append((catIds[begin:end], ))
MMLogger.get_current_instance().info(
'start multi processing evaluation ...')
with mp.Pool(nproc) as pool:
self.evalImgs = pool.starmap(self._evaluateImg, mp_params)
self.evalImgs = list(itertools.chain(*self.evalImgs))
self._paramsEval = copy.deepcopy(self.params)
toc = time.time()
print('DONE (t={:0.2f}s).'.format(toc - tic))
def _evaluateImg(self, catids_chunk):
self._prepare()
p = self.params
maxDet = max(p.maxDets)
all_params = []
for catId in catids_chunk:
for areaRng in p.areaRng:
for imgId in p.imgIds:
all_params.append((catId, areaRng, imgId))
evalImgs = [
self.evaluateImg(imgId, catId, areaRng, maxDet)
for catId, areaRng, imgId in tqdm(all_params)
]
return evalImgs
def evaluateImg(self, imgId, catId, aRng, maxDet):
p = self.params
if p.useCats:
gt = self._gts[imgId, catId]
dt = self._dts[imgId, catId]
else:
gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
if len(gt) == 0 and len(dt) == 0:
return None
for g in gt:
if g['ignore'] or (g['area'] < aRng[0] or g['area'] > aRng[1]):
g['_ignore'] = 1
else:
g['_ignore'] = 0
# sort dt highest score first, sort gt ignore last
gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
gt = [gt[i] for i in gtind]
dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
dt = [dt[i] for i in dtind[0:maxDet]]
iscrowd = [int(o['iscrowd']) for o in gt]
# load computed ious
# ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId] # noqa
ious = self.computeIoU(imgId, catId)
ious = ious[:, gtind] if len(ious) > 0 else ious
T = len(p.iouThrs)
G = len(gt)
D = len(dt)
gtm = np.zeros((T, G))
dtm = np.zeros((T, D))
gtIg = np.array([g['_ignore'] for g in gt])
dtIg = np.zeros((T, D))
if not len(ious) == 0:
for tind, t in enumerate(p.iouThrs):
for dind, d in enumerate(dt):
# information about best match so far (m=-1 -> unmatched)
iou = min([t, 1 - 1e-10])
m = -1
for gind, g in enumerate(gt):
# if this gt already matched, and not a crowd, continue
if gtm[tind, gind] > 0 and not iscrowd[gind]:
continue
# if dt matched to reg gt, and on ignore gt, stop
if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1:
break
# continue to next gt unless better match made
if ious[dind, gind] < iou:
continue
# if match successful and best so far,
# store appropriately
iou = ious[dind, gind]
m = gind
# if match made store id of match for both dt and gt
if m == -1:
continue
dtIg[tind, dind] = gtIg[m]
dtm[tind, dind] = gt[m]['id']
gtm[tind, m] = d['id']
# set unmatched detections outside of area range to ignore
a = np.array([d['area'] < aRng[0] or d['area'] > aRng[1]
for d in dt]).reshape((1, len(dt)))
dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0, np.repeat(a, T,
0)))
# store results for given image and category
return {
'image_id': imgId,
'category_id': catId,
'aRng': aRng,
'maxDet': maxDet,
'dtIds': [d['id'] for d in dt],
'gtIds': [g['id'] for g in gt],
'dtMatches': dtm,
'gtMatches': gtm,
'dtScores': [d['score'] for d in dt],
'gtIgnore': gtIg,
'dtIgnore': dtIg,
}
def summarize(self):
"""Compute and display summary metrics for evaluation results.
Note this function can *only* be applied on the default parameter
setting
"""
def _summarize(ap=1, iouThr=None, areaRng='all', maxDets=100):
p = self.params
iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}' # noqa
titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
typeStr = '(AP)' if ap == 1 else '(AR)'
iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
if iouThr is None else '{:0.2f}'.format(iouThr)
aind = [
i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng
]
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
if ap == 1:
# dimension of precision: [TxRxKxAxM]
s = self.eval['precision']
# IoU
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:, :, :, aind, mind]
else:
# dimension of recall: [TxKxAxM]
s = self.eval['recall']
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:, :, aind, mind]
if len(s[s > -1]) == 0:
mean_s = -1
else:
mean_s = np.mean(s[s > -1])
print(
iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets,
mean_s))
return mean_s
def _summarizeDets():
stats = []
stats.append(_summarize(1, maxDets=self.params.maxDets[-1]))
stats.append(
_summarize(1, iouThr=.5, maxDets=self.params.maxDets[-1]))
stats.append(
_summarize(1, iouThr=.75, maxDets=self.params.maxDets[-1]))
for area_rng in ('small', 'medium', 'large'):
stats.append(
_summarize(
1, areaRng=area_rng, maxDets=self.params.maxDets[-1]))
for max_det in self.params.maxDets:
stats.append(_summarize(0, maxDets=max_det))
for area_rng in ('small', 'medium', 'large'):
stats.append(
_summarize(
0, areaRng=area_rng, maxDets=self.params.maxDets[-1]))
stats = np.array(stats)
return stats
def _summarizeKps():
stats = np.zeros((10, ))
stats[0] = _summarize(1, maxDets=20)
stats[1] = _summarize(1, maxDets=20, iouThr=.5)
stats[2] = _summarize(1, maxDets=20, iouThr=.75)
stats[3] = _summarize(1, maxDets=20, areaRng='medium')
stats[4] = _summarize(1, maxDets=20, areaRng='large')
stats[5] = _summarize(0, maxDets=20)
stats[6] = _summarize(0, maxDets=20, iouThr=.5)
stats[7] = _summarize(0, maxDets=20, iouThr=.75)
stats[8] = _summarize(0, maxDets=20, areaRng='medium')
stats[9] = _summarize(0, maxDets=20, areaRng='large')
return stats
if not self.eval:
raise Exception('Please run accumulate() first')
iouType = self.params.iouType
if iouType == 'segm' or iouType == 'bbox':
summarize = _summarizeDets
elif iouType == 'keypoints':
summarize = _summarizeKps
self.stats = summarize()
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Optional
from mmengine.dataset import BaseDataset
from mmengine.fileio import load
from mmengine.utils import is_abs
from ..registry import DATASETS
@DATASETS.register_module()
class BaseDetDataset(BaseDataset):
"""Base dataset for detection.
Args:
proposal_file (str, optional): Proposals file path. Defaults to None.
file_client_args (dict): Arguments to instantiate the
corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
return_classes (bool): Whether to return class information
for open vocabulary-based algorithms. Defaults to False.
caption_prompt (dict, optional): Prompt for captioning.
Defaults to None.
"""
def __init__(self,
*args,
seg_map_suffix: str = '.png',
proposal_file: Optional[str] = None,
file_client_args: dict = None,
backend_args: dict = None,
return_classes: bool = False,
caption_prompt: Optional[dict] = None,
**kwargs) -> None:
self.seg_map_suffix = seg_map_suffix
self.proposal_file = proposal_file
self.backend_args = backend_args
self.return_classes = return_classes
self.caption_prompt = caption_prompt
if self.caption_prompt is not None:
assert self.return_classes, \
'return_classes must be True when using caption_prompt'
if file_client_args is not None:
raise RuntimeError(
'The `file_client_args` is deprecated, '
'please use `backend_args` instead, please refer to'
'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
)
super().__init__(*args, **kwargs)
def full_init(self) -> None:
"""Load annotation file and set ``BaseDataset._fully_initialized`` to
True.
If ``lazy_init=False``, ``full_init`` will be called during the
instantiation and ``self._fully_initialized`` will be set to True. If
``obj._fully_initialized=False``, the class method decorated by
``force_full_init`` will call ``full_init`` automatically.
Several steps to initialize annotation:
- load_data_list: Load annotations from annotation file.
- load_proposals: Load proposals from proposal file, if
`self.proposal_file` is not None.
- filter data information: Filter annotations according to
filter_cfg.
- slice_data: Slice dataset according to ``self._indices``
- serialize_data: Serialize ``self.data_list`` if
``self.serialize_data`` is True.
"""
if self._fully_initialized:
return
# load data information
self.data_list = self.load_data_list()
# get proposals from file
if self.proposal_file is not None:
self.load_proposals()
# filter illegal data, such as data that has no annotations.
self.data_list = self.filter_data()
# Get subset data according to indices.
if self._indices is not None:
self.data_list = self._get_unserialized_subset(self._indices)
# serialize data_list
if self.serialize_data:
self.data_bytes, self.data_address = self._serialize_data()
self._fully_initialized = True
def load_proposals(self) -> None:
"""Load proposals from proposals file.
The `proposals_list` should be a dict[img_path: proposals]
with the same length as `data_list`. And the `proposals` should be
a `dict` or :obj:`InstanceData` usually contains following keys.
- bboxes (np.ndarry): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
- scores (np.ndarry): Classification scores, has a shape
(num_instance, ).
"""
# TODO: Add Unit Test after fully support Dump-Proposal Metric
if not is_abs(self.proposal_file):
self.proposal_file = osp.join(self.data_root, self.proposal_file)
proposals_list = load(
self.proposal_file, backend_args=self.backend_args)
assert len(self.data_list) == len(proposals_list)
for data_info in self.data_list:
img_path = data_info['img_path']
# `file_name` is the key to obtain the proposals from the
# `proposals_list`.
file_name = osp.join(
osp.split(osp.split(img_path)[0])[-1],
osp.split(img_path)[-1])
proposals = proposals_list[file_name]
data_info['proposals'] = proposals
def get_cat_ids(self, idx: int) -> List[int]:
"""Get COCO category ids by index.
Args:
idx (int): Index of data.
Returns:
List[int]: All categories in the image of specified index.
"""
instances = self.get_data_info(idx)['instances']
return [instance['bbox_label'] for instance in instances]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import Callable, Dict, List, Optional, Sequence, Union
import mmengine
import mmengine.fileio as fileio
import numpy as np
from mmengine.dataset import BaseDataset, Compose
from mmdet.registry import DATASETS
@DATASETS.register_module()
class BaseSegDataset(BaseDataset):
"""Custom dataset for semantic segmentation. An example of file structure
is as followed.
.. code-block:: none
├── data
│ ├── my_dataset
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{img_suffix}
│ │ │ │ ├── yyy{img_suffix}
│ │ │ │ ├── zzz{img_suffix}
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{seg_map_suffix}
│ │ │ │ ├── yyy{seg_map_suffix}
│ │ │ │ ├── zzz{seg_map_suffix}
│ │ │ ├── val
The img/gt_semantic_seg pair of BaseSegDataset should be of the same
except suffix. A valid img/gt_semantic_seg filename pair should be like
``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
in the suffix). If split is given, then ``xxx`` is specified in txt file.
Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.
Args:
ann_file (str): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as
specify classes to load. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to None.
data_prefix (dict, optional): Prefix for training data. Defaults to
dict(img_path=None, seg_map_path=None).
img_suffix (str): Suffix of images. Default: '.jpg'
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
filter_cfg (dict, optional): Config for filter data. Defaults to None.
indices (int or Sequence[int], optional): Support using first few
data in annotation file to facilitate training/testing on a smaller
dataset. Defaults to None which means using all ``data_infos``.
serialize_data (bool, optional): Whether to hold memory using
serialized objects, when enabled, data loader workers can use
shared RAM from master process instead of making a copy. Defaults
to True.
pipeline (list, optional): Processing pipeline. Defaults to [].
test_mode (bool, optional): ``test_mode=True`` means in test phase.
Defaults to False.
lazy_init (bool, optional): Whether to load annotation during
instantiation. In some cases, such as visualization, only the meta
information of the dataset is needed, which is not necessary to
load annotation file. ``Basedataset`` can skip load annotations to
save time by set ``lazy_init=True``. Defaults to False.
use_label_map (bool, optional): Whether to use label map.
Defaults to False.
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
None img. The maximum extra number of cycles to get a valid
image. Defaults to 1000.
backend_args (dict, Optional): Arguments to instantiate a file backend.
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
for details. Defaults to None.
Notes: mmcv>=2.0.0rc4 required.
"""
METAINFO: dict = dict()
def __init__(self,
ann_file: str = '',
img_suffix='.jpg',
seg_map_suffix='.png',
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: dict = dict(img_path='', seg_map_path=''),
filter_cfg: Optional[dict] = None,
indices: Optional[Union[int, Sequence[int]]] = None,
serialize_data: bool = True,
pipeline: List[Union[dict, Callable]] = [],
test_mode: bool = False,
lazy_init: bool = False,
use_label_map: bool = False,
max_refetch: int = 1000,
backend_args: Optional[dict] = None) -> None:
self.img_suffix = img_suffix
self.seg_map_suffix = seg_map_suffix
self.backend_args = backend_args.copy() if backend_args else None
self.data_root = data_root
self.data_prefix = copy.copy(data_prefix)
self.ann_file = ann_file
self.filter_cfg = copy.deepcopy(filter_cfg)
self._indices = indices
self.serialize_data = serialize_data
self.test_mode = test_mode
self.max_refetch = max_refetch
self.data_list: List[dict] = []
self.data_bytes: np.ndarray
# Set meta information.
self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
# Get label map for custom classes
new_classes = self._metainfo.get('classes', None)
self.label_map = self.get_label_map(
new_classes) if use_label_map else None
self._metainfo.update(dict(label_map=self.label_map))
# Update palette based on label map or generate palette
# if it is not defined
updated_palette = self._update_palette()
self._metainfo.update(dict(palette=updated_palette))
# Join paths.
if self.data_root is not None:
self._join_prefix()
# Build pipeline.
self.pipeline = Compose(pipeline)
# Full initialize the dataset.
if not lazy_init:
self.full_init()
if test_mode:
assert self._metainfo.get('classes') is not None, \
'dataset metainfo `classes` should be specified when testing'
@classmethod
def get_label_map(cls,
new_classes: Optional[Sequence] = None
) -> Union[Dict, None]:
"""Require label mapping.
The ``label_map`` is a dictionary, its keys are the old label ids and
its values are the new label ids, and is used for changing pixel
labels in load_annotations. If and only if old classes in cls.METAINFO
is not equal to new classes in self._metainfo and nether of them is not
None, `label_map` is not None.
Args:
new_classes (list, tuple, optional): The new classes name from
metainfo. Default to None.
Returns:
dict, optional: The mapping from old classes in cls.METAINFO to
new classes in self._metainfo
"""
old_classes = cls.METAINFO.get('classes', None)
if (new_classes is not None and old_classes is not None
and list(new_classes) != list(old_classes)):
label_map = {}
if not set(new_classes).issubset(cls.METAINFO['classes']):
raise ValueError(
f'new classes {new_classes} is not a '
f'subset of classes {old_classes} in METAINFO.')
for i, c in enumerate(old_classes):
if c not in new_classes:
# 0 is background
label_map[i] = 0
else:
label_map[i] = new_classes.index(c)
return label_map
else:
return None
def _update_palette(self) -> list:
"""Update palette after loading metainfo.
If length of palette is equal to classes, just return the palette.
If palette is not defined, it will randomly generate a palette.
If classes is updated by customer, it will return the subset of
palette.
Returns:
Sequence: Palette for current dataset.
"""
palette = self._metainfo.get('palette', [])
classes = self._metainfo.get('classes', [])
# palette does match classes
if len(palette) == len(classes):
return palette
if len(palette) == 0:
# Get random state before set seed, and restore
# random state later.
# It will prevent loss of randomness, as the palette
# may be different in each iteration if not specified.
# See: https://github.com/open-mmlab/mmdetection/issues/5844
state = np.random.get_state()
np.random.seed(42)
# random palette
new_palette = np.random.randint(
0, 255, size=(len(classes), 3)).tolist()
np.random.set_state(state)
elif len(palette) >= len(classes) and self.label_map is not None:
new_palette = []
# return subset of palette
for old_id, new_id in sorted(
self.label_map.items(), key=lambda x: x[1]):
# 0 is background
if new_id != 0:
new_palette.append(palette[old_id])
new_palette = type(palette)(new_palette)
elif len(palette) >= len(classes):
# Allow palette length is greater than classes.
return palette
else:
raise ValueError('palette does not match classes '
f'as metainfo is {self._metainfo}.')
return new_palette
def load_data_list(self) -> List[dict]:
"""Load annotation from directory or annotation file.
Returns:
list[dict]: All data info of dataset.
"""
data_list = []
img_dir = self.data_prefix.get('img_path', None)
ann_dir = self.data_prefix.get('seg_map_path', None)
if not osp.isdir(self.ann_file) and self.ann_file:
assert osp.isfile(self.ann_file), \
f'Failed to load `ann_file` {self.ann_file}'
lines = mmengine.list_from_file(
self.ann_file, backend_args=self.backend_args)
for line in lines:
img_name = line.strip()
data_info = dict(
img_path=osp.join(img_dir, img_name + self.img_suffix))
if ann_dir is not None:
seg_map = img_name + self.seg_map_suffix
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
data_info['label_map'] = self.label_map
data_list.append(data_info)
else:
for img in fileio.list_dir_or_file(
dir_path=img_dir,
list_dir=False,
suffix=self.img_suffix,
recursive=True,
backend_args=self.backend_args):
data_info = dict(img_path=osp.join(img_dir, img))
if ann_dir is not None:
seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
data_info['label_map'] = self.label_map
data_list.append(data_info)
data_list = sorted(data_list, key=lambda x: x['img_path'])
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from collections import defaultdict
from typing import Any, List, Tuple
import mmengine.fileio as fileio
from mmengine.dataset import BaseDataset
from mmengine.logging import print_log
from mmdet.datasets.api_wrappers import COCO
from mmdet.registry import DATASETS
@DATASETS.register_module()
class BaseVideoDataset(BaseDataset):
"""Base video dataset for VID, MOT and VIS tasks."""
META = dict(classes=None)
# ann_id is unique in coco dataset.
ANN_ID_UNIQUE = True
def __init__(self, *args, backend_args: dict = None, **kwargs):
self.backend_args = backend_args
super().__init__(*args, **kwargs)
def load_data_list(self) -> Tuple[List[dict], List]:
"""Load annotations from an annotation file named as ``self.ann_file``.
Returns:
tuple(list[dict], list): A list of annotation and a list of
valid data indices.
"""
with fileio.get_local_path(self.ann_file) as local_path:
self.coco = COCO(local_path)
# The order of returned `cat_ids` will not
# change with the order of the classes
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
# used in `filter_data`
self.img_ids_with_ann = set()
img_ids = self.coco.get_img_ids()
total_ann_ids = []
# if ``video_id`` is not in the annotation file, we will assign a big
# unique video_id for this video.
single_video_id = 100000
videos = {}
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
if 'video_id' not in raw_img_info:
single_video_id = single_video_id + 1
video_id = single_video_id
else:
video_id = raw_img_info['video_id']
if video_id not in videos:
videos[video_id] = {
'video_id': video_id,
'images': [],
'video_length': 0
}
videos[video_id]['video_length'] += 1
ann_ids = self.coco.get_ann_ids(
img_ids=[img_id], cat_ids=self.cat_ids)
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info(
dict(raw_img_info=raw_img_info, raw_ann_info=raw_ann_info))
if len(parsed_data_info['instances']) > 0:
self.img_ids_with_ann.add(parsed_data_info['img_id'])
videos[video_id]['images'].append(parsed_data_info)
data_list = [v for v in videos.values()]
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
return data_list
def parse_data_info(self, raw_data_info: dict) -> dict:
"""Parse raw annotation to target format.
Args:
raw_data_info (dict): Raw data information loaded from
``ann_file``.
Returns:
dict: Parsed annotation.
"""
img_info = raw_data_info['raw_img_info']
ann_info = raw_data_info['raw_ann_info']
data_info = {}
data_info.update(img_info)
if self.data_prefix.get('img_path', None) is not None:
img_path = osp.join(self.data_prefix['img_path'],
img_info['file_name'])
else:
img_path = img_info['file_name']
data_info['img_path'] = img_path
instances = []
for i, ann in enumerate(ann_info):
instance = {}
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann['area'] <= 0 or w < 1 or h < 1:
continue
if ann['category_id'] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if ann.get('iscrowd', False):
instance['ignore_flag'] = 1
else:
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = self.cat2label[ann['category_id']]
if ann.get('segmentation', None):
instance['mask'] = ann['segmentation']
if ann.get('instance_id', None):
instance['instance_id'] = ann['instance_id']
else:
# image dataset usually has no `instance_id`.
# Therefore, we set it to `i`.
instance['instance_id'] = i
instances.append(instance)
data_info['instances'] = instances
return data_info
def filter_data(self) -> List[int]:
"""Filter image annotations according to filter_cfg.
Returns:
list[int]: Filtered results.
"""
if self.test_mode:
return self.data_list
num_imgs_before_filter = sum(
[len(info['images']) for info in self.data_list])
num_imgs_after_filter = 0
# obtain images that contain annotations of the required categories
ids_in_cat = set()
for i, class_id in enumerate(self.cat_ids):
ids_in_cat |= set(self.cat_img_map[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= self.img_ids_with_ann
new_data_list = []
for video_data_info in self.data_list:
imgs_data_info = video_data_info['images']
valid_imgs_data_info = []
for data_info in imgs_data_info:
img_id = data_info['img_id']
width = data_info['width']
height = data_info['height']
# TODO: simplify these conditions
if self.filter_cfg is None:
if img_id not in ids_in_cat:
video_data_info['video_length'] -= 1
continue
if min(width, height) >= 32:
valid_imgs_data_info.append(data_info)
num_imgs_after_filter += 1
else:
video_data_info['video_length'] -= 1
else:
if self.filter_cfg.get('filter_empty_gt',
True) and img_id not in ids_in_cat:
video_data_info['video_length'] -= 1
continue
if min(width, height) >= self.filter_cfg.get(
'min_size', 32):
valid_imgs_data_info.append(data_info)
num_imgs_after_filter += 1
else:
video_data_info['video_length'] -= 1
video_data_info['images'] = valid_imgs_data_info
new_data_list.append(video_data_info)
print_log(
'The number of samples before and after filtering: '
f'{num_imgs_before_filter} / {num_imgs_after_filter}', 'current')
return new_data_list
def prepare_data(self, idx) -> Any:
"""Get date processed by ``self.pipeline``. Note that ``idx`` is a
video index in default since the base element of video dataset is a
video. However, in some cases, we need to specific both the video index
and frame index. For example, in traing mode, we may want to sample the
specific frames and all the frames must be sampled once in a epoch; in
test mode, we may want to output data of a single image rather than the
whole video for saving memory.
Args:
idx (int): The index of ``data_info``.
Returns:
Any: Depends on ``self.pipeline``.
"""
if isinstance(idx, tuple):
assert len(idx) == 2, 'The length of idx must be 2: '
'(video_index, frame_index)'
video_idx, frame_idx = idx[0], idx[1]
else:
video_idx, frame_idx = idx, None
data_info = self.get_data_info(video_idx)
if self.test_mode:
# Support two test_mode: frame-level and video-level
final_data_info = defaultdict(list)
if frame_idx is None:
frames_idx_list = list(range(data_info['video_length']))
else:
frames_idx_list = [frame_idx]
for index in frames_idx_list:
frame_ann = data_info['images'][index]
frame_ann['video_id'] = data_info['video_id']
# Collate data_list (list of dict to dict of list)
for key, value in frame_ann.items():
final_data_info[key].append(value)
# copy the info in video-level into img-level
# TODO: the value of this key is the same as that of
# `video_length` in test mode
final_data_info['ori_video_length'].append(
data_info['video_length'])
final_data_info['video_length'] = [len(frames_idx_list)
] * len(frames_idx_list)
return self.pipeline(final_data_info)
else:
# Specify `key_frame_id` for the frame sampling in the pipeline
if frame_idx is not None:
data_info['key_frame_id'] = frame_idx
return self.pipeline(data_info)
def get_cat_ids(self, index) -> List[int]:
"""Following image detection, we provide this interface function. Get
category ids by video index and frame index.
Args:
index: The index of the dataset. It support two kinds of inputs:
Tuple:
video_idx (int): Index of video.
frame_idx (int): Index of frame.
Int: Index of video.
Returns:
List[int]: All categories in the image of specified video index
and frame index.
"""
if isinstance(index, tuple):
assert len(
index
) == 2, f'Expect the length of index is 2, but got {len(index)}'
video_idx, frame_idx = index
instances = self.get_data_info(
video_idx)['images'][frame_idx]['instances']
return [instance['bbox_label'] for instance in instances]
else:
cat_ids = []
for img in self.get_data_info(index)['images']:
for instance in img['instances']:
cat_ids.append(instance['bbox_label'])
return cat_ids
@property
def num_all_imgs(self):
"""Get the number of all the images in this video dataset."""
return sum(
[len(self.get_data_info(i)['images']) for i in range(len(self))])
def get_len_per_video(self, idx):
"""Get length of one video.
Args:
idx (int): Index of video.
Returns:
int (int): The length of the video.
"""
return len(self.get_data_info(idx)['images'])
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/cityscapes.py # noqa
# and https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
from typing import List
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class CityscapesDataset(CocoDataset):
"""Dataset for Cityscapes."""
METAINFO = {
'classes': ('person', 'rider', 'car', 'truck', 'bus', 'train',
'motorcycle', 'bicycle'),
'palette': [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
(0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32)]
}
def filter_data(self) -> List[dict]:
"""Filter annotations according to filter_cfg.
Returns:
List[dict]: Filtered results.
"""
if self.test_mode:
return self.data_list
if self.filter_cfg is None:
return self.data_list
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
min_size = self.filter_cfg.get('min_size', 0)
# obtain images that contain annotation
ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
# obtain images that contain annotations of the required categories
ids_in_cat = set()
for i, class_id in enumerate(self.cat_ids):
ids_in_cat |= set(self.cat_img_map[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= ids_with_ann
valid_data_infos = []
for i, data_info in enumerate(self.data_list):
img_id = data_info['img_id']
width = data_info['width']
height = data_info['height']
all_is_crowd = all([
instance['ignore_flag'] == 1
for instance in data_info['instances']
])
if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd):
continue
if min(width, height) >= min_size:
valid_data_infos.append(data_info)
return valid_data_infos
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import List, Union
from mmengine.fileio import get_local_path
from mmdet.registry import DATASETS
from .api_wrappers import COCO
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class CocoDataset(BaseDetDataset):
"""Dataset for COCO."""
METAINFO = {
'classes':
('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
# palette is a list of color tuples, which is used for visualization.
'palette':
[(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
(0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
(100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
(165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
(0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
(199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
(209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
(92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
(174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
(255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
(207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
(74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
(0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
(227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
(163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
(183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
(166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
(65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
(196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
(246, 0, 122), (191, 162, 208)]
}
COCOAPI = COCO
# ann_id is unique in coco dataset.
ANN_ID_UNIQUE = True
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
# The order of returned `cat_ids` will not
# change with the order of the `classes`
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
img_ids = self.coco.get_img_ids()
data_list = []
total_ann_ids = []
for img_id in img_ids:
raw_img_info = self.coco.load_imgs([img_id])[0]
raw_img_info['img_id'] = img_id
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.coco.load_anns(ann_ids)
total_ann_ids.extend(ann_ids)
parsed_data_info = self.parse_data_info({
'raw_ann_info':
raw_ann_info,
'raw_img_info':
raw_img_info
})
data_list.append(parsed_data_info)
if self.ANN_ID_UNIQUE:
assert len(set(total_ann_ids)) == len(
total_ann_ids
), f"Annotation ids in '{self.ann_file}' are not unique!"
del self.coco
return data_list
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
"""Parse raw annotation to target format.
Args:
raw_data_info (dict): Raw data information load from ``ann_file``
Returns:
Union[dict, List[dict]]: Parsed annotation.
"""
img_info = raw_data_info['raw_img_info']
ann_info = raw_data_info['raw_ann_info']
data_info = {}
# TODO: need to change data_prefix['img'] to data_prefix['img_path']
img_path = osp.join(self.data_prefix['img'], img_info['file_name'])
if self.data_prefix.get('seg', None):
seg_map_path = osp.join(
self.data_prefix['seg'],
img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
else:
seg_map_path = None
data_info['img_path'] = img_path
data_info['img_id'] = img_info['img_id']
data_info['seg_map_path'] = seg_map_path
data_info['height'] = img_info['height']
data_info['width'] = img_info['width']
if self.return_classes:
data_info['text'] = self.metainfo['classes']
data_info['caption_prompt'] = self.caption_prompt
data_info['custom_entities'] = True
instances = []
for i, ann in enumerate(ann_info):
instance = {}
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann['area'] <= 0 or w < 1 or h < 1:
continue
if ann['category_id'] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if ann.get('iscrowd', False):
instance['ignore_flag'] = 1
else:
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = self.cat2label[ann['category_id']]
if ann.get('segmentation', None):
instance['mask'] = ann['segmentation']
instances.append(instance)
data_info['instances'] = instances
return data_info
def filter_data(self) -> List[dict]:
"""Filter annotations according to filter_cfg.
Returns:
List[dict]: Filtered results.
"""
if self.test_mode:
return self.data_list
if self.filter_cfg is None:
return self.data_list
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
min_size = self.filter_cfg.get('min_size', 0)
# obtain images that contain annotation
ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
# obtain images that contain annotations of the required categories
ids_in_cat = set()
for i, class_id in enumerate(self.cat_ids):
ids_in_cat |= set(self.cat_img_map[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= ids_with_ann
valid_data_infos = []
for i, data_info in enumerate(self.data_list):
img_id = data_info['img_id']
width = data_info['width']
height = data_info['height']
if filter_empty_gt and img_id not in ids_in_cat:
continue
if min(width, height) >= min_size:
valid_data_infos.append(data_info)
return valid_data_infos
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend
from mmdet.registry import DATASETS
@DATASETS.register_module()
class CocoCaptionDataset(BaseDataset):
"""COCO2014 Caption dataset."""
def load_data_list(self) -> List[dict]:
"""Load data list."""
img_prefix = self.data_prefix['img_path']
annotations = mmengine.load(self.ann_file)
file_backend = get_file_backend(img_prefix)
data_list = []
for ann in annotations:
data_info = {
'img_id': Path(ann['image']).stem.split('_')[-1],
'img_path': file_backend.join_path(img_prefix, ann['image']),
'gt_caption': ann['caption'],
}
data_list.append(data_info)
return data_list
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Callable, List, Optional, Sequence, Union
from mmdet.registry import DATASETS
from .api_wrappers import COCOPanoptic
from .coco import CocoDataset
@DATASETS.register_module()
class CocoPanopticDataset(CocoDataset):
"""Coco dataset for Panoptic segmentation.
The annotation format is shown as follows. The `ann` field is optional
for testing.
.. code-block:: none
[
{
'filename': f'{image_id:012}.png',
'image_id':9
'segments_info':
[
{
'id': 8345037, (segment_id in panoptic png,
convert from rgb)
'category_id': 51,
'iscrowd': 0,
'bbox': (x1, y1, w, h),
'area': 24315
},
...
]
},
...
]
Args:
ann_file (str): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to None.
data_prefix (dict, optional): Prefix for training data. Defaults to
``dict(img=None, ann=None, seg=None)``. The prefix ``seg`` which is
for panoptic segmentation map must be not None.
filter_cfg (dict, optional): Config for filter data. Defaults to None.
indices (int or Sequence[int], optional): Support using first few
data in annotation file to facilitate training/testing on a smaller
dataset. Defaults to None which means using all ``data_infos``.
serialize_data (bool, optional): Whether to hold memory using
serialized objects, when enabled, data loader workers can use
shared RAM from master process instead of making a copy. Defaults
to True.
pipeline (list, optional): Processing pipeline. Defaults to [].
test_mode (bool, optional): ``test_mode=True`` means in test phase.
Defaults to False.
lazy_init (bool, optional): Whether to load annotation during
instantiation. In some cases, such as visualization, only the meta
information of the dataset is needed, which is not necessary to
load annotation file. ``Basedataset`` can skip load annotations to
save time by set ``lazy_init=False``. Defaults to False.
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
None img. The maximum extra number of cycles to get a valid
image. Defaults to 1000.
"""
METAINFO = {
'classes':
('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff',
'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light',
'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
'wall-wood', 'water-other', 'window-blind', 'window-other',
'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
'cabinet-merged', 'table-merged', 'floor-other-merged',
'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
'paper-merged', 'food-other-merged', 'building-other-merged',
'rock-merged', 'wall-other-merged', 'rug-merged'),
'thing_classes':
('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
'stuff_classes':
('banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain',
'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house',
'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
'wall-wood', 'water-other', 'window-blind', 'window-other',
'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
'cabinet-merged', 'table-merged', 'floor-other-merged',
'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
'paper-merged', 'food-other-merged', 'building-other-merged',
'rock-merged', 'wall-other-merged', 'rug-merged'),
'palette':
[(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
(0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
(100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
(165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
(0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
(199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
(209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
(92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
(174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
(255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
(207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
(74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
(0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
(227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
(163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
(183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
(166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
(65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
(196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
(246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203),
(150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100),
(92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255),
(124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0),
(193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176),
(230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
(254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
(104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
(135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
(183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
(146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140),
(96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152),
(208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0),
(0, 114, 143), (102, 102, 156), (250, 141, 255)]
}
COCOAPI = COCOPanoptic
# ann_id is not unique in coco panoptic dataset.
ANN_ID_UNIQUE = False
def __init__(self,
ann_file: str = '',
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: dict = dict(img=None, ann=None, seg=None),
filter_cfg: Optional[dict] = None,
indices: Optional[Union[int, Sequence[int]]] = None,
serialize_data: bool = True,
pipeline: List[Union[dict, Callable]] = [],
test_mode: bool = False,
lazy_init: bool = False,
max_refetch: int = 1000,
backend_args: dict = None,
**kwargs) -> None:
super().__init__(
ann_file=ann_file,
metainfo=metainfo,
data_root=data_root,
data_prefix=data_prefix,
filter_cfg=filter_cfg,
indices=indices,
serialize_data=serialize_data,
pipeline=pipeline,
test_mode=test_mode,
lazy_init=lazy_init,
max_refetch=max_refetch,
backend_args=backend_args,
**kwargs)
def parse_data_info(self, raw_data_info: dict) -> dict:
"""Parse raw annotation to target format.
Args:
raw_data_info (dict): Raw data information load from ``ann_file``.
Returns:
dict: Parsed annotation.
"""
img_info = raw_data_info['raw_img_info']
ann_info = raw_data_info['raw_ann_info']
# filter out unmatched annotations which have
# same segment_id but belong to other image
ann_info = [
ann for ann in ann_info if ann['image_id'] == img_info['img_id']
]
data_info = {}
img_path = osp.join(self.data_prefix['img'], img_info['file_name'])
if self.data_prefix.get('seg', None):
seg_map_path = osp.join(
self.data_prefix['seg'],
img_info['file_name'].replace('.jpg', '.png'))
else:
seg_map_path = None
data_info['img_path'] = img_path
data_info['img_id'] = img_info['img_id']
data_info['seg_map_path'] = seg_map_path
data_info['height'] = img_info['height']
data_info['width'] = img_info['width']
if self.return_classes:
data_info['text'] = self.metainfo['thing_classes']
data_info['stuff_text'] = self.metainfo['stuff_classes']
data_info['custom_entities'] = True # no important
instances = []
segments_info = []
for ann in ann_info:
instance = {}
x1, y1, w, h = ann['bbox']
if ann['area'] <= 0 or w < 1 or h < 1:
continue
bbox = [x1, y1, x1 + w, y1 + h]
category_id = ann['category_id']
contiguous_cat_id = self.cat2label[category_id]
is_thing = self.coco.load_cats(ids=category_id)[0]['isthing']
if is_thing:
is_crowd = ann.get('iscrowd', False)
instance['bbox'] = bbox
instance['bbox_label'] = contiguous_cat_id
if not is_crowd:
instance['ignore_flag'] = 0
else:
instance['ignore_flag'] = 1
is_thing = False
segment_info = {
'id': ann['id'],
'category': contiguous_cat_id,
'is_thing': is_thing
}
segments_info.append(segment_info)
if len(instance) > 0 and is_thing:
instances.append(instance)
data_info['instances'] = instances
data_info['segments_info'] = segments_info
return data_info
def filter_data(self) -> List[dict]:
"""Filter images too small or without ground truth.
Returns:
List[dict]: ``self.data_list`` after filtering.
"""
if self.test_mode:
return self.data_list
if self.filter_cfg is None:
return self.data_list
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
min_size = self.filter_cfg.get('min_size', 0)
ids_with_ann = set()
# check whether images have legal thing annotations.
for data_info in self.data_list:
for segment_info in data_info['segments_info']:
if not segment_info['is_thing']:
continue
ids_with_ann.add(data_info['img_id'])
valid_data_list = []
for data_info in self.data_list:
img_id = data_info['img_id']
width = data_info['width']
height = data_info['height']
if filter_empty_gt and img_id not in ids_with_ann:
continue
if min(width, height) >= min_size:
valid_data_list.append(data_info)
return valid_data_list
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
from .ade20k import ADE20KSegDataset
@DATASETS.register_module()
class CocoSegDataset(ADE20KSegDataset):
"""COCO dataset.
In segmentation map annotation for COCO. The ``img_suffix`` is fixed to
'.jpg', and ``seg_map_suffix`` is fixed to '.png'.
"""
METAINFO = dict(
classes=(
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
'paper', 'pavement', 'pillow', 'plant-other', 'plastic',
'platform', 'playingfield', 'railing', 'railroad', 'river', 'road',
'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf',
'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs',
'stone', 'straw', 'structural-other', 'table', 'tent',
'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick',
'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone',
'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
'window-blind', 'window-other', 'wood'),
palette=[(120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50),
(4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255),
(230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7),
(150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82),
(143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3),
(0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255),
(255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220),
(255, 9, 92), (112, 9, 255), (8, 255, 214), (7, 255, 224),
(255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255),
(224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7),
(255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153),
(6, 51, 255), (235, 12, 255), (160, 150, 20), (0, 163, 255),
(140, 140, 140), (250, 10, 15), (20, 255, 0), (31, 255, 0),
(255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255),
(255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255),
(11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255),
(0, 255, 112), (0, 255, 133), (255, 0, 0), (255, 163, 0),
(255, 102, 0), (194, 255, 0), (0, 143, 255), (51, 255, 0),
(0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255),
(173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255),
(255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20),
(255, 184, 184), (0, 31, 255), (0, 255, 61), (0, 71, 255),
(255, 0, 204), (0, 255, 194), (0, 255, 82), (0, 10, 255),
(0, 112, 255), (51, 0, 255), (0, 194, 255), (0, 122, 255),
(0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0),
(143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0),
(8, 184, 170), (133, 0, 255), (0, 255, 92), (184, 0, 255),
(255, 0, 31), (0, 184, 255), (0, 214, 255), (255, 0, 112),
(92, 255, 0), (0, 224, 255), (112, 224, 255), (70, 184, 160),
(163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163),
(255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0),
(255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, 0),
(10, 190, 212), (214, 255, 0), (0, 204, 255), (20, 0, 255),
(255, 255, 0), (0, 153, 255), (0, 41, 255), (0, 255, 204),
(41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255),
(71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255),
(184, 255, 0), (0, 133, 255), (255, 214, 0), (25, 194, 194),
(102, 255, 0), (92, 0, 255), (107, 255, 200), (58, 41, 149),
(183, 121, 142), (255, 73, 97), (107, 142, 35),
(190, 153, 153), (146, 139, 141), (70, 130, 180),
(134, 199, 156), (209, 226, 140), (96, 36, 108), (96, 96, 96),
(64, 170, 64), (152, 251, 152), (208, 229, 228),
(206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143),
(102, 102, 156), (250, 141, 255)])
# Copyright (c) OpenMMLab. All rights reserved.
import json
import logging
import os.path as osp
import warnings
from typing import List, Union
import mmcv
from mmengine.dist import get_rank
from mmengine.fileio import dump, get, get_text, load
from mmengine.logging import print_log
from mmengine.utils import ProgressBar
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class CrowdHumanDataset(BaseDetDataset):
r"""Dataset for CrowdHuman.
Args:
data_root (str): The root directory for
``data_prefix`` and ``ann_file``.
ann_file (str): Annotation file path.
extra_ann_file (str | optional):The path of extra image metas
for CrowdHuman. It can be created by CrowdHumanDataset
automatically or by tools/misc/get_crowdhuman_id_hw.py
manually. Defaults to None.
"""
METAINFO = {
'classes': ('person', ),
# palette is a list of color tuples, which is used for visualization.
'palette': [(220, 20, 60)]
}
def __init__(self, data_root, ann_file, extra_ann_file=None, **kwargs):
# extra_ann_file record the size of each image. This file is
# automatically created when you first load the CrowdHuman
# dataset by mmdet.
if extra_ann_file is not None:
self.extra_ann_exist = True
self.extra_anns = load(extra_ann_file)
else:
ann_file_name = osp.basename(ann_file)
if 'train' in ann_file_name:
self.extra_ann_file = osp.join(data_root, 'id_hw_train.json')
elif 'val' in ann_file_name:
self.extra_ann_file = osp.join(data_root, 'id_hw_val.json')
self.extra_ann_exist = False
if not osp.isfile(self.extra_ann_file):
print_log(
'extra_ann_file does not exist, prepare to collect '
'image height and width...',
level=logging.INFO)
self.extra_anns = {}
else:
self.extra_ann_exist = True
self.extra_anns = load(self.extra_ann_file)
super().__init__(data_root=data_root, ann_file=ann_file, **kwargs)
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
anno_strs = get_text(
self.ann_file, backend_args=self.backend_args).strip().split('\n')
print_log('loading CrowdHuman annotation...', level=logging.INFO)
data_list = []
prog_bar = ProgressBar(len(anno_strs))
for i, anno_str in enumerate(anno_strs):
anno_dict = json.loads(anno_str)
parsed_data_info = self.parse_data_info(anno_dict)
data_list.append(parsed_data_info)
prog_bar.update()
if not self.extra_ann_exist and get_rank() == 0:
# TODO: support file client
try:
dump(self.extra_anns, self.extra_ann_file, file_format='json')
except: # noqa
warnings.warn(
'Cache files can not be saved automatically! To speed up'
'loading the dataset, please manually generate the cache'
' file by file tools/misc/get_crowdhuman_id_hw.py')
print_log(
f'\nsave extra_ann_file in {self.data_root}',
level=logging.INFO)
del self.extra_anns
print_log('\nDone', level=logging.INFO)
return data_list
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
"""Parse raw annotation to target format.
Args:
raw_data_info (dict): Raw data information load from ``ann_file``
Returns:
Union[dict, List[dict]]: Parsed annotation.
"""
data_info = {}
img_path = osp.join(self.data_prefix['img'],
f"{raw_data_info['ID']}.jpg")
data_info['img_path'] = img_path
data_info['img_id'] = raw_data_info['ID']
if not self.extra_ann_exist:
img_bytes = get(img_path, backend_args=self.backend_args)
img = mmcv.imfrombytes(img_bytes, backend='cv2')
data_info['height'], data_info['width'] = img.shape[:2]
self.extra_anns[raw_data_info['ID']] = img.shape[:2]
del img, img_bytes
else:
data_info['height'], data_info['width'] = self.extra_anns[
raw_data_info['ID']]
instances = []
for i, ann in enumerate(raw_data_info['gtboxes']):
instance = {}
if ann['tag'] not in self.metainfo['classes']:
instance['bbox_label'] = -1
instance['ignore_flag'] = 1
else:
instance['bbox_label'] = self.metainfo['classes'].index(
ann['tag'])
instance['ignore_flag'] = 0
if 'extra' in ann:
if 'ignore' in ann['extra']:
if ann['extra']['ignore'] != 0:
instance['bbox_label'] = -1
instance['ignore_flag'] = 1
x1, y1, w, h = ann['fbox']
bbox = [x1, y1, x1 + w, y1 + h]
instance['bbox'] = bbox
# Record the full bbox(fbox), head bbox(hbox) and visible
# bbox(vbox) as additional information. If you need to use
# this information, you just need to design the pipeline
# instead of overriding the CrowdHumanDataset.
instance['fbox'] = bbox
hbox = ann['hbox']
instance['hbox'] = [
hbox[0], hbox[1], hbox[0] + hbox[2], hbox[1] + hbox[3]
]
vbox = ann['vbox']
instance['vbox'] = [
vbox[0], vbox[1], vbox[0] + vbox[2], vbox[1] + vbox[3]
]
instances.append(instance)
data_info['instances'] = instances
return data_info
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import copy
from typing import List, Sequence, Union
from mmengine.dataset import BaseDataset
from mmengine.dataset import ConcatDataset as MMENGINE_ConcatDataset
from mmengine.dataset import force_full_init
from mmdet.registry import DATASETS, TRANSFORMS
@DATASETS.register_module()
class MultiImageMixDataset:
"""A wrapper of multiple images mixed dataset.
Suitable for training on multiple images mixed data augmentation like
mosaic and mixup. For the augmentation pipeline of mixed image data,
the `get_indexes` method needs to be provided to obtain the image
indexes, and you can set `skip_flags` to change the pipeline running
process. At the same time, we provide the `dynamic_scale` parameter
to dynamically change the output image size.
Args:
dataset (:obj:`CustomDataset`): The dataset to be mixed.
pipeline (Sequence[dict]): Sequence of transform object or
config dict to be composed.
dynamic_scale (tuple[int], optional): The image scale can be changed
dynamically. Default to None. It is deprecated.
skip_type_keys (list[str], optional): Sequence of type string to
be skip pipeline. Default to None.
max_refetch (int): The maximum number of retry iterations for getting
valid results from the pipeline. If the number of iterations is
greater than `max_refetch`, but results is still None, then the
iteration is terminated and raise the error. Default: 15.
"""
def __init__(self,
dataset: Union[BaseDataset, dict],
pipeline: Sequence[str],
skip_type_keys: Union[Sequence[str], None] = None,
max_refetch: int = 15,
lazy_init: bool = False) -> None:
assert isinstance(pipeline, collections.abc.Sequence)
if skip_type_keys is not None:
assert all([
isinstance(skip_type_key, str)
for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys
self.pipeline = []
self.pipeline_types = []
for transform in pipeline:
if isinstance(transform, dict):
self.pipeline_types.append(transform['type'])
transform = TRANSFORMS.build(transform)
self.pipeline.append(transform)
else:
raise TypeError('pipeline must be a dict')
self.dataset: BaseDataset
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
elif isinstance(dataset, BaseDataset):
self.dataset = dataset
else:
raise TypeError(
'elements in datasets sequence should be config or '
f'`BaseDataset` instance, but got {type(dataset)}')
self._metainfo = self.dataset.metainfo
if hasattr(self.dataset, 'flag'):
self.flag = self.dataset.flag
self.num_samples = len(self.dataset)
self.max_refetch = max_refetch
self._fully_initialized = False
if not lazy_init:
self.full_init()
@property
def metainfo(self) -> dict:
"""Get the meta information of the multi-image-mixed dataset.
Returns:
dict: The meta information of multi-image-mixed dataset.
"""
return copy.deepcopy(self._metainfo)
def full_init(self):
"""Loop to ``full_init`` each dataset."""
if self._fully_initialized:
return
self.dataset.full_init()
self._ori_len = len(self.dataset)
self._fully_initialized = True
@force_full_init
def get_data_info(self, idx: int) -> dict:
"""Get annotation by index.
Args:
idx (int): Global index of ``ConcatDataset``.
Returns:
dict: The idx-th annotation of the datasets.
"""
return self.dataset.get_data_info(idx)
@force_full_init
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
results = copy.deepcopy(self.dataset[idx])
for (transform, transform_type) in zip(self.pipeline,
self.pipeline_types):
if self._skip_type_keys is not None and \
transform_type in self._skip_type_keys:
continue
if hasattr(transform, 'get_indexes'):
for i in range(self.max_refetch):
# Make sure the results passed the loading pipeline
# of the original dataset is not None.
indexes = transform.get_indexes(self.dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
mix_results = [
copy.deepcopy(self.dataset[index]) for index in indexes
]
if None not in mix_results:
results['mix_results'] = mix_results
break
else:
raise RuntimeError(
'The loading pipeline of the original dataset'
' always return None. Please check the correctness '
'of the dataset and its pipeline.')
for i in range(self.max_refetch):
# To confirm the results passed the training pipeline
# of the wrapper is not None.
updated_results = transform(copy.deepcopy(results))
if updated_results is not None:
results = updated_results
break
else:
raise RuntimeError(
'The training pipeline of the dataset wrapper'
' always return None.Please check the correctness '
'of the dataset and its pipeline.')
if 'mix_results' in results:
results.pop('mix_results')
return results
def update_skip_type_keys(self, skip_type_keys):
"""Update skip_type_keys. It is called by an external hook.
Args:
skip_type_keys (list[str], optional): Sequence of type
string to be skip pipeline.
"""
assert all([
isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys
@DATASETS.register_module()
class ConcatDataset(MMENGINE_ConcatDataset):
"""A wrapper of concatenated dataset.
Same as ``torch.utils.data.dataset.ConcatDataset``, support
lazy_init and get_dataset_source.
Note:
``ConcatDataset`` should not inherit from ``BaseDataset`` since
``get_subset`` and ``get_subset_`` could produce ambiguous meaning
sub-dataset which conflicts with original dataset. If you want to use
a sub-dataset of ``ConcatDataset``, you should set ``indices``
arguments for wrapped dataset which inherit from ``BaseDataset``.
Args:
datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
which will be concatenated.
lazy_init (bool, optional): Whether to load annotation during
instantiation. Defaults to False.
ignore_keys (List[str] or str): Ignore the keys that can be
unequal in `dataset.metainfo`. Defaults to None.
`New in version 0.3.0.`
"""
def __init__(self,
datasets: Sequence[Union[BaseDataset, dict]],
lazy_init: bool = False,
ignore_keys: Union[str, List[str], None] = None):
self.datasets: List[BaseDataset] = []
for i, dataset in enumerate(datasets):
if isinstance(dataset, dict):
self.datasets.append(DATASETS.build(dataset))
elif isinstance(dataset, BaseDataset):
self.datasets.append(dataset)
else:
raise TypeError(
'elements in datasets sequence should be config or '
f'`BaseDataset` instance, but got {type(dataset)}')
if ignore_keys is None:
self.ignore_keys = []
elif isinstance(ignore_keys, str):
self.ignore_keys = [ignore_keys]
elif isinstance(ignore_keys, list):
self.ignore_keys = ignore_keys
else:
raise TypeError('ignore_keys should be a list or str, '
f'but got {type(ignore_keys)}')
meta_keys: set = set()
for dataset in self.datasets:
meta_keys |= dataset.metainfo.keys()
# if the metainfo of multiple datasets are the same, use metainfo
# of the first dataset, else the metainfo is a list with metainfo
# of all the datasets
is_all_same = True
self._metainfo_first = self.datasets[0].metainfo
for i, dataset in enumerate(self.datasets, 1):
for key in meta_keys:
if key in self.ignore_keys:
continue
if key not in dataset.metainfo:
is_all_same = False
break
if self._metainfo_first[key] != dataset.metainfo[key]:
is_all_same = False
break
if is_all_same:
self._metainfo = self.datasets[0].metainfo
else:
self._metainfo = [dataset.metainfo for dataset in self.datasets]
self._fully_initialized = False
if not lazy_init:
self.full_init()
if is_all_same:
self._metainfo.update(
dict(cumulative_sizes=self.cumulative_sizes))
else:
for i, dataset in enumerate(self.datasets):
self._metainfo[i].update(
dict(cumulative_sizes=self.cumulative_sizes))
def get_dataset_source(self, idx: int) -> int:
dataset_idx, _ = self._get_ori_dataset_idx(idx)
return dataset_idx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment