task: detection

num_classes: 80
remap_mscoco_category: True

train_dataloader:
  type: DataLoader
  dataset:
    type: CocoDetection
    img_folder: /home/RT-DETR/datasets/train2017/
    ann_file: /home/RT-DETR/datasets/annotations/instances_train2017.json
    transforms:
      type: Compose
      ops: ~
  shuffle: True
  batch_size: 8
  num_workers: 4
  drop_last: True

val_dataloader:
  type: DataLoader
  dataset:
    type: CocoDetection
    img_folder: /home/RT-DETR/datasets/val2017/
    ann_file: /home/RT-DETR/datasets/annotations/instances_val2017.json
    transforms:
      type: Compose
      ops: ~
  shuffle: False
  batch_size: 8
  num_workers: 4
  drop_last: False
# num_classes: 91
# remap_mscoco_category: True

train_dataloader:
  dataset:
    return_masks: False
    transforms:
      ops:
        - {type: RandomPhotometricDistort, p: 0.5}
        - {type: RandomZoomOut, fill: 0}
        - {type: RandomIoUCrop, p: 0.8}
        # - {type: SanitizeBoundingBox, min_size: 1}
        - {type: SanitizeBoundingBoxes, min_size: 1}  # TODO: changed for the torchvision v2 API
        - {type: RandomHorizontalFlip}
        - {type: Resize, size: [640, 640]}
        # - {type: Resize, size: 639, max_size: 640}
        # - {type: PadToSize, spatial_size: 640}
        # - {type: ToImageTensor}
        - {type: ToImage}  # TODO: changed
        # - {type: ConvertDtype}
        - {type: ConvertImageDtype}  # TODO: changed
        # - {type: SanitizeBoundingBox, min_size: 1}
        - {type: SanitizeBoundingBoxes, min_size: 1}  # TODO: changed
        - {type: ConvertBox, out_fmt: 'cxcywh', normalize: True}
  shuffle: True
  batch_size: 4
  num_workers: 4
  collate_fn: default_collate_fn

val_dataloader:
  dataset:
    transforms:
      ops:
        # - {type: Resize, size: 639, max_size: 640}
        # - {type: PadToSize, spatial_size: 640}
        - {type: Resize, size: [640, 640]}
        # - {type: ToImageTensor}
        - {type: ToImage}  # TODO: changed
        # - {type: ConvertDtype}
        - {type: ConvertImageDtype}  # TODO: changed
  shuffle: False
  batch_size: 8
  num_workers: 4
  collate_fn: default_collate_fn
use_ema: True
ema:
  type: ModelEMA
  decay: 0.9999
  warmups: 2000

find_unused_parameters: True

epoches: 72
clip_max_norm: 0.1

optimizer:
  type: AdamW
  params:
    -
      params: 'backbone'
      lr: 0.00001
    -
      params: '^(?=.*encoder(?=.*bias|.*norm.*weight)).*$'
      weight_decay: 0.
    -
      params: '^(?=.*decoder(?=.*bias|.*norm.*weight)).*$'
      weight_decay: 0.

  lr: 0.0001
  betas: [0.9, 0.999]
  weight_decay: 0.0001

lr_scheduler:
  type: MultiStepLR
  milestones: [1000]
  gamma: 0.1
task: detection

model: RTDETR
criterion: SetCriterion
postprocessor: RTDETRPostProcessor

RTDETR:
  backbone: PResNet
  encoder: HybridEncoder
  decoder: RTDETRTransformer
  multi_scale: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800]

PResNet:
  depth: 50
  variant: d
  freeze_at: 0
  return_idx: [1, 2, 3]
  num_stages: 4
  freeze_norm: True
  pretrained: True

HybridEncoder:
  in_channels: [512, 1024, 2048]
  feat_strides: [8, 16, 32]

  # intra
  hidden_dim: 256
  use_encoder_idx: [2]
  num_encoder_layers: 1
  nhead: 8
  dim_feedforward: 1024
  dropout: 0.
  enc_act: 'gelu'
  pe_temperature: 10000

  # cross
  expansion: 1.0
  depth_mult: 1
  act: 'silu'

  # eval
  eval_spatial_size: [640, 640]

RTDETRTransformer:
  feat_channels: [256, 256, 256]
  feat_strides: [8, 16, 32]
  hidden_dim: 256
  num_levels: 3
  num_queries: 300
  num_decoder_layers: 6
  num_denoising: 100
  eval_idx: -1
  eval_spatial_size: [640, 640]

use_focal_loss: True

RTDETRPostProcessor:
  num_top_queries: 300

SetCriterion:
  weight_dict: {loss_vfl: 1, loss_bbox: 5, loss_giou: 2}
  losses: ['vfl', 'boxes']
  alpha: 0.75
  gamma: 2.0

  matcher:
    type: HungarianMatcher
    weight_dict: {cost_class: 2, cost_bbox: 5, cost_giou: 2}
    # use_focal_loss: True
    alpha: 0.25
    gamma: 2.0
__include__: [
  '../dataset/coco_detection.yml',
  '../runtime.yml',
  './include/dataloader.yml',
  './include/optimizer.yml',
  './include/rtdetr_r50vd.yml',
]

PResNet:
  depth: 101

HybridEncoder:
  # intra
  hidden_dim: 384
  dim_feedforward: 2048

RTDETRTransformer:
  feat_channels: [384, 384, 384]

optimizer:
  type: AdamW
  params:
    -
      params: 'backbone'
      lr: 0.000001
__include__: [
  '../dataset/coco_detection.yml',
  '../runtime.yml',
  './include/dataloader.yml',
  './include/optimizer.yml',
  './include/rtdetr_r50vd.yml',
]

output_dir: /home/RT-DETR/output/rtdetr_r18vd_6x_coco

PResNet:
  depth: 18
  freeze_at: -1
  freeze_norm: False
  pretrained: True

HybridEncoder:
  in_channels: [128, 256, 512]
  hidden_dim: 256
  expansion: 0.5

RTDETRTransformer:
  eval_idx: -1
  num_decoder_layers: 3
  num_denoising: 100

optimizer:
  type: AdamW
  params:
    -
      params: '^(?=.*backbone)(?=.*norm).*$'
      lr: 0.00001
      weight_decay: 0.
    -
      params: '^(?=.*backbone)(?!.*norm).*$'
      lr: 0.00001
    -
      params: '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bias)).*$'
      weight_decay: 0.

  lr: 0.0001
  betas: [0.9, 0.999]
  weight_decay: 0.0001
__include__: [
  '../dataset/coco_detection.yml',
  '../runtime.yml',
  './include/dataloader.yml',
  './include/optimizer.yml',
  './include/rtdetr_r50vd.yml',
]

output_dir: ./output/rtdetr_r34vd_6x_coco

PResNet:
  depth: 34
  freeze_at: -1
  freeze_norm: False
  pretrained: True

HybridEncoder:
  in_channels: [128, 256, 512]
  hidden_dim: 256
  expansion: 0.5

RTDETRTransformer:
  num_decoder_layers: 4

optimizer:
  type: AdamW
  params:
    -
      params: '^(?=.*backbone)(?=.*norm|bn).*$'
      weight_decay: 0.
      lr: 0.00001
    -
      params: '^(?=.*backbone)(?!.*norm|bn).*$'
      lr: 0.00001
    -
      params: '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn|bias)).*$'
      weight_decay: 0.

  lr: 0.0001
  betas: [0.9, 0.999]
  weight_decay: 0.0001
__include__: [
  '../dataset/coco_detection.yml',
  '../runtime.yml',
  './include/dataloader.yml',
  './include/optimizer.yml',
  './include/rtdetr_r50vd.yml',
]

output_dir: ./output/rtdetr_r50vd_6x_coco
__include__: [
  '../dataset/coco_detection.yml',
  '../runtime.yml',
  './include/dataloader.yml',
  './include/optimizer.yml',
  './include/rtdetr_r50vd.yml',
]

output_dir: ./output/rtdetr_r50vd_m_6x_coco

HybridEncoder:
  expansion: 0.5

RTDETRTransformer:
  eval_idx: 2  # use the 3rd decoder layer for eval
sync_bn: True
find_unused_parameters: False

use_amp: False
scaler:
  type: GradScaler
  enabled: True

use_ema: False
ema:
  type: ModelEMA
  decay: 0.9999
  warmups: 2000
# torch==2.0.1
# torchvision==0.15.2
onnx==1.14.0
onnxruntime==1.15.1
pycocotools
PyYAML
scipy
from . import data
from . import nn
from . import optim
from . import zoo
"""by lyuwenyu
"""
# from .yaml_utils import register, create, load_config, merge_config, merge_dict
from .yaml_utils import *
from .config import BaseConfig
from .yaml_config import YAMLConfig
"""by lyuwenyu
"""
from pprint import pprint
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.cuda.amp.grad_scaler import GradScaler
from typing import Callable, List, Dict
__all__ = ['BaseConfig', ]
class BaseConfig(object):
    # TODO property

    def __init__(self) -> None:
        super().__init__()

        self.task :str = None

        self._model :nn.Module = None
        self._postprocessor :nn.Module = None
        self._criterion :nn.Module = None
        self._optimizer :Optimizer = None
        self._lr_scheduler :LRScheduler = None
        self._train_dataloader :DataLoader = None
        self._val_dataloader :DataLoader = None
        self._ema :nn.Module = None
        self._scaler :GradScaler = None

        self.train_dataset :Dataset = None
        self.val_dataset :Dataset = None
        self.num_workers :int = 0
        self.collate_fn :Callable = None

        self.batch_size :int = None
        self._train_batch_size :int = None
        self._val_batch_size :int = None
        self._train_shuffle: bool = None
        self._val_shuffle: bool = None

        self.evaluator :Callable[[nn.Module, DataLoader, str], ] = None

        # runtime
        self.resume :str = None
        self.tuning :str = None

        self.epoches :int = None
        self.last_epoch :int = -1
        self.end_epoch :int = None

        self.use_amp :bool = False
        self.use_ema :bool = False
        self.sync_bn :bool = False
        self.clip_max_norm : float = None
        self.find_unused_parameters :bool = None
        # self.ema_decay: float = 0.9999
        # self.grad_clip_: Callable = None

        self.log_dir :str = './logs/'
        self.log_step :int = 10
        self._output_dir :str = None
        self._print_freq :int = None
        self.checkpoint_step :int = 1

        # self.device :str = torch.device('cpu')
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    @property
    def model(self, ) -> nn.Module:
        return self._model

    @model.setter
    def model(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._model = m

    @property
    def postprocessor(self, ) -> nn.Module:
        return self._postprocessor

    @postprocessor.setter
    def postprocessor(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._postprocessor = m

    @property
    def criterion(self, ) -> nn.Module:
        return self._criterion

    @criterion.setter
    def criterion(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._criterion = m

    @property
    def optimizer(self, ) -> Optimizer:
        return self._optimizer

    @optimizer.setter
    def optimizer(self, m):
        assert isinstance(m, Optimizer), f'{type(m)} != optim.Optimizer, please check your model class'
        self._optimizer = m

    @property
    def lr_scheduler(self, ) -> LRScheduler:
        return self._lr_scheduler

    @lr_scheduler.setter
    def lr_scheduler(self, m):
        assert isinstance(m, LRScheduler), f'{type(m)} != LRScheduler, please check your model class'
        self._lr_scheduler = m

    @property
    def train_dataloader(self):
        if self._train_dataloader is None and self.train_dataset is not None:
            loader = DataLoader(self.train_dataset,
                                batch_size=self.train_batch_size,
                                num_workers=self.num_workers,
                                collate_fn=self.collate_fn,
                                shuffle=self.train_shuffle, )
            loader.shuffle = self.train_shuffle
            self._train_dataloader = loader

        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, loader):
        self._train_dataloader = loader

    @property
    def val_dataloader(self):
        if self._val_dataloader is None and self.val_dataset is not None:
            loader = DataLoader(self.val_dataset,
                                batch_size=self.val_batch_size,
                                num_workers=self.num_workers,
                                drop_last=False,
                                collate_fn=self.collate_fn,
                                shuffle=self.val_shuffle)
            loader.shuffle = self.val_shuffle
            self._val_dataloader = loader

        return self._val_dataloader

    @val_dataloader.setter
    def val_dataloader(self, loader):
        self._val_dataloader = loader

    # TODO method
    # @property
    # def ema(self, ) -> nn.Module:
    #     if self._ema is None and self.use_ema and self.model is not None:
    #         self._ema = ModelEMA(self.model, self.ema_decay)
    #     return self._ema

    @property
    def ema(self, ) -> nn.Module:
        return self._ema

    @ema.setter
    def ema(self, obj):
        self._ema = obj

    @property
    def scaler(self) -> GradScaler:
        if self._scaler is None and self.use_amp and torch.cuda.is_available():
            self._scaler = GradScaler()
        return self._scaler

    @scaler.setter
    def scaler(self, obj: GradScaler):
        self._scaler = obj

    @property
    def val_shuffle(self):
        if self._val_shuffle is None:
            print('warning: set default val_shuffle=False')
            return False
        return self._val_shuffle

    @val_shuffle.setter
    def val_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), 'shuffle must be bool'
        self._val_shuffle = shuffle

    @property
    def train_shuffle(self):
        if self._train_shuffle is None:
            print('warning: set default train_shuffle=True')
            return True
        return self._train_shuffle

    @train_shuffle.setter
    def train_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), 'shuffle must be bool'
        self._train_shuffle = shuffle

    @property
    def train_batch_size(self):
        if self._train_batch_size is None and isinstance(self.batch_size, int):
            print(f'warning: set train_batch_size=batch_size={self.batch_size}')
            return self.batch_size
        return self._train_batch_size

    @train_batch_size.setter
    def train_batch_size(self, batch_size):
        assert isinstance(batch_size, int), 'batch_size must be int'
        self._train_batch_size = batch_size

    @property
    def val_batch_size(self):
        if self._val_batch_size is None:
            print(f'warning: set val_batch_size=batch_size={self.batch_size}')
            return self.batch_size
        return self._val_batch_size

    @val_batch_size.setter
    def val_batch_size(self, batch_size):
        assert isinstance(batch_size, int), 'batch_size must be int'
        self._val_batch_size = batch_size

    @property
    def output_dir(self):
        if self._output_dir is None:
            return self.log_dir
        return self._output_dir

    @output_dir.setter
    def output_dir(self, root):
        self._output_dir = root

    @property
    def print_freq(self):
        if self._print_freq is None:
            # self._print_freq = self.log_step
            return self.log_step
        return self._print_freq

    @print_freq.setter
    def print_freq(self, n):
        assert isinstance(n, int), 'print_freq must be int'
        self._print_freq = n

    # def __repr__(self) -> str:
    #     pass
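

# A minimal usage sketch (illustrative, with a stand-in dataset): the lazy
# properties above fall back to the shared `batch_size`/shuffle defaults and
# cache the DataLoader on first access.
if __name__ == '__main__':
    from torch.utils.data import TensorDataset

    cfg = BaseConfig()
    cfg.batch_size = 8
    cfg.train_dataset = TensorDataset(torch.randn(32, 3))

    loader = cfg.train_dataloader           # warns, then builds DataLoader(batch_size=8, shuffle=True)
    assert loader is cfg.train_dataloader   # cached afterwards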
"""by lyuwenyu
"""
import torch
import torch.nn as nn
import re
import copy
from .config import BaseConfig
from .yaml_utils import load_config, merge_config, create, merge_dict
class YAMLConfig(BaseConfig):
    def __init__(self, cfg_path: str, **kwargs) -> None:
        super().__init__()

        cfg = load_config(cfg_path)
        merge_dict(cfg, kwargs)

        # pprint(cfg)

        self.yaml_cfg = cfg

        self.log_step = cfg.get('log_step', 100)
        self.checkpoint_step = cfg.get('checkpoint_step', 1)
        self.epoches = cfg.get('epoches', -1)
        self.resume = cfg.get('resume', '')
        self.tuning = cfg.get('tuning', '')
        self.sync_bn = cfg.get('sync_bn', False)
        self.output_dir = cfg.get('output_dir', None)

        self.use_ema = cfg.get('use_ema', False)
        self.use_amp = cfg.get('use_amp', False)
        self.autocast = cfg.get('autocast', dict())
        self.find_unused_parameters = cfg.get('find_unused_parameters', None)
        self.clip_max_norm = cfg.get('clip_max_norm', 0.)

    @property
    def model(self, ) -> torch.nn.Module:
        if self._model is None and 'model' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._model = create(self.yaml_cfg['model'])
        return self._model

    @property
    def postprocessor(self, ) -> torch.nn.Module:
        if self._postprocessor is None and 'postprocessor' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._postprocessor = create(self.yaml_cfg['postprocessor'])
        return self._postprocessor

    @property
    def criterion(self, ):
        if self._criterion is None and 'criterion' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._criterion = create(self.yaml_cfg['criterion'])
        return self._criterion

    @property
    def optimizer(self, ):
        if self._optimizer is None and 'optimizer' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            params = self.get_optim_params(self.yaml_cfg['optimizer'], self.model)
            self._optimizer = create('optimizer', params=params)
        return self._optimizer

    @property
    def lr_scheduler(self, ):
        if self._lr_scheduler is None and 'lr_scheduler' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._lr_scheduler = create('lr_scheduler', optimizer=self.optimizer)
            print('Initial lr: ', self._lr_scheduler.get_last_lr())
        return self._lr_scheduler

    @property
    def train_dataloader(self, ):
        if self._train_dataloader is None and 'train_dataloader' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._train_dataloader = create('train_dataloader')
            self._train_dataloader.shuffle = self.yaml_cfg['train_dataloader'].get('shuffle', False)
        return self._train_dataloader

    @property
    def val_dataloader(self, ):
        if self._val_dataloader is None and 'val_dataloader' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._val_dataloader = create('val_dataloader')
            self._val_dataloader.shuffle = self.yaml_cfg['val_dataloader'].get('shuffle', False)
        return self._val_dataloader

    @property
    def ema(self, ):
        if self._ema is None and self.yaml_cfg.get('use_ema', False):
            merge_config(self.yaml_cfg)
            self._ema = create('ema', model=self.model)
        return self._ema

    @property
    def scaler(self, ):
        if self._scaler is None and self.yaml_cfg.get('use_amp', False):
            merge_config(self.yaml_cfg)
            self._scaler = create('scaler')
        return self._scaler
    @staticmethod
    def get_optim_params(cfg: dict, model: nn.Module):
        '''
        E.g.:
            ^(?=.*a)(?=.*b).*$           means including a and b
            ^((?!b).)*a((?!b).)*$        means including a but not b
            ^((?!b|c).)*a((?!b|c).)*$    means including a but not (b | c)
        '''
        assert 'type' in cfg, ''
        cfg = copy.deepcopy(cfg)

        if 'params' not in cfg:
            return model.parameters()

        assert isinstance(cfg['params'], list), ''

        param_groups = []
        visited = []
        for pg in cfg['params']:
            pattern = pg['params']
            params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
            pg['params'] = params.values()
            param_groups.append(pg)
            visited.extend(list(params.keys()))

        names = [k for k, v in model.named_parameters() if v.requires_grad]

        if len(visited) < len(names):
            unseen = set(names) - set(visited)
            params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
            param_groups.append({'params': params.values()})
            visited.extend(list(params.keys()))

        assert len(visited) == len(names), ''

        return param_groups
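

# A minimal sketch (with made-up module names) of how the regex-style
# `params` entries from the optimizer YAML are turned into parameter groups
# by get_optim_params.
if __name__ == '__main__':
    toy = nn.ModuleDict({
        'backbone': nn.Linear(4, 4),
        'encoder': nn.ModuleDict({'norm': nn.LayerNorm(4), 'proj': nn.Linear(4, 4)}),
    })
    cfg = {
        'type': 'AdamW',
        'params': [
            {'params': 'backbone', 'lr': 0.00001},
            {'params': '^(?=.*encoder(?=.*bias|.*norm.*weight)).*$', 'weight_decay': 0.},
        ],
    }
    groups = YAMLConfig.get_optim_params(cfg, toy)
    # -> three groups: the backbone parameters, the encoder norm weights and
    #    biases, and a catch-all group holding the leftover encoder.proj.weight.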
""""by lyuwenyu
"""
import os
import yaml
import inspect
import importlib
__all__ = ['GLOBAL_CONFIG', 'register', 'create', 'load_config', 'merge_config', 'merge_dict']
GLOBAL_CONFIG = dict()
INCLUDE_KEY = '__include__'
def register(cls: type):
    '''
    Args:
        cls (type): Module class to be registered.
    '''
    if cls.__name__ in GLOBAL_CONFIG:
        raise ValueError('{} already registered'.format(cls.__name__))

    if inspect.isfunction(cls):
        GLOBAL_CONFIG[cls.__name__] = cls
    elif inspect.isclass(cls):
        GLOBAL_CONFIG[cls.__name__] = extract_schema(cls)
    else:
        raise ValueError(f'register {cls}')

    return cls


def extract_schema(cls: type):
    '''
    Args:
        cls (type): Class to inspect.
    Returns:
        dict: Schema of the constructor (defaults, shared and injected keys).
    '''
    argspec = inspect.getfullargspec(cls.__init__)
    arg_names = [arg for arg in argspec.args if arg != 'self']
    num_defaults = len(argspec.defaults) if argspec.defaults is not None else 0
    num_requires = len(arg_names) - num_defaults

    schema = dict()
    schema['_name'] = cls.__name__
    schema['_pymodule'] = importlib.import_module(cls.__module__)
    schema['_inject'] = getattr(cls, '__inject__', [])
    schema['_share'] = getattr(cls, '__share__', [])

    for i, name in enumerate(arg_names):
        if name in schema['_share']:
            assert i >= num_requires, 'share config must have default value.'
            value = argspec.defaults[i - num_requires]
        elif i >= num_requires:
            value = argspec.defaults[i - num_requires]
        else:
            value = None
        schema[name] = value

    return schema
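

# For a hypothetical registered class such as
#
#     @register
#     class ToyHead:
#         __share__ = ['num_classes']
#         def __init__(self, hidden_dim=256, num_classes=80): ...
#
# extract_schema would record roughly
#     {'_name': 'ToyHead', '_pymodule': <module>, '_inject': [],
#      '_share': ['num_classes'], 'hidden_dim': 256, 'num_classes': 80}
# so YAML values can later override these defaults via merge_config/create.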
def create(type_or_name, **kwargs):
    '''Build an instance from a registered class (or its name).
    '''
    assert type(type_or_name) in (type, str), 'create should be class or name.'

    name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__

    if name in GLOBAL_CONFIG:
        if hasattr(GLOBAL_CONFIG[name], '__dict__'):
            return GLOBAL_CONFIG[name]
    else:
        raise ValueError('The module {} is not registered'.format(name))

    cfg = GLOBAL_CONFIG[name]

    if isinstance(cfg, dict) and 'type' in cfg:
        _cfg: dict = GLOBAL_CONFIG[cfg['type']]
        _cfg.update(cfg)     # update global cls default args
        _cfg.update(kwargs)  # TODO
        name = _cfg.pop('type')
        return create(name)

    cls = getattr(cfg['_pymodule'], name)
    argspec = inspect.getfullargspec(cls.__init__)
    arg_names = [arg for arg in argspec.args if arg != 'self']

    cls_kwargs = {}
    cls_kwargs.update(cfg)

    # shared var
    for k in cfg['_share']:
        if k in GLOBAL_CONFIG:
            cls_kwargs[k] = GLOBAL_CONFIG[k]
        else:
            cls_kwargs[k] = cfg[k]

    # inject
    for k in cfg['_inject']:
        _k = cfg[k]

        if _k is None:
            continue

        if isinstance(_k, str):
            if _k not in GLOBAL_CONFIG:
                raise ValueError(f'Missing inject config of {_k}.')

            _cfg = GLOBAL_CONFIG[_k]

            if isinstance(_cfg, dict):
                cls_kwargs[k] = create(_cfg['_name'])
            else:
                cls_kwargs[k] = _cfg

        elif isinstance(_k, dict):
            if 'type' not in _k.keys():
                raise ValueError(f'Missing inject for `type` style.')

            _type = str(_k['type'])

            if _type not in GLOBAL_CONFIG:
                raise ValueError(f'Missing {_type} in inspect stage.')

            # TODO: updates GLOBAL_CONFIG in place; may give a wrong result when the same type is created more than once
            _cfg: dict = GLOBAL_CONFIG[_type]
            # _cfg_copy = copy.deepcopy(_cfg)
            _cfg.update(_k)  # update
            cls_kwargs[k] = create(_type)
            # _cfg.update(_cfg_copy)  # resume

        else:
            raise ValueError(f'Inject does not support {_k}')

    cls_kwargs = {n: cls_kwargs[n] for n in arg_names}

    return cls(**cls_kwargs)
def load_config(file_path, cfg=dict()):
    '''load config
    '''
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"

    with open(file_path) as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)
        if file_cfg is None:
            return {}

    if INCLUDE_KEY in file_cfg:
        base_yamls = list(file_cfg[INCLUDE_KEY])
        for base_yaml in base_yamls:
            if base_yaml.startswith('~'):
                base_yaml = os.path.expanduser(base_yaml)

            if not base_yaml.startswith('/'):
                base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)

            with open(base_yaml) as f:
                base_cfg = load_config(base_yaml, cfg)
                merge_config(base_cfg, cfg)

    return merge_config(file_cfg, cfg)


def merge_dict(dct, another_dct):
    '''merge another_dct into dct
    '''
    for k in another_dct:
        if (k in dct and isinstance(dct[k], dict) and isinstance(another_dct[k], dict)):
            merge_dict(dct[k], another_dct[k])
        else:
            dct[k] = another_dct[k]

    return dct
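

# For example, nested dicts merge key-by-key while scalars are overwritten:
# merge_dict({'a': {'x': 1, 'y': 2}}, {'a': {'y': 3}}) -> {'a': {'x': 1, 'y': 3}}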
def merge_config(config, another_cfg=None):
    """
    Merge config into global config or another_cfg.

    Args:
        config (dict): Config to be merged.

    Returns: global config
    """
    global GLOBAL_CONFIG
    dct = GLOBAL_CONFIG if another_cfg is None else another_cfg
    return merge_dict(dct, config)
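

# A minimal end-to-end sketch of the registry (ToyHead is hypothetical):
# register a class, merge a config over its schema defaults, then build it.
if __name__ == '__main__':
    @register
    class ToyHead(object):
        __share__ = ['num_classes']

        def __init__(self, hidden_dim=256, num_classes=80):
            self.hidden_dim = hidden_dim
            self.num_classes = num_classes

    merge_config({'ToyHead': {'hidden_dim': 128}, 'num_classes': 91})
    head = create('ToyHead')
    assert (head.hidden_dim, head.num_classes) == (128, 91)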
from .coco import *
from .cifar10 import CIFAR10
from .dataloader import *
from .transforms import *
import torchvision
from typing import Optional, Callable
from src.core import register
@register
class CIFAR10(torchvision.datasets.CIFAR10):
    __inject__ = ['transform', 'target_transform']

    def __init__(self, root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False) -> None:
        super().__init__(root, train, transform, target_transform, download)
from .coco_dataset import (
    CocoDetection,
    mscoco_category2label,
    mscoco_label2category,
    mscoco_category2name,
)
from .coco_eval import *

from .coco_utils import get_coco_api_from_dataset
"""
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
import torch
import torch.utils.data
import torchvision
torchvision.disable_beta_transforms_warning()
# TODO:修改库:
# from torchvision import datapoints
from torchvision import tv_tensors
from pycocotools import mask as coco_mask
from src.core import register
__all__ = ['CocoDetection']
@register
class CocoDetection(torchvision.datasets.CocoDetection):
__inject__ = ['transforms']
__share__ = ['remap_mscoco_category']
def __init__(self, img_folder, ann_file, transforms, return_masks, remap_mscoco_category=False):
super(CocoDetection, self).__init__(img_folder, ann_file)
self._transforms = transforms
self.prepare = ConvertCocoPolysToMask(return_masks, remap_mscoco_category)
self.img_folder = img_folder
self.ann_file = ann_file
self.return_masks = return_masks
self.remap_mscoco_category = remap_mscoco_category
def __getitem__(self, idx):
img, target = super(CocoDetection, self).__getitem__(idx)
image_id = self.ids[idx]
target = {'image_id': image_id, 'annotations': target}
img, target = self.prepare(img, target)
# ['boxes', 'masks', 'labels']:
if 'boxes' in target:
# target['boxes'] = datapoints.BoundingBox(
# target['boxes'],
# format=datapoints.BoundingBoxFormat.XYXY,
# spatial_size=img.size[::-1]) # h w
target['boxes'] = tv_tensors.BoundingBoxes( # TODO:修改
target['boxes'],
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=img.size[::-1]) # h w
if 'masks' in target:
# target['masks'] = datapoints.Mask(target['masks'])
target['masks'] = tv_tensors.Mask(target['masks']) # TODO:修改
if self._transforms is not None:
img, target = self._transforms(img, target)
return img, target
def extra_repr(self) -> str:
s = f' img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n'
s += f' return_masks: {self.return_masks}\n'
if hasattr(self, '_transforms') and self._transforms is not None:
s += f' transforms:\n {repr(self._transforms)}'
return s
def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks
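

# For instance, a single square polygon (flattened [x0, y0, x1, y1, ...])
# rasterized onto an 8x8 canvas:
#   convert_coco_poly_to_mask([[[1., 1., 5., 1., 5., 5., 1., 5.]]], 8, 8)
# returns a (1, 8, 8) boolean mask that is True inside the square.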
class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False, remap_mscoco_category=False):
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]
        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        if self.remap_mscoco_category:
            classes = [mscoco_category2label[obj["category_id"]] for obj in anno]
        else:
            classes = [obj["category_id"] for obj in anno]

        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(w), int(h)])
        target["size"] = torch.as_tensor([int(w), int(h)])

        return image, target
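

# A quick illustration of the conversion above, with a made-up annotation:
#   prepare = ConvertCocoPolysToMask(return_masks=False)
#   img = PIL.Image.new('RGB', (10, 10))
#   target = {'image_id': 0, 'annotations': [
#       {'bbox': [2., 3., 4., 5.], 'category_id': 1, 'area': 20., 'iscrowd': 0}]}
#   _, out = prepare(img, target)
#   # out['boxes'] -> tensor([[2., 3., 6., 8.]])  (COCO xywh converted to xyxy)
#   # out['labels'] -> tensor([1])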
mscoco_category2name = {
    1: 'person',
    2: 'bicycle',
    3: 'car',
    4: 'motorcycle',
    5: 'airplane',
    6: 'bus',
    7: 'train',
    8: 'truck',
    9: 'boat',
    10: 'traffic light',
    11: 'fire hydrant',
    13: 'stop sign',
    14: 'parking meter',
    15: 'bench',
    16: 'bird',
    17: 'cat',
    18: 'dog',
    19: 'horse',
    20: 'sheep',
    21: 'cow',
    22: 'elephant',
    23: 'bear',
    24: 'zebra',
    25: 'giraffe',
    27: 'backpack',
    28: 'umbrella',
    31: 'handbag',
    32: 'tie',
    33: 'suitcase',
    34: 'frisbee',
    35: 'skis',
    36: 'snowboard',
    37: 'sports ball',
    38: 'kite',
    39: 'baseball bat',
    40: 'baseball glove',
    41: 'skateboard',
    42: 'surfboard',
    43: 'tennis racket',
    44: 'bottle',
    46: 'wine glass',
    47: 'cup',
    48: 'fork',
    49: 'knife',
    50: 'spoon',
    51: 'bowl',
    52: 'banana',
    53: 'apple',
    54: 'sandwich',
    55: 'orange',
    56: 'broccoli',
    57: 'carrot',
    58: 'hot dog',
    59: 'pizza',
    60: 'donut',
    61: 'cake',
    62: 'chair',
    63: 'couch',
    64: 'potted plant',
    65: 'bed',
    67: 'dining table',
    70: 'toilet',
    72: 'tv',
    73: 'laptop',
    74: 'mouse',
    75: 'remote',
    76: 'keyboard',
    77: 'cell phone',
    78: 'microwave',
    79: 'oven',
    80: 'toaster',
    81: 'sink',
    82: 'refrigerator',
    84: 'book',
    85: 'clock',
    86: 'vase',
    87: 'scissors',
    88: 'teddy bear',
    89: 'hair drier',
    90: 'toothbrush'
}
mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())}
mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
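

# COCO category ids are non-contiguous (e.g. 12, 26, 29 and 30 are unused), so
# the remap above yields contiguous training labels; a quick sanity check:
if __name__ == '__main__':
    assert mscoco_category2label[1] == 0     # 'person'
    assert mscoco_category2label[13] == 11   # 'stop sign' (id 12 is skipped)
    assert mscoco_label2category[11] == 13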