Commit fdeee889 authored by limm

release v1.6.1 of mmcv

parent df465820
@@ -24,7 +24,9 @@
# SOFTWARE.
import sys
import warnings
from functools import partial
from typing import Any, Callable, Dict, Optional, TextIO, Tuple
import numpy as np
import torch
@@ -33,13 +35,13 @@ import torch.nn as nn
import mmcv


def get_model_complexity_info(model: nn.Module,
                              input_shape: tuple,
                              print_per_layer_stat: bool = True,
                              as_strings: bool = True,
                              input_constructor: Optional[Callable] = None,
                              flush: bool = False,
                              ost: TextIO = sys.stdout) -> tuple:
    """Get complexity information of a model.

    This method can calculate FLOPs and parameter counts of a model with
@@ -48,16 +50,16 @@ def get_model_complexity_info(model,
    Supported layers are listed as below:

    - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
    - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``,
      ``nn.LeakyReLU``, ``nn.ReLU6``.
    - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
      ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
      ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
      ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
      ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
    - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
      ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
      ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
    - Linear: ``nn.Linear``.
    - Deconvolution: ``nn.ConvTranspose2d``.
    - Upsample: ``nn.Upsample``.
@@ -78,8 +80,8 @@ def get_model_complexity_info(model,
    Returns:
        tuple[float | str]: If ``as_strings`` is set to True, it will return
            FLOPs and parameter counts in a string format. Otherwise, it will
            return those in a float number format.
    """
    assert type(input_shape) is tuple
    assert len(input_shape) >= 1
@@ -115,7 +117,9 @@ def get_model_complexity_info(model,
    return flops_count, params_count


def flops_to_string(flops: float,
                    units: Optional[str] = 'GFLOPs',
                    precision: int = 2) -> str:
    """Convert FLOPs number into a string.

    Note that here we take a multiply-add as one FLOP.
@@ -158,7 +162,9 @@ def flops_to_string(flops, units='GFLOPs', precision=2):
    return str(flops) + ' FLOPs'
def params_to_string(num_params: float,
                     units: Optional[str] = None,
                     precision: int = 2) -> str:
    """Convert parameter number into a string.

    Args:
@@ -195,13 +201,13 @@ def params_to_string(num_params, units=None, precision=2):
    return str(num_params)


def print_model_with_flops(model: nn.Module,
                           total_flops: float,
                           total_params: float,
                           units: Optional[str] = 'GFLOPs',
                           precision: int = 3,
                           ost: TextIO = sys.stdout,
                           flush: bool = False) -> None:
    """Print a model with FLOPs for each layer.

    Args:
@@ -276,10 +282,10 @@ def print_model_with_flops(model,
        return ', '.join([
            params_to_string(
                accumulated_num_params, units='M', precision=precision),
            f'{accumulated_num_params / total_params:.3%} Params',
            flops_to_string(
                accumulated_flops_cost, units=units, precision=precision),
            f'{accumulated_flops_cost / total_flops:.3%} FLOPs',
            self.original_extra_repr()
        ])
@@ -304,7 +310,7 @@ def print_model_with_flops(model,
    model.apply(del_extra_repr)


def get_model_parameters_number(model: nn.Module) -> float:
    """Calculate parameter number of a model.

    Args:
@@ -317,16 +323,16 @@ def get_model_parameters_number(model):
    return num_params


def add_flops_counting_methods(net_main_module: nn.Module) -> nn.Module:
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    net_main_module.start_flops_count = start_flops_count.__get__(  # type: ignore # noqa E501
        net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(  # type: ignore # noqa E501
        net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(  # type: ignore # noqa E501
        net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(  # type: ignore # noqa E501
        net_main_module)

    net_main_module.reset_flops_count()
@@ -334,7 +340,7 @@ def add_flops_counting_methods(net_main_module):
    return net_main_module
def compute_average_flops_cost(self) -> Tuple[float, float]:
    """Compute average FLOPs cost.

    A method to compute average FLOPs cost, which will be available after
@@ -352,7 +358,7 @@ def compute_average_flops_cost(self):
    return flops_sum / batches_count, params_sum


def start_flops_count(self) -> None:
    """Activate the computation of mean flops consumption per image.

    A method to activate the computation of mean flops consumption per image.
@@ -361,7 +367,7 @@ def start_flops_count(self):
    """
    add_batch_counter_hook_function(self)

    def add_flops_counter_hook_function(module: nn.Module) -> None:
        if is_supported_instance(module):
            if hasattr(module, '__flops_handle__'):
                return
@@ -375,7 +381,7 @@ def start_flops_count(self):
    self.apply(partial(add_flops_counter_hook_function))


def stop_flops_count(self) -> None:
    """Stop computing the mean flops consumption per image.

    A method to stop computing the mean flops consumption per image, which will
@@ -386,7 +392,7 @@ def stop_flops_count(self):
    self.apply(remove_flops_counter_hook_function)
def reset_flops_count(self) -> None:
    """Reset statistics computed so far.

    A method to reset computed statistics, which will be available after
@@ -397,11 +403,13 @@ def reset_flops_count(self):

# ---- Internal functions
def empty_flops_counter_hook(module: nn.Module, input: tuple,
                             output: Any) -> None:
    module.__flops__ += 0


def upsample_flops_counter_hook(module: nn.Module, input: tuple,
                                output: torch.Tensor) -> None:
    output_size = output[0]
    batch_size = output_size.shape[0]
    output_elements_count = batch_size
@@ -410,39 +418,38 @@ def upsample_flops_counter_hook(module, input, output):
    module.__flops__ += int(output_elements_count)


def relu_flops_counter_hook(module: nn.Module, input: tuple,
                            output: torch.Tensor) -> None:
    active_elements_count = output.numel()
    module.__flops__ += int(active_elements_count)


def linear_flops_counter_hook(module: nn.Module, input: tuple,
                              output: torch.Tensor) -> None:
    output_last_dim = output.shape[
        -1]  # pytorch checks dimensions, so here we don't care much
    module.__flops__ += int(np.prod(input[0].shape) * output_last_dim)


def pool_flops_counter_hook(module: nn.Module, input: tuple,
                            output: torch.Tensor) -> None:
    module.__flops__ += int(np.prod(input[0].shape))


def norm_flops_counter_hook(module: nn.Module, input: tuple,
                            output: torch.Tensor) -> None:
    batch_flops = np.prod(input[0].shape)
    if (getattr(module, 'affine', False)
            or getattr(module, 'elementwise_affine', False)):
        batch_flops *= 2
    module.__flops__ += int(batch_flops)
def deconv_flops_counter_hook(conv_module: nn.Module, input: tuple,
                              output: torch.Tensor) -> None:
    # Can have multiple inputs, getting the first one
    batch_size = input[0].shape[0]
    input_height, input_width = input[0].shape[2:]

    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
@@ -458,17 +465,16 @@ def deconv_flops_counter_hook(conv_module, input, output):
    bias_flops = 0
    if conv_module.bias is not None:
        output_height, output_width = output.shape[2:]
        bias_flops = out_channels * batch_size * output_height * output_width
    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)


def conv_flops_counter_hook(conv_module: nn.Module, input: tuple,
                            output: torch.Tensor) -> None:
    # Can have multiple inputs, getting the first one
    batch_size = input[0].shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
@@ -495,25 +501,23 @@ def conv_flops_counter_hook(conv_module, input, output):
    conv_module.__flops__ += int(overall_flops)


def batch_counter_hook(module: nn.Module, input: tuple, output: Any) -> None:
    batch_size = 1
    if len(input) > 0:
        # Can have multiple inputs, getting the first one
        batch_size = len(input[0])
    else:
        warnings.warn('No positional inputs found for a module, '
                      'assuming batch size is 1.')
    module.__batch_counter__ += batch_size


def add_batch_counter_variables_or_reset(module: nn.Module) -> None:
    module.__batch_counter__ = 0


def add_batch_counter_hook_function(module: nn.Module) -> None:
    if hasattr(module, '__batch_counter_handle__'):
        return
@@ -521,36 +525,36 @@ def add_batch_counter_hook_function(module):
    module.__batch_counter_handle__ = handle


def remove_batch_counter_hook_function(module: nn.Module) -> None:
    if hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__.remove()
        del module.__batch_counter_handle__
def add_flops_counter_variable_or_reset(module: nn.Module) -> None:
    if is_supported_instance(module):
        if hasattr(module, '__flops__') or hasattr(module, '__params__'):
            warnings.warn('variables __flops__ or __params__ are already '
                          'defined for the module ' + type(module).__name__ +
                          ', ptflops can affect your code!')
        module.__flops__ = 0
        module.__params__ = get_model_parameters_number(module)
def is_supported_instance(module: nn.Module) -> bool:
    if type(module) in get_modules_mapping():
        return True
    return False


def remove_flops_counter_hook_function(module: nn.Module) -> None:
    if is_supported_instance(module):
        if hasattr(module, '__flops_handle__'):
            module.__flops_handle__.remove()
            del module.__flops_handle__


def get_modules_mapping() -> Dict:
    return {
        # convolutions
        nn.Conv1d: conv_flops_counter_hook,
......
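As a quick orientation for the flops counter above, here is a hedged usage sketch (not part of the commit; it assumes torchvision is available, and the printed values are indicative only):

# Minimal usage sketch for get_model_complexity_info.
# input_shape excludes the batch dimension.
import torchvision.models as models
from mmcv.cnn import get_model_complexity_info

model = models.resnet18()
flops, params = get_model_complexity_info(model, (3, 224, 224))
# as_strings=True (the default) yields human-readable strings, roughly
# '1.8 GFLOPs' and '11.69 M' for resnet18; pass as_strings=False for floats
print(flops, params)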
@@ -3,7 +3,7 @@ import torch
import torch.nn as nn


def _fuse_conv_bn(conv: nn.Module, bn: nn.Module) -> nn.Module:
    """Fuse conv and bn into one module.

    Args:
@@ -24,7 +24,7 @@ def _fuse_conv_bn(conv, bn):
    return conv


def fuse_conv_bn(module: nn.Module) -> nn.Module:
    """Recursively fuse conv and bn in a module.

    During inference, the functionality of batch norm layers is turned off
......
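A hedged usage sketch for the fusion helper (not part of the commit); fusion folds each BN's statistics into the preceding conv, so the model should be in eval mode first:

import torch
import torch.nn as nn
from mmcv.cnn import fuse_conv_bn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
x = torch.rand(1, 3, 32, 32)
y_ref = model(x)
fused = fuse_conv_bn(model)  # BN folded into the conv weights in-place
# up to float rounding, the fused model matches the original
assert torch.allclose(y_ref, fused(x), atol=1e-5)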
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

import mmcv


class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
    """A general BatchNorm layer without input dimension check.

    Reproduced from @kapily's work:
@@ -14,11 +16,11 @@ class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    SyncBatchNorm.
    """

    def _check_input_dim(self, input: torch.Tensor):
        return


def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
    `BatchNormXd` layers.
......
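A hedged usage sketch for revert_sync_batchnorm (not part of the commit; the import path is an assumption based on mmcv.cnn's exports). This is handy for running a SyncBN model on CPU or a single device where SyncBN is undefined:

import torch
import torch.nn as nn
from mmcv.cnn import revert_sync_batchnorm  # assumed export path

syncbn_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
model = revert_sync_batchnorm(syncbn_model)
# SyncBatchNorm layers are now _BatchNormXd modules with copied weights
# and running statistics, so plain CPU inference works
model.eval()(torch.rand(1, 3, 32, 32))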
@@ -2,6 +2,7 @@
import copy
import math
import warnings
from typing import Dict, List, Optional, Union
import numpy as np
import torch
@@ -13,7 +14,7 @@ from mmcv.utils import Registry, build_from_cfg, get_logger, print_log

INITIALIZERS = Registry('initializer')


def update_init_info(module: nn.Module, init_info: str) -> None:
    """Update the `_params_init_info` in the module if the values of
    parameters are changed.
@@ -45,14 +46,17 @@ def update_init_info(module, init_info):
        module._params_init_info[param]['tmp_mean_value'] = mean_value


def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def xavier_init(module: nn.Module,
                gain: float = 1,
                bias: float = 0,
                distribution: str = 'normal') -> None:
    assert distribution in ['uniform', 'normal']
    if hasattr(module, 'weight') and module.weight is not None:
        if distribution == 'uniform':
@@ -63,7 +67,10 @@ def xavier_init(module, gain=1, bias=0, distribution='normal'):
        nn.init.constant_(module.bias, bias)


def normal_init(module: nn.Module,
                mean: float = 0,
                std: float = 1,
                bias: float = 0) -> None:
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
@@ -82,19 +89,22 @@ def trunc_normal_init(module: nn.Module,
        nn.init.constant_(module.bias, bias)  # type: ignore
def uniform_init(module: nn.Module,
                 a: float = 0,
                 b: float = 1,
                 bias: float = 0) -> None:
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.uniform_(module.weight, a, b)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def kaiming_init(module: nn.Module,
                 a: float = 0,
                 mode: str = 'fan_out',
                 nonlinearity: str = 'relu',
                 bias: float = 0,
                 distribution: str = 'normal') -> None:
    assert distribution in ['uniform', 'normal']
    if hasattr(module, 'weight') and module.weight is not None:
        if distribution == 'uniform':
@@ -107,7 +117,7 @@ def kaiming_init(module,
        nn.init.constant_(module.bias, bias)


def caffe2_xavier_init(module: nn.Module, bias: float = 0) -> None:
    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
    # Acknowledgment to FAIR's internal code
    kaiming_init(
@@ -119,19 +129,23 @@ def caffe2_xavier_init(module, bias=0):
        distribution='uniform')
def bias_init_with_prob(prior_prob: float) -> float:
    """Initialize conv/fc bias value according to a given probability
    value."""
    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
    return bias_init
def _get_bases_name(m: nn.Module) -> List[str]:
    return [b.__name__ for b in m.__class__.__bases__]


class BaseInit:

    def __init__(self,
                 *,
                 bias: float = 0,
                 bias_prob: Optional[float] = None,
                 layer: Union[str, List, None] = None):
        self.wholemodule = False
        if not isinstance(bias, (int, float)):
            raise TypeError(f'bias must be a number, but got a {type(bias)}')
@@ -154,7 +168,7 @@ class BaseInit(object):
        self.bias = bias
        self.layer = [layer] if isinstance(layer, str) else layer

    def _get_init_info(self) -> str:
        info = f'{self.__class__.__name__}, bias={self.bias}'
        return info
@@ -172,11 +186,11 @@ class ConstantInit(BaseInit):
            Defaults to None.
    """

    def __init__(self, val: Union[int, float], **kwargs):
        super().__init__(**kwargs)
        self.val = val

    def __call__(self, module: nn.Module) -> None:

        def init(m):
            if self.wholemodule:
@@ -191,7 +205,7 @@ class ConstantInit(BaseInit):
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self) -> str:
        info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
        return info
@@ -214,12 +228,15 @@ class XavierInit(BaseInit):
            Defaults to None.
    """

    def __init__(self,
                 gain: float = 1,
                 distribution: str = 'normal',
                 **kwargs):
        super().__init__(**kwargs)
        self.gain = gain
        self.distribution = distribution

    def __call__(self, module: nn.Module) -> None:

        def init(m):
            if self.wholemodule:
@@ -234,7 +251,7 @@ class XavierInit(BaseInit):
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self) -> str:
        info = f'{self.__class__.__name__}: gain={self.gain}, ' \
               f'distribution={self.distribution}, bias={self.bias}'
        return info
@@ -257,12 +274,12 @@ class NormalInit(BaseInit):
    """

    def __init__(self, mean: float = 0, std: float = 1, **kwargs):
        super().__init__(**kwargs)
        self.mean = mean
        self.std = std

    def __call__(self, module: nn.Module) -> None:

        def init(m):
            if self.wholemodule:
@@ -277,7 +294,7 @@ class NormalInit(BaseInit):
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self) -> str:
        info = f'{self.__class__.__name__}: mean={self.mean},' \
               f' std={self.std}, bias={self.bias}'
        return info
@@ -355,12 +372,12 @@ class UniformInit(BaseInit):
            Defaults to None.
    """

    def __init__(self, a: float = 0., b: float = 1., **kwargs):
        super().__init__(**kwargs)
        self.a = a
        self.b = b

    def __call__(self, module: nn.Module) -> None:

        def init(m):
            if self.wholemodule:
@@ -375,7 +392,7 @@ class UniformInit(BaseInit):
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self) -> str:
        info = f'{self.__class__.__name__}: a={self.a},' \
               f' b={self.b}, bias={self.bias}'
        return info
@@ -409,10 +426,10 @@ class KaimingInit(BaseInit):
    """

    def __init__(self,
                 a: float = 0,
                 mode: str = 'fan_out',
                 nonlinearity: str = 'relu',
                 distribution: str = 'normal',
                 **kwargs):
        super().__init__(**kwargs)
        self.a = a
@@ -420,7 +437,7 @@ class KaimingInit(BaseInit):
        self.nonlinearity = nonlinearity
        self.distribution = distribution

    def __call__(self, module: nn.Module) -> None:

        def init(m):
            if self.wholemodule:
@@ -437,7 +454,7 @@ class KaimingInit(BaseInit):
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self) -> str:
        info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
               f'nonlinearity={self.nonlinearity}, ' \
               f'distribution={self.distribution}, bias={self.bias}'
@@ -456,12 +473,12 @@ class Caffe2XavierInit(KaimingInit):
            distribution='uniform',
            **kwargs)

    def __call__(self, module: nn.Module) -> None:
        super().__call__(module)


@INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit:
    """Initialize module by loading a pretrained model.

    Args:
@@ -475,12 +492,15 @@ class PretrainedInit(object):
        map_location (str): map tensors into proper locations.
    """

    def __init__(self,
                 checkpoint: str,
                 prefix: Optional[str] = None,
                 map_location: Optional[str] = None):
        self.checkpoint = checkpoint
        self.prefix = prefix
        self.map_location = map_location

    def __call__(self, module: nn.Module) -> None:
        from mmcv.runner import (_load_checkpoint_with_prefix,
                                 load_checkpoint, load_state_dict)
        logger = get_logger('mmcv')
@@ -503,12 +523,14 @@ class PretrainedInit(object):
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self) -> str:
        info = f'{self.__class__.__name__}: load from {self.checkpoint}'
        return info


def _initialize(module: nn.Module,
                cfg: Dict,
                wholemodule: bool = False) -> None:
    func = build_from_cfg(cfg, INITIALIZERS)
    # wholemodule flag is for override mode, there is no layer key in override
    # and initializer will give init values for the whole module with the name
@@ -517,7 +539,8 @@ def _initialize(module, cfg, wholemodule=False):
    func(module)


def _initialize_override(module: nn.Module, override: Union[Dict, List],
                         cfg: Dict) -> None:
    if not isinstance(override, (dict, list)):
        raise TypeError(f'override must be a dict or a list of dict, \
            but got {type(override)}')
@@ -547,8 +570,8 @@ def _initialize_override(module, override, cfg):
                f'but init_cfg is {cp_override}.')
def initialize(module: nn.Module, init_cfg: Union[Dict, List[dict]]) -> None:
    r"""Initialize a module.

    Args:
        module (``torch.nn.Module``): the module to be initialized.
@@ -556,6 +579,7 @@ def initialize(module, init_cfg):
            define initializer. OpenMMLab has implemented 6 initializers
            including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
            ``Kaiming``, and ``Pretrained``.

    Example:
        >>> module = nn.Linear(2, 3, bias=True)
        >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
......
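Building on the docstring example above, a hedged sketch of initialize() with a list of configs (not part of the commit):

import torch.nn as nn
from mmcv.cnn import initialize

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 4))
init_cfg = [
    dict(type='Kaiming', layer='Conv2d'),
    dict(type='Constant', layer='Linear', val=1, bias=2),
]
initialize(model, init_cfg)  # each initializer only touches its `layer` type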
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Sequence, Tuple, Union
import torch.nn as nn
from torch import Tensor
from .utils import constant_init, kaiming_init, normal_init


def conv3x3(in_planes: int, out_planes: int, dilation: int = 1) -> nn.Module:
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
@@ -16,12 +18,12 @@ def conv3x3(in_planes, out_planes, dilation=1):
        dilation=dilation)


def make_vgg_layer(inplanes: int,
                   planes: int,
                   num_blocks: int,
                   dilation: int = 1,
                   with_bn: bool = False,
                   ceil_mode: bool = False) -> List[nn.Module]:
    layers = []
    for _ in range(num_blocks):
        layers.append(conv3x3(inplanes, planes, dilation))
@@ -59,18 +61,18 @@ class VGG(nn.Module):
    }

    def __init__(self,
                 depth: int,
                 with_bn: bool = False,
                 num_classes: int = -1,
                 num_stages: int = 5,
                 dilations: Sequence[int] = (1, 1, 1, 1, 1),
                 out_indices: Sequence[int] = (0, 1, 2, 3, 4),
                 frozen_stages: int = -1,
                 bn_eval: bool = True,
                 bn_frozen: bool = False,
                 ceil_mode: bool = False,
                 with_last_pool: bool = True):
        super().__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for vgg')
        assert num_stages >= 1 and num_stages <= 5
@@ -122,7 +124,7 @@ class VGG(nn.Module):
            nn.Linear(4096, num_classes),
        )

    def init_weights(self, pretrained: Optional[str] = None) -> None:
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            from ..runner import load_checkpoint
@@ -138,7 +140,7 @@ class VGG(nn.Module):
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
        outs = []
        vgg_layers = getattr(self, self.module_name)
        for i in range(len(self.stage_blocks)):
@@ -156,8 +158,8 @@ class VGG(nn.Module):
        else:
            return tuple(outs)

    def train(self, mode: bool = True) -> None:
        super().train(mode)
        if self.bn_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
......
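A hedged usage sketch for the VGG backbone above (not part of the commit; depth=11 is assumed to be a valid arch_settings key):

import torch
from mmcv.cnn import VGG

backbone = VGG(depth=11, out_indices=(4, ))
backbone.eval()
feats = backbone(torch.rand(1, 3, 224, 224))
# per the forward() above, a single out_index yields one Tensor,
# several out_indices yield a tuple of stage features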
# Copyright (c) OpenMMLab. All rights reserved.
from . import ipu, mlu, mps
from .scatter_gather import scatter, scatter_kwargs
from .utils import get_device
__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs']
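A minimal sketch of the new package entry point (not part of the commit; the exact return values of get_device are assumptions based on mmcv's available-device checks):

from mmcv.device import get_device

device = get_device()  # e.g. 'cuda', 'mlu', 'mps' or 'cpu'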
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
from mmcv.utils import deprecated_api_warning
from .utils import get_device
def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
    """scatter copies tensor to devices directly."""
    current_device = get_device()
    if isinstance(input, list):
        outputs = [scatter(_input, devices) for _input in input]
        return outputs
    elif isinstance(input, torch.Tensor):
        output = input.contiguous()
        return output.to(current_device) if devices != [-1] else output
    else:
        raise Exception(f'Unknown type {type(input)}.')


class Scatter:

    @staticmethod
    @deprecated_api_warning({'target_mlus': 'target_devices'},
                            cls_name='Scatter')
    def forward(target_devices, input):
        outputs = scatter(input, target_devices)
        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
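A short usage sketch of the device-agnostic scatter above (not part of the commit; the import follows the `__all__` of the new package):

import torch
from mmcv.device import scatter

t = torch.rand(2, 3)
on_device = scatter(t, [0])   # copied to the current device (cuda/mlu/mps)
unchanged = scatter(t, [-1])  # devices == [-1] keeps the tensor where it is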
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_IPU_AVAILABLE
if IS_IPU_AVAILABLE:
    from .dataloader import IPUDataLoader
    from .hook_wrapper import IPUFp16OptimizerHook
    from .model_wrapper import ipu_model_wrapper
    from .runner import (IPUBaseRunner, IPUEpochBasedRunner,
                         IPUIterBasedRunner)
    from .utils import cfg2options
    __all__ = [
        'cfg2options', 'ipu_model_wrapper', 'IPUFp16OptimizerHook',
        'IPUDataLoader', 'IPUBaseRunner', 'IPUEpochBasedRunner',
        'IPUIterBasedRunner'
    ]
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Mapping, Sequence
from functools import partial
import poptorch
from torch.utils.data.dataloader import default_collate
from mmcv.parallel import DataContainer
def collate(batch, samples_per_gpu=1):
    """Put each data field into a tensor/DataContainer with outer dimension
    batch size.

    TODO support for :type:`~mmcv.parallel.DataContainer`.
    Currently, it will be ignored.

    There are 3 cases.

    1. cpu_only = True, e.g., meta data.
    2. cpu_only = False, stack = True, e.g., images tensors.
    3. cpu_only = False, stack = False, e.g., gt bboxes.
    """
    if not isinstance(batch, Sequence):
        raise TypeError(
            f'`batch` should be a sequence, but got {type(batch)}.')

    if isinstance(batch[0], DataContainer):
        # TODO `DataContainer` will be supported in the future.
        raise TypeError('DataContainer is not supported in ipu data loader.')
    elif isinstance(batch[0], Sequence):
        transposed = zip(*batch)
        collated_batch = []
        for samples in transposed:
            if not isinstance(samples[0], DataContainer):
                # At present, we will skip the processing of datacontainer,
                # which will reduce the performance of the IPU DataLoader
                collated_batch.append(collate(samples, samples_per_gpu))
        return collated_batch
    elif isinstance(batch[0], Mapping):
        collated_batch = {}
        for key in batch[0]:
            if not isinstance(batch[0][key], DataContainer):
                # At present, we will skip the processing of datacontainer,
                # which will reduce the performance of the IPU DataLoader
                collated_batch[key] = collate([d[key] for d in batch])
        return collated_batch
    else:
        return default_collate(batch)
class IPUDataLoader(poptorch.DataLoader):
    """Thin wrapper of `torch.utils.data.DataLoader`.

    Compared with the PyTorch DataLoader, this DataLoader changes the way
    the batch size is calculated and adds the AsynchronousDataAccessor to
    load and release data faster in cpu mode.

    If this data loader is used in a distributed execution environment, it
    will ensure that each process uses a different subset of the dataset,
    provided you first call ``options.randomSeed(N)`` with an integer N
    which is the same across all hosts.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to get the data from.
        options (poptorch.Options): Options that will be used to compile
            and run the model.
        batch_size (int, optional): This is the batch size in the
            conventional sense of being the size that runs through an
            operation in the model at any given time.
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main
            process. (default: ``0``)
        drop_last (bool, optional): If True and the number of elements in the
            dataset is not a multiple of the combined batch size then the
            incomplete batch at the end will be dropped.
        persistent_workers (bool, optional): Re-use workers between
            iterations if True.
        auto_distributed_partitioning (bool, optional): If True, partitions
            the dataset for distributed execution automatically. Otherwise,
            it is assumed that partitioning has been handled manually.
        mode (poptorch.DataLoaderMode, optional): If `DataLoaderMode.Async`,
            uses an :py:class:`~poptorch.AsynchronousDataAccessor` to access
            the dataset. If `DataLoaderMode.Sync`, accesses the dataset
            synchronously.
        async_options (Dict[str, Any], optional): Options to pass to
            :py:class:`~poptorch.AsynchronousDataAccessor`.
        rebatched_worker_size (int, optional): When using AsyncRebatched:
            batch size of the tensors loaded by the workers.
            Default to the combined batch size.
            If specified, the ``rebatched_worker_size`` must be less than
            or equal to the combined batch size.
        kwargs (Dict[str, Any], optional): Other options to pass to PyTorch's
            ``DataLoader`` constructor.
    """
    def __init__(self,
                 dataset,
                 options,
                 batch_size=1,
                 shuffle=False,
                 num_workers=0,
                 drop_last=True,
                 persistent_workers=True,
                 auto_distributed_partitioning=True,
                 mode='sync',
                 async_options=None,
                 rebatched_worker_size=None,
                 **kwargs):
        """Lazy init:

        In many frameworks, the dataloader will be constructed before the
        initialization of the ipu options, so the lazy init method is used
        here, and the real initialization will not be done until the
        dataloader needs to be used and the options are input.
        """
        # lazy init: sometimes, we cannot get IPU options when building the
        # data loader
        self.kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': shuffle,
            'num_workers': num_workers,
            'drop_last': drop_last,
            'persistent_workers': persistent_workers,
            'auto_distributed_partitioning': auto_distributed_partitioning,
            'mode': mode,
            'collate_fn': partial(collate, samples_per_gpu=batch_size),
            'async_options': async_options,
            'rebatched_worker_size': rebatched_worker_size,
            **kwargs
        }
        self.dataset = dataset
        self.initialized = False
        if options:
            self.init(options=options)

    def init(self, options, **kwargs):
        if not self.initialized:
            kwargs = {**self.kwargs, **kwargs, 'options': options}
            if kwargs['mode'] == 'sync':
                kwargs['mode'] = poptorch.DataLoaderMode.Sync
            elif kwargs['mode'] == 'async':
                kwargs['mode'] = poptorch.DataLoaderMode.AsyncRebatched
                if kwargs['async_options'] is None:
                    kwargs['async_options'] = {
                        'load_indefinitely': True,
                        'buffer_size': 8
                    }
                if kwargs['rebatched_worker_size'] is None:
                    kwargs['rebatched_worker_size'] = 128
            super().__init__(**kwargs)
            self.initialized = True
        return self
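A hedged sketch of the lazy init described above (not part of the commit; ToyDataset is a hypothetical stand-in and poptorch must be installed):

import poptorch
from torch.utils.data import Dataset

class ToyDataset(Dataset):  # hypothetical stand-in dataset
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx

# built without options; the poptorch.DataLoader is not constructed yet
loader = IPUDataLoader(ToyDataset(), options=None, batch_size=4, mode='async')
opts = poptorch.Options()   # created once the IPU config is known
loader.init(options=opts)   # the real poptorch.DataLoader is built here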
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import torch
from mmcv.parallel import DataContainer
# A customized None type for HierarchicalDataManager
HierarchicalDataNone = object()
class HierarchicalDataManager:
    """A class managing all the tensors in the hierarchical data.

    At present, the input data structure accepted by IPU is limited,
    while the input data structure of mmcv varies.
    Here, an intermediate class is needed to get and update tensors
    from the original data.

    HierarchicalDataManager will record a hierarchical input/output data in
    self._hierarchical_data. For example, we have an input data:
    {'img': tensorA, 'label': tensorB, 'img_metas': [tensorC, tensorD]}
    To enable IPU to use the input, HierarchicalDataManager will collect
    the torch tensors from self._hierarchical_data into a tuple like:
    (tensorA, tensorB, tensorC, tensorD).
    Meanwhile, the return of IPU is a tuple of tensors, so
    HierarchicalDataManager also has a function named update_all_tensors
    to update the tensors in self._hierarchical_data, which is the output
    for upper calls.

    Args:
        logger (:obj:`logging.Logger`): Logger used during running.
            Defaults to None.
    """
    def __init__(self, logger=None):
        self.atomic_types = (int, str, float, np.ndarray, type(None))
        self.warning = warnings.warn if logger is None else logger.warning
        # enable or disable input data's shape and value check
        self.quick_mode = False
        self._hierarchical_data = None

    def quick(self):
        self.quick_mode = True

    def compare_atomic_type(self, a, b):
        """Compare data, supported datatypes are numpy array and python basic
        types."""
        if isinstance(a, np.ndarray):
            return np.all(a == b)
        else:
            return a == b

    def record_hierarchical_data(self, data):
        """Record a hierarchical data."""
        if self._hierarchical_data is not None:
            if isinstance(data, torch.Tensor):
                assert isinstance(self._hierarchical_data, torch.Tensor), \
                    'original hierarchical data is not torch.tensor'
                self._hierarchical_data = data
            else:
                self.update_hierarchical_data(data)
        else:
            self._hierarchical_data = data

    @property
    def hierarchical_data(self):
        return self._hierarchical_data
    def update_hierarchical_data(self,
                                 dataA,
                                 dataB=HierarchicalDataNone,
                                 strict=True,
                                 address='data'):
        """Update dataB with dataA in-place.

        Args:
            dataA (list or dict or tuple): New hierarchical data.
            dataB (list or dict or tuple): hierarchical data to update.
                If not specified, self.hierarchical_data will be updated.
            strict (bool, optional): If true, an error will be reported
                when the following conditions occur:

                1. Non-torch.Tensor data changed.
                2. Torch.Tensor data shape changed.
            address (str): Record the address of current data to be updated.
                Default: 'data'.
        """
        if dataB is HierarchicalDataNone:
            dataB = self.hierarchical_data

        # Update with a data with the same structure
        # but different values (tensors and basic python data types)
        if isinstance(dataA, (tuple, list)):
            for idx, node in enumerate(dataA):
                new_address = ''
                if not self.quick_mode:
                    new_address = address + f'[{str(idx)}]'
                    assert isinstance(node, type(dataB[idx])), \
                        f'data structure changed: {new_address}'
                if isinstance(node, torch.Tensor):
                    dataB[idx] = node
                else:
                    self.update_hierarchical_data(
                        node, dataB[idx], strict, address=new_address)
        elif isinstance(dataA, dict):
            for k, v in dataA.items():
                new_address = ''
                if not self.quick_mode:
                    new_address = address + f'[{str(k)}]'
                    assert isinstance(v, type(dataB[k])), \
                        f'data structure changed: {new_address}'
                if isinstance(v, torch.Tensor):
                    dataB[k] = v
                else:
                    self.update_hierarchical_data(
                        v, dataB[k], strict, address=new_address)
        elif isinstance(dataA, self.atomic_types):
            if not self.quick_mode:
                is_equal = self.compare_atomic_type(dataA, dataB)
                if not is_equal:
                    if strict:
                        raise ValueError(
                            'all data except torch.Tensor should be same, '
                            f'but data({address}) is changed.')
                    else:
                        self.warning(
                            f'find a non-torch.Tensor data({type(dataA)}) '
                            f'changed, and the address is {address}')
        elif isinstance(dataA, DataContainer):
            if not self.quick_mode:
                assert isinstance(dataB, DataContainer)
            new_address = address + '.data'
            self.update_hierarchical_data(
                dataA.data, dataB.data, False, address=new_address)
        else:
            raise NotImplementedError(
                f'not supported datatype:{type(dataA)}, address is {address}')
    def collect_all_tensors(self, hierarchical_data=None):
        """Collect torch.Tensor data from self.hierarchical_data to a list
        and return."""
        # get a list of tensor from self._hierarchical_data
        if hierarchical_data is None:
            hierarchical_data = self._hierarchical_data
        tensors = []
        if isinstance(hierarchical_data, torch.Tensor):
            tensors = [hierarchical_data]
        else:
            self._collect_tensors(hierarchical_data, tensors)
        return tensors

    def _collect_tensors(self, data, tensors):
        if isinstance(data, (tuple, list)):
            for node in data:
                if isinstance(node, torch.Tensor):
                    tensors.append(node)
                else:
                    self._collect_tensors(node, tensors)
        elif isinstance(data, dict):
            for v in data.values():
                if isinstance(v, torch.Tensor):
                    tensors.append(v)
                else:
                    self._collect_tensors(v, tensors)
        elif isinstance(data, self.atomic_types):
            pass
        elif isinstance(data, DataContainer):
            self._collect_tensors(data.data, tensors)
        else:
            raise NotImplementedError(f'not supported datatype:{type(data)}')
    def update_all_tensors(self, tensors):
        """Put tensors from tuple back to self.hierarchical_data."""
        if isinstance(self._hierarchical_data, torch.Tensor):
            assert len(tensors) == 1
            assert isinstance(tensors[0], torch.Tensor)
            self._hierarchical_data = tensors[0]
        else:
            # convert to list if tensors is tuple
            tensors = list(tensors)
            self._set_tensors(self._hierarchical_data, tensors)
        return self.hierarchical_data
    def _set_tensors(self, data, tensors):
        if isinstance(data, tuple):
            data = list(data)
            for idx in range(len(data)):
                if isinstance(data[idx], torch.Tensor):
                    data[idx] = tensors.pop(0)
                else:
                    self._set_tensors(data[idx], tensors)
            data = tuple(data)
        elif isinstance(data, list):
            for idx in range(len(data)):
                if isinstance(data[idx], torch.Tensor):
                    data[idx] = tensors.pop(0)
                else:
                    self._set_tensors(data[idx], tensors)
        elif isinstance(data, dict):
            for k, v in data.items():
                if isinstance(v, torch.Tensor):
                    data[k] = tensors.pop(0)
                else:
                    self._set_tensors(v, tensors)
        elif isinstance(data, self.atomic_types):
            pass
        elif isinstance(data, DataContainer):
            self._set_tensors(data.data, tensors)
        else:
            raise NotImplementedError(f'not supported datatype:{type(data)}')

    def clean_all_tensors(self):
        """Delete tensors from self.hierarchical_data."""
        self._clean_tensors(self._hierarchical_data)

    def _clean_tensors(self, data):
        if isinstance(data, tuple):
            data = list(data)
            for idx in range(len(data)):
                if isinstance(data[idx], torch.Tensor):
                    data[idx] = None
                else:
                    self._clean_tensors(data[idx])
            data = tuple(data)
        elif isinstance(data, list):
            for idx in range(len(data)):
                if isinstance(data[idx], torch.Tensor):
                    data[idx] = None
                else:
                    self._clean_tensors(data[idx])
        elif isinstance(data, dict):
            for k, v in data.items():
                if isinstance(v, torch.Tensor):
                    data[k] = None
                else:
                    self._clean_tensors(v)
        elif isinstance(data, self.atomic_types):
            pass
        elif isinstance(data, DataContainer):
            self._clean_tensors(data.data)
        else:
            raise NotImplementedError(f'not supported datatype:{type(data)}')
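A hedged sketch of the round trip described in the class docstring (not part of the commit):

import torch

manager = HierarchicalDataManager()
data = {'img': torch.rand(1, 3), 'img_metas': [torch.rand(2), 'meta']}
manager.record_hierarchical_data(data)
flat = manager.collect_all_tensors()       # [img, img_metas[0]]
new_flat = [t + 1 for t in flat]           # e.g. tensors coming back from IPU
restored = manager.update_all_tensors(new_flat)
# restored has the original structure, with the tensors replaced in-place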
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import HOOKS, LrUpdaterHook, OptimizerHook
from mmcv.utils import TORCH_VERSION, digit_version
def wrap_lr_updater_hook(lr_hook_class):
    """A wrapper function to wrap any subclass of LrUpdaterHook.

    IPU needs extra operations to upload optimizer settings. This wrapper
    will override the ``_set_lr`` function of a subclass of LrUpdaterHook.
    """
    assert issubclass(lr_hook_class, LrUpdaterHook)

    class ipu_lr_hook_class(lr_hook_class):

        def _set_lr(self, runner, *args, **kwargs):
            super()._set_lr(runner, *args, **kwargs)
            # convert torch optimizer to poptorch optimizer
            runner.model.setOptimizer(runner.optimizer)

    return ipu_lr_hook_class
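A hedged sketch of applying the wrapper (not part of the commit; MyLrHook is a hypothetical subclass):

from mmcv.runner import LrUpdaterHook

class MyLrHook(LrUpdaterHook):  # hypothetical subclass
    def get_lr(self, runner, base_lr):
        return base_lr

IPUMyLrHook = wrap_lr_updater_hook(MyLrHook)
# _set_lr now also pushes the updated optimizer to the compiled IPU model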
def wrap_optimizer_hook(optimizer_hook_class):
    """A wrapper function to wrap OptimizerHook.

    This is a non-intrusive implementation of wrapping the optimizer hook
    (otherwise you would need to change every config file to use the IPU
    optimizer hook). IPU's clip-norm implementation is different from
    pytorch's, so an error should be raised when using clip-norm.
    """

    class ipu_optimizer_hook_class(OptimizerHook):

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            if self.grad_clip is not None:
                raise NotImplementedError('IPU does not support gradient clip')

    return ipu_optimizer_hook_class
if (TORCH_VERSION != 'parrots'
        and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):

    @HOOKS.register_module()
    class IPUFp16OptimizerHook(OptimizerHook):
        """FP16 optimizer hook (using PyTorch's implementation).

        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
        backend to take care of the optimization procedure.

        Args:
            loss_scale (float | str | dict): Scale factor configuration.
                If loss_scale is a float, static loss scaling will be used
                with the specified scale. If loss_scale is a string, it must
                be 'dynamic', then dynamic loss scaling will be used.
                It can also be a dict containing arguments of GradScalar.
                Defaults to 512. For PyTorch >= 1.6, mmcv uses the official
                implementation of GradScaler. If you use a dict version of
                loss_scale to create GradScaler, please refer to:
                https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
                for the parameters.

        Examples:
            >>> loss_scale = dict(
            ...     init_scale=65536.0,
            ...     growth_factor=2.0,
            ...     backoff_factor=0.5,
            ...     growth_interval=2000
            ... )
            >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
        """

        def __init__(self,
                     grad_clip=None,
                     coalesce=True,
                     bucket_size_mb=-1,
                     loss_scale=512.,
                     distributed=True):
            assert grad_clip is None, \
                'IPU mode does not support `grad_clip` currently'
            assert coalesce, \
                'implemented all reduce in distributed training currently'
            assert bucket_size_mb == -1, \
                '`bucket_size_mb` should not be set in IPU mode'
            self.distributed = distributed
            self._scale_update_param = None
            if loss_scale == 'dynamic':
                raise NotImplementedError(
                    'IPU mode does not support dynamic loss scale currently')
            elif isinstance(loss_scale, float):
                self.loss_scale = loss_scale
            elif isinstance(loss_scale, dict):
                raise NotImplementedError(
                    'IPU mode supports single scale currently')
            else:
                raise ValueError(
                    f'loss_scale should be float, but got {loss_scale} ')

        def after_train_iter(self, runner):
            pass

else:
    raise RuntimeError('The IPU mode only supports torch 1.6 and above')
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
from collections import OrderedDict
from typing import Optional, Union
import poptorch
import torch
import torch.nn as nn
from poptorch import PoplarExecutor, __version__, identity_loss
from poptorch._args_parser import ArgsParser
from mmcv.runner import auto_fp16
from .hierarchical_data_manager import HierarchicalDataManager
from .utils import compare_ndarray, model_sharding, recomputation_checkpoint
class DictArgsParser(ArgsParser):
"""A helper class for handling model input.
Args:
inputs (dict): Inputs of the model.
"""
def __init__(self, inputs):
# Combine args and kwargs:
self._has_variadic_arguments = True
self._varnames = list(inputs.keys())
self._defaults = [inspect.Parameter.empty for _ in self._varnames]
self._warned_not_contiguous_input = False
class WrappedNet(nn.Module):
"""A net wrapper for model conversion.
This wrapper makes the changes and adds the extra functions needed for
training or inference of the model on IPU.
Args:
model (:obj:`nn.Module`): The model to run.
inputs_manager (:obj:`HierarchicalDataManager`): A parser
converting inputs from tuple to dictionary.
outputs_manager (:obj:`HierarchicalDataManager`): A parser
converting outputs from dictionary to tuple.
inter_outputs_in_cpu (dict): A dict used to store the intermediate
features recorded on CPU.
modules_to_record (mmcv.Config, list): Index or name of modules whose
outputs will be recorded, used to align the outputs of the static
graph of model training or inference between IPU and CPU.
"""
def __init__(self,
model,
inputs_manager,
outputs_manager,
inter_outputs_in_cpu,
modules_to_record=None):
super().__init__()
self.model = model
self.inputs_manager = inputs_manager
self.outputs_manager = outputs_manager
self.training = model.training
# Register hook functions to capture the intermediate features
# generated by the network, so that the outputs of IPU and CPU
# can be aligned and checked for consistency
self.inter_outputs_in_cpu = inter_outputs_in_cpu
if modules_to_record is None:
modules_to_record = []
for idx, (name, module) in enumerate(model.named_modules()):
if name in modules_to_record or idx in modules_to_record:
features_hook = self.get_input_output_hook(
name, idx, self.inter_outputs_in_cpu)
module.register_forward_hook(hook=features_hook)
def get_input_output_hook(self, name, idx, save_dict):
def input_output_hook(module, fea_in, fea_out):
if isinstance(fea_in, tuple):
fea_in = list(fea_in)
if isinstance(fea_out, tuple):
fea_out = list(fea_out)
save_dict[name] = {
'fea_in': fea_in,
'fea_out': fea_out,
'idx': idx
}
return None
return input_output_hook
def forward(self, inputs_tuple):
"""This function is used to be compiled to ipu, the inputs and outputs
need to be tuples, so here we need to restore the input back to a
dictionary and convert the output to a tuple."""
self.inputs_manager.update_all_tensors(inputs_tuple)
kwargs = {**(self.inputs_manager.hierarchical_data)}
if self.training:
outputs = self.forward_train(kwargs)
# tell poptorch which loss will be used finally
identity_loss(outputs['loss'], reduction='none')
else:
outputs = self.forward_eval(kwargs)
if isinstance(outputs, torch.Tensor):
# a single tensor output is currently not supported, so wrap it
# with a dictionary and use a keyword to identify this case
outputs = {'output of WrappedNet: single tensor': outputs}
# if some features need to be recorded, add them as extra outputs
for name in self.inter_outputs_in_cpu:
outputs[name] = self.inter_outputs_in_cpu[name]
# record the positions of the returned tensors in the converting stage;
# in the real run stage all the tensors are changed in-place,
# which means the outputs can be obtained directly outside this function
self.outputs_manager.record_hierarchical_data(outputs)
plain_outputs = self.outputs_manager.collect_all_tensors()
return plain_outputs
def forward_train(self, kwargs):
optimizer = kwargs.pop('optimizer')
outputs = self.train_step(kwargs, optimizer)
return outputs
def train_step(self, data, optimizer=None, **kwargs):
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process
including back propagation and optimizer updating is also defined in
this method, such as GAN.
Args:
data (dict): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer`, optional): The
optimizer of runner is passed to ``train_step()``. This
argument is unused and reserved.
Returns:
dict: Dict of outputs. The following fields are contained.
- loss (torch.Tensor): A tensor for back propagation, which \
can be a weighted sum of multiple losses.
- log_vars (dict): Dict contains all the variables to be sent \
to the logger.
- num_samples (int): Indicates the batch size (when the model \
is DDP, it means the batch size on each GPU), which is \
used for averaging the logs.
"""
losses = self.model(**data)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
return outputs
def _parse_losses(self, losses):
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(loss.mean() for loss in loss_value)
elif isinstance(loss_value, dict):
for name, value in loss_value.items():
log_vars[name] = value
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
loss = sum(value for key, value in log_vars.items() if 'loss' in key)
log_vars['loss'] = loss
return loss, log_vars
def forward_eval(self, kwargs):
img = kwargs.pop('img')
img_metas = kwargs.pop('img_metas', None)
return_loss = kwargs.pop('return_loss')
assert not return_loss
# TODO: temporarily hard-code post_process off; otherwise, in the
# third trace (_check_trace), post_process would automatically convert
# the output tensors to numpy arrays, making _check_trace fail
outputs = self.model(
img,
img_metas=img_metas,
return_loss=return_loss,
post_process=False)
return outputs
class MMPoplarExecutor(PoplarExecutor):
"""An executor for inputs/outputs parsing, model compilation, data
alignment and IPU upload/download.
Args:
model (:obj:`nn.Module`): The model to be compiled.
logger (:obj:`logging.Logger`): Logger used during running.
Defaults to None.
training (bool): Model in training mode or eval mode.
modules_to_record (mmcv.Config, list): Index or name of modules whose
outputs will be recorded, used to align the outputs of the static
graph of model training or inference between IPU and CPU.
args (argument list): Arguments passed to the `__init__`
method of PoplarExecutor.
kwargs (keyword arguments): Keyword arguments passed to the `__init__`
method of PoplarExecutor.
"""
def __init__(self,
model,
logger=None,
training=True,
modules_to_record=None,
*args,
**kwargs):
# self.model == self._user_model: the input pytorch model
# self._model: the wrapped model used for compilation
# and weight updates; these two models share the same weights.
# The wrapped model only accepts and returns tuples, so
# HierarchicalDataManager converts dictionaries to tuples
# and converts them back
self.inputs_manager = HierarchicalDataManager(logger=logger)
self.outputs_manager = HierarchicalDataManager(logger=logger)
self.logger = logger
# the features calculated by CPU
self.inter_outputs_in_cpu = {}
# the features calculated by IPU
self.inter_outputs_in_ipu = {}
if modules_to_record is None:
# The IPU implementation of some operators may be inconsistent
# with the expected (CPU) implementation; this mechanism can be
# used to confirm whether there is such a problem
self.compare_with_cpu = False
else:
self.compare_with_cpu = True
# move model.fp16_enabled to self.fp16_enabled and change the place
# where the input is automatically cast to half precision
if getattr(model, 'fp16_enabled', False):
model.fp16_enabled = False
self.fp16_enabled = True
# make torch.jit.trace convert self._model
model = WrappedNet(
model,
self.inputs_manager,
self.outputs_manager,
self.inter_outputs_in_cpu,
modules_to_record=modules_to_record)
super().__init__(model, training=training, *args, **kwargs)
# overwrite self._args_parser in train_step or val_step
self._args_parser = None
if training:
assert self.training
else:
assert not self.training
@property
def training(self):
# When accessing the training attribute of self, since this class
# has no such attribute, the lookup automatically falls through to
# self.model.training. However, the attribute we really want to
# check is self._training, and self.model.training and
# self._training are often inconsistent. It is not clear whether
# this is a PopTorch bug or a deliberate design, so this property
# is used as a temporary fix
return self._training  # comes from self.model._training
@auto_fp16(supported_types=(PoplarExecutor, ))
def run_model(self, data_dict):
# this function parses the input dict, runs the model and
# converts the result back to an output dict
if self.isCompiled():
self.inputs_manager.record_hierarchical_data(data_dict)
inputs_tuple = tuple(self.inputs_manager.collect_all_tensors())
else:
# get tensors out of data and put them in a tuple
self.inputs_manager.record_hierarchical_data(data_dict)
inputs_tuple = tuple(self.inputs_manager.collect_all_tensors())
# turn logger in data manager off after compilation
self.inputs_manager.quick()
self.outputs_manager.quick()
# parse args in the first iter
if self._args_parser is None:
self._args_parser = DictArgsParser({'args': inputs_tuple})
# run or convert the model;
# the plain_outputs will be used in the converting stage
plain_outputs = self(inputs_tuple)
self.inputs_manager.clean_all_tensors()
# put list of tensors back to the output dict
# according to the same order
self.outputs_manager.update_all_tensors(plain_outputs)
# get the real output dictionary from self.outputs_manager
output_dict = self.outputs_manager.hierarchical_data
# split output_dict into inter_outputs_in_ipu
# and output of the torch model
torch_model_output = {}
for name in output_dict:
if name in self.inter_outputs_in_cpu:
self.inter_outputs_in_ipu[name] = output_dict[name]
else:
torch_model_output[name] = output_dict[name]
if 'output of WrappedNet: single tensor' in output_dict:
assert len(torch_model_output) == 1
assert isinstance(
torch_model_output['output of WrappedNet: single tensor'],
torch.Tensor)
torch_model_output = \
torch_model_output['output of WrappedNet: single tensor']
return torch_model_output
def train_step(self, data, optimizer=None, **kwargs):
# arguments from mmcls/models/classifiers/base.py:
# BaseClassifier.train_step
assert self.training
assert len(kwargs) == 0 # TODO, support later if necessary
# TODO: support DataContainer as input
# currently, auto_fp16 and HierarchicalDataManager take too much
# time traversing DataContainer
data['img_metas'] = None
num_samples = len(data['img'].data)
# TODO: the optimizer is ignored because it is not used in the model;
# support it later if necessary
data['optimizer'] = None
output_dict = self.run_model(data)
# the outputs contain loss, log_vars and num_samples, but only
# loss (torch.Tensor) has been updated;
# remove all unchanged vars and keep only the tensor
neat_output_dict = {'loss': output_dict['loss']}
# re-parse the outputs to get back log_vars and num_samples
loss, log_vars = self.model._parse_losses(neat_output_dict)
final_output_dict = dict(
loss=loss, log_vars=log_vars, num_samples=num_samples)
return final_output_dict
def eval_call(self, img, img_metas=None, return_loss=True, **kwargs):
# arguments from mmdet/models/detectors/base.py:BaseDetector.forward
# temporary usage for eval mode
assert not self.training
assert len(kwargs) == 0 # TODO, support later if necessary
assert not return_loss
data = {'img': img, 'img_metas': img_metas, 'return_loss': return_loss}
output_dict = self.run_model(data)
return output_dict
def detachFromDevice(self):
if self.isCompiled() and self._is_attached:
super().detachFromDevice()
def attachToDevice(self):
if self.isCompiled() and not self._is_attached:
super().attachToDevice()
class TrainEvalModel:
"""A class maintaining training MMPoplarExecutor and inference
MMPoplarExecutor.
Args:
train_model (:obj:`nn.Module`): The training model to be compiled.
``train_model`` can be None if only executing validation.
eval_model (:obj:`nn.Module`): The inference model to be compiled.
options (mmcv.Config, dict): Options that will be used to compile
and run the model.
optimizer (:obj:`torch.optim.Optimizer`, optional): torch
optimizer, necessary if in training mode.
logger (:obj:`logging.Logger`): Logger used during running.
Defaults to None.
modules_to_record (mmcv.Config, list): Index or name of modules whose
outputs will be recorded, used to align the outputs of the static
graph of model training or inference between IPU and CPU.
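Example (an illustrative sketch; in practice this class is constructed
by :func:`ipu_model_wrapper` below):
>>> model = TrainEvalModel(train_model, eval_model, options, optimizer)
>>> model.train()  # switch to the training executor
>>> model.eval()   # switch to the inference executor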
"""
def __init__(self,
train_model,
eval_model,
options,
optimizer,
modules_to_record=None,
logger=None):
if train_model is None:
self._train_executor = None
self.training = False
else:
self._train_executor = get_training_model(
train_model,
options=options['training'],
optimizer=optimizer,
logger=logger,
modules_to_record=modules_to_record)
self.training = True
self._eval_executor = get_inference_model(
eval_model, options=options['inference'], logger=logger)
@property
def executor(self):
if self.training:
return self._train_executor
else:
return self._eval_executor
def train(self, mode: bool = True):
"""Sets the module in training mode.
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in
training/evaluation mode, if they are affected,
e.g. :class:`Dropout`, :class:`BatchNorm`, etc.
Args:
mode (bool): whether to set training mode (``True``) or evaluation
mode (``False``). Default: ``True``.
Returns:
Module: self
"""
if not isinstance(mode, bool):
raise ValueError('training mode is expected to be boolean, '
f'but got {type(mode)}')
if self._train_executor is None and mode:
raise RuntimeError(
'The train_executor is not initialized. '
'If you want to initialize train_executor, '
'you need to pass in an optimizer when converting the pytorch model')
if mode == self.training:
self.model.train(mode)
return self
else:
if self.isCompiled():
# copy weights from IPU to cpu before off-loading the current session
self.copyWeightsToHost()
# detach the current session before changing the mode;
# if in training mode and the weights have been updated,
# poptorch will copy the weights from IPU to host
self.detachFromDevice()
self.training = mode  # the session changes with the mode
self.model.train(mode)
# after changing the mode, attach the new session;
# this will copy the model weights to the device
self.attachToDevice()
return self
def eval(self):
"""Sets the module in evaluation mode.
This has an effect only on certain modules.
See the documentation of particular modules
for details of their behaviors in training/evaluation mode,
if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.
This is equivalent to :meth:`self.train(False)
<nn.Module.train>`.
See :ref:`locally-disable-grad-doc` for a comparison between
`.eval()` and several similar mechanisms that may be confused with it.
Returns:
Module: self
"""
return self.train(False)
def compare_data_between_ipu_and_cpu(self, inter_outputs_in_cpu,
inter_outputs_in_ipu):
for key, val in inter_outputs_in_cpu.items():
is_tensor = isinstance(val['fea_in'], torch.Tensor)
fea_in_cpu = val['fea_in']
fea_in_cpu_list = [fea_in_cpu] if is_tensor else fea_in_cpu
fea_in_ipu = inter_outputs_in_ipu[key]['fea_in']
fea_in_ipu_list = [fea_in_ipu] if is_tensor else fea_in_ipu
is_tensor = isinstance(val['fea_out'], torch.Tensor)
fea_out_cpu = val['fea_out']
fea_out_cpu_list = [fea_out_cpu] if is_tensor else fea_out_cpu
fea_out_ipu = inter_outputs_in_ipu[key]['fea_out']
fea_out_ipu_list = [fea_out_ipu] if is_tensor else fea_out_ipu
print('comparing layer:', key)
for idx, (featA, featB) in \
enumerate(zip(fea_in_cpu_list, fea_in_ipu_list)):
print('fea_in, tensor ', idx)
compare_ndarray(featA.detach().numpy(), featB.detach().numpy())
for idx, (featA, featB) in \
enumerate(zip(fea_out_cpu_list, fea_out_ipu_list)):
print('fea_out, tensor', idx)
compare_ndarray(featA.detach().numpy(), featB.detach().numpy())
# TODO: unify the training and eval interfaces and merge
# train_step (train) and __call__ (eval) together
def train_step(self, data, optimizer=None, **kwargs):
assert self.training, 'train_step is not supported in eval mode'
inter_outputs_in_cpu = {}
if (self._train_executor.isCompiled()
and self._train_executor.compare_with_cpu):
self.copyWeightsToHost()
# run in CPU mode
self._train_executor.model.train_step(data, optimizer, **kwargs)
inter_outputs_in_cpu = {
**(self._train_executor.inter_outputs_in_cpu)
}
# run in IPU mode
result = self._train_executor.train_step(data, optimizer, **kwargs)
if (self._train_executor.isCompiled()
and self._train_executor.compare_with_cpu
and len(inter_outputs_in_cpu) > 0):
self.compare_data_between_ipu_and_cpu(
inter_outputs_in_cpu,
self._train_executor.inter_outputs_in_ipu)
return result
# TODO: unify the training and eval interfaces and merge
# train_step (train) and __call__ (eval) together
def __call__(self, *args, **kwargs):
if self.training:
raise NotImplementedError('use train_step rather than __call__')
else:
return self._eval_executor.eval_call(*args, **kwargs)
def __getattr__(self, attr):
return getattr(self.executor, attr)
def get_training_model(model: nn.Module,
options: Optional[poptorch.Options] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
logger=None,
modules_to_record=None) -> poptorch.PoplarExecutor:
"""Create a PopTorch training model from a PyTorch model, running on IPU
hardware in training mode.
Note:
PopTorch makes a shallow copy of the model. Changes to the
parameters in the returned training model affect the original model
and vice versa. However, primitive variable types are not synced: for
example calling ``model.train()`` on the original model, which
changes the ``training`` bool of the model instance, will not alter the
model returned by this function. You may need to call ``model.train()``
on your model before you call this function for correct behavior.
Args:
model (:obj:`nn.Module`): The model to run.
options (poptorch.Options): Options that will be used to compile
and run the model.
optimizer (:obj:`torch.optim.Optimizer`, optional): The optimizer
to apply during training.
logger (:obj:`logging.Logger`): Logger used during running.
Defaults to None.
modules_to_record (mmcv.Config, list): Index or name of modules whose
outputs will be recorded, used to align the outputs of the static
graph of model training or inference between IPU and CPU.
Returns:
The :class:`poptorch.PoplarExecutor` wrapper to use in place
of ``model``.
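Example (an illustrative sketch; assumes ``net`` is a compatible
``nn.Module``):
>>> opts = poptorch.Options()
>>> optim = torch.optim.SGD(net.parameters(), lr=0.1)
>>> training_model = get_training_model(net, options=opts,
...                                     optimizer=optim)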
"""
# Create a copy of the original model in case it needs to be wrapped
maybe_wrapped_model = copy.copy(model)
return MMPoplarExecutor(
model=maybe_wrapped_model,
logger=logger,
options=options,
training=True,
optimizer=optimizer,
user_model=model,
modules_to_record=modules_to_record,
poptorch_version=__version__)
def get_inference_model(model: Union[nn.Module, poptorch.PoplarExecutor],
options: Optional[poptorch.Options] = None,
logger=None) -> poptorch.PoplarExecutor:
"""Create a PopTorch inference model from a PyTorch model, running on IPU
hardware in inference mode.
Note:
PopTorch makes a shallow copy of the model. Changes to the
parameters in the returned inference model affect the original model
and vice versa. However, primitive variable types are not synced: for
example calling ``model.eval()`` on the original model will not alter
the model returned by this function. You may need to call
``model.eval()`` on your model before you call this function for
correct behavior.
Args:
model (:obj:`nn.Module`): The model to run.
options (poptorch.Options): Options that will be used to compile
and run the model.
logger (:obj:`logging.Logger`): Logger used during running.
Defaults to None.
Returns:
The :class:`poptorch.PoplarExecutor` wrapper to use in place of
``model``.
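Example (an illustrative sketch):
>>> inference_model = get_inference_model(net.eval(),
...                                       options=poptorch.Options())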
"""
return MMPoplarExecutor(
model=copy.copy(model),
logger=logger,
options=options,
training=False,
poptorch_version=__version__)
def ipu_model_wrapper(model,
options,
optimizer=None,
logger=None,
modules_to_record=None,
ipu_model_cfg=None,
fp16_cfg=None):
"""Convert torch model to IPU model.
Args:
model (nn.Module): The target model to be converted.
options (dict[str, poptorch.Options]): IPU options, generated
by :func:`cfg2options`.
optimizer (:obj:`torch.optim.Optimizer`, optional): torch
optimizer, necessary if in training mode.
logger (:obj:`logging.Logger`): Logger used during training.
modules_to_record (mmcv.Config, list): Index or name of modules whose
outputs will be recorded, used to align the outputs of the static
graph of model training or inference between IPU and CPU.
ipu_model_cfg (dict): A dictionary containing train_split_edges and
train_ckpt_nodes. See details in :func:`model_sharding` and
:func:`recomputation_checkpoint`.
fp16_cfg (dict): Config for IPU fp16 training. Currently supports
configs: `loss_scale`, `velocity_accum_type` and `accum_type`.
See details in
https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/index.html
Returns:
TrainEvalModel: IPU wrapped model.
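Example (an illustrative sketch; assumes an options dict produced by
:func:`cfg2options`):
>>> options = cfg2options(dict(train_cfg=dict(), eval_cfg=dict()))
>>> ipu_model = ipu_model_wrapper(model, options, optimizer=optimizer)
>>> outputs = ipu_model.train_step(data)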
"""
if ipu_model_cfg is None:
ipu_model_cfg = {}
training = model.training if optimizer is not None else False
# set mixed-precision
if fp16_cfg is not None:
from mmcv.runner import wrap_fp16_model
loss_scale = fp16_cfg['loss_scale']
wrap_fp16_model(model)
model.half()
# TODO: temporary usage to set loss scaling for the original torch optimizer
if optimizer is not None:
optimizer.loss_scaling = loss_scale
if fp16_cfg.get('velocity_accum_type', False):
if fp16_cfg['velocity_accum_type'] == 'half':
optimizer.velocity_accum_type = torch.half
else:
optimizer.velocity_accum_type = torch.float32
if fp16_cfg.get('accum_type', False):
if fp16_cfg['accum_type'] == 'half':
optimizer.accum_type = torch.half
else:
optimizer.accum_type = torch.float32
# TODO support feature alignment for fp16
if modules_to_record is not None:
raise NotImplementedError(
'Feature alignment for fp16 is not implemented')
# set model partition
if optimizer is None:
train_model = None
else:
# split model into multi-IPUs if specified
train_model = model_sharding(
copy.copy(model).train(),
ipu_model_cfg.get('train_split_edges', []))
recomputation_checkpoint(train_model,
ipu_model_cfg.get('train_ckpt_nodes', []))
# TODO support feature alignment for gradient accumulation mode
gradient_accumulation = \
getattr(options['training'].Training, 'gradient_accumulation', 1)
if gradient_accumulation > 1:
assert modules_to_record is None, \
'Feature alignment for grad-accumulation mode not implemented'
# TODO support feature alignment for multi-replica mode
replication_factor = \
getattr(options['training'], 'replication_factor', 1)
if replication_factor > 1:
assert modules_to_record is None, \
'Feature alignment for multi-replica mode not implemented'
# TODO: support different model partitions between train and eval mode
assert len(ipu_model_cfg.get('eval_split_edges', [])) == 0,\
'Currently, BeginBlock can only be used once on the same model'
eval_model = copy.copy(model).eval()
# wrap model for compilation
model = TrainEvalModel(
train_model,
eval_model,
options=options,
optimizer=optimizer,
logger=logger,
modules_to_record=modules_to_record)
model.train(training)
return model
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import (HOOKS, RUNNERS, BaseRunner, EpochBasedRunner,
IterBasedRunner)
from mmcv.utils import IS_IPU_AVAILABLE
if IS_IPU_AVAILABLE:
from .dataloader import IPUDataLoader
from .hook_wrapper import (IPUFp16OptimizerHook, wrap_lr_updater_hook,
wrap_optimizer_hook)
from .model_wrapper import ipu_model_wrapper
from .utils import build_from_cfg_with_wrapper, cfg2options
class IPUBaseRunner(BaseRunner):
"""A base runner for IPU.
This runner has some extra processes for IPU which are shown below:
1. Parse options for IPU.
2. Wrap the pytorch model for IPU.
3. Raise errors on illegal usage.
4. Input IPU options and initialize the dataloader if an instance
of IPUDataLoader is found.
Args:
model (:obj:`nn.Module`): The model to run.
options_cfg (mmcv.Config, dict): Options that will be used to compile
and run the model.
modules_to_record (mmcv.Config, list): Index or name of modules whose
outputs will be recorded, used to align the outputs of the static
graph of model training or inference between IPU and CPU.
ipu_model_cfg (mmcv.Config, dict): Config of model partition and
recomputation checkpointing.
fp16_cfg (mmcv.Config): Config for fp16 training.
batch_processor (callable): A callable method that processes a data
batch. Should be None for the IPU runner.
kwargs (Dict[str, Any], optional): Keyword arguments will be passed to
``base_runner.BaseRunner``.
"""
def __init__(self,
model,
options_cfg=None,
modules_to_record=None,
ipu_model_cfg=None,
fp16_cfg=None,
batch_processor=None,
**kwargs):
assert hasattr(model, 'train_step') and batch_processor is None,\
'the IPU runner only supports models with train_step'
if options_cfg is None:
options_cfg = {}
# call BaseRunner.__init__() here
super().__init__(model, **kwargs)
# process options of ipu
if IS_IPU_AVAILABLE:
self.options = cfg2options(options_cfg)
self.model = ipu_model_wrapper(
self.model,
self.options,
self.optimizer,
self.logger,
modules_to_record=modules_to_record,
ipu_model_cfg=ipu_model_cfg,
fp16_cfg=fp16_cfg)
else:
raise NotImplementedError('CPU mode is not supported on the IPU runner')
def register_lr_hook(self, lr_config):
if lr_config is None:
return
assert isinstance(lr_config, dict)
assert 'policy' in lr_config
policy_type = lr_config.pop('policy')
# If the policy type is all in lower case,
# e.g., 'cyclic', its first letter will be capitalized,
# e.g., to 'Cyclic'.
# This is for the convenient usage of the LR updater.
# Since this is not applicable for
# `CosineAnnealingLrUpdater`, the string will not be changed
# if it contains capital letters.
if policy_type == policy_type.lower():
policy_type = policy_type.title()
hook_type = policy_type + 'LrUpdaterHook'
lr_config['type'] = hook_type
hook = build_from_cfg_with_wrapper(lr_config, HOOKS,
wrap_lr_updater_hook)
self.register_hook(hook, priority='VERY_HIGH')
def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
return
assert isinstance(optimizer_config, (dict, IPUFp16OptimizerHook))
if isinstance(optimizer_config, dict):
optimizer_config.setdefault('type', 'OptimizerHook')
hook = build_from_cfg_with_wrapper(optimizer_config, HOOKS,
wrap_optimizer_hook)
else:
hook = optimizer_config
self.register_hook(hook, priority='ABOVE_NORMAL')
def run(self, data_loaders, workflow, *args, **kwargs):
for i, flow in enumerate(workflow):
mode, _ = flow
# initialize IPU dataloader if not initialized
assert isinstance(data_loaders[i], IPUDataLoader),\
'IPU runner can only work with `IPUDataLoader`'
data_loaders[i].init(options=self.get_options(mode))
super().run(data_loaders, workflow, *args, **kwargs)
def get_options(self, mode):
if mode == 'train':
return self.options['training']
elif mode == 'val':
return self.options['inference']
else:
raise ValueError(f'mode should be train or val but got {mode}')
@RUNNERS.register_module()
class IPUEpochBasedRunner(IPUBaseRunner, EpochBasedRunner):
"""Epoch-based Runner for IPU.
The inheritance order (MRO) is: IPUEpochBasedRunner -> IPUBaseRunner ->
EpochBasedRunner -> BaseRunner. This runner trains models epoch by epoch.
"""
pass
@RUNNERS.register_module()
class IPUIterBasedRunner(IPUBaseRunner, IterBasedRunner):
"""Iteration-based Runner for IPU.
The inheritance order (MRO) is: IPUIterBasedRunner -> IPUBaseRunner ->
IterBasedRunner -> BaseRunner. This runner trains models iteration by
iteration.
"""
pass
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import numpy as np
import popart
import poptorch
import torch
import torch.nn as nn
from mmcv.utils import Registry
def _options_assigner(cfg, options_node):
# set poptorch options from the config
# cfg: dict, plain python data
# options_node: an options object, or one of its callable attributes
if isinstance(cfg, dict):
for key in cfg:
_options_assigner(cfg[key], getattr(options_node, key))
elif isinstance(cfg, (int, float, str, list)):
if callable(options_node):
options_node(cfg)
else:
error_msg = f'options_node type {type(options_node)} not supported'
raise NotImplementedError(error_msg)
else:
error_msg = f'cfg type {type(cfg)} not supported'
raise NotImplementedError(error_msg)
def cfg2options(cfg):
"""Parse dictionary to ipu options.
Args:
cfg (dict): A dictionary of ipu settings.
Returns:
dict[str, poptorch.Options]: Training options and inference options
of IPU.
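Example (an illustrative config; top-level keys mirror
``poptorch.Options`` attributes):
>>> ipu_options = cfg2options(
...     dict(randomSeed=888,
...          train_cfg=dict(executionStrategy='SameAsIpu')))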
"""
# set ipu options for inference and training by config
train_cfg = cfg.pop('train_cfg', {})
eval_cfg = cfg.pop('eval_cfg', {})
eval_cfg['replicationFactor'] = 1  # eval mode only uses one replica
eval_cfg['executionStrategy'] = 'ShardedExecution'
# overwrite default ipu cfg with specified train cfgs
training_ipu_cfg = {**cfg, **train_cfg}
# overwrite default ipu cfg with specified eval cfgs
inference_ipu_cfg = {**cfg, **eval_cfg}
ipu_options = {
'training': _cast_to_options(training_ipu_cfg),
'inference': _cast_to_options(inference_ipu_cfg)
}
# TODO: make these settings configurable
ipu_options['training']._Popart.set('disableGradAccumulationTensorStreams',
True)
ipu_options['training']._Popart.set(
'accumulateOuterFragmentSettings.schedule',
int(popart.AccumulateOuterFragmentSchedule.OverlapMemoryOptimized))
ipu_options['training'].Precision.enableStochasticRounding(True)
return ipu_options
def _cast_to_options(cfg):
# Options that cannot be assigned directly are parsed with explicit
# if statements; the rest are assigned through _options_assigner
options = poptorch.Options()
if 'availableMemoryProportion' in cfg:
available_memory_proportion = cfg.pop('availableMemoryProportion')
mem_props = {}
for i, mem_prop in enumerate(available_memory_proportion):
mem_props[f'IPU{i}'] = mem_prop
options.setAvailableMemoryProportion(mem_props)
if 'executionStrategy' in cfg:
execution_strategy = cfg.pop('executionStrategy')
if execution_strategy == 'SameAsIpu':
options.setExecutionStrategy(
poptorch.PipelinedExecution(
getattr(poptorch.AutoStage, execution_strategy)))
elif execution_strategy == 'ShardedExecution':
options.setExecutionStrategy(poptorch.ShardedExecution())
else:
raise NotImplementedError(
'executionStrategy should be "SameAsIpu" or "ShardedExecution"'
f', but got {execution_strategy}')
if 'partialsType' in cfg:
partials_type = cfg.pop('partialsType')
options.Precision.setPartialsType(getattr(
torch, partials_type)) # half or float
_options_assigner(cfg, options)
return options
def model_sharding(model, split_edges):
"""split models in-place into multi-IPUs.
Args:
model (nn.Module): The target model to be split.
split_edges (list of dict): Model layer names or layer numbers
of split edge. Each item of ``split_edges`` is a dictionary,
which may contain the following key-pairs:
- layer_to_call: PyTorch module to assign to the block
- user_id (optional): A user defined identifier for the block.
- ipu_id: The id of the IPU to run on.
Examples:
>>> split_edges = [
... dict(layer_to_call='model.conv1', ipu_id=0),
... dict(layer_to_call='model.conv3', ipu_id=1)]
>>> sharding_model = model_sharding(torch_model, split_edges)
Returns:
nn.Module: Split model.
"""
if len(split_edges) == 0:
return model
assert isinstance(split_edges, list)
split_edges_dict = {edge['layer_to_call']: edge for edge in split_edges}
for idx, (name, module) in enumerate(model.named_modules()):
if idx in split_edges_dict and name in split_edges_dict:
raise ValueError(
'The same layer is referenced twice while doing model'
f' partition: idx is {idx} and name is {name}')
edge = split_edges_dict.pop(name, None)
edge = split_edges_dict.pop(idx, edge)
if edge is not None:
poptorch.BeginBlock(module, edge.get('user_id', name),
edge['ipu_id'])
# ensure all split_edges are used
if len(split_edges_dict) > 0:
split_edge_names = list(split_edges_dict.keys())
raise RuntimeError(
f'split_edges: {split_edge_names} are not contained in the model')
return model
def recomputation_checkpoint(model: nn.Module, module_names: list):
"""Annotates the output of a module to be checkpointed instead of
recomputed.
If recomputation mode is enabled, the IPU releases the activations of
intermediate layers to save memory and recomputes them during the
backward pass. This function declares the activations of selected
intermediate layers that should be saved, so that recomputing those
layers can be skipped.
Args:
model (nn.Module): The target model to apply recomputation
checkpoint.
module_names (list): Layer names of module.
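Example (an illustrative sketch; layer names must match those from
``model.named_modules()``):
>>> recomputation_checkpoint(model, ['backbone.layer1'])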
"""
def recompute_outputs(module, inputs, outputs):
if isinstance(outputs, tuple):
return tuple(poptorch.recomputationCheckpoint(y) for y in outputs)
else:
return poptorch.recomputationCheckpoint(outputs)
for name, module in model.named_modules():
if name in module_names:
module.register_forward_hook(recompute_outputs)
module_names.remove(name)
# check all module_names are used
assert len(module_names) == 0,\
f'recomputed nodes: {module_names} are not contained in the model'
def compare_ndarray(featA, featB, rtol=1e-3, atol=1e-5):
"""Align data between two activations or weights."""
try:
np.testing.assert_allclose(featA, featB, rtol=rtol, atol=atol)
except AssertionError as e:
print(e)
def build_from_cfg_with_wrapper(cfg,
registry,
wrapper_func=None,
default_args=None):
"""Build a module from config dict and wrap module with "wrapper_func".
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
wrapper_func (function): Used to wrap the retrieved class before
instantiation.
Returns:
object: The constructed object.
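Example (an illustrative sketch; assumes 'StepLrUpdaterHook' is
registered in ``HOOKS``):
>>> lr_config = dict(type='StepLrUpdaterHook', step=[8, 11])
>>> lr_hook = build_from_cfg_with_wrapper(lr_config, HOOKS,
...                                       wrap_lr_updater_hook)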
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
if wrapper_func is None:
wrapped_obj_cls = obj_cls
else:
wrapped_obj_cls = wrapper_func(obj_cls)
try:
return wrapped_obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{wrapped_obj_cls.__name__}: {e}')
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MLUDataParallel
from .distributed import MLUDistributedDataParallel
__all__ = ['MLUDataParallel', 'MLUDistributedDataParallel']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
"""scatter copies tensor to MLU directly."""
if isinstance(input, list):
outputs = [scatter(_input, devices) for _input in input]
return outputs
elif isinstance(input, torch.Tensor):
output = input.contiguous()
return output.to('mlu') if devices != [-1] else output
else:
raise TypeError(f'Unknown type {type(input)}.')
class Scatter:
@staticmethod
def forward(target_mlus, input):
outputs = scatter(input, target_mlus)
return tuple(outputs) if isinstance(outputs, list) else (outputs, )
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel
from .scatter_gather import scatter_kwargs
class MLUDataParallel(MMDataParallel):
"""The MLUDataParallel module that supports DataContainer.
MLUDataParallel is a class inherited from MMDataParallel, which
supports MLU training and inference only.
The main differences with MMDataParallel:
- It supports only a single MLU card, and only uses the first card to
run training and inference.
- It uses direct host-to-device copy instead of stream-background
scatter.
.. warning::
MLUDataParallel only supports single MLU training, if you need to
train with multiple MLUs, please use MLUDistributedDataParallel
instead. If you have multiple MLUs, you can set the environment
variable ``MLU_VISIBLE_DEVICES=0`` (or any other card number(s))
to specify the running device.
Args:
module (:class:`nn.Module`): Module to be encapsulated.
dim (int): Dimension used to scatter the data. Defaults to 0.
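Example (an illustrative sketch; assumes an MLU device is available):
>>> model = MLUDataParallel(model.to('mlu'))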
"""
def __init__(self, *args, dim=0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.device_ids = [0]
self.src_device_obj = torch.device('mlu:0')
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.parallel import MMDistributedDataParallel
from .scatter_gather import scatter_kwargs
class MLUDistributedDataParallel(MMDistributedDataParallel):
"""The DDP module supports DataContainer.
MLUDDP has one difference from MMDDP which moves data to MLU with coping
instead of scattering.
"""
def to_kwargs(self, inputs, kwargs, device_id):
# Use `self.to_kwargs` instead of `self.scatter` in PyTorch 1.8
# to move all tensors to device_id
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel.data_container import DataContainer
from ._functions import Scatter
def scatter(inputs, target_mlus, dim=0):
"""Scatter inputs to target mlu.
The only difference from original :func:`scatter` is to add support for
:class:`~mmcv.parallel.DataContainer`.
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if target_mlus != [-1]:
obj = obj.to('mlu')
return [obj]
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_mlus, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_mlus, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for _ in target_mlus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
def scatter_kwargs(inputs, kwargs, target_mlus, dim=0):
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_mlus, dim) if inputs else []
kwargs = scatter(kwargs, target_mlus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MPSDataParallel
__all__ = ['MPSDataParallel']