# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.utils import Registry, build_from_cfg

MODULE_HOOKS = Registry('module_hooks')


def register_module_hooks(Module, module_hooks_list):
    """Register hooks on sub-modules of ``Module`` from a list of configs.

    Each config must contain a ``type`` registered in ``MODULE_HOOKS`` and
    may specify ``hooked_module`` (default: ``'backbone'``) and ``hook_pos``
    (one of ``'forward_pre'``, ``'forward'`` or ``'backward'``; default:
    ``'forward_pre'``).

    Returns:
        list[torch.utils.hooks.RemovableHandle]: Handles that can be used
            to remove the registered hooks later.
    """
    handles = []
    for module_hook_cfg in module_hooks_list:
        hooked_module_name = module_hook_cfg.pop('hooked_module', 'backbone')
        if not hasattr(Module, hooked_module_name):
            raise ValueError(
                f'{Module.__class__} has no {hooked_module_name}!')
        hooked_module = getattr(Module, hooked_module_name)

        hook_pos = module_hook_cfg.pop('hook_pos', 'forward_pre')
        if hook_pos == 'forward_pre':
            handle = hooked_module.register_forward_pre_hook(
                build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func())
        elif hook_pos == 'forward':
            handle = hooked_module.register_forward_hook(
                build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func())
        elif hook_pos == 'backward':
            # Note: recent PyTorch releases deprecate register_backward_hook
            # in favor of register_full_backward_hook.
            handle = hooked_module.register_backward_hook(
                build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func())
        else:
            raise ValueError(
                f'hook_pos must be `forward_pre`, `forward` or `backward`, '
                f'but got {hook_pos}')
        handles.append(handle)
    return handles


@MODULE_HOOKS.register_module()
class GPUNormalize:
    """Normalize images with the given mean and std values on GPUs.

    Calling the member function ``hook_func`` returns the forward pre-hook
    function for module registration. GPU normalization is recommended over
    CPU normalization when the model runs on GPUs with strong compute
    capacity, such as a Tesla V100.

    Args:
        input_format (str): Format of the input data. Candidates are
            'NCTHW', 'NCHW', 'NCHW_Flow' and 'NPTCHW'.
        mean (Sequence[float]): Mean values of different channels.
        std (Sequence[float]): Std values of different channels.
    """

    def __init__(self, input_format, mean, std):
        if input_format not in ['NCTHW', 'NCHW', 'NCHW_Flow', 'NPTCHW']:
            raise ValueError(f'The input format {input_format} is invalid.')
        self.input_format = input_format

        # Reshape mean/std with singleton dims so they broadcast against the
        # channel axis of the corresponding input layout.
        _mean = torch.tensor(mean)
        _std = torch.tensor(std)
        if input_format == 'NCTHW':
            self._mean = _mean[None, :, None, None, None]
            self._std = _std[None, :, None, None, None]
        elif input_format in ('NCHW', 'NCHW_Flow'):
            self._mean = _mean[None, :, None, None]
            self._std = _std[None, :, None, None]
        elif input_format == 'NPTCHW':
            self._mean = _mean[None, None, None, :, None, None]
            self._std = _std[None, None, None, :, None, None]

    def hook_func(self):
        """Return the forward pre-hook that normalizes uint8 inputs."""

        def normalize_hook(Module, input):
            x = input[0]
            assert x.dtype == torch.uint8, (
                f'The previous augmentation should use uint8 data type to '
                f'speed up computation, but got {x.dtype}')
            mean = self._mean.to(x.device)
            std = self._std.to(x.device)
            with torch.no_grad():
                x = x.float().sub_(mean).div_(std)
            return (x, *input[1:])

        return normalize_hook
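

# A minimal usage sketch (an addition, not part of the original module): it
# wires a GPUNormalize forward pre-hook onto a toy model via
# register_module_hooks. The ``Recognizer`` class below is a hypothetical
# stand-in for an mmaction2 recognizer; it only needs a ``backbone``
# attribute for the hook to attach. The mean/std values are the common
# ImageNet statistics used in many OpenMMLab configs.
if __name__ == '__main__':
    import torch.nn as nn

    class Recognizer(nn.Module):
        """Toy model exposing a ``backbone`` attribute."""

        def __init__(self):
            super().__init__()
            self.backbone = nn.Conv3d(3, 8, kernel_size=1)

        def forward(self, x):
            return self.backbone(x)

    model = Recognizer()
    module_hooks_cfg = [
        dict(
            type='GPUNormalize',
            hooked_module='backbone',
            hook_pos='forward_pre',
            input_format='NCTHW',
            mean=[123.675, 116.28, 103.53],
            std=[58.395, 57.12, 57.375])
    ]
    handles = register_module_hooks(model, module_hooks_cfg)

    # Inputs stay uint8 until the pre-hook converts and normalizes them.
    imgs = torch.randint(0, 256, (2, 3, 4, 8, 8), dtype=torch.uint8)
    out = model(imgs)
    print(out.shape)  # expected: torch.Size([2, 8, 4, 8, 8])

    # Hooks can be detached once they are no longer needed.
    for handle in handles:
        handle.remove()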