# Copyright (c) OpenMMLab. All rights reserved.
import logging
import torch.nn as nn
import torch.utils.checkpoint as cp
from .utils import constant_init, kaiming_init
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
"""3x3 convolution with padding."""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
dilation=dilation,
bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False):
super(BasicBlock, self).__init__()
assert style in ['pytorch', 'caffe']
self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
assert not with_cp
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False):
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
assert style in ['pytorch', 'caffe']
if style == 'pytorch':
conv1_stride = 1
conv2_stride = stride
else:
conv1_stride = stride
conv2_stride = 1
self.conv1 = nn.Conv2d(
inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.with_cp = with_cp
def forward(self, x):
def _inner_forward(x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
def make_res_layer(block,
inplanes,
planes,
blocks,
stride=1,
dilation=1,
style='pytorch',
with_cp=False):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(
block(
inplanes,
planes,
stride,
dilation,
downsample,
style=style,
with_cp=with_cp))
inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))
return nn.Sequential(*layers)
class ResNet(nn.Module):
"""ResNet backbone.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
depth,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
frozen_stages=-1,
bn_eval=True,
bn_frozen=False,
with_cp=False):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
self.out_indices = out_indices
self.style = style
self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
self.with_cp = with_cp
self.inplanes = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.res_layers = []
for i, num_blocks in enumerate(stage_blocks):
stride = strides[i]
dilation = dilations[i]
planes = 64 * 2**i
res_layer = make_res_layer(
block,
self.inplanes,
planes,
num_blocks,
stride=stride,
dilation=dilation,
style=self.style,
with_cp=with_cp)
self.inplanes = planes * block.expansion
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
from ..runner import load_checkpoint
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def train(self, mode=True):
super(ResNet, self).train(mode)
if self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
if mode and self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for param in self.bn1.parameters():
param.requires_grad = False
self.bn1.eval()
self.bn1.weight.requires_grad = False
self.bn1.bias.requires_grad = False
for i in range(1, self.frozen_stages + 1):
mod = getattr(self, f'layer{i}')
mod.eval()
for param in mod.parameters():
param.requires_grad = False
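# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of building the ResNet backbone above and running a
# forward pass. The helper name `_demo_resnet_backbone` and the 224x224 input
# size are arbitrary choices made only for illustration.
def _demo_resnet_backbone():
    import torch
    model = ResNet(depth=50, out_indices=(0, 1, 2, 3))
    model.init_weights(pretrained=None)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    # Four feature maps with strides 4, 8, 16, 32 and channel counts
    # 256, 512, 1024, 2048 (Bottleneck expansion is 4).
    for feat in feats:
        print(tuple(feat.shape))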
# Copyright (c) OpenMMLab. All rights reserved.
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .sync_bn import revert_sync_batchnorm
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
KaimingInit, NormalInit, PretrainedInit,
TruncNormalInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init,
constant_init, initialize, kaiming_init, normal_init,
trunc_normal_init, uniform_init, xavier_init)
__all__ = [
'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit', 'revert_sync_batchnorm'
]
# Modified from flops-counter.pytorch by Vladislav Sovrasov
# original repo: https://github.com/sovrasov/flops-counter.pytorch
# MIT License
# Copyright (c) 2018 Vladislav Sovrasov
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import sys
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import annotator.uniformer.mmcv as mmcv
def get_model_complexity_info(model,
input_shape,
print_per_layer_stat=True,
as_strings=True,
input_constructor=None,
flush=False,
ost=sys.stdout):
"""Get complexity information of a model.
This method can calculate FLOPs and parameter counts of a model with
corresponding input shape. It can also print complexity information for
each layer in a model.
Supported layers are listed as below:
- Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
- Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
``nn.ReLU6``.
- Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
- BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
- Linear: ``nn.Linear``.
- Deconvolution: ``nn.ConvTranspose2d``.
- Upsample: ``nn.Upsample``.
Args:
model (nn.Module): The model for complexity calculation.
input_shape (tuple): Input shape used for calculation.
print_per_layer_stat (bool): Whether to print complexity information
for each layer in a model. Default: True.
as_strings (bool): Output FLOPs and params counts in a string form.
Default: True.
        input_constructor (None | callable): If specified, it takes a callable
            method that generates input. Otherwise, it will generate a random
            tensor with input shape to calculate FLOPs. Default: None.
flush (bool): same as that in :func:`print`. Default: False.
ost (stream): same as ``file`` param in :func:`print`.
Default: sys.stdout.
Returns:
        tuple[float | str]: If ``as_strings`` is set to True, it will return
            FLOPs and parameter counts in a string format. Otherwise, it will
            return those in a float number format.
"""
assert type(input_shape) is tuple
assert len(input_shape) >= 1
assert isinstance(model, nn.Module)
flops_model = add_flops_counting_methods(model)
flops_model.eval()
flops_model.start_flops_count()
if input_constructor:
input = input_constructor(input_shape)
_ = flops_model(**input)
else:
try:
batch = torch.ones(()).new_empty(
(1, *input_shape),
dtype=next(flops_model.parameters()).dtype,
device=next(flops_model.parameters()).device)
except StopIteration:
            # Avoid StopIteration for models which have no parameters,
            # like `nn.ReLU()`, `nn.AvgPool2d`, etc.
batch = torch.ones(()).new_empty((1, *input_shape))
_ = flops_model(batch)
flops_count, params_count = flops_model.compute_average_flops_cost()
if print_per_layer_stat:
print_model_with_flops(
flops_model, flops_count, params_count, ost=ost, flush=flush)
flops_model.stop_flops_count()
if as_strings:
return flops_to_string(flops_count), params_to_string(params_count)
return flops_count, params_count
def flops_to_string(flops, units='GFLOPs', precision=2):
"""Convert FLOPs number into a string.
    Note that here we treat one multiply-add as one FLOP.
Args:
flops (float): FLOPs number to be converted.
units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
precision (int): Digit number after the decimal point. Default: 2.
Returns:
str: The converted FLOPs number with units.
Examples:
>>> flops_to_string(1e9)
'1.0 GFLOPs'
>>> flops_to_string(2e5, 'MFLOPs')
'0.2 MFLOPs'
>>> flops_to_string(3e-9, None)
'3e-09 FLOPs'
"""
if units is None:
if flops // 10**9 > 0:
return str(round(flops / 10.**9, precision)) + ' GFLOPs'
elif flops // 10**6 > 0:
return str(round(flops / 10.**6, precision)) + ' MFLOPs'
elif flops // 10**3 > 0:
return str(round(flops / 10.**3, precision)) + ' KFLOPs'
else:
return str(flops) + ' FLOPs'
else:
if units == 'GFLOPs':
return str(round(flops / 10.**9, precision)) + ' ' + units
elif units == 'MFLOPs':
return str(round(flops / 10.**6, precision)) + ' ' + units
elif units == 'KFLOPs':
return str(round(flops / 10.**3, precision)) + ' ' + units
else:
return str(flops) + ' FLOPs'
def params_to_string(num_params, units=None, precision=2):
"""Convert parameter number into a string.
Args:
num_params (float): Parameter number to be converted.
        units (str | None): Converted parameter units. Options are None, 'M',
            'K' and ''. If set to None, it will automatically choose the most
            suitable unit for the parameter number. Default: None.
precision (int): Digit number after the decimal point. Default: 2.
Returns:
str: The converted parameter number with units.
Examples:
>>> params_to_string(1e9)
'1000.0 M'
>>> params_to_string(2e5)
'200.0 k'
>>> params_to_string(3e-9)
'3e-09'
"""
if units is None:
if num_params // 10**6 > 0:
return str(round(num_params / 10**6, precision)) + ' M'
elif num_params // 10**3:
return str(round(num_params / 10**3, precision)) + ' k'
else:
return str(num_params)
else:
if units == 'M':
return str(round(num_params / 10.**6, precision)) + ' ' + units
elif units == 'K':
return str(round(num_params / 10.**3, precision)) + ' ' + units
else:
return str(num_params)
def print_model_with_flops(model,
total_flops,
total_params,
units='GFLOPs',
precision=3,
ost=sys.stdout,
flush=False):
"""Print a model with FLOPs for each layer.
Args:
model (nn.Module): The model to be printed.
total_flops (float): Total FLOPs of the model.
total_params (float): Total parameter counts of the model.
units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
precision (int): Digit number after the decimal point. Default: 3.
ost (stream): same as `file` param in :func:`print`.
Default: sys.stdout.
flush (bool): same as that in :func:`print`. Default: False.
Example:
>>> class ExampleModel(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> self.conv1 = nn.Conv2d(3, 8, 3)
>>> self.conv2 = nn.Conv2d(8, 256, 3)
>>> self.conv3 = nn.Conv2d(256, 8, 3)
>>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
>>> self.flatten = nn.Flatten()
>>> self.fc = nn.Linear(8, 1)
>>> def forward(self, x):
>>> x = self.conv1(x)
>>> x = self.conv2(x)
>>> x = self.conv3(x)
>>> x = self.avg_pool(x)
>>> x = self.flatten(x)
>>> x = self.fc(x)
>>> return x
>>> model = ExampleModel()
>>> x = (3, 16, 16)
        To print the complexity information for each layer, you can use
>>> get_model_complexity_info(model, x)
or directly use
>>> print_model_with_flops(model, 4579784.0, 37361)
ExampleModel(
0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
(conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501
(conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
(conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
(avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
(flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
(fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
)
"""
def accumulate_params(self):
if is_supported_instance(self):
return self.__params__
else:
sum = 0
for m in self.children():
sum += m.accumulate_params()
return sum
def accumulate_flops(self):
if is_supported_instance(self):
return self.__flops__ / model.__batch_counter__
else:
sum = 0
for m in self.children():
sum += m.accumulate_flops()
return sum
def flops_repr(self):
accumulated_num_params = self.accumulate_params()
accumulated_flops_cost = self.accumulate_flops()
return ', '.join([
params_to_string(
accumulated_num_params, units='M', precision=precision),
'{:.3%} Params'.format(accumulated_num_params / total_params),
flops_to_string(
accumulated_flops_cost, units=units, precision=precision),
'{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
self.original_extra_repr()
])
def add_extra_repr(m):
m.accumulate_flops = accumulate_flops.__get__(m)
m.accumulate_params = accumulate_params.__get__(m)
flops_extra_repr = flops_repr.__get__(m)
if m.extra_repr != flops_extra_repr:
m.original_extra_repr = m.extra_repr
m.extra_repr = flops_extra_repr
assert m.extra_repr != m.original_extra_repr
def del_extra_repr(m):
if hasattr(m, 'original_extra_repr'):
m.extra_repr = m.original_extra_repr
del m.original_extra_repr
if hasattr(m, 'accumulate_flops'):
del m.accumulate_flops
model.apply(add_extra_repr)
print(model, file=ost, flush=flush)
model.apply(del_extra_repr)
def get_model_parameters_number(model):
"""Calculate parameter number of a model.
Args:
model (nn.module): The model for parameter number calculation.
Returns:
float: Parameter number of the model.
"""
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
return num_params
def add_flops_counting_methods(net_main_module):
# adding additional methods to the existing module object,
# this is done this way so that each function has access to self object
net_main_module.start_flops_count = start_flops_count.__get__(
net_main_module)
net_main_module.stop_flops_count = stop_flops_count.__get__(
net_main_module)
net_main_module.reset_flops_count = reset_flops_count.__get__(
net_main_module)
net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # noqa: E501
net_main_module)
net_main_module.reset_flops_count()
return net_main_module
def compute_average_flops_cost(self):
"""Compute average FLOPs cost.
A method to compute average FLOPs cost, which will be available after
`add_flops_counting_methods()` is called on a desired net object.
Returns:
float: Current mean flops consumption per image.
"""
batches_count = self.__batch_counter__
flops_sum = 0
for module in self.modules():
if is_supported_instance(module):
flops_sum += module.__flops__
params_sum = get_model_parameters_number(self)
return flops_sum / batches_count, params_sum
def start_flops_count(self):
"""Activate the computation of mean flops consumption per image.
    A method to activate the computation of mean flops consumption per image,
    which will be available after ``add_flops_counting_methods()`` is called
    on a desired net object. It should be called before running the network.
"""
add_batch_counter_hook_function(self)
def add_flops_counter_hook_function(module):
if is_supported_instance(module):
if hasattr(module, '__flops_handle__'):
return
else:
handle = module.register_forward_hook(
get_modules_mapping()[type(module)])
module.__flops_handle__ = handle
self.apply(partial(add_flops_counter_hook_function))
def stop_flops_count(self):
"""Stop computing the mean flops consumption per image.
A method to stop computing the mean flops consumption per image, which will
be available after ``add_flops_counting_methods()`` is called on a desired
    net object. It can be called to pause the computation at any time.
"""
remove_batch_counter_hook_function(self)
self.apply(remove_flops_counter_hook_function)
def reset_flops_count(self):
"""Reset statistics computed so far.
    A method to reset computed statistics, which will be available after
`add_flops_counting_methods()` is called on a desired net object.
"""
add_batch_counter_variables_or_reset(self)
self.apply(add_flops_counter_variable_or_reset)
# ---- Internal functions
def empty_flops_counter_hook(module, input, output):
module.__flops__ += 0
def upsample_flops_counter_hook(module, input, output):
output_size = output[0]
batch_size = output_size.shape[0]
output_elements_count = batch_size
for val in output_size.shape[1:]:
output_elements_count *= val
module.__flops__ += int(output_elements_count)
def relu_flops_counter_hook(module, input, output):
active_elements_count = output.numel()
module.__flops__ += int(active_elements_count)
def linear_flops_counter_hook(module, input, output):
input = input[0]
output_last_dim = output.shape[
-1] # pytorch checks dimensions, so here we don't care much
module.__flops__ += int(np.prod(input.shape) * output_last_dim)
def pool_flops_counter_hook(module, input, output):
input = input[0]
module.__flops__ += int(np.prod(input.shape))
def norm_flops_counter_hook(module, input, output):
input = input[0]
batch_flops = np.prod(input.shape)
if (getattr(module, 'affine', False)
or getattr(module, 'elementwise_affine', False)):
batch_flops *= 2
module.__flops__ += int(batch_flops)
def deconv_flops_counter_hook(conv_module, input, output):
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = input.shape[0]
input_height, input_width = input.shape[2:]
kernel_height, kernel_width = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
filters_per_channel = out_channels // groups
conv_per_position_flops = (
kernel_height * kernel_width * in_channels * filters_per_channel)
active_elements_count = batch_size * input_height * input_width
overall_conv_flops = conv_per_position_flops * active_elements_count
bias_flops = 0
if conv_module.bias is not None:
output_height, output_width = output.shape[2:]
        bias_flops = out_channels * batch_size * output_height * output_width
overall_flops = overall_conv_flops + bias_flops
conv_module.__flops__ += int(overall_flops)
def conv_flops_counter_hook(conv_module, input, output):
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = input.shape[0]
output_dims = list(output.shape[2:])
kernel_dims = list(conv_module.kernel_size)
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
filters_per_channel = out_channels // groups
conv_per_position_flops = int(
np.prod(kernel_dims)) * in_channels * filters_per_channel
active_elements_count = batch_size * int(np.prod(output_dims))
overall_conv_flops = conv_per_position_flops * active_elements_count
bias_flops = 0
if conv_module.bias is not None:
bias_flops = out_channels * active_elements_count
overall_flops = overall_conv_flops + bias_flops
conv_module.__flops__ += int(overall_flops)
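# --- Worked example (illustrative comment, not part of the original file) ---
# For nn.Conv2d(3, 8, kernel_size=3) applied to a (1, 3, 16, 16) input, the
# hook above counts:
#   conv_per_position_flops = 3 * 3 * 3 * (8 // 1) = 216
#   active_elements_count   = 1 * 14 * 14          = 196
#   overall_conv_flops      = 216 * 196            = 42336
#   bias_flops              = 8 * 196              = 1568
#   overall_flops           = 43904
# which matches the `0.959% FLOPs` shown for `conv1` in the
# `print_model_with_flops` docstring above (43904 / 4579784 ~= 0.959%).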
def batch_counter_hook(module, input, output):
batch_size = 1
if len(input) > 0:
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = len(input)
else:
pass
print('Warning! No positional inputs found for a module, '
'assuming batch size is 1.')
module.__batch_counter__ += batch_size
def add_batch_counter_variables_or_reset(module):
module.__batch_counter__ = 0
def add_batch_counter_hook_function(module):
if hasattr(module, '__batch_counter_handle__'):
return
handle = module.register_forward_hook(batch_counter_hook)
module.__batch_counter_handle__ = handle
def remove_batch_counter_hook_function(module):
if hasattr(module, '__batch_counter_handle__'):
module.__batch_counter_handle__.remove()
del module.__batch_counter_handle__
def add_flops_counter_variable_or_reset(module):
if is_supported_instance(module):
if hasattr(module, '__flops__') or hasattr(module, '__params__'):
            print('Warning: variables __flops__ or __params__ are already '
                  'defined for the module ' + type(module).__name__ +
                  '. ptflops can affect your code!')
module.__flops__ = 0
module.__params__ = get_model_parameters_number(module)
def is_supported_instance(module):
if type(module) in get_modules_mapping():
return True
return False
def remove_flops_counter_hook_function(module):
if is_supported_instance(module):
if hasattr(module, '__flops_handle__'):
module.__flops_handle__.remove()
del module.__flops_handle__
def get_modules_mapping():
return {
# convolutions
nn.Conv1d: conv_flops_counter_hook,
nn.Conv2d: conv_flops_counter_hook,
mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
nn.Conv3d: conv_flops_counter_hook,
mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
# activations
nn.ReLU: relu_flops_counter_hook,
nn.PReLU: relu_flops_counter_hook,
nn.ELU: relu_flops_counter_hook,
nn.LeakyReLU: relu_flops_counter_hook,
nn.ReLU6: relu_flops_counter_hook,
# poolings
nn.MaxPool1d: pool_flops_counter_hook,
nn.AvgPool1d: pool_flops_counter_hook,
nn.AvgPool2d: pool_flops_counter_hook,
nn.MaxPool2d: pool_flops_counter_hook,
mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
nn.MaxPool3d: pool_flops_counter_hook,
mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
nn.AvgPool3d: pool_flops_counter_hook,
nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
# normalizations
nn.BatchNorm1d: norm_flops_counter_hook,
nn.BatchNorm2d: norm_flops_counter_hook,
nn.BatchNorm3d: norm_flops_counter_hook,
nn.GroupNorm: norm_flops_counter_hook,
nn.InstanceNorm1d: norm_flops_counter_hook,
nn.InstanceNorm2d: norm_flops_counter_hook,
nn.InstanceNorm3d: norm_flops_counter_hook,
nn.LayerNorm: norm_flops_counter_hook,
# FC
nn.Linear: linear_flops_counter_hook,
mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
# Upscale
nn.Upsample: upsample_flops_counter_hook,
# Deconvolution
nn.ConvTranspose2d: deconv_flops_counter_hook,
mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
}
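# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of the public API defined above. The model, the helper
# name `_demo_get_model_complexity_info` and the (3, 16, 16) input shape are
# arbitrary; with `as_strings=True` the call returns human-readable strings.
def _demo_get_model_complexity_info():
    model = nn.Sequential(
        nn.Conv2d(3, 8, 3),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(8, 1))
    flops, params = get_model_complexity_info(
        model, (3, 16, 16), print_per_layer_stat=False, as_strings=True)
    print(flops, params)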
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
def _fuse_conv_bn(conv, bn):
"""Fuse conv and bn into one module.
Args:
conv (nn.Module): Conv to be fused.
bn (nn.Module): BN to be fused.
Returns:
nn.Module: Fused module.
"""
conv_w = conv.weight
conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
bn.running_mean)
factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
conv.weight = nn.Parameter(conv_w *
factor.reshape([conv.out_channels, 1, 1, 1]))
conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
return conv
def fuse_conv_bn(module):
"""Recursively fuse conv and bn in a module.
    During inference, batch norm layers reduce to per-channel affine
    transforms using their running mean and variance, which makes it possible
    to fuse them into the preceding conv layers to save computation and
    simplify the network structure.
Args:
module (nn.Module): Module to be fused.
Returns:
nn.Module: Fused module.
"""
last_conv = None
last_conv_name = None
for name, child in module.named_children():
if isinstance(child,
(nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
if last_conv is None: # only fuse BN that is after Conv
continue
fused_conv = _fuse_conv_bn(last_conv, child)
module._modules[last_conv_name] = fused_conv
# To reduce changes, set BN as Identity instead of deleting it.
module._modules[name] = nn.Identity()
last_conv = None
elif isinstance(child, nn.Conv2d):
last_conv = child
last_conv_name = name
else:
fuse_conv_bn(child)
return module
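# --- Illustrative check (not part of the original file) ---
# A small sketch verifying that fusing Conv + BN preserves the forward pass in
# eval mode, up to floating point tolerance. The helper name and the toy
# module are arbitrary.
def _demo_fuse_conv_bn():
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    model.eval()
    x = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        ref = model(x)                 # Conv followed by BN (running stats)
        fused = fuse_conv_bn(model)    # BN folded into the conv weights
        out = fused(x)
    assert torch.allclose(ref, out, atol=1e-5)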
import torch
import annotator.uniformer.mmcv as mmcv
class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
"""A general BatchNorm layer without input dimension check.
Reproduced from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
    is `_check_input_dim`, which is designed for tensor sanity checks.
    The check has been bypassed in this class for the convenience of
    converting SyncBatchNorm.
"""
def _check_input_dim(self, input):
return
def revert_sync_batchnorm(module):
"""Helper function to convert all `SyncBatchNorm` (SyncBN) and
`mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
`BatchNormXd` layers.
Adapted from @kapily's work:
(https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
Args:
module (nn.Module): The module containing `SyncBatchNorm` layers.
Returns:
module_output: The converted module with `BatchNormXd` layers.
"""
module_output = module
module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
if hasattr(mmcv, 'ops'):
module_checklist.append(mmcv.ops.SyncBatchNorm)
if isinstance(module, tuple(module_checklist)):
module_output = _BatchNormXd(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
# no_grad() may not be needed here but
# just to be consistent with `convert_sync_batchnorm()`
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
module_output.training = module.training
# qconfig exists in quantized models
if hasattr(module, 'qconfig'):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, revert_sync_batchnorm(child))
del module
return module_output
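# --- Illustrative usage sketch (not part of the original file) ---
# Converting a SyncBatchNorm-based model back to a form that runs without a
# distributed process group. The toy model and the helper name are arbitrary.
def _demo_revert_sync_batchnorm():
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3), torch.nn.SyncBatchNorm(8))
    reverted = revert_sync_batchnorm(model)
    print(type(reverted[1]).__name__)  # _BatchNormXd
    # The reverted model can now run on a single device / CPU.
    print(reverted(torch.randn(1, 3, 16, 16)).shape)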
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import warnings
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from annotator.uniformer.mmcv.utils import Registry, build_from_cfg, get_logger, print_log
INITIALIZERS = Registry('initializer')
def update_init_info(module, init_info):
"""Update the `_params_init_info` in the module if the value of parameters
are changed.
Args:
module (obj:`nn.Module`): The module of PyTorch with a user-defined
attribute `_params_init_info` which records the initialization
information.
init_info (str): The string that describes the initialization.
"""
assert hasattr(
module,
'_params_init_info'), f'Can not find `_params_init_info` in {module}'
for name, param in module.named_parameters():
assert param in module._params_init_info, (
f'Find a new :obj:`Parameter` '
f'named `{name}` during executing the '
f'`init_weights` of '
f'`{module.__class__.__name__}`. '
f'Please do not add or '
f'replace parameters during executing '
f'the `init_weights`. ')
# The parameter has been changed during executing the
# `init_weights` of module
mean_value = param.data.mean()
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
module._params_init_info[param]['init_info'] = init_info
module._params_init_info[param]['tmp_mean_value'] = mean_value
def constant_init(module, val, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
nn.init.xavier_uniform_(module.weight, gain=gain)
else:
nn.init.xavier_normal_(module.weight, gain=gain)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.normal_(module.weight, mean, std)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def trunc_normal_init(module: nn.Module,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
trunc_normal_(module.weight, mean, std, a, b) # type: ignore
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) # type: ignore
def uniform_init(module, a=0, b=1, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.uniform_(module.weight, a, b)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def kaiming_init(module,
a=0,
mode='fan_out',
nonlinearity='relu',
bias=0,
distribution='normal'):
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
nn.init.kaiming_uniform_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
else:
nn.init.kaiming_normal_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def caffe2_xavier_init(module, bias=0):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
kaiming_init(
module,
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
bias=bias,
distribution='uniform')
def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init
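# --- Worked example (illustrative comment, not part of the original file) ---
# bias_init_with_prob(p) returns the logit of p, i.e. -log((1 - p) / p), so
# that sigmoid(bias) == p right after initialization. For example, a prior
# probability of 0.01 gives -log(99) ~= -4.595.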
def _get_bases_name(m):
return [b.__name__ for b in m.__class__.__bases__]
class BaseInit(object):
def __init__(self, *, bias=0, bias_prob=None, layer=None):
self.wholemodule = False
if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a number, but got a {type(bias)}')
if bias_prob is not None:
if not isinstance(bias_prob, float):
raise TypeError(f'bias_prob type must be float, \
but got {type(bias_prob)}')
if layer is not None:
if not isinstance(layer, (str, list)):
raise TypeError(f'layer must be a str or a list of str, \
but got a {type(layer)}')
else:
layer = []
if bias_prob is not None:
self.bias = bias_init_with_prob(bias_prob)
else:
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer
def _get_init_info(self):
info = f'{self.__class__.__name__}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
"""Initialize module parameters with constant values.
Args:
val (int | float): the value to fill the weights in the module with
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
"""
def __init__(self, val, **kwargs):
super().__init__(**kwargs)
self.val = val
def __call__(self, module):
def init(m):
if self.wholemodule:
constant_init(m, self.val, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
constant_init(m, self.val, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Xavier')
class XavierInit(BaseInit):
r"""Initialize module parameters with values according to the method
described in `Understanding the difficulty of training deep feedforward
neural networks - Glorot, X. & Bengio, Y. (2010).
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
Args:
gain (int | float): an optional scaling factor. Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
        distribution (str): distribution to use, either ``'normal'`` or
            ``'uniform'``. Defaults to ``'normal'``.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
"""
def __init__(self, gain=1, distribution='normal', **kwargs):
super().__init__(**kwargs)
self.gain = gain
self.distribution = distribution
def __call__(self, module):
def init(m):
if self.wholemodule:
xavier_init(m, self.gain, self.bias, self.distribution)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
xavier_init(m, self.gain, self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: gain={self.gain}, ' \
f'distribution={self.distribution}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Normal')
class NormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
Args:
        mean (int | float): the mean of the normal distribution. Defaults to 0.
std (int | float): the standard deviation of the normal distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
"""
def __init__(self, mean=0, std=1, **kwargs):
super().__init__(**kwargs)
self.mean = mean
self.std = std
def __call__(self, module):
def init(m):
if self.wholemodule:
normal_init(m, self.mean, self.std, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
normal_init(m, self.mean, self.std, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: mean={self.mean},' \
f' std={self.std}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
    outside :math:`[a, b]` redrawn until they are within the bounds.
Args:
mean (float): the mean of the normal distribution. Defaults to 0.
std (float): the standard deviation of the normal distribution.
Defaults to 1.
a (float): The minimum cutoff value.
        b (float): The maximum cutoff value.
bias (float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
"""
def __init__(self,
mean: float = 0,
std: float = 1,
a: float = -2,
b: float = 2,
**kwargs) -> None:
super().__init__(**kwargs)
self.mean = mean
self.std = std
self.a = a
self.b = b
def __call__(self, module: nn.Module) -> None:
def init(m):
if self.wholemodule:
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
trunc_normal_init(m, self.mean, self.std, self.a, self.b,
self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
f' mean={self.mean}, std={self.std}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Uniform')
class UniformInit(BaseInit):
r"""Initialize module parameters with values drawn from the uniform
distribution :math:`\mathcal{U}(a, b)`.
Args:
a (int | float): the lower bound of the uniform distribution.
Defaults to 0.
b (int | float): the upper bound of the uniform distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
"""
def __init__(self, a=0, b=1, **kwargs):
super().__init__(**kwargs)
self.a = a
self.b = b
def __call__(self, module):
def init(m):
if self.wholemodule:
uniform_init(m, self.a, self.b, self.bias)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
uniform_init(m, self.a, self.b, self.bias)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a},' \
f' b={self.b}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Kaiming')
class KaimingInit(BaseInit):
r"""Initialize module parameters with the values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification - He, K. et al. (2015).
<https://www.cv-foundation.org/openaccess/content_iccv_2015/
papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
Args:
a (int | float): the negative slope of the rectifier used after this
layer (only used with ``'leaky_relu'``). Defaults to 0.
mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
``'fan_in'`` preserves the magnitude of the variance of the weights
in the forward pass. Choosing ``'fan_out'`` preserves the
magnitudes in the backwards pass. Defaults to ``'fan_out'``.
nonlinearity (str): the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
Defaults to 'relu'.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
        distribution (str): distribution to use, either ``'normal'`` or
            ``'uniform'``. Defaults to ``'normal'``.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
"""
def __init__(self,
a=0,
mode='fan_out',
nonlinearity='relu',
distribution='normal',
**kwargs):
super().__init__(**kwargs)
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
self.distribution = distribution
def __call__(self, module):
def init(m):
if self.wholemodule:
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
else:
layername = m.__class__.__name__
basesname = _get_bases_name(m)
if len(set(self.layer) & set([layername] + basesname)):
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
module.apply(init)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
f'nonlinearity={self.nonlinearity}, ' \
f'distribution ={self.distribution}, bias={self.bias}'
return info
@INITIALIZERS.register_module(name='Caffe2Xavier')
class Caffe2XavierInit(KaimingInit):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
def __init__(self, **kwargs):
super().__init__(
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
distribution='uniform',
**kwargs)
def __call__(self, module):
super().__call__(module)
@INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object):
"""Initialize module by loading a pretrained model.
Args:
        checkpoint (str): the checkpoint file of the pretrained model to be
            loaded.
        prefix (str, optional): the prefix of a sub-module in the pretrained
            model. It is for loading a part of the pretrained model to
initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None.
map_location (str): map tensors into proper locations.
"""
def __init__(self, checkpoint, prefix=None, map_location=None):
self.checkpoint = checkpoint
self.prefix = prefix
self.map_location = map_location
def __call__(self, module):
from annotator.uniformer.mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
load_state_dict)
logger = get_logger('mmcv')
if self.prefix is None:
print_log(f'load model from: {self.checkpoint}', logger=logger)
load_checkpoint(
module,
self.checkpoint,
map_location=self.map_location,
strict=False,
logger=logger)
else:
print_log(
f'load {self.prefix} in model from: {self.checkpoint}',
logger=logger)
state_dict = _load_checkpoint_with_prefix(
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
def _get_init_info(self):
info = f'{self.__class__.__name__}: load from {self.checkpoint}'
return info
def _initialize(module, cfg, wholemodule=False):
func = build_from_cfg(cfg, INITIALIZERS)
    # The wholemodule flag is for override mode: there is no 'layer' key in
    # override, so the initializer will initialize the whole sub-module with
    # the given name.
func.wholemodule = wholemodule
func(module)
def _initialize_override(module, override, cfg):
if not isinstance(override, (dict, list)):
raise TypeError(f'override must be a dict or a list of dict, \
but got {type(override)}')
override = [override] if isinstance(override, dict) else override
for override_ in override:
cp_override = copy.deepcopy(override_)
name = cp_override.pop('name', None)
if name is None:
raise ValueError('`override` must contain the key "name",'
f'but got {cp_override}')
# if override only has name key, it means use args in init_cfg
if not cp_override:
cp_override.update(cfg)
# if override has name key and other args except type key, it will
# raise error
elif 'type' not in cp_override.keys():
raise ValueError(
f'`override` need "type" key, but got {cp_override}')
if hasattr(module, name):
_initialize(getattr(module, name), cp_override, wholemodule=True)
else:
raise RuntimeError(f'module did not have attribute {name}, '
f'but init_cfg is {cp_override}.')
def initialize(module, init_cfg):
"""Initialize a module.
Args:
module (``torch.nn.Module``): the module will be initialized.
init_cfg (dict | list[dict]): initialization configuration dict to
define initializer. OpenMMLab has implemented 6 initializers
including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming``, and ``Pretrained``.
Example:
>>> module = nn.Linear(2, 3, bias=True)
>>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
>>> initialize(module, init_cfg)
>>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
>>> # define key ``'layer'`` for initializing layer with different
>>> # configuration
>>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
dict(type='Constant', layer='Linear', val=2)]
>>> initialize(module, init_cfg)
>>> # define key``'override'`` to initialize some specific part in
>>> # module
>>> class FooNet(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> self.feat = nn.Conv2d(3, 16, 3)
>>> self.reg = nn.Conv2d(16, 10, 3)
>>> self.cls = nn.Conv2d(16, 5, 3)
>>> model = FooNet()
>>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
>>> override=dict(type='Constant', name='reg', val=3, bias=4))
>>> initialize(model, init_cfg)
>>> model = ResNet(depth=50)
>>> # Initialize weights with the pretrained model.
>>> init_cfg = dict(type='Pretrained',
checkpoint='torchvision://resnet50')
>>> initialize(model, init_cfg)
>>> # Initialize weights of a sub-module with the specific part of
>>> # a pretrained model by using "prefix".
>>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
>>> 'retinanet_r50_fpn_1x_coco/'\
>>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
>>> init_cfg = dict(type='Pretrained',
checkpoint=url, prefix='backbone.')
"""
if not isinstance(init_cfg, (dict, list)):
raise TypeError(f'init_cfg must be a dict or a list of dict, \
but got {type(init_cfg)}')
if isinstance(init_cfg, dict):
init_cfg = [init_cfg]
for cfg in init_cfg:
        # Deeply copy the original config because cfg may be shared by other
        # modules, e.g., one init_cfg shared by multiple bottleneck blocks;
        # popping keys from the shared cfg would change the initialization
        # behavior of the other modules.
cp_cfg = copy.deepcopy(cfg)
override = cp_cfg.pop('override', None)
_initialize(module, cp_cfg)
if override is not None:
cp_cfg.pop('layer', None)
_initialize_override(module, override, cp_cfg)
else:
# All attributes in module have same initialization.
pass
def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
b: float) -> Tensor:
# Method based on
# https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
# Modified from
# https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [lower, upper], then translate
# to [2lower-1, 2upper-1].
tensor.uniform_(2 * lower - 1, 2 * upper - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor: Tensor,
mean: float = 0.,
std: float = 1.,
a: float = -2.,
b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
Args:
tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
mean (float): the mean of the normal distribution.
std (float): the standard deviation of the normal distribution.
a (float): the minimum cutoff value.
b (float): the maximum cutoff value.
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
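# --- Illustrative check (not part of the original file) ---
# A small sketch of `trunc_normal_`: every sampled value must lie inside the
# cutoff interval [a, b], while mean/std describe the underlying normal
# distribution. The helper name is arbitrary.
def _demo_trunc_normal():
    t = torch.empty(10000)
    trunc_normal_(t, mean=0., std=1., a=-2., b=2.)
    assert float(t.min()) >= -2. and float(t.max()) <= 2.
    print(float(t.mean()), float(t.std()))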
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import torch.nn as nn
from .utils import constant_init, kaiming_init, normal_init
def conv3x3(in_planes, out_planes, dilation=1):
"""3x3 convolution with padding."""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
padding=dilation,
dilation=dilation)
def make_vgg_layer(inplanes,
planes,
num_blocks,
dilation=1,
with_bn=False,
ceil_mode=False):
layers = []
for _ in range(num_blocks):
layers.append(conv3x3(inplanes, planes, dilation))
if with_bn:
layers.append(nn.BatchNorm2d(planes))
layers.append(nn.ReLU(inplace=True))
inplanes = planes
layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
return layers
class VGG(nn.Module):
"""VGG backbone.
Args:
depth (int): Depth of vgg, from {11, 13, 16, 19}.
with_bn (bool): Use BatchNorm or not.
num_classes (int): number of classes for classification.
num_stages (int): VGG stages, normally 5.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
"""
arch_settings = {
11: (1, 1, 2, 2, 2),
13: (2, 2, 2, 2, 2),
16: (2, 2, 3, 3, 3),
19: (2, 2, 4, 4, 4)
}
def __init__(self,
depth,
with_bn=False,
num_classes=-1,
num_stages=5,
dilations=(1, 1, 1, 1, 1),
out_indices=(0, 1, 2, 3, 4),
frozen_stages=-1,
bn_eval=True,
bn_frozen=False,
ceil_mode=False,
with_last_pool=True):
super(VGG, self).__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for vgg')
assert num_stages >= 1 and num_stages <= 5
stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
assert len(dilations) == num_stages
assert max(out_indices) <= num_stages
self.num_classes = num_classes
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
self.inplanes = 3
start_idx = 0
vgg_layers = []
self.range_sub_modules = []
for i, num_blocks in enumerate(self.stage_blocks):
num_modules = num_blocks * (2 + with_bn) + 1
end_idx = start_idx + num_modules
dilation = dilations[i]
planes = 64 * 2**i if i < 4 else 512
vgg_layer = make_vgg_layer(
self.inplanes,
planes,
num_blocks,
dilation=dilation,
with_bn=with_bn,
ceil_mode=ceil_mode)
vgg_layers.extend(vgg_layer)
self.inplanes = planes
self.range_sub_modules.append([start_idx, end_idx])
start_idx = end_idx
if not with_last_pool:
vgg_layers.pop(-1)
self.range_sub_modules[-1][1] -= 1
self.module_name = 'features'
self.add_module(self.module_name, nn.Sequential(*vgg_layers))
if self.num_classes > 0:
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
from ..runner import load_checkpoint
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
elif isinstance(m, nn.Linear):
normal_init(m, std=0.01)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
outs = []
vgg_layers = getattr(self, self.module_name)
for i in range(len(self.stage_blocks)):
for j in range(*self.range_sub_modules[i]):
vgg_layer = vgg_layers[j]
x = vgg_layer(x)
if i in self.out_indices:
outs.append(x)
if self.num_classes > 0:
x = x.view(x.size(0), -1)
x = self.classifier(x)
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def train(self, mode=True):
super(VGG, self).train(mode)
if self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
vgg_layers = getattr(self, self.module_name)
if mode and self.frozen_stages >= 0:
for i in range(self.frozen_stages):
for j in range(*self.range_sub_modules[i]):
mod = vgg_layers[j]
mod.eval()
for param in mod.parameters():
param.requires_grad = False
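# --- Illustrative usage sketch (not part of the original file) ---
# Building the VGG backbone above as a feature extractor (num_classes=-1, so
# no classifier head) and inspecting the per-stage outputs. The helper name
# and the 224x224 input size are arbitrary choices made only for illustration.
def _demo_vgg_backbone():
    import torch
    model = VGG(depth=16, out_indices=(0, 1, 2, 3, 4))
    model.init_weights(pretrained=None)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    # Five feature maps with strides 2, 4, 8, 16, 32 and channel counts
    # 64, 128, 256, 512, 512.
    for feat in feats:
        print(tuple(feat.shape))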
# Copyright (c) OpenMMLab. All rights reserved.
from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
single_gpu_test)
__all__ = [
'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
'single_gpu_test'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import pickle
import shutil
import tempfile
import time
import torch
import torch.distributed as dist
import annotator.uniformer.mmcv as mmcv
from annotator.uniformer.mmcv.runner import get_dist_info
def single_gpu_test(model, data_loader):
"""Test model with a single gpu.
    This method tests the model with a single gpu and displays a test
    progress bar.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
Returns:
list: The prediction results.
"""
model.eval()
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for data in data_loader:
with torch.no_grad():
result = model(return_loss=False, **data)
results.extend(result)
        # Assume result has the same length as batch_size
# refer to https://github.com/open-mmlab/mmcv/issues/985
batch_size = len(result)
for _ in range(batch_size):
prog_bar.update()
return results
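# --- Illustrative usage sketch (not part of the original file) ---
# `single_gpu_test` expects a model whose forward accepts `return_loss` and
# returns a list of per-sample results, and a data loader whose batches are
# dicts that can be unpacked as keyword arguments. The toy dataset, toy model
# and helper name below are hypothetical stand-ins that only demonstrate the
# calling convention.
def _demo_single_gpu_test():
    from torch.utils.data import DataLoader, Dataset

    class _ToyDataset(Dataset):

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return dict(img=torch.randn(3, 16, 16))

    class _ToyModel(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 1, 3)

        def forward(self, img, return_loss=False):
            # Return one result per sample in the batch.
            return [o.mean().item() for o in self.conv(img)]

    loader = DataLoader(_ToyDataset(), batch_size=2)
    results = single_gpu_test(_ToyModel(), loader)
    print(len(results))  # 8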
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
"""Test model with multiple gpus.
    This method tests the model with multiple gpus and collects the results
    under two different modes: gpu and cpu. By setting ``gpu_collect=True``,
    it encodes results to gpu tensors and uses gpu communication for results
    collection. In cpu mode it saves the results on different gpus to
    ``tmpdir`` and lets the rank 0 worker collect them.
Args:
model (nn.Module): Model to be tested.
        data_loader (torch.utils.data.DataLoader): PyTorch data loader.
tmpdir (str): Path of directory to save the temporary results from
different gpus under cpu mode.
gpu_collect (bool): Option to use either gpu or cpu to collect results.
Returns:
list: The prediction results.
"""
model.eval()
results = []
dataset = data_loader.dataset
rank, world_size = get_dist_info()
if rank == 0:
prog_bar = mmcv.ProgressBar(len(dataset))
time.sleep(2) # This line can prevent deadlock problem in some cases.
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, **data)
results.extend(result)
if rank == 0:
batch_size = len(result)
batch_size_all = batch_size * world_size
if batch_size_all + prog_bar.completed > len(dataset):
batch_size_all = len(dataset) - prog_bar.completed
for _ in range(batch_size_all):
prog_bar.update()
# collect results from all ranks
if gpu_collect:
results = collect_results_gpu(results, len(dataset))
else:
results = collect_results_cpu(results, len(dataset), tmpdir)
return results
def collect_results_cpu(result_part, size, tmpdir=None):
"""Collect results under cpu mode.
    In cpu mode, this function saves the results on different gpus to
    ``tmpdir`` and collects them with the rank 0 worker.
Args:
result_part (list): Result list containing result parts
to be collected.
size (int): Size of the results, commonly equal to length of
the results.
        tmpdir (str | None): Temporary directory in which to store the
            collected results. If set to None, a random temporary directory
            will be created.
Returns:
list: The collected results.
"""
rank, world_size = get_dist_info()
# create a tmp dir if it is not specified
if tmpdir is None:
MAX_LEN = 512
# 32 is whitespace
dir_tensor = torch.full((MAX_LEN, ),
32,
dtype=torch.uint8,
device='cuda')
if rank == 0:
mmcv.mkdir_or_exist('.dist_test')
tmpdir = tempfile.mkdtemp(dir='.dist_test')
tmpdir = torch.tensor(
bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
dir_tensor[:len(tmpdir)] = tmpdir
dist.broadcast(dir_tensor, 0)
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
else:
mmcv.mkdir_or_exist(tmpdir)
# dump the part result to the dir
mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
dist.barrier()
# collect all parts
if rank != 0:
return None
else:
# load results of all parts from tmp dir
part_list = []
for i in range(world_size):
part_file = osp.join(tmpdir, f'part_{i}.pkl')
part_result = mmcv.load(part_file)
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
if part_result:
part_list.append(part_result)
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
shutil.rmtree(tmpdir)
return ordered_results
def collect_results_gpu(result_part, size):
"""Collect results under gpu mode.
    In gpu mode, this function encodes results to gpu tensors and uses gpu
    communication for results collection.
Args:
result_part (list): Result list containing result parts
to be collected.
size (int): Size of the results, commonly equal to length of
the results.
Returns:
list: The collected results.
"""
rank, world_size = get_dist_info()
# dump result part to tensor with pickle
part_tensor = torch.tensor(
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
# gather all result part tensor shape
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
shape_list = [shape_tensor.clone() for _ in range(world_size)]
dist.all_gather(shape_list, shape_tensor)
# padding result part tensor to max length
shape_max = torch.tensor(shape_list).max()
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
part_send[:shape_tensor[0]] = part_tensor
part_recv_list = [
part_tensor.new_zeros(shape_max) for _ in range(world_size)
]
# gather all result part
dist.all_gather(part_recv_list, part_send)
if rank == 0:
part_list = []
for recv, shape in zip(part_recv_list, shape_list):
part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
if part_result:
part_list.append(part_result)
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
return ordered_results
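# --- Illustrative usage sketch (editor's addition, not part of the original
# module): how ``single_gpu_test`` is typically driven. ``DummyModel`` and
# the toy dataset below are hypothetical stand-ins; real callers pass an
# MMCV-style model whose forward accepts ``return_loss=False``.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    class DummyModel(nn.Module):
        def forward(self, return_loss=False, img=None):
            # one "prediction" per sample in the batch
            return [x.sum().item() for x in img]
    dataset = TensorDataset(torch.randn(8, 3, 4, 4))
    loader = DataLoader(
        dataset,
        batch_size=2,
        collate_fn=lambda batch: {'img': [sample[0] for sample in batch]})
    predictions = single_gpu_test(DummyModel(), loader)
    print(len(predictions))  # expected: 8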
# Copyright (c) OpenMMLab. All rights reserved.
from .file_client import BaseStorageBackend, FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from .io import dump, load, register_handler
from .parse import dict_from_file, list_from_file
__all__ = [
'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
'list_from_file', 'dict_from_file'
]
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import os
import os.path as osp
import re
import tempfile
import warnings
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable, Iterator, Optional, Tuple, Union
from urllib.request import urlopen
import annotator.uniformer.mmcv as mmcv
from annotator.uniformer.mmcv.utils.misc import has_method
from annotator.uniformer.mmcv.utils.path import is_filepath
class BaseStorageBackend(metaclass=ABCMeta):
"""Abstract class of storage backends.
All backends need to implement two apis: ``get()`` and ``get_text()``.
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
    as text.
"""
# a flag to indicate whether the backend can create a symlink for a file
_allow_symlink = False
@property
def name(self):
return self.__class__.__name__
@property
def allow_symlink(self):
return self._allow_symlink
@abstractmethod
def get(self, filepath):
pass
@abstractmethod
def get_text(self, filepath):
pass
class CephBackend(BaseStorageBackend):
"""Ceph storage backend (for internal use).
Args:
path_mapping (dict|None): path mapping dict from local path to Petrel
path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
will be replaced by ``dst``. Default: None.
.. warning::
:class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
"""
def __init__(self, path_mapping=None):
try:
import ceph
except ImportError:
raise ImportError('Please install ceph to enable CephBackend.')
warnings.warn(
'CephBackend will be deprecated, please use PetrelBackend instead')
self._client = ceph.S3Client()
assert isinstance(path_mapping, dict) or path_mapping is None
self.path_mapping = path_mapping
def get(self, filepath):
filepath = str(filepath)
if self.path_mapping is not None:
for k, v in self.path_mapping.items():
filepath = filepath.replace(k, v)
value = self._client.Get(filepath)
value_buf = memoryview(value)
return value_buf
def get_text(self, filepath, encoding=None):
raise NotImplementedError
class PetrelBackend(BaseStorageBackend):
"""Petrel storage backend (for internal use).
PetrelBackend supports reading and writing data to multiple clusters.
If the file path contains the cluster name, PetrelBackend will read data
    from the specified cluster or write data to it. Otherwise, PetrelBackend will
access the default cluster.
Args:
path_mapping (dict, optional): Path mapping dict from local path to
Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
``filepath`` will be replaced by ``dst``. Default: None.
enable_mc (bool, optional): Whether to enable memcached support.
Default: True.
Examples:
>>> filepath1 = 's3://path/of/file'
>>> filepath2 = 'cluster-name:s3://path/of/file'
>>> client = PetrelBackend()
>>> client.get(filepath1) # get data from default cluster
>>> client.get(filepath2) # get data from 'cluster-name' cluster
"""
def __init__(self,
path_mapping: Optional[dict] = None,
enable_mc: bool = True):
try:
from petrel_client import client
except ImportError:
raise ImportError('Please install petrel_client to enable '
'PetrelBackend.')
self._client = client.Client(enable_mc=enable_mc)
assert isinstance(path_mapping, dict) or path_mapping is None
self.path_mapping = path_mapping
def _map_path(self, filepath: Union[str, Path]) -> str:
"""Map ``filepath`` to a string path whose prefix will be replaced by
:attr:`self.path_mapping`.
Args:
filepath (str): Path to be mapped.
"""
filepath = str(filepath)
if self.path_mapping is not None:
for k, v in self.path_mapping.items():
filepath = filepath.replace(k, v)
return filepath
def _format_path(self, filepath: str) -> str:
"""Convert a ``filepath`` to standard format of petrel oss.
If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
        environment, the ``filepath`` will be in the format of
's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
Args:
filepath (str): Path to be formatted.
"""
return re.sub(r'\\+', '/', filepath)
def get(self, filepath: Union[str, Path]) -> memoryview:
"""Read data from a given ``filepath`` with 'rb' mode.
Args:
filepath (str or Path): Path to read data.
Returns:
memoryview: A memory view of expected bytes object to avoid
copying. The memoryview object can be converted to bytes by
``value_buf.tobytes()``.
"""
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
value = self._client.Get(filepath)
value_buf = memoryview(value)
return value_buf
def get_text(self,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> str:
"""Read data from a given ``filepath`` with 'r' mode.
Args:
filepath (str or Path): Path to read data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
Returns:
str: Expected text reading from ``filepath``.
"""
return str(self.get(filepath), encoding=encoding)
def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Save data to a given ``filepath``.
Args:
obj (bytes): Data to be saved.
filepath (str or Path): Path to write data.
"""
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
self._client.put(filepath, obj)
def put_text(self,
obj: str,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> None:
"""Save data to a given ``filepath``.
Args:
obj (str): Data to be written.
filepath (str or Path): Path to write data.
encoding (str): The encoding format used to encode the ``obj``.
Default: 'utf-8'.
"""
self.put(bytes(obj, encoding=encoding), filepath)
def remove(self, filepath: Union[str, Path]) -> None:
"""Remove a file.
Args:
filepath (str or Path): Path to be removed.
"""
if not has_method(self._client, 'delete'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `delete` method, please use a higher version or dev'
' branch instead.'))
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
self._client.delete(filepath)
def exists(self, filepath: Union[str, Path]) -> bool:
"""Check whether a file path exists.
Args:
filepath (str or Path): Path to be checked whether exists.
Returns:
bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
"""
if not (has_method(self._client, 'contains')
and has_method(self._client, 'isdir')):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `contains` and `isdir` methods, please use a higher'
'version or dev branch instead.'))
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
return self._client.contains(filepath) or self._client.isdir(filepath)
def isdir(self, filepath: Union[str, Path]) -> bool:
"""Check whether a file path is a directory.
Args:
filepath (str or Path): Path to be checked whether it is a
directory.
Returns:
bool: Return ``True`` if ``filepath`` points to a directory,
``False`` otherwise.
"""
if not has_method(self._client, 'isdir'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `isdir` method, please use a higher version or dev'
' branch instead.'))
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
return self._client.isdir(filepath)
def isfile(self, filepath: Union[str, Path]) -> bool:
"""Check whether a file path is a file.
Args:
filepath (str or Path): Path to be checked whether it is a file.
Returns:
bool: Return ``True`` if ``filepath`` points to a file, ``False``
otherwise.
"""
if not has_method(self._client, 'contains'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `contains` method, please use a higher version or '
'dev branch instead.'))
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
return self._client.contains(filepath)
def join_path(self, filepath: Union[str, Path],
*filepaths: Union[str, Path]) -> str:
"""Concatenate all file paths.
Args:
filepath (str or Path): Path to be concatenated.
Returns:
str: The result after concatenation.
"""
filepath = self._format_path(self._map_path(filepath))
if filepath.endswith('/'):
filepath = filepath[:-1]
formatted_paths = [filepath]
for path in filepaths:
formatted_paths.append(self._format_path(self._map_path(path)))
return '/'.join(formatted_paths)
@contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
"""Download a file from ``filepath`` and return a temporary path.
        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with a ``with`` statement, and when exiting the
        ``with`` statement, the temporary path will be released.
Args:
filepath (str | Path): Download a file from ``filepath``.
Examples:
>>> client = PetrelBackend()
            >>> # After exiting from the ``with`` clause,
>>> # the path will be removed
>>> with client.get_local_path('s3://path/of/your/file') as path:
... # do something here
Yields:
Iterable[str]: Only yield one temporary path.
"""
filepath = self._map_path(filepath)
filepath = self._format_path(filepath)
assert self.isfile(filepath)
try:
f = tempfile.NamedTemporaryFile(delete=False)
f.write(self.get(filepath))
f.close()
yield f.name
finally:
os.remove(f.name)
def list_dir_or_file(self,
dir_path: Union[str, Path],
list_dir: bool = True,
list_file: bool = True,
suffix: Optional[Union[str, Tuple[str]]] = None,
recursive: bool = False) -> Iterator[str]:
"""Scan a directory to find the interested directories or files in
arbitrary order.
Note:
Petrel has no concept of directories but it simulates the directory
hierarchy in the filesystem through public prefixes. In addition,
if the returned path ends with '/', it means the path is a public
prefix which is a logical directory.
Note:
:meth:`list_dir_or_file` returns the path relative to ``dir_path``.
            In addition, the returned directory path will not contain the
            trailing '/', which is consistent with other backends.
Args:
dir_path (str | Path): Path of the directory.
list_dir (bool): List the directories. Default: True.
list_file (bool): List the path of files. Default: True.
suffix (str or tuple[str], optional): File suffix
that we are interested in. Default: None.
recursive (bool): If set to True, recursively scan the
directory. Default: False.
Yields:
Iterable[str]: A relative path to ``dir_path``.
"""
if not has_method(self._client, 'list'):
raise NotImplementedError(
('Current version of Petrel Python SDK has not supported '
'the `list` method, please use a higher version or dev'
' branch instead.'))
dir_path = self._map_path(dir_path)
dir_path = self._format_path(dir_path)
if list_dir and suffix is not None:
raise TypeError(
'`list_dir` should be False when `suffix` is not None')
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('`suffix` must be a string or tuple of strings')
# Petrel's simulated directory hierarchy assumes that directory paths
# should end with `/`
if not dir_path.endswith('/'):
dir_path += '/'
root = dir_path
def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
recursive):
for path in self._client.list(dir_path):
# the `self.isdir` is not used here to determine whether path
# is a directory, because `self.isdir` relies on
# `self._client.list`
if path.endswith('/'): # a directory path
next_dir_path = self.join_path(dir_path, path)
if list_dir:
# get the relative path and exclude the last
# character '/'
rel_dir = next_dir_path[len(root):-1]
yield rel_dir
if recursive:
yield from _list_dir_or_file(next_dir_path, list_dir,
list_file, suffix,
recursive)
else: # a file path
absolute_path = self.join_path(dir_path, path)
rel_path = absolute_path[len(root):]
if (suffix is None
or rel_path.endswith(suffix)) and list_file:
yield rel_path
return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
recursive)
class MemcachedBackend(BaseStorageBackend):
"""Memcached storage backend.
Attributes:
server_list_cfg (str): Config file for memcached server list.
client_cfg (str): Config file for memcached client.
sys_path (str | None): Additional path to be appended to `sys.path`.
Default: None.
"""
def __init__(self, server_list_cfg, client_cfg, sys_path=None):
if sys_path is not None:
import sys
sys.path.append(sys_path)
try:
import mc
except ImportError:
raise ImportError(
'Please install memcached to enable MemcachedBackend.')
self.server_list_cfg = server_list_cfg
self.client_cfg = client_cfg
self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
self.client_cfg)
        # mc.pyvector serves as a pointer to a memory cache
self._mc_buffer = mc.pyvector()
def get(self, filepath):
filepath = str(filepath)
import mc
self._client.Get(filepath, self._mc_buffer)
value_buf = mc.ConvertBuffer(self._mc_buffer)
return value_buf
def get_text(self, filepath, encoding=None):
raise NotImplementedError
class LmdbBackend(BaseStorageBackend):
"""Lmdb storage backend.
Args:
db_path (str): Lmdb database path.
readonly (bool, optional): Lmdb environment parameter. If True,
disallow any write operations. Default: True.
lock (bool, optional): Lmdb environment parameter. If False, when
concurrent access occurs, do not lock the database. Default: False.
readahead (bool, optional): Lmdb environment parameter. If False,
disable the OS filesystem readahead mechanism, which may improve
random read performance when a database is larger than RAM.
Default: False.
Attributes:
db_path (str): Lmdb database path.
"""
def __init__(self,
db_path,
readonly=True,
lock=False,
readahead=False,
**kwargs):
try:
import lmdb
except ImportError:
raise ImportError('Please install lmdb to enable LmdbBackend.')
self.db_path = str(db_path)
self._client = lmdb.open(
self.db_path,
readonly=readonly,
lock=lock,
readahead=readahead,
**kwargs)
def get(self, filepath):
"""Get values according to the filepath.
Args:
filepath (str | obj:`Path`): Here, filepath is the lmdb key.
"""
filepath = str(filepath)
with self._client.begin(write=False) as txn:
value_buf = txn.get(filepath.encode('ascii'))
return value_buf
def get_text(self, filepath, encoding=None):
raise NotImplementedError
class HardDiskBackend(BaseStorageBackend):
"""Raw hard disks storage backend."""
_allow_symlink = True
def get(self, filepath: Union[str, Path]) -> bytes:
"""Read data from a given ``filepath`` with 'rb' mode.
Args:
filepath (str or Path): Path to read data.
Returns:
bytes: Expected bytes object.
"""
with open(filepath, 'rb') as f:
value_buf = f.read()
return value_buf
def get_text(self,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> str:
"""Read data from a given ``filepath`` with 'r' mode.
Args:
filepath (str or Path): Path to read data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
Returns:
str: Expected text reading from ``filepath``.
"""
with open(filepath, 'r', encoding=encoding) as f:
value_buf = f.read()
return value_buf
def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'wb' mode.
Note:
``put`` will create a directory if the directory of ``filepath``
does not exist.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
mmcv.mkdir_or_exist(osp.dirname(filepath))
with open(filepath, 'wb') as f:
f.write(obj)
def put_text(self,
obj: str,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> None:
"""Write data to a given ``filepath`` with 'w' mode.
Note:
``put_text`` will create a directory if the directory of
``filepath`` does not exist.
Args:
obj (str): Data to be written.
filepath (str or Path): Path to write data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
"""
mmcv.mkdir_or_exist(osp.dirname(filepath))
with open(filepath, 'w', encoding=encoding) as f:
f.write(obj)
def remove(self, filepath: Union[str, Path]) -> None:
"""Remove a file.
Args:
filepath (str or Path): Path to be removed.
"""
os.remove(filepath)
def exists(self, filepath: Union[str, Path]) -> bool:
"""Check whether a file path exists.
Args:
filepath (str or Path): Path to be checked whether exists.
Returns:
bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
"""
return osp.exists(filepath)
def isdir(self, filepath: Union[str, Path]) -> bool:
"""Check whether a file path is a directory.
Args:
filepath (str or Path): Path to be checked whether it is a
directory.
Returns:
bool: Return ``True`` if ``filepath`` points to a directory,
``False`` otherwise.
"""
return osp.isdir(filepath)
def isfile(self, filepath: Union[str, Path]) -> bool:
"""Check whether a file path is a file.
Args:
filepath (str or Path): Path to be checked whether it is a file.
Returns:
bool: Return ``True`` if ``filepath`` points to a file, ``False``
otherwise.
"""
return osp.isfile(filepath)
def join_path(self, filepath: Union[str, Path],
*filepaths: Union[str, Path]) -> str:
"""Concatenate all file paths.
Join one or more filepath components intelligently. The return value
is the concatenation of filepath and any members of *filepaths.
Args:
filepath (str or Path): Path to be concatenated.
Returns:
str: The result of concatenation.
"""
return osp.join(filepath, *filepaths)
@contextmanager
def get_local_path(
self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
"""Only for unified API and do nothing."""
yield filepath
def list_dir_or_file(self,
dir_path: Union[str, Path],
list_dir: bool = True,
list_file: bool = True,
suffix: Optional[Union[str, Tuple[str]]] = None,
recursive: bool = False) -> Iterator[str]:
"""Scan a directory to find the interested directories or files in
arbitrary order.
Note:
:meth:`list_dir_or_file` returns the path relative to ``dir_path``.
Args:
dir_path (str | Path): Path of the directory.
list_dir (bool): List the directories. Default: True.
list_file (bool): List the path of files. Default: True.
suffix (str or tuple[str], optional): File suffix
that we are interested in. Default: None.
recursive (bool): If set to True, recursively scan the
directory. Default: False.
Yields:
Iterable[str]: A relative path to ``dir_path``.
"""
if list_dir and suffix is not None:
raise TypeError('`suffix` should be None when `list_dir` is True')
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('`suffix` must be a string or tuple of strings')
root = dir_path
def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
rel_path = osp.relpath(entry.path, root)
if (suffix is None
or rel_path.endswith(suffix)) and list_file:
yield rel_path
elif osp.isdir(entry.path):
if list_dir:
rel_dir = osp.relpath(entry.path, root)
yield rel_dir
if recursive:
yield from _list_dir_or_file(entry.path, list_dir,
list_file, suffix,
recursive)
return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
recursive)
class HTTPBackend(BaseStorageBackend):
"""HTTP and HTTPS storage bachend."""
def get(self, filepath):
value_buf = urlopen(filepath).read()
return value_buf
def get_text(self, filepath, encoding='utf-8'):
value_buf = urlopen(filepath).read()
return value_buf.decode(encoding)
@contextmanager
def get_local_path(self, filepath: str) -> Iterable[str]:
"""Download a file from ``filepath``.
        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with a ``with`` statement, and when exiting the
        ``with`` statement, the temporary path will be released.
Args:
filepath (str): Download a file from ``filepath``.
Examples:
>>> client = HTTPBackend()
            >>> # After exiting from the ``with`` clause,
>>> # the path will be removed
>>> with client.get_local_path('http://path/of/your/file') as path:
... # do something here
"""
try:
f = tempfile.NamedTemporaryFile(delete=False)
f.write(self.get(filepath))
f.close()
yield f.name
finally:
os.remove(f.name)
class FileClient:
"""A general file client to access files in different backends.
The client loads a file or text in a specified backend from its path
and returns it as a binary or text file. There are two ways to choose a
    backend: the name of the backend and the prefix of the path. Although both
    can be used to choose a storage backend, ``backend`` has a higher priority,
    that is, if both are set, the storage backend will be chosen by the
    ``backend`` argument. If both are ``None``, the disk backend will be chosen.
    Note that it can also register other backend accessors with a given name,
    prefixes, and backend class. In addition, we use the singleton pattern to
    avoid repeated object creation. If the arguments are the same, the same
    object will be returned.
Args:
backend (str, optional): The storage backend type. Options are "disk",
"ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
prefix (str, optional): The prefix of the registered storage backend.
Options are "s3", "http", "https". Default: None.
Examples:
>>> # only set backend
>>> file_client = FileClient(backend='petrel')
>>> # only set prefix
>>> file_client = FileClient(prefix='s3')
>>> # set both backend and prefix but use backend to choose client
>>> file_client = FileClient(backend='petrel', prefix='s3')
>>> # if the arguments are the same, the same object is returned
>>> file_client1 = FileClient(backend='petrel')
>>> file_client1 is file_client
True
Attributes:
client (:obj:`BaseStorageBackend`): The backend object.
"""
_backends = {
'disk': HardDiskBackend,
'ceph': CephBackend,
'memcached': MemcachedBackend,
'lmdb': LmdbBackend,
'petrel': PetrelBackend,
'http': HTTPBackend,
}
    # This collection is used to record the overridden backends. When a
    # backend appears in the collection, the singleton pattern is disabled for
    # that backend, because if the singleton pattern were used, the object
    # returned would be the backend registered before the override.
_overridden_backends = set()
_prefix_to_backends = {
's3': PetrelBackend,
'http': HTTPBackend,
'https': HTTPBackend,
}
_overridden_prefixes = set()
_instances = {}
def __new__(cls, backend=None, prefix=None, **kwargs):
if backend is None and prefix is None:
backend = 'disk'
if backend is not None and backend not in cls._backends:
raise ValueError(
f'Backend {backend} is not supported. Currently supported ones'
f' are {list(cls._backends.keys())}')
if prefix is not None and prefix not in cls._prefix_to_backends:
raise ValueError(
f'prefix {prefix} is not supported. Currently supported ones '
f'are {list(cls._prefix_to_backends.keys())}')
# concatenate the arguments to a unique key for determining whether
# objects with the same arguments were created
arg_key = f'{backend}:{prefix}'
for key, value in kwargs.items():
arg_key += f':{key}:{value}'
# if a backend was overridden, it will create a new object
if (arg_key in cls._instances
and backend not in cls._overridden_backends
and prefix not in cls._overridden_prefixes):
_instance = cls._instances[arg_key]
else:
# create a new object and put it to _instance
_instance = super().__new__(cls)
if backend is not None:
_instance.client = cls._backends[backend](**kwargs)
else:
_instance.client = cls._prefix_to_backends[prefix](**kwargs)
cls._instances[arg_key] = _instance
return _instance
@property
def name(self):
return self.client.name
@property
def allow_symlink(self):
return self.client.allow_symlink
@staticmethod
def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
"""Parse the prefix of a uri.
Args:
uri (str | Path): Uri to be parsed that contains the file prefix.
Examples:
>>> FileClient.parse_uri_prefix('s3://path/of/your/file')
's3'
Returns:
str | None: Return the prefix of uri if the uri contains '://'
else ``None``.
"""
assert is_filepath(uri)
uri = str(uri)
if '://' not in uri:
return None
else:
prefix, _ = uri.split('://')
            # In the case of PetrelBackend, the prefix may contain the cluster
            # name like clusterName:s3
if ':' in prefix:
_, prefix = prefix.split(':')
return prefix
@classmethod
def infer_client(cls,
file_client_args: Optional[dict] = None,
uri: Optional[Union[str, Path]] = None) -> 'FileClient':
"""Infer a suitable file client based on the URI and arguments.
Args:
file_client_args (dict, optional): Arguments to instantiate a
FileClient. Default: None.
uri (str | Path, optional): Uri to be parsed that contains the file
prefix. Default: None.
Examples:
>>> uri = 's3://path/of/your/file'
>>> file_client = FileClient.infer_client(uri=uri)
>>> file_client_args = {'backend': 'petrel'}
>>> file_client = FileClient.infer_client(file_client_args)
Returns:
FileClient: Instantiated FileClient object.
"""
assert file_client_args is not None or uri is not None
if file_client_args is None:
file_prefix = cls.parse_uri_prefix(uri) # type: ignore
return cls(prefix=file_prefix)
else:
return cls(**file_client_args)
@classmethod
def _register_backend(cls, name, backend, force=False, prefixes=None):
if not isinstance(name, str):
raise TypeError('the backend name should be a string, '
f'but got {type(name)}')
if not inspect.isclass(backend):
raise TypeError(
f'backend should be a class but got {type(backend)}')
if not issubclass(backend, BaseStorageBackend):
raise TypeError(
f'backend {backend} is not a subclass of BaseStorageBackend')
if not force and name in cls._backends:
raise KeyError(
f'{name} is already registered as a storage backend, '
'add "force=True" if you want to override it')
if name in cls._backends and force:
cls._overridden_backends.add(name)
cls._backends[name] = backend
if prefixes is not None:
if isinstance(prefixes, str):
prefixes = [prefixes]
else:
assert isinstance(prefixes, (list, tuple))
for prefix in prefixes:
if prefix not in cls._prefix_to_backends:
cls._prefix_to_backends[prefix] = backend
elif (prefix in cls._prefix_to_backends) and force:
cls._overridden_prefixes.add(prefix)
cls._prefix_to_backends[prefix] = backend
else:
raise KeyError(
f'{prefix} is already registered as a storage backend,'
' add "force=True" if you want to override it')
@classmethod
def register_backend(cls, name, backend=None, force=False, prefixes=None):
"""Register a backend to FileClient.
This method can be used as a normal class method or a decorator.
.. code-block:: python
class NewBackend(BaseStorageBackend):
def get(self, filepath):
return filepath
def get_text(self, filepath):
return filepath
FileClient.register_backend('new', NewBackend)
or
.. code-block:: python
@FileClient.register_backend('new')
class NewBackend(BaseStorageBackend):
def get(self, filepath):
return filepath
def get_text(self, filepath):
return filepath
Args:
name (str): The name of the registered backend.
backend (class, optional): The backend class to be registered,
which must be a subclass of :class:`BaseStorageBackend`.
When this method is used as a decorator, backend is None.
Defaults to None.
force (bool, optional): Whether to override the backend if the name
has already been registered. Defaults to False.
prefixes (str or list[str] or tuple[str], optional): The prefixes
of the registered storage backend. Default: None.
`New in version 1.3.15.`
"""
if backend is not None:
cls._register_backend(
name, backend, force=force, prefixes=prefixes)
return
def _register(backend_cls):
cls._register_backend(
name, backend_cls, force=force, prefixes=prefixes)
return backend_cls
return _register
def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
"""Read data from a given ``filepath`` with 'rb' mode.
Note:
There are two types of return values for ``get``, one is ``bytes``
and the other is ``memoryview``. The advantage of using memoryview
is that you can avoid copying, and if you want to convert it to
``bytes``, you can use ``.tobytes()``.
Args:
filepath (str or Path): Path to read data.
Returns:
bytes | memoryview: Expected bytes object or a memory view of the
bytes object.
"""
return self.client.get(filepath)
def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
"""Read data from a given ``filepath`` with 'r' mode.
Args:
filepath (str or Path): Path to read data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
Returns:
str: Expected text reading from ``filepath``.
"""
return self.client.get_text(filepath, encoding)
def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'wb' mode.
Note:
``put`` should create a directory if the directory of ``filepath``
does not exist.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
self.client.put(obj, filepath)
def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'w' mode.
Note:
``put_text`` should create a directory if the directory of
``filepath`` does not exist.
Args:
obj (str): Data to be written.
filepath (str or Path): Path to write data.
"""
self.client.put_text(obj, filepath)
def remove(self, filepath: Union[str, Path]) -> None:
"""Remove a file.
Args:
filepath (str, Path): Path to be removed.
"""
self.client.remove(filepath)
def exists(self, filepath: Union[str, Path]) -> bool:
"""Check whether a file path exists.
Args:
filepath (str or Path): Path to be checked whether exists.
Returns:
bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
"""
return self.client.exists(filepath)
def isdir(self, filepath: Union[str, Path]) -> bool:
"""Check whether a file path is a directory.
Args:
filepath (str or Path): Path to be checked whether it is a
directory.
Returns:
bool: Return ``True`` if ``filepath`` points to a directory,
``False`` otherwise.
"""
return self.client.isdir(filepath)
def isfile(self, filepath: Union[str, Path]) -> bool:
"""Check whether a file path is a file.
Args:
filepath (str or Path): Path to be checked whether it is a file.
Returns:
bool: Return ``True`` if ``filepath`` points to a file, ``False``
otherwise.
"""
return self.client.isfile(filepath)
def join_path(self, filepath: Union[str, Path],
*filepaths: Union[str, Path]) -> str:
"""Concatenate all file paths.
Join one or more filepath components intelligently. The return value
is the concatenation of filepath and any members of *filepaths.
Args:
filepath (str or Path): Path to be concatenated.
Returns:
str: The result of concatenation.
"""
return self.client.join_path(filepath, *filepaths)
@contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
"""Download data from ``filepath`` and write the data to local path.
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
can be called with ``with`` statement, and when exists from the
``with`` statement, the temporary path will be released.
Note:
If the ``filepath`` is a local path, just return itself.
.. warning::
``get_local_path`` is an experimental interface that may change in
the future.
Args:
filepath (str or Path): Path to be read data.
Examples:
>>> file_client = FileClient(prefix='s3')
>>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
... # do something here
Yields:
Iterable[str]: Only yield one path.
"""
with self.client.get_local_path(str(filepath)) as local_path:
yield local_path
def list_dir_or_file(self,
dir_path: Union[str, Path],
list_dir: bool = True,
list_file: bool = True,
suffix: Optional[Union[str, Tuple[str]]] = None,
recursive: bool = False) -> Iterator[str]:
"""Scan a directory to find the interested directories or files in
arbitrary order.
Note:
:meth:`list_dir_or_file` returns the path relative to ``dir_path``.
Args:
dir_path (str | Path): Path of the directory.
list_dir (bool): List the directories. Default: True.
list_file (bool): List the path of files. Default: True.
suffix (str or tuple[str], optional): File suffix
that we are interested in. Default: None.
recursive (bool): If set to True, recursively scan the
directory. Default: False.
Yields:
Iterable[str]: A relative path to ``dir_path``.
"""
yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
suffix, recursive)
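# --- Illustrative usage sketch (editor's addition, not part of the original
# module): registering a toy in-memory backend and accessing it through the
# generic FileClient API. The ``ToyBackend`` class and the 'toy' name are
# hypothetical.
if __name__ == '__main__':
    @FileClient.register_backend('toy')
    class ToyBackend(BaseStorageBackend):
        def get(self, filepath):
            return b'bytes-from-' + str(filepath).encode()
        def get_text(self, filepath, encoding='utf-8'):
            return f'text-from-{filepath}'
    client = FileClient(backend='toy')
    print(client.name)                 # expected: ToyBackend
    print(client.get('a/b.bin'))       # expected: b'bytes-from-a/b.bin'
    print(client.get_text('a/b.txt'))  # expected: text-from-a/b.txt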
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseFileHandler
from .json_handler import JsonHandler
from .pickle_handler import PickleHandler
from .yaml_handler import YamlHandler
__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
class BaseFileHandler(metaclass=ABCMeta):
    # `str_like` is a flag to indicate whether the file object type is
    # str-like or bytes-like. Pickle only processes bytes-like objects,
    # while json only processes str-like objects. If it is a str-like
    # object, `StringIO` will be used to process the buffer.
str_like = True
@abstractmethod
def load_from_fileobj(self, file, **kwargs):
pass
@abstractmethod
def dump_to_fileobj(self, obj, file, **kwargs):
pass
@abstractmethod
def dump_to_str(self, obj, **kwargs):
pass
def load_from_path(self, filepath, mode='r', **kwargs):
with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs)
def dump_to_path(self, obj, filepath, mode='w', **kwargs):
with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
import json
import numpy as np
from .base import BaseFileHandler
def set_default(obj):
"""Set default json values for non-serializable values.
It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
    etc.) into plain Python built-in numbers.
"""
if isinstance(obj, (set, range)):
return list(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.generic):
return obj.item()
raise TypeError(f'{type(obj)} is unsupported for json dump')
class JsonHandler(BaseFileHandler):
def load_from_fileobj(self, file):
return json.load(file)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault('default', set_default)
json.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('default', set_default)
return json.dumps(obj, **kwargs)
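# --- Illustrative usage sketch (editor's addition, not part of the original
# module): ``set_default`` lets json serialize numpy scalars/arrays and
# set/range objects that the standard library rejects.
if __name__ == '__main__':
    handler = JsonHandler()
    payload = {'ids': range(3), 'score': np.float32(0.5), 'arr': np.arange(2)}
    print(handler.dump_to_str(payload, sort_keys=True))
    # expected: {"arr": [0, 1], "ids": [0, 1, 2], "score": 0.5}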
# Copyright (c) OpenMMLab. All rights reserved.
import pickle
from .base import BaseFileHandler
class PickleHandler(BaseFileHandler):
str_like = False
def load_from_fileobj(self, file, **kwargs):
return pickle.load(file, **kwargs)
def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path(
filepath, mode='rb', **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('protocol', 2)
return pickle.dumps(obj, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault('protocol', 2)
pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path(
obj, filepath, mode='wb', **kwargs)
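# --- Illustrative usage sketch (editor's addition, not part of the original
# module): despite its name, ``dump_to_str`` returns ``bytes`` here because
# ``str_like`` is False, so the round trip goes through a BytesIO buffer.
if __name__ == '__main__':
    from io import BytesIO
    handler = PickleHandler()
    blob = handler.dump_to_str({'a': 1})
    print(type(blob))                                # expected: <class 'bytes'>
    print(handler.load_from_fileobj(BytesIO(blob)))  # expected: {'a': 1}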
# Copyright (c) OpenMMLab. All rights reserved.
import yaml
try:
from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
from yaml import Loader, Dumper
from .base import BaseFileHandler # isort:skip
class YamlHandler(BaseFileHandler):
def load_from_fileobj(self, file, **kwargs):
kwargs.setdefault('Loader', Loader)
return yaml.load(file, **kwargs)
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault('Dumper', Dumper)
yaml.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('Dumper', Dumper)
return yaml.dump(obj, **kwargs)
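# --- Illustrative usage sketch (editor's addition, not part of the original
# module): extra keyword arguments are forwarded to ``yaml.dump``, so the
# usual PyYAML options work through the handler.
if __name__ == '__main__':
    handler = YamlHandler()
    print(handler.dump_to_str({'a': [1, 2]}, default_flow_style=True))
    # expected: {a: [1, 2]}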
# Copyright (c) OpenMMLab. All rights reserved.
from io import BytesIO, StringIO
from pathlib import Path
from ..utils import is_list_of, is_str
from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
file_handlers = {
'json': JsonHandler(),
'yaml': YamlHandler(),
'yml': YamlHandler(),
'pickle': PickleHandler(),
'pkl': PickleHandler()
}
def load(file, file_format=None, file_client_args=None, **kwargs):
"""Load data from json/yaml/pickle files.
This method provides a unified api for loading data from serialized files.
Note:
In v1.3.16 and later, ``load`` supports loading data from serialized
        files that can be stored in different backends.
Args:
file (str or :obj:`Path` or file-like object): Filename or a file-like
object.
file_format (str, optional): If not specified, the file format will be
inferred from the file extension, otherwise use the specified one.
Currently supported formats include "json", "yaml/yml" and
"pickle/pkl".
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Examples:
        >>> load('/path/of/your/file') # file is stored on disk
        >>> load('https://path/of/your/file') # file is stored on the Internet
        >>> load('s3://path/of/your/file') # file is stored in petrel
Returns:
The content from the file.
"""
if isinstance(file, Path):
file = str(file)
if file_format is None and is_str(file):
file_format = file.split('.')[-1]
if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}')
handler = file_handlers[file_format]
if is_str(file):
file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like:
with StringIO(file_client.get_text(file)) as f:
obj = handler.load_from_fileobj(f, **kwargs)
else:
with BytesIO(file_client.get(file)) as f:
obj = handler.load_from_fileobj(f, **kwargs)
elif hasattr(file, 'read'):
obj = handler.load_from_fileobj(file, **kwargs)
else:
raise TypeError('"file" must be a filepath str or a file-object')
return obj
def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
"""Dump data to json/yaml/pickle strings or files.
This method provides a unified api for dumping data as strings or to files,
and also supports custom arguments for each file format.
Note:
In v1.3.16 and later, ``dump`` supports dumping data as strings or to
        files that are saved to different backends.
Args:
obj (any): The python object to be dumped.
file (str or :obj:`Path` or file-like object, optional): If not
specified, then the object is dumped to a str, otherwise to a file
specified by the filename or file-like object.
file_format (str, optional): Same as :func:`load`.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Examples:
>>> dump('hello world', '/path/of/your/file') # disk
>>> dump('hello world', 's3://path/of/your/file') # ceph or petrel
Returns:
bool: True for success, False otherwise.
"""
if isinstance(file, Path):
file = str(file)
if file_format is None:
if is_str(file):
file_format = file.split('.')[-1]
elif file is None:
raise ValueError(
'file_format must be specified since file is None')
if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}')
handler = file_handlers[file_format]
if file is None:
return handler.dump_to_str(obj, **kwargs)
elif is_str(file):
file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like:
with StringIO() as f:
handler.dump_to_fileobj(obj, f, **kwargs)
file_client.put_text(f.getvalue(), file)
else:
with BytesIO() as f:
handler.dump_to_fileobj(obj, f, **kwargs)
file_client.put(f.getvalue(), file)
elif hasattr(file, 'write'):
handler.dump_to_fileobj(obj, file, **kwargs)
else:
raise TypeError('"file" must be a filename str or a file-object')
def _register_handler(handler, file_formats):
"""Register a handler for some file extensions.
Args:
handler (:obj:`BaseFileHandler`): Handler to be registered.
file_formats (str or list[str]): File formats to be handled by this
handler.
"""
if not isinstance(handler, BaseFileHandler):
raise TypeError(
f'handler must be a child of BaseFileHandler, not {type(handler)}')
if isinstance(file_formats, str):
file_formats = [file_formats]
if not is_list_of(file_formats, str):
raise TypeError('file_formats must be a str or a list of str')
for ext in file_formats:
file_handlers[ext] = handler
def register_handler(file_formats, **kwargs):
def wrap(cls):
_register_handler(cls(**kwargs), file_formats)
return cls
return wrap
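# --- Illustrative usage sketch (editor's addition, not part of the original
# module): adding a plain-text handler through the ``register_handler``
# decorator so ``load``/``dump`` accept a new extension. ``TxtHandler`` and
# the 'txt' format are hypothetical.
if __name__ == '__main__':
    @register_handler('txt')
    class TxtHandler(BaseFileHandler):
        def load_from_fileobj(self, file):
            return file.read()
        def dump_to_fileobj(self, obj, file):
            file.write(str(obj))
        def dump_to_str(self, obj):
            return str(obj)
    print('txt' in file_handlers)             # expected: True
    print(dump({'a': 1}, file_format='txt'))  # expected: {'a': 1}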
# Copyright (c) OpenMMLab. All rights reserved.
from io import StringIO
from .file_client import FileClient
def list_from_file(filename,
prefix='',
offset=0,
max_num=0,
encoding='utf-8',
file_client_args=None):
"""Load a text file and parse the content as a list of strings.
Note:
In v1.3.16 and later, ``list_from_file`` supports loading a text file
        which can be stored in different backends and parsing the content as
        a list of strings.
Args:
filename (str): Filename.
prefix (str): The prefix to be inserted to the beginning of each item.
offset (int): The offset of lines.
max_num (int): The maximum number of lines to be read,
            zero and negative values mean no limitation.
encoding (str): Encoding used to open the file. Default utf-8.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Examples:
>>> list_from_file('/path/of/your/file') # disk
['hello', 'world']
>>> list_from_file('s3://path/of/your/file') # ceph or petrel
['hello', 'world']
Returns:
list[str]: A list of strings.
"""
cnt = 0
item_list = []
file_client = FileClient.infer_client(file_client_args, filename)
with StringIO(file_client.get_text(filename, encoding)) as f:
for _ in range(offset):
f.readline()
for line in f:
if 0 < max_num <= cnt:
break
item_list.append(prefix + line.rstrip('\n\r'))
cnt += 1
return item_list
def dict_from_file(filename,
key_type=str,
encoding='utf-8',
file_client_args=None):
"""Load a text file and parse the content as a dict.
Each line of the text file will be two or more columns split by
whitespaces or tabs. The first column will be parsed as dict keys, and
the following columns will be parsed as dict values.
Note:
In v1.3.16 and later, ``dict_from_file`` supports loading a text file
        which can be stored in different backends and parsing the content as
a dict.
Args:
        filename (str): Filename.
        key_type (type): Type of the dict keys. str is used by default and
type conversion will be performed if specified.
encoding (str): Encoding used to open the file. Default utf-8.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Examples:
>>> dict_from_file('/path/of/your/file') # disk
{'key1': 'value1', 'key2': 'value2'}
>>> dict_from_file('s3://path/of/your/file') # ceph or petrel
{'key1': 'value1', 'key2': 'value2'}
Returns:
dict: The parsed contents.
"""
mapping = {}
file_client = FileClient.infer_client(file_client_args, filename)
with StringIO(file_client.get_text(filename, encoding)) as f:
for line in f:
items = line.rstrip('\n').split()
assert len(items) >= 2
key = key_type(items[0])
val = items[1:] if len(items) > 2 else items[1]
mapping[key] = val
return mapping
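# --- Illustrative usage sketch (editor's addition, not part of the original
# module): parsing a small ad-hoc text file from local disk with
# ``list_from_file`` and ``dict_from_file``.
if __name__ == '__main__':
    import os
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
        f.write('cat 0\ndog 1\n')
        tmp_path = f.name
    print(list_from_file(tmp_path))  # expected: ['cat 0', 'dog 1']
    print(dict_from_file(tmp_path))  # expected: {'cat': '0', 'dog': '1'}
    os.remove(tmp_path)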
# Copyright (c) OpenMMLab. All rights reserved.
from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
from .geometric import (cutout, imcrop, imflip, imflip_, impad,
impad_to_multiple, imrescale, imresize, imresize_like,
imresize_to_multiple, imrotate, imshear, imtranslate,
rescale_size)
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
from .misc import tensor2imgs
from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
adjust_lighting, adjust_sharpness, auto_contrast,
clahe, imdenormalize, imequalize, iminvert,
imnormalize, imnormalize_, lut_transform, posterize,
solarize)
__all__ = [
'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
]
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
def imconvert(img, src, dst):
"""Convert an image from the src colorspace to dst colorspace.
Args:
img (ndarray): The input image.
src (str): The source colorspace, e.g., 'rgb', 'hsv'.
dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
Returns:
ndarray: The converted image.
"""
code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
out_img = cv2.cvtColor(img, code)
return out_img
def bgr2gray(img, keepdim=False):
"""Convert a BGR image to grayscale image.
Args:
img (ndarray): The input image.
keepdim (bool): If False (by default), then return the grayscale image
with 2 dims, otherwise 3 dims.
Returns:
ndarray: The converted grayscale image.
"""
out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
if keepdim:
out_img = out_img[..., None]
return out_img
def rgb2gray(img, keepdim=False):
"""Convert a RGB image to grayscale image.
Args:
img (ndarray): The input image.
keepdim (bool): If False (by default), then return the grayscale image
with 2 dims, otherwise 3 dims.
Returns:
ndarray: The converted grayscale image.
"""
out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
if keepdim:
out_img = out_img[..., None]
return out_img
def gray2bgr(img):
"""Convert a grayscale image to BGR image.
Args:
img (ndarray): The input image.
Returns:
ndarray: The converted BGR image.
"""
img = img[..., None] if img.ndim == 2 else img
out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
return out_img
def gray2rgb(img):
"""Convert a grayscale image to RGB image.
Args:
img (ndarray): The input image.
Returns:
ndarray: The converted RGB image.
"""
img = img[..., None] if img.ndim == 2 else img
out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
return out_img
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
It is mainly used for pre-processing the input image in colorspace
conversion functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
(ndarray): The converted image with type of np.float32 and range of
[0, 1].
"""
img_type = img.dtype
img = img.astype(np.float32)
if img_type == np.float32:
pass
elif img_type == np.uint8:
img /= 255.
else:
raise TypeError('The img type should be np.float32 or np.uint8, '
f'but got {img_type}')
return img
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
images will be converted to np.uint8 type with range [0, 255]. If
`dst_type` is np.float32, it converts the image to np.float32 type with
range [0, 1].
It is mainly used for post-processing images in colorspace conversion
functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The image to be converted with np.float32 type and
range [0, 255].
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
converts the image to np.uint8 type with range [0, 255]. If
dst_type is np.float32, it converts the image to np.float32 type
with range [0, 1].
Returns:
(ndarray): The converted image with desired type and range.
"""
if dst_type not in (np.uint8, np.float32):
raise TypeError('The dst_type should be np.float32 or np.uint8, '
f'but got {dst_type}')
if dst_type == np.uint8:
img = img.round()
else:
img /= 255.
return img.astype(dst_type)
def rgb2ycbcr(img, y_only=False):
"""Convert a RGB image to YCbCr image.
This function produces the same results as Matlab's `rgb2ycbcr` function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
else:
out_img = np.matmul(
img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def ycbcr2rgb(img):
"""Convert a YCbCr image to RGB image.
This function produces the same results as Matlab's ycbcr2rgb function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted RGB image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [
-222.921, 135.576, -276.836
]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def ycbcr2bgr(img):
"""Convert a YCbCr image to BGR image.
The bgr version of ycbcr2rgb.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted BGR image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
[0.00791071, -0.00153632, 0],
[0, -0.00318811, 0.00625893]]) * 255.0 + [
-276.836, 135.576, -222.921
]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def convert_color_factory(src, dst):
code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
def convert_color(img):
out_img = cv2.cvtColor(img, code)
return out_img
convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
image.
Args:
img (ndarray or str): The input image.
Returns:
ndarray: The converted {dst.upper()} image.
"""
return convert_color
bgr2rgb = convert_color_factory('bgr', 'rgb')
rgb2bgr = convert_color_factory('rgb', 'bgr')
bgr2hsv = convert_color_factory('bgr', 'hsv')
hsv2bgr = convert_color_factory('hsv', 'bgr')
bgr2hls = convert_color_factory('bgr', 'hls')
hls2bgr = convert_color_factory('hls', 'bgr')
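# --- Illustrative usage sketch (editor's addition, not part of the original
# module): the BT.601 converter preserves the input dtype, and the
# factory-made converters simply wrap ``cv2.cvtColor``.
if __name__ == '__main__':
    img = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
    y = rgb2ycbcr(img, y_only=True)
    print(y.shape, y.dtype)    # expected: (4, 4) uint8, values in [16, 235]
    print(bgr2rgb(img).shape)  # expected: (4, 4, 3)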