Commit baf20b93 authored by dlyrm

update yolox

parent ec3f5448
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms import LoadImageFromFile
from mmdet.datasets.transforms import LoadAnnotations, LoadPanopticAnnotations
from mmdet.registry import TRANSFORMS
def get_loading_pipeline(pipeline):
"""Only keep loading image and annotations related configuration.
Args:
pipeline (list[dict]): Data pipeline configs.
Returns:
        list[dict]: The new pipeline list that keeps only the
            configuration related to loading images and annotations.
Examples:
>>> pipelines = [
... dict(type='LoadImageFromFile'),
... dict(type='LoadAnnotations', with_bbox=True),
... dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
... dict(type='RandomFlip', flip_ratio=0.5),
... dict(type='Normalize', **img_norm_cfg),
... dict(type='Pad', size_divisor=32),
... dict(type='DefaultFormatBundle'),
... dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
... ]
>>> expected_pipelines = [
... dict(type='LoadImageFromFile'),
... dict(type='LoadAnnotations', with_bbox=True)
... ]
>>> assert expected_pipelines ==\
... get_loading_pipeline(pipelines)
"""
loading_pipeline_cfg = []
for cfg in pipeline:
obj_cls = TRANSFORMS.get(cfg['type'])
        # TODO: use a more elegant way to distinguish loading modules
if obj_cls is not None and obj_cls in (LoadImageFromFile,
LoadAnnotations,
LoadPanopticAnnotations):
loading_pipeline_cfg.append(cfg)
    assert len(loading_pipeline_cfg) == 2, \
        'The data pipeline in your config file must include both ' \
        'image-loading and annotation-loading transforms.'
return loading_pipeline_cfg
# Copyright (c) OpenMMLab. All rights reserved.
import os.path
from typing import Optional
import mmengine
from mmdet.registry import DATASETS
from .coco import CocoDataset
@DATASETS.register_module()
class V3DetDataset(CocoDataset):
"""Dataset for V3Det."""
METAINFO = {
'classes': None,
'palette': None,
}
def __init__(
self,
*args,
metainfo: Optional[dict] = None,
data_root: str = '',
label_file='annotations/category_name_13204_v3det_2023_v1.txt', # noqa
**kwargs) -> None:
class_names = tuple(
mmengine.list_from_file(os.path.join(data_root, label_file)))
if metainfo is None:
metainfo = {'classes': class_names}
super().__init__(
*args, data_root=data_root, metainfo=metainfo, **kwargs)
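
# A minimal config sketch for using this dataset; the data root and
# annotation file name below are illustrative assumptions, not taken
# from this commit.
train_dataset = dict(
    type='V3DetDataset',
    data_root='data/V3Det/',
    ann_file='annotations/v3det_2023_v1_train.json',
    data_prefix=dict(img=''))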
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
from .xml_style import XMLDataset
@DATASETS.register_module()
class VOCDataset(XMLDataset):
"""Dataset for PASCAL VOC."""
METAINFO = {
'classes':
('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'),
# palette is a list of color tuples, which is used for visualization.
'palette': [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192),
(197, 226, 255), (0, 60, 100), (0, 0, 142), (255, 77, 255),
(153, 69, 1), (120, 166, 157), (0, 182, 199),
(0, 226, 252), (182, 182, 255), (0, 0, 230), (220, 20, 60),
(163, 255, 0), (0, 82, 0), (3, 95, 161), (0, 80, 100),
(183, 130, 88)]
}
def __init__(self, **kwargs):
super().__init__(**kwargs)
if 'VOC2007' in self.sub_data_root:
self._metainfo['dataset_type'] = 'VOC2007'
elif 'VOC2012' in self.sub_data_root:
self._metainfo['dataset_type'] = 'VOC2012'
else:
self._metainfo['dataset_type'] = None
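
# A minimal config sketch for VOC2007; `sub_data_root` is what the
# `dataset_type` detection in `__init__` above inspects. Paths follow
# the usual VOCdevkit layout and are illustrative.
train_dataset = dict(
    type='VOCDataset',
    data_root='data/VOCdevkit/',
    ann_file='VOC2007/ImageSets/Main/trainval.txt',
    data_prefix=dict(sub_data_root='VOC2007/'))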
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import xml.etree.ElementTree as ET
from mmengine.dist import is_main_process
from mmengine.fileio import get_local_path, list_from_file
from mmengine.utils import ProgressBar
from mmdet.registry import DATASETS
from mmdet.utils.typing_utils import List, Union
from .xml_style import XMLDataset
@DATASETS.register_module()
class WIDERFaceDataset(XMLDataset):
"""Reader for the WIDER Face dataset in PASCAL VOC format.
    Conversion scripts can be found at
    https://github.com/sovrasov/wider-face-pascal-voc-annotations
"""
METAINFO = {'classes': ('face', ), 'palette': [(0, 255, 0)]}
def load_data_list(self) -> List[dict]:
"""Load annotation from XML style ann_file.
Returns:
list[dict]: Annotation info from XML file.
"""
        assert self._metainfo.get('classes', None) is not None, \
            '`classes` in `WIDERFaceDataset` cannot be None.'
self.cat2label = {
cat: i
for i, cat in enumerate(self._metainfo['classes'])
}
data_list = []
img_ids = list_from_file(self.ann_file, backend_args=self.backend_args)
# loading process takes around 10 mins
if is_main_process():
prog_bar = ProgressBar(len(img_ids))
for img_id in img_ids:
raw_img_info = {}
raw_img_info['img_id'] = img_id
raw_img_info['file_name'] = f'{img_id}.jpg'
parsed_data_info = self.parse_data_info(raw_img_info)
data_list.append(parsed_data_info)
if is_main_process():
prog_bar.update()
return data_list
def parse_data_info(self, img_info: dict) -> Union[dict, List[dict]]:
"""Parse raw annotation to target format.
Args:
img_info (dict): Raw image information, usually it includes
`img_id`, `file_name`, and `xml_path`.
Returns:
Union[dict, List[dict]]: Parsed annotation.
"""
data_info = {}
img_id = img_info['img_id']
xml_path = osp.join(self.data_prefix['img'], 'Annotations',
f'{img_id}.xml')
data_info['img_id'] = img_id
data_info['xml_path'] = xml_path
# deal with xml file
with get_local_path(
xml_path, backend_args=self.backend_args) as local_path:
raw_ann_info = ET.parse(local_path)
root = raw_ann_info.getroot()
size = root.find('size')
width = int(size.find('width').text)
height = int(size.find('height').text)
folder = root.find('folder').text
img_path = osp.join(self.data_prefix['img'], folder,
img_info['file_name'])
data_info['img_path'] = img_path
data_info['height'] = height
data_info['width'] = width
# Coordinates are in range [0, width - 1 or height - 1]
data_info['instances'] = self._parse_instance_info(
raw_ann_info, minus_one=False)
return data_info
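
# A minimal config sketch, assuming the dataset has been converted to
# the PASCAL VOC layout with the scripts linked above; paths are
# illustrative.
train_dataset = dict(
    type='WIDERFaceDataset',
    data_root='data/WIDERFace/',
    ann_file='train.txt',
    data_prefix=dict(img='WIDER_train'),
    pipeline=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
    ])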
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import xml.etree.ElementTree as ET
from typing import List, Optional, Union
import mmcv
from mmengine.fileio import get, get_local_path, list_from_file
from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset
@DATASETS.register_module()
class XMLDataset(BaseDetDataset):
"""XML dataset for detection.
Args:
img_subdir (str): Subdir where images are stored. Default: JPEGImages.
ann_subdir (str): Subdir where annotations are. Default: Annotations.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
"""
def __init__(self,
img_subdir: str = 'JPEGImages',
ann_subdir: str = 'Annotations',
**kwargs) -> None:
self.img_subdir = img_subdir
self.ann_subdir = ann_subdir
super().__init__(**kwargs)
@property
def sub_data_root(self) -> str:
"""Return the sub data root."""
return self.data_prefix.get('sub_data_root', '')
def load_data_list(self) -> List[dict]:
"""Load annotation from XML style ann_file.
Returns:
list[dict]: Annotation info from XML file.
"""
assert self._metainfo.get('classes', None) is not None, \
            '`classes` in `XMLDataset` cannot be None.'
self.cat2label = {
cat: i
for i, cat in enumerate(self._metainfo['classes'])
}
data_list = []
img_ids = list_from_file(self.ann_file, backend_args=self.backend_args)
for img_id in img_ids:
file_name = osp.join(self.img_subdir, f'{img_id}.jpg')
xml_path = osp.join(self.sub_data_root, self.ann_subdir,
f'{img_id}.xml')
raw_img_info = {}
raw_img_info['img_id'] = img_id
raw_img_info['file_name'] = file_name
raw_img_info['xml_path'] = xml_path
parsed_data_info = self.parse_data_info(raw_img_info)
data_list.append(parsed_data_info)
return data_list
@property
def bbox_min_size(self) -> Optional[int]:
"""Return the minimum size of bounding boxes in the images."""
if self.filter_cfg is not None:
return self.filter_cfg.get('bbox_min_size', None)
else:
return None
def parse_data_info(self, img_info: dict) -> Union[dict, List[dict]]:
"""Parse raw annotation to target format.
Args:
img_info (dict): Raw image information, usually it includes
`img_id`, `file_name`, and `xml_path`.
Returns:
Union[dict, List[dict]]: Parsed annotation.
"""
data_info = {}
img_path = osp.join(self.sub_data_root, img_info['file_name'])
data_info['img_path'] = img_path
data_info['img_id'] = img_info['img_id']
data_info['xml_path'] = img_info['xml_path']
# deal with xml file
with get_local_path(
img_info['xml_path'],
backend_args=self.backend_args) as local_path:
raw_ann_info = ET.parse(local_path)
root = raw_ann_info.getroot()
size = root.find('size')
if size is not None:
width = int(size.find('width').text)
height = int(size.find('height').text)
else:
img_bytes = get(img_path, backend_args=self.backend_args)
img = mmcv.imfrombytes(img_bytes, backend='cv2')
height, width = img.shape[:2]
del img, img_bytes
data_info['height'] = height
data_info['width'] = width
data_info['instances'] = self._parse_instance_info(
raw_ann_info, minus_one=True)
return data_info
def _parse_instance_info(self,
raw_ann_info: ET,
minus_one: bool = True) -> List[dict]:
"""parse instance information.
Args:
raw_ann_info (ElementTree): ElementTree object.
minus_one (bool): Whether to subtract 1 from the coordinates.
Defaults to True.
Returns:
List[dict]: List of instances.
"""
instances = []
for obj in raw_ann_info.findall('object'):
instance = {}
name = obj.find('name').text
if name not in self._metainfo['classes']:
continue
difficult = obj.find('difficult')
difficult = 0 if difficult is None else int(difficult.text)
bnd_box = obj.find('bndbox')
bbox = [
int(float(bnd_box.find('xmin').text)),
int(float(bnd_box.find('ymin').text)),
int(float(bnd_box.find('xmax').text)),
int(float(bnd_box.find('ymax').text))
]
            # VOC annotations are 1-based, so subtract 1 from the coordinates
if minus_one:
bbox = [x - 1 for x in bbox]
ignore = False
if self.bbox_min_size is not None:
assert not self.test_mode
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
if w < self.bbox_min_size or h < self.bbox_min_size:
ignore = True
if difficult or ignore:
instance['ignore_flag'] = 1
else:
instance['ignore_flag'] = 0
instance['bbox'] = bbox
instance['bbox_label'] = self.cat2label[name]
instances.append(instance)
return instances
def filter_data(self) -> List[dict]:
"""Filter annotations according to filter_cfg.
Returns:
List[dict]: Filtered results.
"""
if self.test_mode:
return self.data_list
filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \
if self.filter_cfg is not None else False
min_size = self.filter_cfg.get('min_size', 0) \
if self.filter_cfg is not None else 0
valid_data_infos = []
for i, data_info in enumerate(self.data_list):
width = data_info['width']
height = data_info['height']
if filter_empty_gt and len(data_info['instances']) == 0:
continue
if min(width, height) >= min_size:
valid_data_infos.append(data_info)
return valid_data_infos
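
# A toy sketch of the VOC-style XML that `_parse_instance_info`
# consumes; the annotation values are made up for illustration.
if __name__ == '__main__':
    example = ET.fromstring(
        '<annotation>'
        '  <object>'
        '    <name>dog</name>'
        '    <difficult>0</difficult>'
        '    <bndbox>'
        '      <xmin>48</xmin><ymin>240</ymin>'
        '      <xmax>195</xmax><ymax>371</ymax>'
        '    </bndbox>'
        '  </object>'
        '</annotation>')
    # With minus_one=True (the VOC convention), an instance parsed from
    # this object becomes:
    # {'ignore_flag': 0, 'bbox': [47, 239, 194, 370], 'bbox_label': ...}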
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
from .base_video_dataset import BaseVideoDataset
@DATASETS.register_module()
class YouTubeVISDataset(BaseVideoDataset):
"""YouTube VIS dataset for video instance segmentation.
Args:
dataset_version (str): Select dataset year version.
"""
def __init__(self, dataset_version: str, *args, **kwargs):
self.set_dataset_classes(dataset_version)
super().__init__(*args, **kwargs)
@classmethod
def set_dataset_classes(cls, dataset_version: str) -> None:
"""Pass the category of the corresponding year to metainfo.
Args:
dataset_version (str): Select dataset year version.
"""
classes_2019_version = ('person', 'giant_panda', 'lizard', 'parrot',
'skateboard', 'sedan', 'ape', 'dog', 'snake',
'monkey', 'hand', 'rabbit', 'duck', 'cat',
'cow', 'fish', 'train', 'horse', 'turtle',
'bear', 'motorbike', 'giraffe', 'leopard',
'fox', 'deer', 'owl', 'surfboard', 'airplane',
'truck', 'zebra', 'tiger', 'elephant',
'snowboard', 'boat', 'shark', 'mouse', 'frog',
'eagle', 'earless_seal', 'tennis_racket')
classes_2021_version = ('airplane', 'bear', 'bird', 'boat', 'car',
'cat', 'cow', 'deer', 'dog', 'duck',
'earless_seal', 'elephant', 'fish',
'flying_disc', 'fox', 'frog', 'giant_panda',
'giraffe', 'horse', 'leopard', 'lizard',
'monkey', 'motorbike', 'mouse', 'parrot',
'person', 'rabbit', 'shark', 'skateboard',
'snake', 'snowboard', 'squirrel', 'surfboard',
'tennis_racket', 'tiger', 'train', 'truck',
'turtle', 'whale', 'zebra')
if dataset_version == '2019':
cls.METAINFO = dict(classes=classes_2019_version)
elif dataset_version == '2021':
cls.METAINFO = dict(classes=classes_2021_version)
else:
            raise NotImplementedError('Unsupported YouTubeVIS dataset '
                                      f'version: {dataset_version}')
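
# A minimal config sketch; `dataset_version` must be '2019' or '2021'.
# The annotation path is an illustrative assumption.
train_dataset = dict(
    type='YouTubeVISDataset',
    dataset_version='2019',
    ann_file='annotations/youtube_vis_2019_train.json',
    data_prefix=dict(img_path='train/JPEGImages'))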
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .runner import * # noqa: F401, F403
from .schedulers import * # noqa: F401, F403
# Copyright (c) OpenMMLab. All rights reserved.
from .checkloss_hook import CheckInvalidLossHook
from .mean_teacher_hook import MeanTeacherHook
from .memory_profiler_hook import MemoryProfilerHook
from .num_class_check_hook import NumClassCheckHook
from .pipeline_switch_hook import PipelineSwitchHook
from .set_epoch_info_hook import SetEpochInfoHook
from .sync_norm_hook import SyncNormHook
from .utils import trigger_visualization_hook
from .visualization_hook import DetVisualizationHook, TrackVisualizationHook
from .yolox_mode_switch_hook import YOLOXModeSwitchHook
__all__ = [
'YOLOXModeSwitchHook', 'SyncNormHook', 'CheckInvalidLossHook',
'SetEpochInfoHook', 'MemoryProfilerHook', 'DetVisualizationHook',
'NumClassCheckHook', 'MeanTeacherHook', 'trigger_visualization_hook',
'PipelineSwitchHook', 'TrackVisualizationHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
"""Check invalid loss hook.
This hook will regularly check whether the loss is valid
during training.
Args:
interval (int): Checking interval (every k iterations).
Default: 50.
"""
def __init__(self, interval: int = 50) -> None:
self.interval = interval
def after_train_iter(self,
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None,
outputs: Optional[dict] = None) -> None:
"""Regularly check whether the loss is valid every n iterations.
Args:
runner (:obj:`Runner`): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict, Optional): Data from dataloader.
Defaults to None.
outputs (dict, Optional): Outputs from model. Defaults to None.
"""
if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']), \
                'loss becomes infinite or NaN!'
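
# A minimal sketch of enabling this hook from a config via
# `custom_hooks`; the interval value is illustrative.
custom_hooks = [dict(type='CheckInvalidLossHook', interval=50)]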
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner
from mmdet.registry import HOOKS
@HOOKS.register_module()
class MeanTeacherHook(Hook):
"""Mean Teacher Hook.
Mean Teacher is an efficient semi-supervised learning method in
`Mean Teacher <https://arxiv.org/abs/1703.01780>`_.
This method requires two models with exactly the same structure,
as the student model and the teacher model, respectively.
The student model updates the parameters through gradient descent,
and the teacher model updates the parameters through
exponential moving average of the student model.
Compared with the student model, the teacher model
is smoother and accumulates more knowledge.
Args:
momentum (float): The momentum used for updating teacher's parameter.
            Teacher's parameters are updated with the formula:
`teacher = (1-momentum) * teacher + momentum * student`.
Defaults to 0.001.
interval (int): Update teacher's parameter every interval iteration.
Defaults to 1.
        skip_buffers (bool): Whether to skip the model buffers, such as
            batch-norm running stats (``running_mean``, ``running_var``),
            which are then excluded from the EMA update. Defaults to True.
"""
def __init__(self,
momentum: float = 0.001,
interval: int = 1,
                 skip_buffers: bool = True) -> None:
assert 0 < momentum < 1
self.momentum = momentum
self.interval = interval
        self.skip_buffers = skip_buffers
def before_train(self, runner: Runner) -> None:
"""To check that teacher model and student model exist."""
model = runner.model
if is_model_wrapper(model):
model = model.module
assert hasattr(model, 'teacher')
assert hasattr(model, 'student')
# only do it at initial stage
if runner.iter == 0:
self.momentum_update(model, 1)
def after_train_iter(self,
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None,
outputs: Optional[dict] = None) -> None:
"""Update teacher's parameter every self.interval iterations."""
if (runner.iter + 1) % self.interval != 0:
return
model = runner.model
if is_model_wrapper(model):
model = model.module
self.momentum_update(model, self.momentum)
def momentum_update(self, model: nn.Module, momentum: float) -> None:
"""Compute the moving average of the parameters using exponential
moving average."""
if self.skip_buffers:
for (src_name, src_parm), (dst_name, dst_parm) in zip(
model.student.named_parameters(),
model.teacher.named_parameters()):
dst_parm.data.mul_(1 - momentum).add_(
src_parm.data, alpha=momentum)
else:
for (src_parm,
dst_parm) in zip(model.student.state_dict().values(),
model.teacher.state_dict().values()):
                # skip non-float buffers such as `num_batches_tracked`
if dst_parm.dtype.is_floating_point:
dst_parm.data.mul_(1 - momentum).add_(
src_parm.data, alpha=momentum)
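
# A toy sketch of the EMA update that `momentum_update` applies to
# every teacher parameter (standalone, without the hook machinery).
if __name__ == '__main__':
    import torch
    momentum = 0.001
    student = torch.tensor([1.0, 2.0])
    teacher = torch.zeros(2)
    teacher.mul_(1 - momentum).add_(student, alpha=momentum)
    # teacher is now momentum * student, i.e. tensor([0.0010, 0.0020])
    print(teacher)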
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS
from mmdet.structures import DetDataSample
@HOOKS.register_module()
class MemoryProfilerHook(Hook):
"""Memory profiler hook recording memory information including virtual
memory, swap memory, and the memory of the current process.
Args:
interval (int): Checking interval (every k iterations).
Default: 50.
"""
def __init__(self, interval: int = 50) -> None:
try:
from psutil import swap_memory, virtual_memory
self._swap_memory = swap_memory
self._virtual_memory = virtual_memory
except ImportError:
raise ImportError('psutil is not installed, please install it by: '
'pip install psutil')
try:
from memory_profiler import memory_usage
self._memory_usage = memory_usage
except ImportError:
raise ImportError(
'memory_profiler is not installed, please install it by: '
'pip install memory_profiler')
self.interval = interval
def _record_memory_information(self, runner: Runner) -> None:
"""Regularly record memory information.
Args:
runner (:obj:`Runner`): The runner of the training or evaluation
process.
"""
        # in bytes
virtual_memory = self._virtual_memory()
swap_memory = self._swap_memory()
# in MB
process_memory = self._memory_usage()[0]
factor = 1024 * 1024
runner.logger.info(
'Memory information '
'available_memory: '
f'{round(virtual_memory.available / factor)} MB, '
'used_memory: '
f'{round(virtual_memory.used / factor)} MB, '
f'memory_utilization: {virtual_memory.percent} %, '
'available_swap_memory: '
f'{round((swap_memory.total - swap_memory.used) / factor)}'
' MB, '
f'used_swap_memory: {round(swap_memory.used / factor)} MB, '
f'swap_memory_utilization: {swap_memory.percent} %, '
'current_process_memory: '
f'{round(process_memory)} MB')
def after_train_iter(self,
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None,
outputs: Optional[dict] = None) -> None:
"""Regularly record memory information.
Args:
runner (:obj:`Runner`): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict, optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model. Defaults to None.
"""
if self.every_n_inner_iters(batch_idx, self.interval):
self._record_memory_information(runner)
def after_val_iter(
self,
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None,
outputs: Optional[Sequence[DetDataSample]] = None) -> None:
"""Regularly record memory information.
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict, optional): Data from dataloader.
Defaults to None.
outputs (Sequence[:obj:`DetDataSample`], optional):
Outputs from model. Defaults to None.
"""
if self.every_n_inner_iters(batch_idx, self.interval):
self._record_memory_information(runner)
def after_test_iter(
self,
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None,
outputs: Optional[Sequence[DetDataSample]] = None) -> None:
"""Regularly record memory information.
Args:
runner (:obj:`Runner`): The runner of the testing process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (dict, optional): Data from dataloader.
Defaults to None.
outputs (Sequence[:obj:`DetDataSample`], optional):
Outputs from model. Defaults to None.
"""
if self.every_n_inner_iters(batch_idx, self.interval):
self._record_memory_information(runner)
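
# A minimal sketch of enabling this hook from a config; it requires
# `psutil` and `memory_profiler` to be installed (checked in __init__).
custom_hooks = [dict(type='MemoryProfilerHook', interval=50)]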
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import VGG
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS
@HOOKS.register_module()
class NumClassCheckHook(Hook):
"""Check whether the `num_classes` in head matches the length of `classes`
in `dataset.metainfo`."""
def _check_head(self, runner: Runner, mode: str) -> None:
"""Check whether the `num_classes` in head matches the length of
`classes` in `dataset.metainfo`.
Args:
runner (:obj:`Runner`): The runner of the training or evaluation
process.
"""
assert mode in ['train', 'val']
model = runner.model
dataset = runner.train_dataloader.dataset if mode == 'train' else \
runner.val_dataloader.dataset
if dataset.metainfo.get('classes', None) is None:
runner.logger.warning(
f'Please set `classes` '
                f'in the {dataset.__class__.__name__} `metainfo` and '
f'check if it is consistent with the `num_classes` '
f'of head')
else:
classes = dataset.metainfo['classes']
assert type(classes) is not str, \
                (f'`classes` in {dataset.__class__.__name__} '
                 f'should be a tuple of str. '
                 f'Add a comma if the number of classes is 1, as in '
f'classes = ({classes},)')
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
for name, module in model.named_modules():
if hasattr(module, 'num_classes') and not name.endswith(
'rpn_head') and not isinstance(
module, (VGG, FusedSemanticHead)):
assert module.num_classes == len(classes), \
(f'The `num_classes` ({module.num_classes}) in '
f'{module.__class__.__name__} of '
                     f'{model.__class__.__name__} does not match '
                     f'the length of `classes` '
                     f'({len(classes)}) in '
f'{dataset.__class__.__name__}')
def before_train_epoch(self, runner: Runner) -> None:
"""Check whether the training dataset is compatible with head.
Args:
runner (:obj:`Runner`): The runner of the training or evaluation
process.
"""
self._check_head(runner, 'train')
def before_val_epoch(self, runner: Runner) -> None:
"""Check whether the dataset in val epoch is compatible with head.
Args:
runner (:obj:`Runner`): The runner of the training or evaluation
process.
"""
self._check_head(runner, 'val')
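
# A minimal sketch of enabling this check from a config; the hook
# takes no arguments.
custom_hooks = [dict(type='NumClassCheckHook')]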
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms import Compose
from mmengine.hooks import Hook
from mmdet.registry import HOOKS
@HOOKS.register_module()
class PipelineSwitchHook(Hook):
"""Switch data pipeline at switch_epoch.
Args:
        switch_epoch (int): The epoch at which to switch the pipeline.
        switch_pipeline (list[dict]): The pipeline to switch to.
"""
def __init__(self, switch_epoch, switch_pipeline):
self.switch_epoch = switch_epoch
self.switch_pipeline = switch_pipeline
self._restart_dataloader = False
self._has_switched = False
def before_train_epoch(self, runner):
"""switch pipeline."""
epoch = runner.epoch
train_loader = runner.train_dataloader
if epoch >= self.switch_epoch and not self._has_switched:
runner.logger.info('Switch pipeline now!')
# The dataset pipeline cannot be updated when persistent_workers
# is True, so we need to force the dataloader's multi-process
# restart. This is a very hacky approach.
train_loader.dataset.pipeline = Compose(self.switch_pipeline)
if hasattr(train_loader, 'persistent_workers'
) and train_loader.persistent_workers is True:
train_loader._DataLoader__initialized = False
train_loader._iterator = None
self._restart_dataloader = True
self._has_switched = True
else:
# Once the restart is complete, we need to restore
# the initialization flag.
if self._restart_dataloader:
train_loader._DataLoader__initialized = True
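
# A hedged config sketch in the YOLOX style: drop strong augmentation
# for the final epochs. The epoch number and the transforms listed in
# `switch_pipeline` are illustrative assumptions.
custom_hooks = [
    dict(
        type='PipelineSwitchHook',
        switch_epoch=280,
        switch_pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),
            dict(type='Resize', scale=(640, 640), keep_ratio=True),
            dict(
                type='Pad',
                size=(640, 640),
                pad_val=dict(img=(114, 114, 114))),
            dict(type='PackDetInputs'),
        ])
]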
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.model.wrappers import is_model_wrapper
from mmdet.registry import HOOKS
@HOOKS.register_module()
class SetEpochInfoHook(Hook):
"""Set runner's epoch information to the model."""
def before_train_epoch(self, runner):
epoch = runner.epoch
model = runner.model
if is_model_wrapper(model):
model = model.module
model.set_epoch(epoch)
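
# A toy sketch of the contract this hook assumes: the (unwrapped) model
# only needs to expose a `set_epoch` method. The class below is purely
# illustrative.
import torch.nn as nn

class EpochAwareModel(nn.Module):
    """Illustrative model that adjusts its behavior by epoch."""

    def __init__(self) -> None:
        super().__init__()
        self._epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self._epoch = epoch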
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from mmengine.dist import get_dist_info
from mmengine.hooks import Hook
from torch import nn
from mmdet.registry import HOOKS
from mmdet.utils import all_reduce_dict
def get_norm_states(module: nn.Module) -> OrderedDict:
"""Get the state_dict of batch norms in the module."""
async_norm_states = OrderedDict()
for name, child in module.named_modules():
if isinstance(child, nn.modules.batchnorm._NormBase):
for k, v in child.state_dict().items():
async_norm_states['.'.join([name, k])] = v
return async_norm_states
@HOOKS.register_module()
class SyncNormHook(Hook):
"""Synchronize Norm states before validation, currently used in YOLOX."""
def before_val_epoch(self, runner):
"""Synchronizing norm."""
module = runner.model
_, world_size = get_dist_info()
if world_size == 1:
return
norm_states = get_norm_states(module)
if len(norm_states) == 0:
return
# TODO: use `all_reduce_dict` in mmengine
norm_states = all_reduce_dict(norm_states, op='mean')
module.load_state_dict(norm_states, strict=False)
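
# A toy sketch of what `get_norm_states` collects on a small model;
# the dotted keys are the buffer names that get all-reduced.
if __name__ == '__main__':
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    states = get_norm_states(model)
    # e.g. ['1.weight', '1.bias', '1.running_mean', '1.running_var',
    #       '1.num_batches_tracked']
    print(list(states.keys()))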
# Copyright (c) OpenMMLab. All rights reserved.
def trigger_visualization_hook(cfg, args):
default_hooks = cfg.default_hooks
if 'visualization' in default_hooks:
visualization_hook = default_hooks['visualization']
# Turn on visualization
visualization_hook['draw'] = True
if args.show:
visualization_hook['show'] = True
visualization_hook['wait_time'] = args.wait_time
if args.show_dir:
visualization_hook['test_out_dir'] = args.show_dir
else:
raise RuntimeError(
            'VisualizationHook must be included in default_hooks. '
            'Refer to the usage '
            '"visualization=dict(type=\'VisualizationHook\')".')
return cfg
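
# A minimal usage sketch, assuming a config whose `default_hooks`
# already contains a `visualization` entry; the `args` fields mirror
# the common test-script flags and are illustrative.
if __name__ == '__main__':
    from argparse import Namespace
    from mmengine.config import Config
    cfg = Config(
        dict(default_hooks=dict(
            visualization=dict(type='DetVisualizationHook'))))
    args = Namespace(show=False, wait_time=0, show_dir='outputs/vis')
    cfg = trigger_visualization_hook(cfg, args)
    # cfg.default_hooks.visualization now has draw=True and
    # test_out_dir='outputs/vis'.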