InstructBLIP_pytorch (ModelZoo) · Commits

Commit c04f261a, authored Aug 22, 2024 by dongchy920
Commit message: InstructBLIP
Pipeline #1594 canceled with stages · Changes: 421 · Pipelines: 1
Showing 20 changed files with 3987 additions and 0 deletions (+3987 −0).
lavis/common/annotator/uniformer/mmcv/cnn/resnet.py  +316 −0
lavis/common/annotator/uniformer/mmcv/cnn/utils/__init__.py  +19 −0
lavis/common/annotator/uniformer/mmcv/cnn/utils/flops_counter.py  +599 −0
lavis/common/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py  +59 −0
lavis/common/annotator/uniformer/mmcv/cnn/utils/sync_bn.py  +59 −0
lavis/common/annotator/uniformer/mmcv/cnn/utils/weight_init.py  +684 −0
lavis/common/annotator/uniformer/mmcv/cnn/vgg.py  +175 −0
lavis/common/annotator/uniformer/mmcv/engine/__init__.py  +8 −0
lavis/common/annotator/uniformer/mmcv/engine/test.py  +202 −0
lavis/common/annotator/uniformer/mmcv/fileio/__init__.py  +11 −0
lavis/common/annotator/uniformer/mmcv/fileio/file_client.py  +1148 −0
lavis/common/annotator/uniformer/mmcv/fileio/handlers/__init__.py  +7 −0
lavis/common/annotator/uniformer/mmcv/fileio/handlers/base.py  +30 −0
lavis/common/annotator/uniformer/mmcv/fileio/handlers/json_handler.py  +36 −0
lavis/common/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py  +28 −0
lavis/common/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py  +24 −0
lavis/common/annotator/uniformer/mmcv/fileio/io.py  +151 −0
lavis/common/annotator/uniformer/mmcv/fileio/parse.py  +97 −0
lavis/common/annotator/uniformer/mmcv/image/__init__.py  +28 −0
lavis/common/annotator/uniformer/mmcv/image/colorspace.py  +306 −0
Too many changes to show: to preserve performance only 421 of 421+ files are displayed.
lavis/common/annotator/uniformer/mmcv/cnn/resnet.py  (new file, mode 0 → 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch.nn as nn
import torch.utils.checkpoint as cp

from .utils import constant_init, kaiming_init


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False):
        super(BasicBlock, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False):
        """Bottleneck block.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
        it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        if style == 'pytorch':
            conv1_stride = 1
            conv2_stride = stride
        else:
            conv1_stride = stride
            conv2_stride = 1
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)

        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    def forward(self, x):

        def _inner_forward(x):
            residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)

            if self.downsample is not None:
                residual = self.downsample(x)

            out += residual

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )

    layers = []
    layers.append(
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp))
    inplanes = planes * block.expansion
    for _ in range(1, blocks):
        layers.append(
            block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))

    return nn.Sequential(*layers)


class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
            running stats (mean and var).
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 with_cp=False):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        assert num_stages >= 1 and num_stages <= 4
        block, stage_blocks = self.arch_settings[depth]
        stage_blocks = stage_blocks[:num_stages]
        assert len(strides) == len(dilations) == num_stages
        assert max(out_indices) < num_stages

        self.out_indices = out_indices
        self.style = style
        self.frozen_stages = frozen_stages
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen
        self.with_cp = with_cp

        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.res_layers = []
        for i, num_blocks in enumerate(stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp)
            self.inplanes = planes * block.expansion
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            from ..runner import load_checkpoint
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        super(ResNet, self).train(mode)
        if self.bn_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if self.bn_frozen:
                        for params in m.parameters():
                            params.requires_grad = False
        if mode and self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
            for param in self.bn1.parameters():
                param.requires_grad = False
            self.bn1.eval()
            self.bn1.weight.requires_grad = False
            self.bn1.bias.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                mod = getattr(self, f'layer{i}')
                mod.eval()
                for param in mod.parameters():
                    param.requires_grad = False
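A minimal usage sketch for the backbone defined above (not part of the committed file; it assumes the repository root is on the Python path so that `lavis.common...resnet` is importable). It builds a ResNet-50, initializes it, and inspects the multi-scale features returned by `forward`:

# Hypothetical usage sketch; the import path is an assumption.
import torch
from lavis.common.annotator.uniformer.mmcv.cnn.resnet import ResNet

backbone = ResNet(depth=50, num_stages=4, out_indices=(0, 1, 2, 3),
                  frozen_stages=1)
backbone.init_weights()   # Kaiming init for convs, constant init for BN
backbone.train()          # BN layers stay in eval mode because bn_eval=True

x = torch.randn(1, 3, 224, 224)
feats = backbone(x)       # tuple of 4 stage outputs at strides 4/8/16/32
for f in feats:
    print(f.shape)
# (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)

With `frozen_stages=1`, `train()` also freezes `conv1`, `bn1`, and `layer1`, which matches the docstring's description of stage freezing.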
lavis/common/annotator/uniformer/mmcv/cnn/utils/__init__.py  (new file, mode 0 → 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .sync_bn import revert_sync_batchnorm
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
                          KaimingInit, NormalInit, PretrainedInit,
                          TruncNormalInit, UniformInit, XavierInit,
                          bias_init_with_prob, caffe2_xavier_init,
                          constant_init, initialize, kaiming_init,
                          normal_init, trunc_normal_init, uniform_init,
                          xavier_init)

__all__ = [
    'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
    'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
    'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
    'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
    'Caffe2XavierInit', 'revert_sync_batchnorm'
]
lavis/common/annotator/uniformer/mmcv/cnn/utils/flops_counter.py  (new file, mode 0 → 100644)

# Modified from flops-counter.pytorch by Vladislav Sovrasov
# original repo: https://github.com/sovrasov/flops-counter.pytorch

# MIT License

# Copyright (c) 2018 Vladislav Sovrasov

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import sys
from functools import partial

import numpy as np
import torch
import torch.nn as nn

import annotator.uniformer.mmcv as mmcv


def get_model_complexity_info(model,
                              input_shape,
                              print_per_layer_stat=True,
                              as_strings=True,
                              input_constructor=None,
                              flush=False,
                              ost=sys.stdout):
    """Get complexity information of a model.

    This method can calculate FLOPs and parameter counts of a model with
    corresponding input shape. It can also print complexity information for
    each layer in a model.

    Supported layers are listed as below:
        - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
        - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
          ``nn.ReLU6``.
        - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
          ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
          ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
          ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
          ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
        - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
          ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
          ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
        - Linear: ``nn.Linear``.
        - Deconvolution: ``nn.ConvTranspose2d``.
        - Upsample: ``nn.Upsample``.

    Args:
        model (nn.Module): The model for complexity calculation.
        input_shape (tuple): Input shape used for calculation.
        print_per_layer_stat (bool): Whether to print complexity information
            for each layer in a model. Default: True.
        as_strings (bool): Output FLOPs and params counts in a string form.
            Default: True.
        input_constructor (None | callable): If specified, it takes a callable
            method that generates input. otherwise, it will generate a random
            tensor with input shape to calculate FLOPs. Default: None.
        flush (bool): same as that in :func:`print`. Default: False.
        ost (stream): same as ``file`` param in :func:`print`.
            Default: sys.stdout.

    Returns:
        tuple[float | str]: If ``as_strings`` is set to True, it will return
            FLOPs and parameter counts in a string format. otherwise, it will
            return those in a float number format.
    """
    assert type(input_shape) is tuple
    assert len(input_shape) >= 1
    assert isinstance(model, nn.Module)
    flops_model = add_flops_counting_methods(model)
    flops_model.eval()
    flops_model.start_flops_count()
    if input_constructor:
        input = input_constructor(input_shape)
        _ = flops_model(**input)
    else:
        try:
            batch = torch.ones(()).new_empty(
                (1, *input_shape),
                dtype=next(flops_model.parameters()).dtype,
                device=next(flops_model.parameters()).device)
        except StopIteration:
            # Avoid StopIteration for models which have no parameters,
            # like `nn.Relu()`, `nn.AvgPool2d`, etc.
            batch = torch.ones(()).new_empty((1, *input_shape))

        _ = flops_model(batch)

    flops_count, params_count = flops_model.compute_average_flops_cost()
    if print_per_layer_stat:
        print_model_with_flops(
            flops_model, flops_count, params_count, ost=ost, flush=flush)
    flops_model.stop_flops_count()

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)

    return flops_count, params_count


def flops_to_string(flops, units='GFLOPs', precision=2):
    """Convert FLOPs number into a string.

    Note that Here we take a multiply-add counts as one FLOP.

    Args:
        flops (float): FLOPs number to be converted.
        units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
            'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
            choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
        precision (int): Digit number after the decimal point. Default: 2.

    Returns:
        str: The converted FLOPs number with units.

    Examples:
        >>> flops_to_string(1e9)
        '1.0 GFLOPs'
        >>> flops_to_string(2e5, 'MFLOPs')
        '0.2 MFLOPs'
        >>> flops_to_string(3e-9, None)
        '3e-09 FLOPs'
    """
    if units is None:
        if flops // 10**9 > 0:
            return str(round(flops / 10.**9, precision)) + ' GFLOPs'
        elif flops // 10**6 > 0:
            return str(round(flops / 10.**6, precision)) + ' MFLOPs'
        elif flops // 10**3 > 0:
            return str(round(flops / 10.**3, precision)) + ' KFLOPs'
        else:
            return str(flops) + ' FLOPs'
    else:
        if units == 'GFLOPs':
            return str(round(flops / 10.**9, precision)) + ' ' + units
        elif units == 'MFLOPs':
            return str(round(flops / 10.**6, precision)) + ' ' + units
        elif units == 'KFLOPs':
            return str(round(flops / 10.**3, precision)) + ' ' + units
        else:
            return str(flops) + ' FLOPs'


def params_to_string(num_params, units=None, precision=2):
    """Convert parameter number into a string.

    Args:
        num_params (float): Parameter number to be converted.
        units (str | None): Converted FLOPs units. Options are None, 'M',
            'K' and ''. If set to None, it will automatically choose the most
            suitable unit for Parameter number. Default: None.
        precision (int): Digit number after the decimal point. Default: 2.

    Returns:
        str: The converted parameter number with units.

    Examples:
        >>> params_to_string(1e9)
        '1000.0 M'
        >>> params_to_string(2e5)
        '200.0 k'
        >>> params_to_string(3e-9)
        '3e-09'
    """
    if units is None:
        if num_params // 10**6 > 0:
            return str(round(num_params / 10**6, precision)) + ' M'
        elif num_params // 10**3:
            return str(round(num_params / 10**3, precision)) + ' k'
        else:
            return str(num_params)
    else:
        if units == 'M':
            return str(round(num_params / 10.**6, precision)) + ' ' + units
        elif units == 'K':
            return str(round(num_params / 10.**3, precision)) + ' ' + units
        else:
            return str(num_params)


def print_model_with_flops(model,
                           total_flops,
                           total_params,
                           units='GFLOPs',
                           precision=3,
                           ost=sys.stdout,
                           flush=False):
    """Print a model with FLOPs for each layer.

    Args:
        model (nn.Module): The model to be printed.
        total_flops (float): Total FLOPs of the model.
        total_params (float): Total parameter counts of the model.
        units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
        precision (int): Digit number after the decimal point. Default: 3.
        ost (stream): same as `file` param in :func:`print`.
            Default: sys.stdout.
        flush (bool): same as that in :func:`print`. Default: False.

    Example:
        >>> class ExampleModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.conv1 = nn.Conv2d(3, 8, 3)
        >>>         self.conv2 = nn.Conv2d(8, 256, 3)
        >>>         self.conv3 = nn.Conv2d(256, 8, 3)
        >>>         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Linear(8, 1)
        >>>     def forward(self, x):
        >>>         x = self.conv1(x)
        >>>         x = self.conv2(x)
        >>>         x = self.conv3(x)
        >>>         x = self.avg_pool(x)
        >>>         x = self.flatten(x)
        >>>         x = self.fc(x)
        >>>         return x
        >>> model = ExampleModel()
        >>> x = (3, 16, 16)
        to print the complexity information state for each layer, you can use
        >>> get_model_complexity_info(model, x)
        or directly use
        >>> print_model_with_flops(model, 4579784.0, 37361)
        ExampleModel(
          0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
          (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))  # noqa: E501
          (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
          (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
          (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
          (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
          (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
        )
    """

    def accumulate_params(self):
        if is_supported_instance(self):
            return self.__params__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_params()
            return sum

    def accumulate_flops(self):
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_flops()
            return sum

    def flops_repr(self):
        accumulated_num_params = self.accumulate_params()
        accumulated_flops_cost = self.accumulate_flops()
        return ', '.join([
            params_to_string(
                accumulated_num_params, units='M', precision=precision),
            '{:.3%} Params'.format(accumulated_num_params / total_params),
            flops_to_string(
                accumulated_flops_cost, units=units, precision=precision),
            '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
            self.original_extra_repr()
        ])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        m.accumulate_params = accumulate_params.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(model, file=ost, flush=flush)
    model.apply(del_extra_repr)


def get_model_parameters_number(model):
    """Calculate parameter number of a model.

    Args:
        model (nn.module): The model for parameter number calculation.

    Returns:
        float: Parameter number of the model.
    """
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return num_params


def add_flops_counting_methods(net_main_module):
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    net_main_module.start_flops_count = start_flops_count.__get__(
        net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(
        net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(
        net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(  # noqa: E501
        net_main_module)

    net_main_module.reset_flops_count()

    return net_main_module


def compute_average_flops_cost(self):
    """Compute average FLOPs cost.

    A method to compute average FLOPs cost, which will be available after
    `add_flops_counting_methods()` is called on a desired net object.

    Returns:
        float: Current mean flops consumption per image.
    """
    batches_count = self.__batch_counter__
    flops_sum = 0
    for module in self.modules():
        if is_supported_instance(module):
            flops_sum += module.__flops__
    params_sum = get_model_parameters_number(self)
    return flops_sum / batches_count, params_sum


def start_flops_count(self):
    """Activate the computation of mean flops consumption per image.

    A method to activate the computation of mean flops consumption per image.
    which will be available after ``add_flops_counting_methods()`` is called on
    a desired net object. It should be called before running the network.
    """
    add_batch_counter_hook_function(self)

    def add_flops_counter_hook_function(module):
        if is_supported_instance(module):
            if hasattr(module, '__flops_handle__'):
                return
            else:
                handle = module.register_forward_hook(
                    get_modules_mapping()[type(module)])

                module.__flops_handle__ = handle

    self.apply(partial(add_flops_counter_hook_function))


def stop_flops_count(self):
    """Stop computing the mean flops consumption per image.

    A method to stop computing the mean flops consumption per image, which will
    be available after ``add_flops_counting_methods()`` is called on a desired
    net object. It can be called to pause the computation whenever.
    """
    remove_batch_counter_hook_function(self)
    self.apply(remove_flops_counter_hook_function)


def reset_flops_count(self):
    """Reset statistics computed so far.

    A method to Reset computed statistics, which will be available after
    `add_flops_counting_methods()` is called on a desired net object.
    """
    add_batch_counter_variables_or_reset(self)
    self.apply(add_flops_counter_variable_or_reset)


# ---- Internal functions
def empty_flops_counter_hook(module, input, output):
    module.__flops__ += 0


def upsample_flops_counter_hook(module, input, output):
    output_size = output[0]
    batch_size = output_size.shape[0]
    output_elements_count = batch_size
    for val in output_size.shape[1:]:
        output_elements_count *= val
    module.__flops__ += int(output_elements_count)


def relu_flops_counter_hook(module, input, output):
    active_elements_count = output.numel()
    module.__flops__ += int(active_elements_count)


def linear_flops_counter_hook(module, input, output):
    input = input[0]
    # pytorch checks dimensions, so here we don't care much
    output_last_dim = output.shape[-1]
    module.__flops__ += int(np.prod(input.shape) * output_last_dim)


def pool_flops_counter_hook(module, input, output):
    input = input[0]
    module.__flops__ += int(np.prod(input.shape))


def norm_flops_counter_hook(module, input, output):
    input = input[0]

    batch_flops = np.prod(input.shape)
    if (getattr(module, 'affine', False)
            or getattr(module, 'elementwise_affine', False)):
        batch_flops *= 2
    module.__flops__ += int(batch_flops)


def deconv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    input_height, input_width = input.shape[2:]

    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = (
        kernel_height * kernel_width * in_channels * filters_per_channel)

    active_elements_count = batch_size * input_height * input_width
    overall_conv_flops = conv_per_position_flops * active_elements_count
    bias_flops = 0
    if conv_module.bias is not None:
        output_height, output_width = output.shape[2:]
        bias_flops = out_channels * batch_size * output_height * output_height
    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)


def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(
        np.prod(kernel_dims)) * in_channels * filters_per_channel

    active_elements_count = batch_size * int(np.prod(output_dims))

    overall_conv_flops = conv_per_position_flops * active_elements_count

    bias_flops = 0

    if conv_module.bias is not None:
        bias_flops = out_channels * active_elements_count

    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)


def batch_counter_hook(module, input, output):
    batch_size = 1
    if len(input) > 0:
        # Can have multiple inputs, getting the first one
        input = input[0]
        batch_size = len(input)
    else:
        pass
        print('Warning! No positional inputs found for a module, '
              'assuming batch size is 1.')
    module.__batch_counter__ += batch_size


def add_batch_counter_variables_or_reset(module):
    module.__batch_counter__ = 0


def add_batch_counter_hook_function(module):
    if hasattr(module, '__batch_counter_handle__'):
        return

    handle = module.register_forward_hook(batch_counter_hook)
    module.__batch_counter_handle__ = handle


def remove_batch_counter_hook_function(module):
    if hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__.remove()
        del module.__batch_counter_handle__


def add_flops_counter_variable_or_reset(module):
    if is_supported_instance(module):
        if hasattr(module, '__flops__') or hasattr(module, '__params__'):
            print('Warning: variables __flops__ or __params__ are already '
                  'defined for the module' + type(module).__name__ +
                  ' ptflops can affect your code!')
        module.__flops__ = 0
        module.__params__ = get_model_parameters_number(module)


def is_supported_instance(module):
    if type(module) in get_modules_mapping():
        return True
    return False


def remove_flops_counter_hook_function(module):
    if is_supported_instance(module):
        if hasattr(module, '__flops_handle__'):
            module.__flops_handle__.remove()
            del module.__flops_handle__


def get_modules_mapping():
    return {
        # convolutions
        nn.Conv1d: conv_flops_counter_hook,
        nn.Conv2d: conv_flops_counter_hook,
        mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
        nn.Conv3d: conv_flops_counter_hook,
        mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
        # activations
        nn.ReLU: relu_flops_counter_hook,
        nn.PReLU: relu_flops_counter_hook,
        nn.ELU: relu_flops_counter_hook,
        nn.LeakyReLU: relu_flops_counter_hook,
        nn.ReLU6: relu_flops_counter_hook,
        # poolings
        nn.MaxPool1d: pool_flops_counter_hook,
        nn.AvgPool1d: pool_flops_counter_hook,
        nn.AvgPool2d: pool_flops_counter_hook,
        nn.MaxPool2d: pool_flops_counter_hook,
        mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
        nn.MaxPool3d: pool_flops_counter_hook,
        mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
        nn.AvgPool3d: pool_flops_counter_hook,
        nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
        nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
        nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
        nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
        nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
        nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
        # normalizations
        nn.BatchNorm1d: norm_flops_counter_hook,
        nn.BatchNorm2d: norm_flops_counter_hook,
        nn.BatchNorm3d: norm_flops_counter_hook,
        nn.GroupNorm: norm_flops_counter_hook,
        nn.InstanceNorm1d: norm_flops_counter_hook,
        nn.InstanceNorm2d: norm_flops_counter_hook,
        nn.InstanceNorm3d: norm_flops_counter_hook,
        nn.LayerNorm: norm_flops_counter_hook,
        # FC
        nn.Linear: linear_flops_counter_hook,
        mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
        # Upscale
        nn.Upsample: upsample_flops_counter_hook,
        # Deconvolution
        nn.ConvTranspose2d: deconv_flops_counter_hook,
        mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
    }
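A small usage sketch for the counter above (not part of the committed file; it assumes the `annotator.uniformer.mmcv` package and the `lavis` root are importable). The hooks count one multiply-add per FLOP, so a 3x3 conv over a 32x32 map contributes roughly kernel_area x in_channels x out_channels x output_positions:

# Hypothetical usage sketch; import path is an assumption.
import torch.nn as nn
from lavis.common.annotator.uniformer.mmcv.cnn.utils import (
    get_model_complexity_info)

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # 224 params
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),                           # 90 params
)
flops, params = get_model_complexity_info(
    model, (3, 32, 32), as_strings=False, print_per_layer_stat=False)
print(flops, params)  # roughly 2.5e5 multiply-adds and 314 parameters

Passing `as_strings=True` (the default) would instead return formatted strings such as '0.0 GFLOPs', since the default unit is GFLOPs.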
lavis/common/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py  (new file, mode 0 → 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


def _fuse_conv_bn(conv, bn):
    """Fuse conv and bn into one module.

    Args:
        conv (nn.Module): Conv to be fused.
        bn (nn.Module): BN to be fused.

    Returns:
        nn.Module: Fused module.
    """
    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
        bn.running_mean)

    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv.weight = nn.Parameter(conv_w *
                               factor.reshape([conv.out_channels, 1, 1, 1]))
    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
    return conv


def fuse_conv_bn(module):
    """Recursively fuse conv and bn in a module.

    During inference, the functionary of batch norm layers is turned off
    but only the mean and var alone channels are used, which exposes the
    chance to fuse it with the preceding conv layers to save computations and
    simplify network structures.

    Args:
        module (nn.Module): Module to be fused.

    Returns:
        nn.Module: Fused module.
    """
    last_conv = None
    last_conv_name = None

    for name, child in module.named_children():
        if isinstance(child,
                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = _fuse_conv_bn(last_conv, child)
            module._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            module._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            fuse_conv_bn(child)
    return module
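A quick equivalence check for the fusion above (a sketch, not part of the committed file; import path assumed as before). In eval mode, BN applies `gamma * (x - mean) / sqrt(var + eps) + beta` per channel, so folding `factor = gamma / sqrt(var + eps)` into the conv weight and `(b - mean) * factor + beta` into the conv bias should leave the output unchanged up to floating-point error:

# Hypothetical verification sketch.
import copy

import torch
import torch.nn as nn
from lavis.common.annotator.uniformer.mmcv.cnn.utils import fuse_conv_bn

block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
# Give BN non-trivial running statistics so the comparison is meaningful.
block[1].running_mean.uniform_(-1, 1)
block[1].running_var.uniform_(0.5, 2.0)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    y_ref = block(x)
    y_fused = fuse_conv_bn(copy.deepcopy(block))(x)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # expected: True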
lavis/common/annotator/uniformer/mmcv/cnn/utils/sync_bn.py  (new file, mode 0 → 100644)

import torch

import annotator.uniformer.mmcv as mmcv


class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    """A general BatchNorm layer without input dimension check.

    Reproduced from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
    is `_check_input_dim` that is designed for tensor sanity checks.
    The check has been bypassed in this class for the convenience of converting
    SyncBatchNorm.
    """

    def _check_input_dim(self, input):
        return


def revert_sync_batchnorm(module):
    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
    `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
    `BatchNormXd` layers.

    Adapted from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    Args:
        module (nn.Module): The module containing `SyncBatchNorm` layers.

    Returns:
        module_output: The converted module with `BatchNormXd` layers.
    """
    module_output = module
    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
    if hasattr(mmcv, 'ops'):
        module_checklist.append(mmcv.ops.SyncBatchNorm)
    if isinstance(module, tuple(module_checklist)):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            # no_grad() may not be needed here but
            # just to be consistent with `convert_sync_batchnorm()`
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        module_output.training = module.training
        # qconfig exists in quantized models
        if hasattr(module, 'qconfig'):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    del module
    return module_output
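A usage sketch (same import-path assumption as above): converting SyncBatchNorm layers back to plain BatchNorm lets a distributed-trained model run on a single device or CPU, where the SyncBN forward would otherwise require a process group.

# Hypothetical usage sketch.
import torch.nn as nn
from lavis.common.annotator.uniformer.mmcv.cnn.utils import revert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
model = revert_sync_batchnorm(model)
print(type(model[1]).__name__)  # _BatchNormXd, with the same affine params and running stats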
lavis/common/annotator/uniformer/mmcv/cnn/utils/weight_init.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
math
import
warnings
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
torch
import
Tensor
from
annotator.uniformer.mmcv.utils
import
Registry
,
build_from_cfg
,
get_logger
,
print_log
INITIALIZERS
=
Registry
(
'initializer'
)
def
update_init_info
(
module
,
init_info
):
"""Update the `_params_init_info` in the module if the value of parameters
are changed.
Args:
module (obj:`nn.Module`): The module of PyTorch with a user-defined
attribute `_params_init_info` which records the initialization
information.
init_info (str): The string that describes the initialization.
"""
assert
hasattr
(
module
,
'_params_init_info'
),
f
'Can not find `_params_init_info` in
{
module
}
'
for
name
,
param
in
module
.
named_parameters
():
assert
param
in
module
.
_params_init_info
,
(
f
'Find a new :obj:`Parameter` '
f
'named `
{
name
}
` during executing the '
f
'`init_weights` of '
f
'`
{
module
.
__class__
.
__name__
}
`. '
f
'Please do not add or '
f
'replace parameters during executing '
f
'the `init_weights`. '
)
# The parameter has been changed during executing the
# `init_weights` of module
mean_value
=
param
.
data
.
mean
()
if
module
.
_params_init_info
[
param
][
'tmp_mean_value'
]
!=
mean_value
:
module
.
_params_init_info
[
param
][
'init_info'
]
=
init_info
module
.
_params_init_info
[
param
][
'tmp_mean_value'
]
=
mean_value
def
constant_init
(
module
,
val
,
bias
=
0
):
if
hasattr
(
module
,
'weight'
)
and
module
.
weight
is
not
None
:
nn
.
init
.
constant_
(
module
.
weight
,
val
)
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
def
xavier_init
(
module
,
gain
=
1
,
bias
=
0
,
distribution
=
'normal'
):
assert
distribution
in
[
'uniform'
,
'normal'
]
if
hasattr
(
module
,
'weight'
)
and
module
.
weight
is
not
None
:
if
distribution
==
'uniform'
:
nn
.
init
.
xavier_uniform_
(
module
.
weight
,
gain
=
gain
)
else
:
nn
.
init
.
xavier_normal_
(
module
.
weight
,
gain
=
gain
)
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
def
normal_init
(
module
,
mean
=
0
,
std
=
1
,
bias
=
0
):
if
hasattr
(
module
,
'weight'
)
and
module
.
weight
is
not
None
:
nn
.
init
.
normal_
(
module
.
weight
,
mean
,
std
)
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
def
trunc_normal_init
(
module
:
nn
.
Module
,
mean
:
float
=
0
,
std
:
float
=
1
,
a
:
float
=
-
2
,
b
:
float
=
2
,
bias
:
float
=
0
)
->
None
:
if
hasattr
(
module
,
'weight'
)
and
module
.
weight
is
not
None
:
trunc_normal_
(
module
.
weight
,
mean
,
std
,
a
,
b
)
# type: ignore
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
# type: ignore
def
uniform_init
(
module
,
a
=
0
,
b
=
1
,
bias
=
0
):
if
hasattr
(
module
,
'weight'
)
and
module
.
weight
is
not
None
:
nn
.
init
.
uniform_
(
module
.
weight
,
a
,
b
)
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
def
kaiming_init
(
module
,
a
=
0
,
mode
=
'fan_out'
,
nonlinearity
=
'relu'
,
bias
=
0
,
distribution
=
'normal'
):
assert
distribution
in
[
'uniform'
,
'normal'
]
if
hasattr
(
module
,
'weight'
)
and
module
.
weight
is
not
None
:
if
distribution
==
'uniform'
:
nn
.
init
.
kaiming_uniform_
(
module
.
weight
,
a
=
a
,
mode
=
mode
,
nonlinearity
=
nonlinearity
)
else
:
nn
.
init
.
kaiming_normal_
(
module
.
weight
,
a
=
a
,
mode
=
mode
,
nonlinearity
=
nonlinearity
)
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
bias
)
def
caffe2_xavier_init
(
module
,
bias
=
0
):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
kaiming_init
(
module
,
a
=
1
,
mode
=
'fan_in'
,
nonlinearity
=
'leaky_relu'
,
bias
=
bias
,
distribution
=
'uniform'
)
def
bias_init_with_prob
(
prior_prob
):
"""initialize conv/fc bias value according to a given probability value."""
bias_init
=
float
(
-
np
.
log
((
1
-
prior_prob
)
/
prior_prob
))
return
bias_init
def
_get_bases_name
(
m
):
return
[
b
.
__name__
for
b
in
m
.
__class__
.
__bases__
]
class
BaseInit
(
object
):
def
__init__
(
self
,
*
,
bias
=
0
,
bias_prob
=
None
,
layer
=
None
):
self
.
wholemodule
=
False
if
not
isinstance
(
bias
,
(
int
,
float
)):
raise
TypeError
(
f
'bias must be a number, but got a
{
type
(
bias
)
}
'
)
if
bias_prob
is
not
None
:
if
not
isinstance
(
bias_prob
,
float
):
raise
TypeError
(
f
'bias_prob type must be float,
\
but got
{
type
(
bias_prob
)
}
'
)
if
layer
is
not
None
:
if
not
isinstance
(
layer
,
(
str
,
list
)):
raise
TypeError
(
f
'layer must be a str or a list of str,
\
but got a
{
type
(
layer
)
}
'
)
else
:
layer
=
[]
if
bias_prob
is
not
None
:
self
.
bias
=
bias_init_with_prob
(
bias_prob
)
else
:
self
.
bias
=
bias
self
.
layer
=
[
layer
]
if
isinstance
(
layer
,
str
)
else
layer
def
_get_init_info
(
self
):
info
=
f
'
{
self
.
__class__
.
__name__
}
, bias=
{
self
.
bias
}
'
return
info
@
INITIALIZERS
.
register_module
(
name
=
'Constant'
)
class
ConstantInit
(
BaseInit
):
"""Initialize module parameters with constant values.
Args:
val (int | float): the value to fill the weights in the module with
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def
__init__
(
self
,
val
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
val
=
val
def
__call__
(
self
,
module
):
def
init
(
m
):
if
self
.
wholemodule
:
constant_init
(
m
,
self
.
val
,
self
.
bias
)
else
:
layername
=
m
.
__class__
.
__name__
basesname
=
_get_bases_name
(
m
)
if
len
(
set
(
self
.
layer
)
&
set
([
layername
]
+
basesname
)):
constant_init
(
m
,
self
.
val
,
self
.
bias
)
module
.
apply
(
init
)
if
hasattr
(
module
,
'_params_init_info'
):
update_init_info
(
module
,
init_info
=
self
.
_get_init_info
())
def
_get_init_info
(
self
):
info
=
f
'
{
self
.
__class__
.
__name__
}
: val=
{
self
.
val
}
, bias=
{
self
.
bias
}
'
return
info
@
INITIALIZERS
.
register_module
(
name
=
'Xavier'
)
class
XavierInit
(
BaseInit
):
r
"""Initialize module parameters with values according to the method
described in `Understanding the difficulty of training deep feedforward
neural networks - Glorot, X. & Bengio, Y. (2010).
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
Args:
gain (int | float): an optional scaling factor. Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'``
or ``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def
__init__
(
self
,
gain
=
1
,
distribution
=
'normal'
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
gain
=
gain
self
.
distribution
=
distribution
def
__call__
(
self
,
module
):
def
init
(
m
):
if
self
.
wholemodule
:
xavier_init
(
m
,
self
.
gain
,
self
.
bias
,
self
.
distribution
)
else
:
layername
=
m
.
__class__
.
__name__
basesname
=
_get_bases_name
(
m
)
if
len
(
set
(
self
.
layer
)
&
set
([
layername
]
+
basesname
)):
xavier_init
(
m
,
self
.
gain
,
self
.
bias
,
self
.
distribution
)
module
.
apply
(
init
)
if
hasattr
(
module
,
'_params_init_info'
):
update_init_info
(
module
,
init_info
=
self
.
_get_init_info
())
def
_get_init_info
(
self
):
info
=
f
'
{
self
.
__class__
.
__name__
}
: gain=
{
self
.
gain
}
, '
\
f
'distribution=
{
self
.
distribution
}
, bias=
{
self
.
bias
}
'
return
info
@
INITIALIZERS
.
register_module
(
name
=
'Normal'
)
class
NormalInit
(
BaseInit
):
r
"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
Args:
mean (int | float):the mean of the normal distribution. Defaults to 0.
std (int | float): the standard deviation of the normal distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def
__init__
(
self
,
mean
=
0
,
std
=
1
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
mean
=
mean
self
.
std
=
std
def
__call__
(
self
,
module
):
def
init
(
m
):
if
self
.
wholemodule
:
normal_init
(
m
,
self
.
mean
,
self
.
std
,
self
.
bias
)
else
:
layername
=
m
.
__class__
.
__name__
basesname
=
_get_bases_name
(
m
)
if
len
(
set
(
self
.
layer
)
&
set
([
layername
]
+
basesname
)):
normal_init
(
m
,
self
.
mean
,
self
.
std
,
self
.
bias
)
module
.
apply
(
init
)
if
hasattr
(
module
,
'_params_init_info'
):
update_init_info
(
module
,
init_info
=
self
.
_get_init_info
())
def
_get_init_info
(
self
):
info
=
f
'
{
self
.
__class__
.
__name__
}
: mean=
{
self
.
mean
}
,'
\
f
' std=
{
self
.
std
}
, bias=
{
self
.
bias
}
'
return
info
@
INITIALIZERS
.
register_module
(
name
=
'TruncNormal'
)
class
TruncNormalInit
(
BaseInit
):
r
"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
outside :math:`[a, b]`.
Args:
mean (float): the mean of the normal distribution. Defaults to 0.
std (float): the standard deviation of the normal distribution.
Defaults to 1.
a (float): The minimum cutoff value.
b ( float): The maximum cutoff value.
bias (float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def
__init__
(
self
,
mean
:
float
=
0
,
std
:
float
=
1
,
a
:
float
=
-
2
,
b
:
float
=
2
,
**
kwargs
)
->
None
:
super
().
__init__
(
**
kwargs
)
self
.
mean
=
mean
self
.
std
=
std
self
.
a
=
a
self
.
b
=
b
def
__call__
(
self
,
module
:
nn
.
Module
)
->
None
:
def
init
(
m
):
if
self
.
wholemodule
:
trunc_normal_init
(
m
,
self
.
mean
,
self
.
std
,
self
.
a
,
self
.
b
,
self
.
bias
)
else
:
layername
=
m
.
__class__
.
__name__
basesname
=
_get_bases_name
(
m
)
if
len
(
set
(
self
.
layer
)
&
set
([
layername
]
+
basesname
)):
trunc_normal_init
(
m
,
self
.
mean
,
self
.
std
,
self
.
a
,
self
.
b
,
self
.
bias
)
module
.
apply
(
init
)
if
hasattr
(
module
,
'_params_init_info'
):
update_init_info
(
module
,
init_info
=
self
.
_get_init_info
())
def
_get_init_info
(
self
):
info
=
f
'
{
self
.
__class__
.
__name__
}
: a=
{
self
.
a
}
, b=
{
self
.
b
}
,'
\
f
' mean=
{
self
.
mean
}
, std=
{
self
.
std
}
, bias=
{
self
.
bias
}
'
return
info
@
INITIALIZERS
.
register_module
(
name
=
'Uniform'
)
class
UniformInit
(
BaseInit
):
r
"""Initialize module parameters with values drawn from the uniform
distribution :math:`\mathcal{U}(a, b)`.
Args:
a (int | float): the lower bound of the uniform distribution.
Defaults to 0.
b (int | float): the upper bound of the uniform distribution.
Defaults to 1.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def
__init__
(
self
,
a
=
0
,
b
=
1
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
a
=
a
self
.
b
=
b
def
__call__
(
self
,
module
):
def
init
(
m
):
if
self
.
wholemodule
:
uniform_init
(
m
,
self
.
a
,
self
.
b
,
self
.
bias
)
else
:
layername
=
m
.
__class__
.
__name__
basesname
=
_get_bases_name
(
m
)
if
len
(
set
(
self
.
layer
)
&
set
([
layername
]
+
basesname
)):
uniform_init
(
m
,
self
.
a
,
self
.
b
,
self
.
bias
)
module
.
apply
(
init
)
if
hasattr
(
module
,
'_params_init_info'
):
update_init_info
(
module
,
init_info
=
self
.
_get_init_info
())
def
_get_init_info
(
self
):
info
=
f
'
{
self
.
__class__
.
__name__
}
: a=
{
self
.
a
}
,'
\
f
' b=
{
self
.
b
}
, bias=
{
self
.
bias
}
'
return
info
@
INITIALIZERS
.
register_module
(
name
=
'Kaiming'
)
class
KaimingInit
(
BaseInit
):
r
"""Initialize module parameters with the values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification - He, K. et al. (2015).
<https://www.cv-foundation.org/openaccess/content_iccv_2015/
papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
Args:
a (int | float): the negative slope of the rectifier used after this
layer (only used with ``'leaky_relu'``). Defaults to 0.
mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
``'fan_in'`` preserves the magnitude of the variance of the weights
in the forward pass. Choosing ``'fan_out'`` preserves the
magnitudes in the backwards pass. Defaults to ``'fan_out'``.
nonlinearity (str): the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
Defaults to 'relu'.
bias (int | float): the value to fill the bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'`` or
``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def
__init__
(
self
,
a
=
0
,
mode
=
'fan_out'
,
nonlinearity
=
'relu'
,
distribution
=
'normal'
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
a
=
a
self
.
mode
=
mode
self
.
nonlinearity
=
nonlinearity
self
.
distribution
=
distribution
def
__call__
(
self
,
module
):
def
init
(
m
):
if
self
.
wholemodule
:
kaiming_init
(
m
,
self
.
a
,
self
.
mode
,
self
.
nonlinearity
,
self
.
bias
,
self
.
distribution
)
else
:
layername
=
m
.
__class__
.
__name__
basesname
=
_get_bases_name
(
m
)
if
len
(
set
(
self
.
layer
)
&
set
([
layername
]
+
basesname
)):
kaiming_init
(
m
,
self
.
a
,
self
.
mode
,
self
.
nonlinearity
,
self
.
bias
,
self
.
distribution
)
module
.
apply
(
init
)
if
hasattr
(
module
,
'_params_init_info'
):
update_init_info
(
module
,
init_info
=
self
.
_get_init_info
())
def
_get_init_info
(
self
):
info
=
f
'
{
self
.
__class__
.
__name__
}
: a=
{
self
.
a
}
, mode=
{
self
.
mode
}
, '
\
f
'nonlinearity=
{
self
.
nonlinearity
}
, '
\
f
'distribution =
{
self
.
distribution
}
, bias=
{
self
.
bias
}
'
return
info
@
INITIALIZERS
.
register_module
(
name
=
'Caffe2Xavier'
)
class
Caffe2XavierInit
(
KaimingInit
):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
a
=
1
,
mode
=
'fan_in'
,
nonlinearity
=
'leaky_relu'
,
distribution
=
'uniform'
,
**
kwargs
)
def
__call__
(
self
,
module
):
super
().
__call__
(
module
)
@
INITIALIZERS
.
register_module
(
name
=
'Pretrained'
)
class
PretrainedInit
(
object
):
"""Initialize module by loading a pretrained model.
Args:
checkpoint (str): the checkpoint file of the pretrained model should
be load.
prefix (str, optional): the prefix of a sub-module in the pretrained
model. it is for loading a part of the pretrained model to
initialize. For example, if we would like to only load the
backbone of a detector model, we can set ``prefix='backbone.'``.
Defaults to None.
map_location (str): map tensors into proper locations.
"""
def
__init__
(
self
,
checkpoint
,
prefix
=
None
,
map_location
=
None
):
self
.
checkpoint
=
checkpoint
self
.
prefix
=
prefix
self
.
map_location
=
map_location
def
__call__
(
self
,
module
):
from
annotator.uniformer.mmcv.runner
import
(
_load_checkpoint_with_prefix
,
load_checkpoint
,
load_state_dict
)
logger
=
get_logger
(
'mmcv'
)
if
self
.
prefix
is
None
:
print_log
(
f
'load model from:
{
self
.
checkpoint
}
'
,
logger
=
logger
)
load_checkpoint
(
module
,
self
.
checkpoint
,
map_location
=
self
.
map_location
,
strict
=
False
,
logger
=
logger
)
else
:
print_log
(
f
'load
{
self
.
prefix
}
in model from:
{
self
.
checkpoint
}
'
,
logger
=
logger
)
state_dict
=
_load_checkpoint_with_prefix
(
self
.
prefix
,
self
.
checkpoint
,
map_location
=
self
.
map_location
)
load_state_dict
(
module
,
state_dict
,
strict
=
False
,
logger
=
logger
)
if
hasattr
(
module
,
'_params_init_info'
):
update_init_info
(
module
,
init_info
=
self
.
_get_init_info
())
def
_get_init_info
(
self
):
info
=
f
'
{
self
.
__class__
.
__name__
}
: load from
{
self
.
checkpoint
}
'
return
info
def
_initialize
(
module
,
cfg
,
wholemodule
=
False
):
func
=
build_from_cfg
(
cfg
,
INITIALIZERS
)
# wholemodule flag is for override mode, there is no layer key in override
# and initializer will give init values for the whole module with the name
# in override.
func
.
wholemodule
=
wholemodule
func
(
module
)
def
_initialize_override
(
module
,
override
,
cfg
):
if
not
isinstance
(
override
,
(
dict
,
list
)):
raise
TypeError
(
f
'override must be a dict or a list of dict,
\
but got
{
type
(
override
)
}
'
)
override
=
[
override
]
if
isinstance
(
override
,
dict
)
else
override
for
override_
in
override
:
cp_override
=
copy
.
deepcopy
(
override_
)
name
=
cp_override
.
pop
(
'name'
,
None
)
if
name
is
None
:
raise
ValueError
(
'`override` must contain the key "name",'
f
'but got
{
cp_override
}
'
)
# if override only has name key, it means use args in init_cfg
if
not
cp_override
:
cp_override
.
update
(
cfg
)
        # if override has name key and other args except type key, it will
        # raise error
        elif 'type' not in cp_override.keys():
            raise ValueError(
                f'`override` need "type" key, but got {cp_override}')

        if hasattr(module, name):
            _initialize(getattr(module, name), cp_override, wholemodule=True)
        else:
            raise RuntimeError(f'module did not have attribute {name}, '
                               f'but init_cfg is {cp_override}.')


def initialize(module, init_cfg):
    """Initialize a module.

    Args:
        module (``torch.nn.Module``): the module will be initialized.
        init_cfg (dict | list[dict]): initialization configuration dict to
            define initializer. OpenMMLab has implemented 6 initializers
            including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
            ``Kaiming``, and ``Pretrained``.

    Example:
        >>> module = nn.Linear(2, 3, bias=True)
        >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
        >>> initialize(module, init_cfg)

        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1, 2))
        >>> # define key ``'layer'`` for initializing layer with different
        >>> # configuration
        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
                dict(type='Constant', layer='Linear', val=2)]
        >>> initialize(module, init_cfg)

        >>> # define key ``'override'`` to initialize some specific part in
        >>> # module
        >>> class FooNet(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.feat = nn.Conv2d(3, 16, 3)
        >>>         self.reg = nn.Conv2d(16, 10, 3)
        >>>         self.cls = nn.Conv2d(16, 5, 3)
        >>> model = FooNet()
        >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
        >>> initialize(model, init_cfg)

        >>> model = ResNet(depth=50)
        >>> # Initialize weights with the pretrained model.
        >>> init_cfg = dict(type='Pretrained',
                checkpoint='torchvision://resnet50')
        >>> initialize(model, init_cfg)

        >>> # Initialize weights of a sub-module with the specific part of
        >>> # a pretrained model by using "prefix".
        >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/' \
        >>>     'retinanet_r50_fpn_1x_coco/' \
        >>>     'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
        >>> init_cfg = dict(type='Pretrained',
                checkpoint=url, prefix='backbone.')
    """
    if not isinstance(init_cfg, (dict, list)):
        raise TypeError(f'init_cfg must be a dict or a list of dict, \
                but got {type(init_cfg)}')

    if isinstance(init_cfg, dict):
        init_cfg = [init_cfg]

    for cfg in init_cfg:
        # should deeply copy the original config because cfg may be used by
        # other modules, e.g., one init_cfg shared by multiple bottleneck
        # blocks, the expected cfg will be changed after pop and will change
        # the initialization behavior of other modules
        cp_cfg = copy.deepcopy(cfg)
        override = cp_cfg.pop('override', None)
        _initialize(module, cp_cfg)

        if override is not None:
            cp_cfg.pop('layer', None)
            _initialize_override(module, override, cp_cfg)
        else:
            # All attributes in module have same initialization.
            pass


def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
                           b: float) -> Tensor:
    # Method based on
    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # Modified from
    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], then translate
        # to [2lower-1, 2upper-1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor: Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Modified from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

    Args:
        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
        mean (float): the mean of the normal distribution.
        std (float): the standard deviation of the normal distribution.
        a (float): the minimum cutoff value.
        b (float): the maximum cutoff value.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
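A quick usage sketch of ``trunc_normal_`` (not part of the file above): it fills a parameter in place and returns it, and ``a``/``b`` act as absolute cutoffs rather than multiples of ``std``. The import path below assumes the vendored package is importable as ``annotator.uniformer.mmcv``.

import torch.nn as nn
from annotator.uniformer.mmcv.cnn.utils.weight_init import trunc_normal_  # assumed path

linear = nn.Linear(16, 8)
trunc_normal_(linear.weight, mean=0., std=0.02, a=-2., b=2.)
assert float(linear.weight.abs().max()) <= 2.0  # values are clamped to [a, b]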
lavis/common/annotator/uniformer/mmcv/cnn/vgg.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch.nn as nn

from .utils import constant_init, kaiming_init, normal_init


def conv3x3(in_planes, out_planes, dilation=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        padding=dilation,
        dilation=dilation)


def make_vgg_layer(inplanes,
                   planes,
                   num_blocks,
                   dilation=1,
                   with_bn=False,
                   ceil_mode=False):
    layers = []
    for _ in range(num_blocks):
        layers.append(conv3x3(inplanes, planes, dilation))
        if with_bn:
            layers.append(nn.BatchNorm2d(planes))
        layers.append(nn.ReLU(inplace=True))
        inplanes = planes
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))

    return layers


class VGG(nn.Module):
    """VGG backbone.

    Args:
        depth (int): Depth of vgg, from {11, 13, 16, 19}.
        with_bn (bool): Use BatchNorm or not.
        num_classes (int): number of classes for classification.
        num_stages (int): VGG stages, normally 5.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
            running stats (mean and var).
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
    """

    arch_settings = {
        11: (1, 1, 2, 2, 2),
        13: (2, 2, 2, 2, 2),
        16: (2, 2, 3, 3, 3),
        19: (2, 2, 4, 4, 4)
    }

    def __init__(self,
                 depth,
                 with_bn=False,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3, 4),
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 ceil_mode=False,
                 with_last_pool=True):
        super(VGG, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for vgg')
        assert num_stages >= 1 and num_stages <= 5
        stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        assert len(dilations) == num_stages
        assert max(out_indices) <= num_stages

        self.num_classes = num_classes
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen

        self.inplanes = 3
        start_idx = 0
        vgg_layers = []
        self.range_sub_modules = []
        for i, num_blocks in enumerate(self.stage_blocks):
            num_modules = num_blocks * (2 + with_bn) + 1
            end_idx = start_idx + num_modules
            dilation = dilations[i]
            planes = 64 * 2**i if i < 4 else 512
            vgg_layer = make_vgg_layer(
                self.inplanes,
                planes,
                num_blocks,
                dilation=dilation,
                with_bn=with_bn,
                ceil_mode=ceil_mode)
            vgg_layers.extend(vgg_layer)
            self.inplanes = planes
            self.range_sub_modules.append([start_idx, end_idx])
            start_idx = end_idx
        if not with_last_pool:
            vgg_layers.pop(-1)
            self.range_sub_modules[-1][1] -= 1
        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*vgg_layers))

        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            from ..runner import load_checkpoint
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        outs = []
        vgg_layers = getattr(self, self.module_name)
        for i in range(len(self.stage_blocks)):
            for j in range(*self.range_sub_modules[i]):
                vgg_layer = vgg_layers[j]
                x = vgg_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if self.num_classes > 0:
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        super(VGG, self).train(mode)
        if self.bn_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if self.bn_frozen:
                        for params in m.parameters():
                            params.requires_grad = False
        vgg_layers = getattr(self, self.module_name)
        if mode and self.frozen_stages >= 0:
            for i in range(self.frozen_stages):
                for j in range(*self.range_sub_modules[i]):
                    mod = vgg_layers[j]
                    mod.eval()
                    for param in mod.parameters():
                        param.requires_grad = False
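A minimal forward-pass sketch for the VGG backbone above (the 224x224 input and the single out index are illustrative; the import assumes the vendored package exports ``VGG`` from ``annotator.uniformer.mmcv.cnn``):

import torch
from annotator.uniformer.mmcv.cnn import VGG  # assumed export path

model = VGG(depth=16, num_classes=-1, out_indices=(4,))
model.init_weights()   # Kaiming init for convs, N(0, 0.01) for linear layers
model.eval()
with torch.no_grad():
    feat = model(torch.randn(1, 3, 224, 224))
print(feat.shape)      # expected torch.Size([1, 512, 7, 7]) after the 5 pooling stages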
lavis/common/annotator/uniformer/mmcv/engine/__init__.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
                   single_gpu_test)

__all__ = [
    'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
    'single_gpu_test'
]
lavis/common/annotator/uniformer/mmcv/engine/test.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import pickle
import shutil
import tempfile
import time

import torch
import torch.distributed as dist

import annotator.uniformer.mmcv as mmcv
from annotator.uniformer.mmcv.runner import get_dist_info


def single_gpu_test(model, data_loader):
    """Test model with a single gpu.

    This method tests model with a single gpu and displays test progress bar.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, **data)
        results.extend(result)

        # Assume result has the same length as batch_size
        # refer to https://github.com/open-mmlab/mmcv/issues/985
        batch_size = len(result)
        for _ in range(batch_size):
            prog_bar.update()
    return results


def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.

    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting
    ``gpu_collect=True``, it encodes results to gpu tensors and uses gpu
    communication for results collection. In cpu mode it saves the results on
    different gpus to ``tmpdir`` and collects them by the rank 0 worker.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.

    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)
        results.extend(result)

        if rank == 0:
            batch_size = len(result)
            batch_size_all = batch_size * world_size
            if batch_size_all + prog_bar.completed > len(dataset):
                batch_size_all = len(dataset) - prog_bar.completed
            for _ in range(batch_size_all):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results


def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results under cpu mode.

    On cpu mode, this function will save the results on different gpus to
    ``tmpdir`` and collect them by the rank 0 worker.

    Args:
        result_part (list): Result list containing result parts
            to be collected.
        size (int): Size of the results, commonly equal to length of
            the results.
        tmpdir (str | None): Temporal directory for collected results to
            store. If set to None, it will create a random temporal directory
            for it.

    Returns:
        list: The collected results.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_result = mmcv.load(part_file)
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
            if part_result:
                part_list.append(part_result)
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def collect_results_gpu(result_part, size):
    """Collect results under gpu mode.

    On gpu mode, this function will encode results to gpu tensors and use gpu
    communication for results collection.

    Args:
        result_part (list): Result list containing result parts
            to be collected.
        size (int): Size of the results, commonly equal to length of
            the results.

    Returns:
        list: The collected results.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
            if part_result:
                part_list.append(part_result)
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
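The test loops above only assume that ``model(return_loss=False, **batch)`` returns a list with one entry per sample. A toy single-GPU sketch of that contract (``DummyWrapper`` and the collate function are hypothetical; actually running ``single_gpu_test`` also needs the vendored mmcv importable for the progress bar):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class DummyWrapper(nn.Module):
    """Toy model following the ``return_loss=False`` test convention."""

    def forward(self, return_loss=True, **data):
        x = data['img']
        # one scalar prediction per sample, so len(result) == batch size
        return x.mean(dim=(1, 2, 3)).tolist()

def collate(batch):
    return {'img': torch.stack([sample[0] for sample in batch])}

loader = DataLoader(
    TensorDataset(torch.randn(8, 3, 4, 4)), batch_size=2, collate_fn=collate)
# results = single_gpu_test(DummyWrapper(), loader)  # -> list of 8 floats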
lavis/common/annotator/uniformer/mmcv/fileio/__init__.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
from .file_client import BaseStorageBackend, FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from .io import dump, load, register_handler
from .parse import dict_from_file, list_from_file

__all__ = [
    'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
    'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
    'list_from_file', 'dict_from_file'
]
lavis/common/annotator/uniformer/mmcv/fileio/file_client.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import os
import os.path as osp
import re
import tempfile
import warnings
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable, Iterator, Optional, Tuple, Union
from urllib.request import urlopen

import annotator.uniformer.mmcv as mmcv
from annotator.uniformer.mmcv.utils.misc import has_method
from annotator.uniformer.mmcv.utils.path import is_filepath


class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two apis: ``get()`` and ``get_text()``.
    ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
    as texts.
    """

    # a flag to indicate whether the backend can create a symlink for a file
    _allow_symlink = False

    @property
    def name(self):
        return self.__class__.__name__

    @property
    def allow_symlink(self):
        return self._allow_symlink

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass


class CephBackend(BaseStorageBackend):
    """Ceph storage backend (for internal use).

    Args:
        path_mapping (dict|None): path mapping dict from local path to Petrel
            path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
            will be replaced by ``dst``. Default: None.

    .. warning::
        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
    """

    def __init__(self, path_mapping=None):
        try:
            import ceph
        except ImportError:
            raise ImportError('Please install ceph to enable CephBackend.')

        warnings.warn(
            'CephBackend will be deprecated, please use PetrelBackend instead')
        self._client = ceph.S3Client()
        assert isinstance(path_mapping, dict) or path_mapping is None
        self.path_mapping = path_mapping

    def get(self, filepath):
        filepath = str(filepath)
        if self.path_mapping is not None:
            for k, v in self.path_mapping.items():
                filepath = filepath.replace(k, v)
        value = self._client.Get(filepath)
        value_buf = memoryview(value)
        return value_buf

    def get_text(self, filepath, encoding=None):
        raise NotImplementedError


class PetrelBackend(BaseStorageBackend):
    """Petrel storage backend (for internal use).

    PetrelBackend supports reading and writing data to multiple clusters.
    If the file path contains the cluster name, PetrelBackend will read data
    from the specified cluster or write data to it. Otherwise, PetrelBackend
    will access the default cluster.

    Args:
        path_mapping (dict, optional): Path mapping dict from local path to
            Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
            ``filepath`` will be replaced by ``dst``. Default: None.
        enable_mc (bool, optional): Whether to enable memcached support.
            Default: True.

    Examples:
        >>> filepath1 = 's3://path/of/file'
        >>> filepath2 = 'cluster-name:s3://path/of/file'
        >>> client = PetrelBackend()
        >>> client.get(filepath1)  # get data from default cluster
        >>> client.get(filepath2)  # get data from 'cluster-name' cluster
    """

    def __init__(self,
                 path_mapping: Optional[dict] = None,
                 enable_mc: bool = True):
        try:
            from petrel_client import client
        except ImportError:
            raise ImportError('Please install petrel_client to enable '
                              'PetrelBackend.')

        self._client = client.Client(enable_mc=enable_mc)
        assert isinstance(path_mapping, dict) or path_mapping is None
        self.path_mapping = path_mapping

    def _map_path(self, filepath: Union[str, Path]) -> str:
        """Map ``filepath`` to a string path whose prefix will be replaced by
        :attr:`self.path_mapping`.

        Args:
            filepath (str): Path to be mapped.
        """
        filepath = str(filepath)
        if self.path_mapping is not None:
            for k, v in self.path_mapping.items():
                filepath = filepath.replace(k, v)
        return filepath

    def _format_path(self, filepath: str) -> str:
        """Convert a ``filepath`` to standard format of petrel oss.

        If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
        environment, the ``filepath`` will be the format of
        's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
        above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.

        Args:
            filepath (str): Path to be formatted.
        """
        return re.sub(r'\\+', '/', filepath)

    def get(self, filepath: Union[str, Path]) -> memoryview:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            memoryview: A memory view of expected bytes object to avoid
                copying. The memoryview object can be converted to bytes by
                ``value_buf.tobytes()``.
        """
        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        value = self._client.Get(filepath)
        value_buf = memoryview(value)
        return value_buf

    def get_text(self,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        return str(self.get(filepath), encoding=encoding)

    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Save data to a given ``filepath``.

        Args:
            obj (bytes): Data to be saved.
            filepath (str or Path): Path to write data.
        """
        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        self._client.put(filepath, obj)

    def put_text(self,
                 obj: str,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> None:
        """Save data to a given ``filepath``.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to encode the ``obj``.
                Default: 'utf-8'.
        """
        self.put(bytes(obj, encoding=encoding), filepath)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove a file.

        Args:
            filepath (str or Path): Path to be removed.
        """
        if not has_method(self._client, 'delete'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `delete` method, please use a higher version or dev'
                 ' branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        self._client.delete(filepath)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        if not (has_method(self._client, 'contains')
                and has_method(self._client, 'isdir')):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `contains` and `isdir` methods, please use a higher'
                 'version or dev branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        return self._client.contains(filepath) or self._client.isdir(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
                ``False`` otherwise.
        """
        if not has_method(self._client, 'isdir'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `isdir` method, please use a higher version or dev'
                 ' branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        return self._client.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
                otherwise.
        """
        if not has_method(self._client, 'contains'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `contains` method, please use a higher version or '
                 'dev branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        return self._client.contains(filepath)

    def join_path(self, filepath: Union[str, Path],
                  *filepaths: Union[str, Path]) -> str:
        """Concatenate all file paths.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result after concatenation.
        """
        filepath = self._format_path(self._map_path(filepath))
        if filepath.endswith('/'):
            filepath = filepath[:-1]
        formatted_paths = [filepath]
        for path in filepaths:
            formatted_paths.append(self._format_path(self._map_path(path)))
        return '/'.join(formatted_paths)

    @contextmanager
    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
        """Download a file from ``filepath`` and return a temporary path.

        ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
        can be called with a ``with`` statement, and when exiting the ``with``
        statement, the temporary path will be released.

        Args:
            filepath (str | Path): Download a file from ``filepath``.

        Examples:
            >>> client = PetrelBackend()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with client.get_local_path('s3://path/of/your/file') as path:
            ...     # do something here

        Yields:
            Iterable[str]: Only yield one temporary path.
        """
        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        assert self.isfile(filepath)
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.get(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)

    def list_dir_or_file(self,
                         dir_path: Union[str, Path],
                         list_dir: bool = True,
                         list_file: bool = True,
                         suffix: Optional[Union[str, Tuple[str]]] = None,
                         recursive: bool = False) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            Petrel has no concept of directories but it simulates the directory
            hierarchy in the filesystem through public prefixes. In addition,
            if the returned path ends with '/', it means the path is a public
            prefix which is a logical directory.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
            In addition, the returned path of a directory will not contain the
            suffix '/', which is consistent with other backends.

        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Default: True.
            list_file (bool): List the path of files. Default: True.
            suffix (str or tuple[str], optional): File suffix
                that we are interested in. Default: None.
            recursive (bool): If set to True, recursively scan the
                directory. Default: False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.
        """
        if not has_method(self._client, 'list'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `list` method, please use a higher version or dev'
                 ' branch instead.'))

        dir_path = self._map_path(dir_path)
        dir_path = self._format_path(dir_path)
        if list_dir and suffix is not None:
            raise TypeError(
                '`list_dir` should be False when `suffix` is not None')

        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
            raise TypeError('`suffix` must be a string or tuple of strings')

        # Petrel's simulated directory hierarchy assumes that directory paths
        # should end with `/`
        if not dir_path.endswith('/'):
            dir_path += '/'

        root = dir_path

        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                              recursive):
            for path in self._client.list(dir_path):
                # the `self.isdir` is not used here to determine whether path
                # is a directory, because `self.isdir` relies on
                # `self._client.list`
                if path.endswith('/'):  # a directory path
                    next_dir_path = self.join_path(dir_path, path)
                    if list_dir:
                        # get the relative path and exclude the last
                        # character '/'
                        rel_dir = next_dir_path[len(root):-1]
                        yield rel_dir
                    if recursive:
                        yield from _list_dir_or_file(next_dir_path, list_dir,
                                                     list_file, suffix,
                                                     recursive)
                else:  # a file path
                    absolute_path = self.join_path(dir_path, path)
                    rel_path = absolute_path[len(root):]
                    if (suffix is None
                            or rel_path.endswith(suffix)) and list_file:
                        yield rel_path

        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                                 recursive)


class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError(
                'Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
                                                      self.client_cfg)
        # mc.pyvector serves as a pointer which points to a memory cache
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        filepath = str(filepath)
        import mc
        self._client.Get(filepath, self._mc_buffer)
        value_buf = mc.ConvertBuffer(self._mc_buffer)
        return value_buf

    def get_text(self, filepath, encoding=None):
        raise NotImplementedError


class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_path (str): Lmdb database path.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_path (str): Lmdb database path.
    """

    def __init__(self,
                 db_path,
                 readonly=True,
                 lock=False,
                 readahead=False,
                 **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        self.db_path = str(db_path)
        self._client = lmdb.open(
            self.db_path,
            readonly=readonly,
            lock=lock,
            readahead=readahead,
            **kwargs)

    def get(self, filepath):
        """Get values according to the filepath.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
        """
        filepath = str(filepath)
        with self._client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath, encoding=None):
        raise NotImplementedError


class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    _allow_symlink = True

    def get(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with open(filepath, 'rb') as f:
            value_buf = f.read()
        return value_buf

    def get_text(self,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        with open(filepath, 'r', encoding=encoding) as f:
            value_buf = f.read()
        return value_buf

    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``put`` will create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        mmcv.mkdir_or_exist(osp.dirname(filepath))
        with open(filepath, 'wb') as f:
            f.write(obj)

    def put_text(self,
                 obj: str,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``put_text`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        mmcv.mkdir_or_exist(osp.dirname(filepath))
        with open(filepath, 'w', encoding=encoding) as f:
            f.write(obj)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove a file.

        Args:
            filepath (str or Path): Path to be removed.
        """
        os.remove(filepath)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        return osp.exists(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
                ``False`` otherwise.
        """
        return osp.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
                otherwise.
        """
        return osp.isfile(filepath)

    def join_path(self, filepath: Union[str, Path],
                  *filepaths: Union[str, Path]) -> str:
        """Concatenate all file paths.

        Join one or more filepath components intelligently. The return value
        is the concatenation of filepath and any members of *filepaths.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result of concatenation.
        """
        return osp.join(filepath, *filepaths)

    @contextmanager
    def get_local_path(
            self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
        """Only for unified API and do nothing."""
        yield filepath

    def list_dir_or_file(self,
                         dir_path: Union[str, Path],
                         list_dir: bool = True,
                         list_file: bool = True,
                         suffix: Optional[Union[str, Tuple[str]]] = None,
                         recursive: bool = False) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.

        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Default: True.
            list_file (bool): List the path of files. Default: True.
            suffix (str or tuple[str], optional): File suffix
                that we are interested in. Default: None.
            recursive (bool): If set to True, recursively scan the
                directory. Default: False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.
        """
        if list_dir and suffix is not None:
            raise TypeError('`suffix` should be None when `list_dir` is True')

        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
            raise TypeError('`suffix` must be a string or tuple of strings')

        root = dir_path

        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                              recursive):
            for entry in os.scandir(dir_path):
                if not entry.name.startswith('.') and entry.is_file():
                    rel_path = osp.relpath(entry.path, root)
                    if (suffix is None
                            or rel_path.endswith(suffix)) and list_file:
                        yield rel_path
                elif osp.isdir(entry.path):
                    if list_dir:
                        rel_dir = osp.relpath(entry.path, root)
                        yield rel_dir
                    if recursive:
                        yield from _list_dir_or_file(entry.path, list_dir,
                                                     list_file, suffix,
                                                     recursive)

        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                                 recursive)


class HTTPBackend(BaseStorageBackend):
    """HTTP and HTTPS storage backend."""

    def get(self, filepath):
        value_buf = urlopen(filepath).read()
        return value_buf

    def get_text(self, filepath, encoding='utf-8'):
        value_buf = urlopen(filepath).read()
        return value_buf.decode(encoding)

    @contextmanager
    def get_local_path(self, filepath: str) -> Iterable[str]:
        """Download a file from ``filepath``.

        ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
        can be called with a ``with`` statement, and when exiting the ``with``
        statement, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Examples:
            >>> client = HTTPBackend()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with client.get_local_path('http://path/of/your/file') as path:
            ...     # do something here
        """
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.get(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)


class FileClient:
    """A general file client to access files in different backends.

    The client loads a file or text in a specified backend from its path
    and returns it as a binary or text file. There are two ways to choose a
    backend: the name of the backend and the prefix of the path. Although both
    of them can be used to choose a storage backend, ``backend`` has a higher
    priority, that is, if they are both set, the storage backend will be chosen
    by the ``backend`` argument. If they are both `None`, the disk backend will
    be chosen. Note that it can also register other backend accessors with a
    given name, prefixes, and backend class. In addition, we use the singleton
    pattern to avoid repeated object creation. If the arguments are the same,
    the same object will be returned.

    Args:
        backend (str, optional): The storage backend type. Options are "disk",
            "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
        prefix (str, optional): The prefix of the registered storage backend.
            Options are "s3", "http", "https". Default: None.

    Examples:
        >>> # only set backend
        >>> file_client = FileClient(backend='petrel')
        >>> # only set prefix
        >>> file_client = FileClient(prefix='s3')
        >>> # set both backend and prefix but use backend to choose client
        >>> file_client = FileClient(backend='petrel', prefix='s3')
        >>> # if the arguments are the same, the same object is returned
        >>> file_client1 = FileClient(backend='petrel')
        >>> file_client1 is file_client
        True

    Attributes:
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    _backends = {
        'disk': HardDiskBackend,
        'ceph': CephBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
        'petrel': PetrelBackend,
        'http': HTTPBackend,
    }
    # This collection is used to record the overridden backends, and when a
    # backend appears in the collection, the singleton pattern is disabled for
    # that backend, because if the singleton pattern is used, then the object
    # returned will be the backend before overwriting
    _overridden_backends = set()
    _prefix_to_backends = {
        's3': PetrelBackend,
        'http': HTTPBackend,
        'https': HTTPBackend,
    }
    _overridden_prefixes = set()

    _instances = {}

    def __new__(cls, backend=None, prefix=None, **kwargs):
        if backend is None and prefix is None:
            backend = 'disk'
        if backend is not None and backend not in cls._backends:
            raise ValueError(
                f'Backend {backend} is not supported. Currently supported ones'
                f' are {list(cls._backends.keys())}')
        if prefix is not None and prefix not in cls._prefix_to_backends:
            raise ValueError(
                f'prefix {prefix} is not supported. Currently supported ones '
                f'are {list(cls._prefix_to_backends.keys())}')

        # concatenate the arguments to a unique key for determining whether
        # objects with the same arguments were created
        arg_key = f'{backend}:{prefix}'
        for key, value in kwargs.items():
            arg_key += f':{key}:{value}'

        # if a backend was overridden, it will create a new object
        if (arg_key in cls._instances
                and backend not in cls._overridden_backends
                and prefix not in cls._overridden_prefixes):
            _instance = cls._instances[arg_key]
        else:
            # create a new object and put it to _instances
            _instance = super().__new__(cls)
            if backend is not None:
                _instance.client = cls._backends[backend](**kwargs)
            else:
                _instance.client = cls._prefix_to_backends[prefix](**kwargs)

            cls._instances[arg_key] = _instance

        return _instance

    @property
    def name(self):
        return self.client.name

    @property
    def allow_symlink(self):
        return self.client.allow_symlink

    @staticmethod
    def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
        """Parse the prefix of a uri.

        Args:
            uri (str | Path): Uri to be parsed that contains the file prefix.

        Examples:
            >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
            's3'

        Returns:
            str | None: Return the prefix of uri if the uri contains '://'
                else ``None``.
        """
        assert is_filepath(uri)
        uri = str(uri)
        if '://' not in uri:
            return None
        else:
            prefix, _ = uri.split('://')
            # In the case of PetrelBackend, the prefix may contain the cluster
            # name like clusterName:s3
            if ':' in prefix:
                _, prefix = prefix.split(':')
            return prefix

    @classmethod
    def infer_client(cls,
                     file_client_args: Optional[dict] = None,
                     uri: Optional[Union[str, Path]] = None) -> 'FileClient':
        """Infer a suitable file client based on the URI and arguments.

        Args:
            file_client_args (dict, optional): Arguments to instantiate a
                FileClient. Default: None.
            uri (str | Path, optional): Uri to be parsed that contains the file
                prefix. Default: None.

        Examples:
            >>> uri = 's3://path/of/your/file'
            >>> file_client = FileClient.infer_client(uri=uri)
            >>> file_client_args = {'backend': 'petrel'}
            >>> file_client = FileClient.infer_client(file_client_args)

        Returns:
            FileClient: Instantiated FileClient object.
        """
        assert file_client_args is not None or uri is not None
        if file_client_args is None:
            file_prefix = cls.parse_uri_prefix(uri)  # type: ignore
            return cls(prefix=file_prefix)
        else:
            return cls(**file_client_args)

    @classmethod
    def _register_backend(cls, name, backend, force=False, prefixes=None):
        if not isinstance(name, str):
            raise TypeError('the backend name should be a string, '
                            f'but got {type(name)}')
        if not inspect.isclass(backend):
            raise TypeError(
                f'backend should be a class but got {type(backend)}')
        if not issubclass(backend, BaseStorageBackend):
            raise TypeError(
                f'backend {backend} is not a subclass of BaseStorageBackend')
        if not force and name in cls._backends:
            raise KeyError(
                f'{name} is already registered as a storage backend, '
                'add "force=True" if you want to override it')

        if name in cls._backends and force:
            cls._overridden_backends.add(name)
        cls._backends[name] = backend

        if prefixes is not None:
            if isinstance(prefixes, str):
                prefixes = [prefixes]
            else:
                assert isinstance(prefixes, (list, tuple))
            for prefix in prefixes:
                if prefix not in cls._prefix_to_backends:
                    cls._prefix_to_backends[prefix] = backend
                elif (prefix in cls._prefix_to_backends) and force:
                    cls._overridden_prefixes.add(prefix)
                    cls._prefix_to_backends[prefix] = backend
                else:
                    raise KeyError(
                        f'{prefix} is already registered as a storage backend,'
                        ' add "force=True" if you want to override it')

    @classmethod
    def register_backend(cls, name, backend=None, force=False, prefixes=None):
        """Register a backend to FileClient.

        This method can be used as a normal class method or a decorator.

        .. code-block:: python

            class NewBackend(BaseStorageBackend):

                def get(self, filepath):
                    return filepath

                def get_text(self, filepath):
                    return filepath

            FileClient.register_backend('new', NewBackend)

        or

        .. code-block:: python

            @FileClient.register_backend('new')
            class NewBackend(BaseStorageBackend):

                def get(self, filepath):
                    return filepath

                def get_text(self, filepath):
                    return filepath

        Args:
            name (str): The name of the registered backend.
            backend (class, optional): The backend class to be registered,
                which must be a subclass of :class:`BaseStorageBackend`.
                When this method is used as a decorator, backend is None.
                Defaults to None.
            force (bool, optional): Whether to override the backend if the name
                has already been registered. Defaults to False.
            prefixes (str or list[str] or tuple[str], optional): The prefixes
                of the registered storage backend. Default: None.
                `New in version 1.3.15.`
        """
        if backend is not None:
            cls._register_backend(
                name, backend, force=force, prefixes=prefixes)
            return

        def _register(backend_cls):
            cls._register_backend(
                name, backend_cls, force=force, prefixes=prefixes)
            return backend_cls

        return _register

    def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
        """Read data from a given ``filepath`` with 'rb' mode.

        Note:
            There are two types of return values for ``get``, one is ``bytes``
            and the other is ``memoryview``. The advantage of using memoryview
            is that you can avoid copying, and if you want to convert it to
            ``bytes``, you can use ``.tobytes()``.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes | memoryview: Expected bytes object or a memory view of the
                bytes object.
        """
        return self.client.get(filepath)

    def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        return self.client.get_text(filepath, encoding)

    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``put`` should create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        self.client.put(obj, filepath)

    def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``put_text`` should create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str, optional): The encoding format used to open the
                `filepath`. Default: 'utf-8'.
        """
        self.client.put_text(obj, filepath)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove a file.

        Args:
            filepath (str, Path): Path to be removed.
        """
        self.client.remove(filepath)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        return self.client.exists(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
                ``False`` otherwise.
        """
        return self.client.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
                otherwise.
        """
        return self.client.isfile(filepath)

    def join_path(self, filepath: Union[str, Path],
                  *filepaths: Union[str, Path]) -> str:
        """Concatenate all file paths.

        Join one or more filepath components intelligently. The return value
        is the concatenation of filepath and any members of *filepaths.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result of concatenation.
        """
        return self.client.join_path(filepath, *filepaths)

    @contextmanager
    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
        """Download data from ``filepath`` and write the data to a local path.

        ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
        can be called with a ``with`` statement, and when exiting the ``with``
        statement, the temporary path will be released.

        Note:
            If the ``filepath`` is a local path, just return itself.

        .. warning::
            ``get_local_path`` is an experimental interface that may change in
            the future.

        Args:
            filepath (str or Path): Path to be read data.

        Examples:
            >>> file_client = FileClient(prefix='s3')
            >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
            ...     # do something here

        Yields:
            Iterable[str]: Only yield one path.
        """
        with self.client.get_local_path(str(filepath)) as local_path:
            yield local_path

    def list_dir_or_file(self,
                         dir_path: Union[str, Path],
                         list_dir: bool = True,
                         list_file: bool = True,
                         suffix: Optional[Union[str, Tuple[str]]] = None,
                         recursive: bool = False) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.

        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Default: True.
            list_file (bool): List the path of files. Default: True.
            suffix (str or tuple[str], optional): File suffix
                that we are interested in. Default: None.
            recursive (bool): If set to True, recursively scan the
                directory. Default: False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.
        """
        yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
                                                suffix, recursive)
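A short sketch of how a project can plug its own storage into ``FileClient`` via ``register_backend`` (``UpperBackend`` and the file name are made up for illustration; the import path assumes the vendored package layout):

from annotator.uniformer.mmcv.fileio import BaseStorageBackend, FileClient  # assumed path

@FileClient.register_backend('upper')
class UpperBackend(BaseStorageBackend):
    """Toy backend that upper-cases any text it reads from local disk."""

    def get(self, filepath):
        with open(filepath, 'rb') as f:
            return f.read()

    def get_text(self, filepath, encoding='utf-8'):
        with open(filepath, encoding=encoding) as f:
            return f.read().upper()

client = FileClient(backend='upper')
# client.get_text('notes.txt') would return the upper-cased file content;
# FileClient(backend='upper') returns the same singleton instance again.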
lavis/common/annotator/uniformer/mmcv/fileio/handlers/__init__.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseFileHandler
from .json_handler import JsonHandler
from .pickle_handler import PickleHandler
from .yaml_handler import YamlHandler

__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
lavis/common/annotator/uniformer/mmcv/fileio/handlers/base.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod


class BaseFileHandler(metaclass=ABCMeta):
    # `str_like` is a flag to indicate whether the type of file object is
    # str-like object or bytes-like object. Pickle only processes bytes-like
    # objects but json only processes str-like object. If it is str-like
    # object, `StringIO` will be used to process the buffer.
    str_like = True

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        pass

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        pass

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        pass

    def load_from_path(self, filepath, mode='r', **kwargs):
        with open(filepath, mode) as f:
            return self.load_from_fileobj(f, **kwargs)

    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
        with open(filepath, mode) as f:
            self.dump_to_fileobj(obj, f, **kwargs)
lavis/common/annotator/uniformer/mmcv/fileio/handlers/json_handler.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
import json

import numpy as np

from .base import BaseFileHandler


def set_default(obj):
    """Set default json values for non-serializable values.

    It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
    It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
    etc.) into plain numbers of plain python built-in types.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'{type(obj)} is unsupported for json dump')


class JsonHandler(BaseFileHandler):

    def load_from_fileobj(self, file):
        return json.load(file)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)
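A quick illustration of ``set_default`` as the ``json.dump`` fallback hook (the sample values are made up; ``set_default`` here is the function defined in the file above, and the order of a list produced from a multi-element set is not deterministic):

import json
import numpy as np

data = {'ids': np.arange(3), 'score': np.float32(0.5), 'tags': {'a'}}
print(json.dumps(data, default=set_default))
# e.g. {"ids": [0, 1, 2], "score": 0.5, "tags": ["a"]}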
lavis/common/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
import pickle

from .base import BaseFileHandler


class PickleHandler(BaseFileHandler):

    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        return super(PickleHandler, self).load_from_path(
            filepath, mode='rb', **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('protocol', 2)
        return pickle.dumps(obj, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('protocol', 2)
        pickle.dump(obj, file, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        super(PickleHandler, self).dump_to_path(
            obj, filepath, mode='wb', **kwargs)
lavis/common/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py
0 → 100644

# Copyright (c) OpenMMLab. All rights reserved.
import yaml

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

from .base import BaseFileHandler  # isort:skip


class YamlHandler(BaseFileHandler):

    def load_from_fileobj(self, file, **kwargs):
        kwargs.setdefault('Loader', Loader)
        return yaml.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('Dumper', Dumper)
        yaml.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('Dumper', Dumper)
        return yaml.dump(obj, **kwargs)
lavis/common/annotator/uniformer/mmcv/fileio/io.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
from
io
import
BytesIO
,
StringIO
from
pathlib
import
Path
from
..utils
import
is_list_of
,
is_str
from
.file_client
import
FileClient
from
.handlers
import
BaseFileHandler
,
JsonHandler
,
PickleHandler
,
YamlHandler
file_handlers
=
{
'json'
:
JsonHandler
(),
'yaml'
:
YamlHandler
(),
'yml'
:
YamlHandler
(),
'pickle'
:
PickleHandler
(),
'pkl'
:
PickleHandler
()
}
def
load
(
file
,
file_format
=
None
,
file_client_args
=
None
,
**
kwargs
):
"""Load data from json/yaml/pickle files.
This method provides a unified api for loading data from serialized files.
Note:
In v1.3.16 and later, ``load`` supports loading data from serialized
files those can be storaged in different backends.
Args:
file (str or :obj:`Path` or file-like object): Filename or a file-like
object.
file_format (str, optional): If not specified, the file format will be
inferred from the file extension, otherwise use the specified one.
Currently supported formats include "json", "yaml/yml" and
"pickle/pkl".
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Examples:
>>> load('/path/of/your/file') # file is storaged in disk
>>> load('https://path/of/your/file') # file is storaged in Internet
>>> load('s3://path/of/your/file') # file is storaged in petrel
Returns:
The content from the file.
"""
if
isinstance
(
file
,
Path
):
file
=
str
(
file
)
if
file_format
is
None
and
is_str
(
file
):
file_format
=
file
.
split
(
'.'
)[
-
1
]
if
file_format
not
in
file_handlers
:
raise
TypeError
(
f
'Unsupported format:
{
file_format
}
'
)
handler
=
file_handlers
[
file_format
]
if
is_str
(
file
):
file_client
=
FileClient
.
infer_client
(
file_client_args
,
file
)
if
handler
.
str_like
:
with
StringIO
(
file_client
.
get_text
(
file
))
as
f
:
obj
=
handler
.
load_from_fileobj
(
f
,
**
kwargs
)
else
:
with
BytesIO
(
file_client
.
get
(
file
))
as
f
:
obj
=
handler
.
load_from_fileobj
(
f
,
**
kwargs
)
elif
hasattr
(
file
,
'read'
):
obj
=
handler
.
load_from_fileobj
(
file
,
**
kwargs
)
else
:
raise
TypeError
(
'"file" must be a filepath str or a file-object'
)
return
obj
def
dump
(
obj
,
file
=
None
,
file_format
=
None
,
file_client_args
=
None
,
**
kwargs
):
"""Dump data to json/yaml/pickle strings or files.
This method provides a unified api for dumping data as strings or to files,
and also supports custom arguments for each file format.
Note:
In v1.3.16 and later, ``dump`` supports dumping data as strings or to
files which is saved to different backends.
Args:
obj (any): The python object to be dumped.
file (str or :obj:`Path` or file-like object, optional): If not
specified, then the object is dumped to a str, otherwise to a file
specified by the filename or file-like object.
file_format (str, optional): Same as :func:`load`.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Examples:
>>> dump('hello world', '/path/of/your/file') # disk
>>> dump('hello world', 's3://path/of/your/file') # ceph or petrel
Returns:
bool: True for success, False otherwise.
"""
if
isinstance
(
file
,
Path
):
file
=
str
(
file
)
if
file_format
is
None
:
if
is_str
(
file
):
file_format
=
file
.
split
(
'.'
)[
-
1
]
elif
file
is
None
:
raise
ValueError
(
'file_format must be specified since file is None'
)
if
file_format
not
in
file_handlers
:
raise
TypeError
(
f
'Unsupported format:
{
file_format
}
'
)
handler
=
file_handlers
[
file_format
]
if
file
is
None
:
return
handler
.
dump_to_str
(
obj
,
**
kwargs
)
elif
is_str
(
file
):
file_client
=
FileClient
.
infer_client
(
file_client_args
,
file
)
if
handler
.
str_like
:
with
StringIO
()
as
f
:
handler
.
dump_to_fileobj
(
obj
,
f
,
**
kwargs
)
file_client
.
put_text
(
f
.
getvalue
(),
file
)
else
:
with
BytesIO
()
as
f
:
handler
.
dump_to_fileobj
(
obj
,
f
,
**
kwargs
)
file_client
.
put
(
f
.
getvalue
(),
file
)
elif
hasattr
(
file
,
'write'
):
handler
.
dump_to_fileobj
(
obj
,
file
,
**
kwargs
)
else
:
raise
TypeError
(
'"file" must be a filename str or a file-object'
)
def
_register_handler
(
handler
,
file_formats
):
"""Register a handler for some file extensions.
Args:
handler (:obj:`BaseFileHandler`): Handler to be registered.
file_formats (str or list[str]): File formats to be handled by this
handler.
"""
if
not
isinstance
(
handler
,
BaseFileHandler
):
raise
TypeError
(
f
'handler must be a child of BaseFileHandler, not
{
type
(
handler
)
}
'
)
if
isinstance
(
file_formats
,
str
):
file_formats
=
[
file_formats
]
if
not
is_list_of
(
file_formats
,
str
):
raise
TypeError
(
'file_formats must be a str or a list of str'
)
for
ext
in
file_formats
:
file_handlers
[
ext
]
=
handler
def register_handler(file_formats, **kwargs):

    def wrap(cls):
        _register_handler(cls(**kwargs), file_formats)
        return cls

    return wrap
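``register_handler`` is the decorator form of ``_register_handler``. A sketch of wiring up a hypothetical plain-text handler (the three methods are the ones ``BaseFileHandler`` declares):

@register_handler('txt')
class TxtHandler(BaseFileHandler):

    def load_from_fileobj(self, file):
        return file.read()

    def dump_to_fileobj(self, obj, file):
        file.write(str(obj))

    def dump_to_str(self, obj):
        return str(obj)

# Afterwards load('notes.txt') and dump(obj, 'notes.txt') dispatch to TxtHandler.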
lavis/common/annotator/uniformer/mmcv/fileio/parse.py
# Copyright (c) OpenMMLab. All rights reserved.
from io import StringIO

from .file_client import FileClient


def list_from_file(filename,
                   prefix='',
                   offset=0,
                   max_num=0,
                   encoding='utf-8',
                   file_client_args=None):
    """Load a text file and parse the content as a list of strings.

    Note:
        In v1.3.16 and later, ``list_from_file`` supports loading a text file
        which can be stored in different backends and parsing the content as
        a list of strings.

    Args:
        filename (str): Filename.
        prefix (str): The prefix to be inserted to the beginning of each item.
        offset (int): The offset of lines.
        max_num (int): The maximum number of lines to be read,
            zeros and negatives mean no limitation.
        encoding (str): Encoding used to open the file. Default utf-8.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.

    Examples:
        >>> list_from_file('/path/of/your/file')  # disk
        ['hello', 'world']
        >>> list_from_file('s3://path/of/your/file')  # ceph or petrel
        ['hello', 'world']

    Returns:
        list[str]: A list of strings.
    """
    cnt = 0
    item_list = []
    file_client = FileClient.infer_client(file_client_args, filename)
    with StringIO(file_client.get_text(filename, encoding)) as f:
        for _ in range(offset):
            f.readline()
        for line in f:
            if 0 < max_num <= cnt:
                break
            item_list.append(prefix + line.rstrip('\n\r'))
            cnt += 1
    return item_list
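A small sketch of how the arguments interact (hypothetical ``names.txt`` containing the lines ``cat``, ``dog``, ``bird``):

list_from_file('names.txt')                       # ['cat', 'dog', 'bird']
list_from_file('names.txt', offset=1, max_num=1)  # ['dog']
list_from_file('names.txt', prefix='cls/')        # ['cls/cat', 'cls/dog', 'cls/bird']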
def dict_from_file(filename,
                   key_type=str,
                   encoding='utf-8',
                   file_client_args=None):
    """Load a text file and parse the content as a dict.

    Each line of the text file will be two or more columns split by
    whitespaces or tabs. The first column will be parsed as dict keys, and
    the following columns will be parsed as dict values.

    Note:
        In v1.3.16 and later, ``dict_from_file`` supports loading a text file
        which can be stored in different backends and parsing the content as
        a dict.

    Args:
        filename (str): Filename.
        key_type (type): Type of the dict keys. str is used by default and
            type conversion will be performed if specified.
        encoding (str): Encoding used to open the file. Default utf-8.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.

    Examples:
        >>> dict_from_file('/path/of/your/file')  # disk
        {'key1': 'value1', 'key2': 'value2'}
        >>> dict_from_file('s3://path/of/your/file')  # ceph or petrel
        {'key1': 'value1', 'key2': 'value2'}

    Returns:
        dict: The parsed contents.
    """
    mapping = {}
    file_client = FileClient.infer_client(file_client_args, filename)
    with StringIO(file_client.get_text(filename, encoding)) as f:
        for line in f:
            items = line.rstrip('\n').split()
            assert len(items) >= 2
            key = key_type(items[0])
            val = items[1:] if len(items) > 2 else items[1]
            mapping[key] = val
    return mapping
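A sketch of the value handling (hypothetical ``mapping.txt`` whose lines are ``1 cat`` and ``2 dog fox``): a line with exactly two columns yields a scalar value, more columns yield a list.

dict_from_file('mapping.txt')                # {'1': 'cat', '2': ['dog', 'fox']}
dict_from_file('mapping.txt', key_type=int)  # {1: 'cat', 2: ['dog', 'fox']}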
lavis/common/annotator/uniformer/mmcv/image/__init__.py
# Copyright (c) OpenMMLab. All rights reserved.
from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
                         gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
                         rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
from .geometric import (cutout, imcrop, imflip, imflip_, impad,
                        impad_to_multiple, imrescale, imresize, imresize_like,
                        imresize_to_multiple, imrotate, imshear, imtranslate,
                        rescale_size)
from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
from .misc import tensor2imgs
from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
                          adjust_lighting, adjust_sharpness, auto_contrast,
                          clahe, imdenormalize, imequalize, iminvert,
                          imnormalize, imnormalize_, lut_transform, posterize,
                          solarize)

__all__ = [
    'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
    'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
    'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
    'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
    'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
    'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
    'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
    'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
    'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
    'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
]
lavis/common/annotator/uniformer/mmcv/image/colorspace.py
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np


def imconvert(img, src, dst):
    """Convert an image from the src colorspace to dst colorspace.

    Args:
        img (ndarray): The input image.
        src (str): The source colorspace, e.g., 'rgb', 'hsv'.
        dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.

    Returns:
        ndarray: The converted image.
    """
    code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
    out_img = cv2.cvtColor(img, code)
    return out_img
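A sketch of how ``imconvert`` resolves the OpenCV conversion code from the two colorspace names (hypothetical array):

img_bgr = np.zeros((4, 4, 3), dtype=np.uint8)
img_hsv = imconvert(img_bgr, 'bgr', 'hsv')   # resolves getattr(cv2, 'COLOR_BGR2HSV')
img_rgb = imconvert(img_bgr, 'bgr', 'rgb')   # resolves getattr(cv2, 'COLOR_BGR2RGB')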
def bgr2gray(img, keepdim=False):
    """Convert a BGR image to grayscale image.

    Args:
        img (ndarray): The input image.
        keepdim (bool): If False (by default), then return the grayscale image
            with 2 dims, otherwise 3 dims.

    Returns:
        ndarray: The converted grayscale image.
    """
    out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    if keepdim:
        out_img = out_img[..., None]
    return out_img


def rgb2gray(img, keepdim=False):
    """Convert an RGB image to grayscale image.

    Args:
        img (ndarray): The input image.
        keepdim (bool): If False (by default), then return the grayscale image
            with 2 dims, otherwise 3 dims.

    Returns:
        ndarray: The converted grayscale image.
    """
    out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    if keepdim:
        out_img = out_img[..., None]
    return out_img


def gray2bgr(img):
    """Convert a grayscale image to BGR image.

    Args:
        img (ndarray): The input image.

    Returns:
        ndarray: The converted BGR image.
    """
    img = img[..., None] if img.ndim == 2 else img
    out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    return out_img


def gray2rgb(img):
    """Convert a grayscale image to RGB image.

    Args:
        img (ndarray): The input image.

    Returns:
        ndarray: The converted RGB image.
    """
    img = img[..., None] if img.ndim == 2 else img
    out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    return out_img
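Shape behaviour of the four grayscale helpers, sketched on a hypothetical array:

gray = np.random.randint(0, 256, (8, 8), dtype=np.uint8)
bgr = gray2bgr(gray)                  # (8, 8)    -> (8, 8, 3), channel replicated
back = bgr2gray(bgr)                  # (8, 8, 3) -> (8, 8)
back3 = bgr2gray(bgr, keepdim=True)   # (8, 8, 3) -> (8, 8, 1)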
def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    conversion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
            [0, 1].
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError('The img type should be np.float32 or np.uint8, '
                        f'but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    It converts the image to desired type and range. If `dst_type` is np.uint8,
    images will be converted to np.uint8 type with range [0, 255]. If
    `dst_type` is np.float32, it converts the image to np.float32 type with
    range [0, 1].
    It is mainly used for post-processing images in colorspace conversion
    functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, '
                        f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)
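Together these two helpers let the YCbCr conversions below accept either uint8 images in [0, 255] or float32 images in [0, 1] and hand back the same convention. A small sketch on a hypothetical array:

u8 = np.array([[0, 128, 255]], dtype=np.uint8)
f32 = _convert_input_type_range(u8)                     # float32 in [0, 1]
out = _convert_output_type_range(f32 * 255, np.uint8)   # rounded back to uint8 in [0, 255]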
def rgb2ycbcr(img, y_only=False):
    """Convert an RGB image to YCbCr image.

    This function produces the same results as Matlab's `rgb2ycbcr` function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        out_img = np.matmul(
            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                  [24.966, 112.0, -18.214]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
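For the ``y_only`` path the BT.601 luma is a dot product with the studio-swing weights plus an offset of 16, so a pure white pixel lands at 235. A quick sketch on a hypothetical float32 input:

white = np.ones((1, 1, 3), dtype=np.float32)   # float32 input in [0, 1]
y = rgb2ycbcr(white, y_only=True)              # 65.481 + 128.553 + 24.966 + 16 = 235,
                                               # returned as 235 / 255 ≈ 0.9216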
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        out_img = np.matmul(
            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                  [65.481, -37.797, 112.0]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB image.

    This function produces the same results as Matlab's ycbcr2rgb function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(
        img, [[0.00456621, 0.00456621, 0.00456621],
              [0, -0.00153632, 0.00791071],
              [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR image.

    The bgr version of ycbcr2rgb.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted BGR image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(
        img, [[0.00456621, 0.00456621, 0.00456621],
              [0.00791071, -0.00153632, 0],
              [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
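The forward and inverse BT.601 conversions are consistent with each other, so a round trip is expected to recover the input up to small rounding differences (hypothetical image):

img = np.random.randint(0, 256, (16, 16, 3), dtype=np.uint8)
restored = ycbcr2bgr(bgr2ycbcr(img))    # close to img, differences only from rounding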
def convert_color_factory(src, dst):

    code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')

    def convert_color(img):
        out_img = cv2.cvtColor(img, code)
        return out_img

    convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
        image.

    Args:
        img (ndarray or str): The input image.

    Returns:
        ndarray: The converted {dst.upper()} image.
    """

    return convert_color


bgr2rgb = convert_color_factory('bgr', 'rgb')

rgb2bgr = convert_color_factory('rgb', 'bgr')

bgr2hsv = convert_color_factory('bgr', 'hsv')

hsv2bgr = convert_color_factory('hsv', 'bgr')

bgr2hls = convert_color_factory('bgr', 'hls')

hls2bgr = convert_color_factory('hls', 'bgr')
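The module-level converters generated by ``convert_color_factory`` are thin wrappers around ``cv2.cvtColor`` with the conversion code baked in (hypothetical array):

img_bgr = np.zeros((4, 4, 3), dtype=np.uint8)
img_rgb = bgr2rgb(img_bgr)    # same as cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img_hls = bgr2hls(img_bgr)    # same as cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HLS)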