Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
InstructBLIP_pytorch
Commits
c04f261a
Commit
c04f261a
authored
Aug 22, 2024
by
dongchy920
Browse files
InstruceBLIP
parents
Pipeline
#1594
canceled with stages
Changes
421
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3264 additions
and
0 deletions
+3264
-0
lavis/common/annotator/uniformer/mmcv/runner/utils.py
lavis/common/annotator/uniformer/mmcv/runner/utils.py
+93
-0
lavis/common/annotator/uniformer/mmcv/utils/__init__.py
lavis/common/annotator/uniformer/mmcv/utils/__init__.py
+69
-0
lavis/common/annotator/uniformer/mmcv/utils/config.py
lavis/common/annotator/uniformer/mmcv/utils/config.py
+688
-0
lavis/common/annotator/uniformer/mmcv/utils/env.py
lavis/common/annotator/uniformer/mmcv/utils/env.py
+95
-0
lavis/common/annotator/uniformer/mmcv/utils/ext_loader.py
lavis/common/annotator/uniformer/mmcv/utils/ext_loader.py
+71
-0
lavis/common/annotator/uniformer/mmcv/utils/logging.py
lavis/common/annotator/uniformer/mmcv/utils/logging.py
+110
-0
lavis/common/annotator/uniformer/mmcv/utils/misc.py
lavis/common/annotator/uniformer/mmcv/utils/misc.py
+377
-0
lavis/common/annotator/uniformer/mmcv/utils/parrots_jit.py
lavis/common/annotator/uniformer/mmcv/utils/parrots_jit.py
+41
-0
lavis/common/annotator/uniformer/mmcv/utils/parrots_wrapper.py
.../common/annotator/uniformer/mmcv/utils/parrots_wrapper.py
+107
-0
lavis/common/annotator/uniformer/mmcv/utils/path.py
lavis/common/annotator/uniformer/mmcv/utils/path.py
+101
-0
lavis/common/annotator/uniformer/mmcv/utils/progressbar.py
lavis/common/annotator/uniformer/mmcv/utils/progressbar.py
+208
-0
lavis/common/annotator/uniformer/mmcv/utils/registry.py
lavis/common/annotator/uniformer/mmcv/utils/registry.py
+315
-0
lavis/common/annotator/uniformer/mmcv/utils/testing.py
lavis/common/annotator/uniformer/mmcv/utils/testing.py
+140
-0
lavis/common/annotator/uniformer/mmcv/utils/timer.py
lavis/common/annotator/uniformer/mmcv/utils/timer.py
+118
-0
lavis/common/annotator/uniformer/mmcv/utils/trace.py
lavis/common/annotator/uniformer/mmcv/utils/trace.py
+23
-0
lavis/common/annotator/uniformer/mmcv/utils/version_utils.py
lavis/common/annotator/uniformer/mmcv/utils/version_utils.py
+90
-0
lavis/common/annotator/uniformer/mmcv/version.py
lavis/common/annotator/uniformer/mmcv/version.py
+35
-0
lavis/common/annotator/uniformer/mmcv/video/__init__.py
lavis/common/annotator/uniformer/mmcv/video/__init__.py
+11
-0
lavis/common/annotator/uniformer/mmcv/video/io.py
lavis/common/annotator/uniformer/mmcv/video/io.py
+318
-0
lavis/common/annotator/uniformer/mmcv/video/optflow.py
lavis/common/annotator/uniformer/mmcv/video/optflow.py
+254
-0
No files found.
Too many changes to show.
To preserve performance only
421 of 421+
files are displayed.
Plain diff
Email patch
lavis/common/annotator/uniformer/mmcv/runner/utils.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
random
import
sys
import
time
import
warnings
from
getpass
import
getuser
from
socket
import
gethostname
import
numpy
as
np
import
torch
import
annotator.uniformer.mmcv
as
mmcv
def
get_host_info
():
"""Get hostname and username.
Return empty string if exception raised, e.g. ``getpass.getuser()`` will
lead to error in docker container
"""
host
=
''
try
:
host
=
f
'
{
getuser
()
}
@
{
gethostname
()
}
'
except
Exception
as
e
:
warnings
.
warn
(
f
'Host or user not found:
{
str
(
e
)
}
'
)
finally
:
return
host
def
get_time_str
():
return
time
.
strftime
(
'%Y%m%d_%H%M%S'
,
time
.
localtime
())
def
obj_from_dict
(
info
,
parent
=
None
,
default_args
=
None
):
"""Initialize an object from dict.
The dict must contain the key "type", which indicates the object type, it
can be either a string or type, such as "list" or ``list``. Remaining
fields are treated as the arguments for constructing the object.
Args:
info (dict): Object types and arguments.
parent (:class:`module`): Module which may containing expected object
classes.
default_args (dict, optional): Default arguments for initializing the
object.
Returns:
any type: Object built from the dict.
"""
assert
isinstance
(
info
,
dict
)
and
'type'
in
info
assert
isinstance
(
default_args
,
dict
)
or
default_args
is
None
args
=
info
.
copy
()
obj_type
=
args
.
pop
(
'type'
)
if
mmcv
.
is_str
(
obj_type
):
if
parent
is
not
None
:
obj_type
=
getattr
(
parent
,
obj_type
)
else
:
obj_type
=
sys
.
modules
[
obj_type
]
elif
not
isinstance
(
obj_type
,
type
):
raise
TypeError
(
'type must be a str or valid type, but '
f
'got
{
type
(
obj_type
)
}
'
)
if
default_args
is
not
None
:
for
name
,
value
in
default_args
.
items
():
args
.
setdefault
(
name
,
value
)
return
obj_type
(
**
args
)
def
set_random_seed
(
seed
,
deterministic
=
False
,
use_rank_shift
=
False
):
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
rank_shift (bool): Whether to add rank number to the random seed to
have different random seed in different threads. Default: False.
"""
if
use_rank_shift
:
rank
,
_
=
mmcv
.
runner
.
get_dist_info
()
seed
+=
rank
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
os
.
environ
[
'PYTHONHASHSEED'
]
=
str
(
seed
)
if
deterministic
:
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
benchmark
=
False
lavis/common/annotator/uniformer/mmcv/utils/__init__.py
0 → 100644
View file @
c04f261a
# flake8: noqa
# Copyright (c) OpenMMLab. All rights reserved.
from
.config
import
Config
,
ConfigDict
,
DictAction
from
.misc
import
(
check_prerequisites
,
concat_list
,
deprecated_api_warning
,
has_method
,
import_modules_from_strings
,
is_list_of
,
is_method_overridden
,
is_seq_of
,
is_str
,
is_tuple_of
,
iter_cast
,
list_cast
,
requires_executable
,
requires_package
,
slice_list
,
to_1tuple
,
to_2tuple
,
to_3tuple
,
to_4tuple
,
to_ntuple
,
tuple_cast
)
from
.path
import
(
check_file_exist
,
fopen
,
is_filepath
,
mkdir_or_exist
,
scandir
,
symlink
)
from
.progressbar
import
(
ProgressBar
,
track_iter_progress
,
track_parallel_progress
,
track_progress
)
from
.testing
import
(
assert_attrs_equal
,
assert_dict_contains_subset
,
assert_dict_has_keys
,
assert_is_norm_layer
,
assert_keys_equal
,
assert_params_all_zeros
,
check_python_script
)
from
.timer
import
Timer
,
TimerError
,
check_time
from
.version_utils
import
digit_version
,
get_git_hash
try
:
import
torch
except
ImportError
:
__all__
=
[
'Config'
,
'ConfigDict'
,
'DictAction'
,
'is_str'
,
'iter_cast'
,
'list_cast'
,
'tuple_cast'
,
'is_seq_of'
,
'is_list_of'
,
'is_tuple_of'
,
'slice_list'
,
'concat_list'
,
'check_prerequisites'
,
'requires_package'
,
'requires_executable'
,
'is_filepath'
,
'fopen'
,
'check_file_exist'
,
'mkdir_or_exist'
,
'symlink'
,
'scandir'
,
'ProgressBar'
,
'track_progress'
,
'track_iter_progress'
,
'track_parallel_progress'
,
'Timer'
,
'TimerError'
,
'check_time'
,
'deprecated_api_warning'
,
'digit_version'
,
'get_git_hash'
,
'import_modules_from_strings'
,
'assert_dict_contains_subset'
,
'assert_attrs_equal'
,
'assert_dict_has_keys'
,
'assert_keys_equal'
,
'check_python_script'
,
'to_1tuple'
,
'to_2tuple'
,
'to_3tuple'
,
'to_4tuple'
,
'to_ntuple'
,
'is_method_overridden'
,
'has_method'
]
else
:
from
.env
import
collect_env
from
.logging
import
get_logger
,
print_log
from
.parrots_jit
import
jit
,
skip_no_elena
from
.parrots_wrapper
import
(
TORCH_VERSION
,
BuildExtension
,
CppExtension
,
CUDAExtension
,
DataLoader
,
PoolDataLoader
,
SyncBatchNorm
,
_AdaptiveAvgPoolNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_BatchNorm
,
_ConvNd
,
_ConvTransposeMixin
,
_InstanceNorm
,
_MaxPoolNd
,
get_build_config
,
is_rocm_pytorch
,
_get_cuda_home
)
from
.registry
import
Registry
,
build_from_cfg
from
.trace
import
is_jit_tracing
__all__
=
[
'Config'
,
'ConfigDict'
,
'DictAction'
,
'collect_env'
,
'get_logger'
,
'print_log'
,
'is_str'
,
'iter_cast'
,
'list_cast'
,
'tuple_cast'
,
'is_seq_of'
,
'is_list_of'
,
'is_tuple_of'
,
'slice_list'
,
'concat_list'
,
'check_prerequisites'
,
'requires_package'
,
'requires_executable'
,
'is_filepath'
,
'fopen'
,
'check_file_exist'
,
'mkdir_or_exist'
,
'symlink'
,
'scandir'
,
'ProgressBar'
,
'track_progress'
,
'track_iter_progress'
,
'track_parallel_progress'
,
'Registry'
,
'build_from_cfg'
,
'Timer'
,
'TimerError'
,
'check_time'
,
'SyncBatchNorm'
,
'_AdaptiveAvgPoolNd'
,
'_AdaptiveMaxPoolNd'
,
'_AvgPoolNd'
,
'_BatchNorm'
,
'_ConvNd'
,
'_ConvTransposeMixin'
,
'_InstanceNorm'
,
'_MaxPoolNd'
,
'get_build_config'
,
'BuildExtension'
,
'CppExtension'
,
'CUDAExtension'
,
'DataLoader'
,
'PoolDataLoader'
,
'TORCH_VERSION'
,
'deprecated_api_warning'
,
'digit_version'
,
'get_git_hash'
,
'import_modules_from_strings'
,
'jit'
,
'skip_no_elena'
,
'assert_dict_contains_subset'
,
'assert_attrs_equal'
,
'assert_dict_has_keys'
,
'assert_keys_equal'
,
'assert_is_norm_layer'
,
'assert_params_all_zeros'
,
'check_python_script'
,
'is_method_overridden'
,
'is_jit_tracing'
,
'is_rocm_pytorch'
,
'_get_cuda_home'
,
'has_method'
]
lavis/common/annotator/uniformer/mmcv/utils/config.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
ast
import
copy
import
os
import
os.path
as
osp
import
platform
import
shutil
import
sys
import
tempfile
import
uuid
import
warnings
from
argparse
import
Action
,
ArgumentParser
from
collections
import
abc
from
importlib
import
import_module
from
addict
import
Dict
from
yapf.yapflib.yapf_api
import
FormatCode
from
.misc
import
import_modules_from_strings
from
.path
import
check_file_exist
if
platform
.
system
()
==
'Windows'
:
import
regex
as
re
else
:
import
re
BASE_KEY
=
'_base_'
DELETE_KEY
=
'_delete_'
DEPRECATION_KEY
=
'_deprecation_'
RESERVED_KEYS
=
[
'filename'
,
'text'
,
'pretty_text'
]
class
ConfigDict
(
Dict
):
def
__missing__
(
self
,
name
):
raise
KeyError
(
name
)
def
__getattr__
(
self
,
name
):
try
:
value
=
super
(
ConfigDict
,
self
).
__getattr__
(
name
)
except
KeyError
:
ex
=
AttributeError
(
f
"'
{
self
.
__class__
.
__name__
}
' object has no "
f
"attribute '
{
name
}
'"
)
except
Exception
as
e
:
ex
=
e
else
:
return
value
raise
ex
def
add_args
(
parser
,
cfg
,
prefix
=
''
):
for
k
,
v
in
cfg
.
items
():
if
isinstance
(
v
,
str
):
parser
.
add_argument
(
'--'
+
prefix
+
k
)
elif
isinstance
(
v
,
int
):
parser
.
add_argument
(
'--'
+
prefix
+
k
,
type
=
int
)
elif
isinstance
(
v
,
float
):
parser
.
add_argument
(
'--'
+
prefix
+
k
,
type
=
float
)
elif
isinstance
(
v
,
bool
):
parser
.
add_argument
(
'--'
+
prefix
+
k
,
action
=
'store_true'
)
elif
isinstance
(
v
,
dict
):
add_args
(
parser
,
v
,
prefix
+
k
+
'.'
)
elif
isinstance
(
v
,
abc
.
Iterable
):
parser
.
add_argument
(
'--'
+
prefix
+
k
,
type
=
type
(
v
[
0
]),
nargs
=
'+'
)
else
:
print
(
f
'cannot parse key
{
prefix
+
k
}
of type
{
type
(
v
)
}
'
)
return
parser
class
Config
:
"""A facility for config and config files.
It supports common file formats as configs: python/json/yaml. The interface
is the same as a dict object and also allows access config values as
attributes.
Example:
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
>>> cfg.a
1
>>> cfg.b
{'b1': [0, 1]}
>>> cfg.b.b1
[0, 1]
>>> cfg = Config.fromfile('tests/data/config/a.py')
>>> cfg.filename
"/home/kchen/projects/mmcv/tests/data/config/a.py"
>>> cfg.item4
'test'
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@
staticmethod
def
_validate_py_syntax
(
filename
):
with
open
(
filename
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
# Setting encoding explicitly to resolve coding issue on windows
content
=
f
.
read
()
try
:
ast
.
parse
(
content
)
except
SyntaxError
as
e
:
raise
SyntaxError
(
'There are syntax errors in config '
f
'file
{
filename
}
:
{
e
}
'
)
@
staticmethod
def
_substitute_predefined_vars
(
filename
,
temp_config_name
):
file_dirname
=
osp
.
dirname
(
filename
)
file_basename
=
osp
.
basename
(
filename
)
file_basename_no_extension
=
osp
.
splitext
(
file_basename
)[
0
]
file_extname
=
osp
.
splitext
(
filename
)[
1
]
support_templates
=
dict
(
fileDirname
=
file_dirname
,
fileBasename
=
file_basename
,
fileBasenameNoExtension
=
file_basename_no_extension
,
fileExtname
=
file_extname
)
with
open
(
filename
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
# Setting encoding explicitly to resolve coding issue on windows
config_file
=
f
.
read
()
for
key
,
value
in
support_templates
.
items
():
regexp
=
r
'\{\{\s*'
+
str
(
key
)
+
r
'\s*\}\}'
value
=
value
.
replace
(
'
\\
'
,
'/'
)
config_file
=
re
.
sub
(
regexp
,
value
,
config_file
)
with
open
(
temp_config_name
,
'w'
,
encoding
=
'utf-8'
)
as
tmp_config_file
:
tmp_config_file
.
write
(
config_file
)
@
staticmethod
def
_pre_substitute_base_vars
(
filename
,
temp_config_name
):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with
open
(
filename
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
# Setting encoding explicitly to resolve coding issue on windows
config_file
=
f
.
read
()
base_var_dict
=
{}
regexp
=
r
'\{\{\s*'
+
BASE_KEY
+
r
'\.([\w\.]+)\s*\}\}'
base_vars
=
set
(
re
.
findall
(
regexp
,
config_file
))
for
base_var
in
base_vars
:
randstr
=
f
'_
{
base_var
}
_
{
uuid
.
uuid4
().
hex
.
lower
()[:
6
]
}
'
base_var_dict
[
randstr
]
=
base_var
regexp
=
r
'\{\{\s*'
+
BASE_KEY
+
r
'\.'
+
base_var
+
r
'\s*\}\}'
config_file
=
re
.
sub
(
regexp
,
f
'"
{
randstr
}
"'
,
config_file
)
with
open
(
temp_config_name
,
'w'
,
encoding
=
'utf-8'
)
as
tmp_config_file
:
tmp_config_file
.
write
(
config_file
)
return
base_var_dict
@
staticmethod
def
_substitute_base_vars
(
cfg
,
base_var_dict
,
base_cfg
):
"""Substitute variable strings to their actual values."""
cfg
=
copy
.
deepcopy
(
cfg
)
if
isinstance
(
cfg
,
dict
):
for
k
,
v
in
cfg
.
items
():
if
isinstance
(
v
,
str
)
and
v
in
base_var_dict
:
new_v
=
base_cfg
for
new_k
in
base_var_dict
[
v
].
split
(
'.'
):
new_v
=
new_v
[
new_k
]
cfg
[
k
]
=
new_v
elif
isinstance
(
v
,
(
list
,
tuple
,
dict
)):
cfg
[
k
]
=
Config
.
_substitute_base_vars
(
v
,
base_var_dict
,
base_cfg
)
elif
isinstance
(
cfg
,
tuple
):
cfg
=
tuple
(
Config
.
_substitute_base_vars
(
c
,
base_var_dict
,
base_cfg
)
for
c
in
cfg
)
elif
isinstance
(
cfg
,
list
):
cfg
=
[
Config
.
_substitute_base_vars
(
c
,
base_var_dict
,
base_cfg
)
for
c
in
cfg
]
elif
isinstance
(
cfg
,
str
)
and
cfg
in
base_var_dict
:
new_v
=
base_cfg
for
new_k
in
base_var_dict
[
cfg
].
split
(
'.'
):
new_v
=
new_v
[
new_k
]
cfg
=
new_v
return
cfg
@
staticmethod
def
_file2dict
(
filename
,
use_predefined_variables
=
True
):
filename
=
osp
.
abspath
(
osp
.
expanduser
(
filename
))
check_file_exist
(
filename
)
fileExtname
=
osp
.
splitext
(
filename
)[
1
]
if
fileExtname
not
in
[
'.py'
,
'.json'
,
'.yaml'
,
'.yml'
]:
raise
IOError
(
'Only py/yml/yaml/json type are supported now!'
)
with
tempfile
.
TemporaryDirectory
()
as
temp_config_dir
:
temp_config_file
=
tempfile
.
NamedTemporaryFile
(
dir
=
temp_config_dir
,
suffix
=
fileExtname
)
if
platform
.
system
()
==
'Windows'
:
temp_config_file
.
close
()
temp_config_name
=
osp
.
basename
(
temp_config_file
.
name
)
# Substitute predefined variables
if
use_predefined_variables
:
Config
.
_substitute_predefined_vars
(
filename
,
temp_config_file
.
name
)
else
:
shutil
.
copyfile
(
filename
,
temp_config_file
.
name
)
# Substitute base variables from placeholders to strings
base_var_dict
=
Config
.
_pre_substitute_base_vars
(
temp_config_file
.
name
,
temp_config_file
.
name
)
if
filename
.
endswith
(
'.py'
):
temp_module_name
=
osp
.
splitext
(
temp_config_name
)[
0
]
sys
.
path
.
insert
(
0
,
temp_config_dir
)
Config
.
_validate_py_syntax
(
filename
)
mod
=
import_module
(
temp_module_name
)
sys
.
path
.
pop
(
0
)
cfg_dict
=
{
name
:
value
for
name
,
value
in
mod
.
__dict__
.
items
()
if
not
name
.
startswith
(
'__'
)
}
# delete imported module
del
sys
.
modules
[
temp_module_name
]
elif
filename
.
endswith
((
'.yml'
,
'.yaml'
,
'.json'
)):
import
annotator.uniformer.mmcv
as
mmcv
cfg_dict
=
mmcv
.
load
(
temp_config_file
.
name
)
# close temp file
temp_config_file
.
close
()
# check deprecation information
if
DEPRECATION_KEY
in
cfg_dict
:
deprecation_info
=
cfg_dict
.
pop
(
DEPRECATION_KEY
)
warning_msg
=
f
'The config file
{
filename
}
will be deprecated '
\
'in the future.'
if
'expected'
in
deprecation_info
:
warning_msg
+=
f
' Please use
{
deprecation_info
[
"expected"
]
}
'
\
'instead.'
if
'reference'
in
deprecation_info
:
warning_msg
+=
' More information can be found at '
\
f
'
{
deprecation_info
[
"reference"
]
}
'
warnings
.
warn
(
warning_msg
)
cfg_text
=
filename
+
'
\n
'
with
open
(
filename
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
# Setting encoding explicitly to resolve coding issue on windows
cfg_text
+=
f
.
read
()
if
BASE_KEY
in
cfg_dict
:
cfg_dir
=
osp
.
dirname
(
filename
)
base_filename
=
cfg_dict
.
pop
(
BASE_KEY
)
base_filename
=
base_filename
if
isinstance
(
base_filename
,
list
)
else
[
base_filename
]
cfg_dict_list
=
list
()
cfg_text_list
=
list
()
for
f
in
base_filename
:
_cfg_dict
,
_cfg_text
=
Config
.
_file2dict
(
osp
.
join
(
cfg_dir
,
f
))
cfg_dict_list
.
append
(
_cfg_dict
)
cfg_text_list
.
append
(
_cfg_text
)
base_cfg_dict
=
dict
()
for
c
in
cfg_dict_list
:
duplicate_keys
=
base_cfg_dict
.
keys
()
&
c
.
keys
()
if
len
(
duplicate_keys
)
>
0
:
raise
KeyError
(
'Duplicate key is not allowed among bases. '
f
'Duplicate keys:
{
duplicate_keys
}
'
)
base_cfg_dict
.
update
(
c
)
# Substitute base variables from strings to their actual values
cfg_dict
=
Config
.
_substitute_base_vars
(
cfg_dict
,
base_var_dict
,
base_cfg_dict
)
base_cfg_dict
=
Config
.
_merge_a_into_b
(
cfg_dict
,
base_cfg_dict
)
cfg_dict
=
base_cfg_dict
# merge cfg_text
cfg_text_list
.
append
(
cfg_text
)
cfg_text
=
'
\n
'
.
join
(
cfg_text_list
)
return
cfg_dict
,
cfg_text
@
staticmethod
def
_merge_a_into_b
(
a
,
b
,
allow_list_keys
=
False
):
"""merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
Args:
a (dict): The source dict to be merged into ``b``.
b (dict): The origin dict to be fetch keys from ``a``.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in source ``a`` and will replace the element of the
corresponding index in b if b is a list. Default: False.
Returns:
dict: The modified dict of ``b`` using ``a``.
Examples:
# Normally merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# Delete b first and merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# b is a list
>>> Config._merge_a_into_b(
... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
[{'a': 2}, {'b': 2}]
"""
b
=
b
.
copy
()
for
k
,
v
in
a
.
items
():
if
allow_list_keys
and
k
.
isdigit
()
and
isinstance
(
b
,
list
):
k
=
int
(
k
)
if
len
(
b
)
<=
k
:
raise
KeyError
(
f
'Index
{
k
}
exceeds the length of list
{
b
}
'
)
b
[
k
]
=
Config
.
_merge_a_into_b
(
v
,
b
[
k
],
allow_list_keys
)
elif
isinstance
(
v
,
dict
)
and
k
in
b
and
not
v
.
pop
(
DELETE_KEY
,
False
):
allowed_types
=
(
dict
,
list
)
if
allow_list_keys
else
dict
if
not
isinstance
(
b
[
k
],
allowed_types
):
raise
TypeError
(
f
'
{
k
}
=
{
v
}
in child config cannot inherit from base '
f
'because
{
k
}
is a dict in the child config but is of '
f
'type
{
type
(
b
[
k
])
}
in base config. You may set '
f
'`
{
DELETE_KEY
}
=True` to ignore the base config'
)
b
[
k
]
=
Config
.
_merge_a_into_b
(
v
,
b
[
k
],
allow_list_keys
)
else
:
b
[
k
]
=
v
return
b
@
staticmethod
def
fromfile
(
filename
,
use_predefined_variables
=
True
,
import_custom_modules
=
True
):
cfg_dict
,
cfg_text
=
Config
.
_file2dict
(
filename
,
use_predefined_variables
)
if
import_custom_modules
and
cfg_dict
.
get
(
'custom_imports'
,
None
):
import_modules_from_strings
(
**
cfg_dict
[
'custom_imports'
])
return
Config
(
cfg_dict
,
cfg_text
=
cfg_text
,
filename
=
filename
)
@
staticmethod
def
fromstring
(
cfg_str
,
file_format
):
"""Generate config from config str.
Args:
cfg_str (str): Config str.
file_format (str): Config file format corresponding to the
config str. Only py/yml/yaml/json type are supported now!
Returns:
obj:`Config`: Config obj.
"""
if
file_format
not
in
[
'.py'
,
'.json'
,
'.yaml'
,
'.yml'
]:
raise
IOError
(
'Only py/yml/yaml/json type are supported now!'
)
if
file_format
!=
'.py'
and
'dict('
in
cfg_str
:
# check if users specify a wrong suffix for python
warnings
.
warn
(
'Please check "file_format", the file format may be .py'
)
with
tempfile
.
NamedTemporaryFile
(
'w'
,
encoding
=
'utf-8'
,
suffix
=
file_format
,
delete
=
False
)
as
temp_file
:
temp_file
.
write
(
cfg_str
)
# on windows, previous implementation cause error
# see PR 1077 for details
cfg
=
Config
.
fromfile
(
temp_file
.
name
)
os
.
remove
(
temp_file
.
name
)
return
cfg
@
staticmethod
def
auto_argparser
(
description
=
None
):
"""Generate argparser from config file automatically (experimental)"""
partial_parser
=
ArgumentParser
(
description
=
description
)
partial_parser
.
add_argument
(
'config'
,
help
=
'config file path'
)
cfg_file
=
partial_parser
.
parse_known_args
()[
0
].
config
cfg
=
Config
.
fromfile
(
cfg_file
)
parser
=
ArgumentParser
(
description
=
description
)
parser
.
add_argument
(
'config'
,
help
=
'config file path'
)
add_args
(
parser
,
cfg
)
return
parser
,
cfg
def
__init__
(
self
,
cfg_dict
=
None
,
cfg_text
=
None
,
filename
=
None
):
if
cfg_dict
is
None
:
cfg_dict
=
dict
()
elif
not
isinstance
(
cfg_dict
,
dict
):
raise
TypeError
(
'cfg_dict must be a dict, but '
f
'got
{
type
(
cfg_dict
)
}
'
)
for
key
in
cfg_dict
:
if
key
in
RESERVED_KEYS
:
raise
KeyError
(
f
'
{
key
}
is reserved for config file'
)
super
(
Config
,
self
).
__setattr__
(
'_cfg_dict'
,
ConfigDict
(
cfg_dict
))
super
(
Config
,
self
).
__setattr__
(
'_filename'
,
filename
)
if
cfg_text
:
text
=
cfg_text
elif
filename
:
with
open
(
filename
,
'r'
)
as
f
:
text
=
f
.
read
()
else
:
text
=
''
super
(
Config
,
self
).
__setattr__
(
'_text'
,
text
)
@
property
def
filename
(
self
):
return
self
.
_filename
@
property
def
text
(
self
):
return
self
.
_text
@
property
def
pretty_text
(
self
):
indent
=
4
def
_indent
(
s_
,
num_spaces
):
s
=
s_
.
split
(
'
\n
'
)
if
len
(
s
)
==
1
:
return
s_
first
=
s
.
pop
(
0
)
s
=
[(
num_spaces
*
' '
)
+
line
for
line
in
s
]
s
=
'
\n
'
.
join
(
s
)
s
=
first
+
'
\n
'
+
s
return
s
def
_format_basic_types
(
k
,
v
,
use_mapping
=
False
):
if
isinstance
(
v
,
str
):
v_str
=
f
"'
{
v
}
'"
else
:
v_str
=
str
(
v
)
if
use_mapping
:
k_str
=
f
"'
{
k
}
'"
if
isinstance
(
k
,
str
)
else
str
(
k
)
attr_str
=
f
'
{
k_str
}
:
{
v_str
}
'
else
:
attr_str
=
f
'
{
str
(
k
)
}
=
{
v_str
}
'
attr_str
=
_indent
(
attr_str
,
indent
)
return
attr_str
def
_format_list
(
k
,
v
,
use_mapping
=
False
):
# check if all items in the list are dict
if
all
(
isinstance
(
_
,
dict
)
for
_
in
v
):
v_str
=
'[
\n
'
v_str
+=
'
\n
'
.
join
(
f
'dict(
{
_indent
(
_format_dict
(
v_
),
indent
)
}
),'
for
v_
in
v
).
rstrip
(
','
)
if
use_mapping
:
k_str
=
f
"'
{
k
}
'"
if
isinstance
(
k
,
str
)
else
str
(
k
)
attr_str
=
f
'
{
k_str
}
:
{
v_str
}
'
else
:
attr_str
=
f
'
{
str
(
k
)
}
=
{
v_str
}
'
attr_str
=
_indent
(
attr_str
,
indent
)
+
']'
else
:
attr_str
=
_format_basic_types
(
k
,
v
,
use_mapping
)
return
attr_str
def
_contain_invalid_identifier
(
dict_str
):
contain_invalid_identifier
=
False
for
key_name
in
dict_str
:
contain_invalid_identifier
|=
\
(
not
str
(
key_name
).
isidentifier
())
return
contain_invalid_identifier
def
_format_dict
(
input_dict
,
outest_level
=
False
):
r
=
''
s
=
[]
use_mapping
=
_contain_invalid_identifier
(
input_dict
)
if
use_mapping
:
r
+=
'{'
for
idx
,
(
k
,
v
)
in
enumerate
(
input_dict
.
items
()):
is_last
=
idx
>=
len
(
input_dict
)
-
1
end
=
''
if
outest_level
or
is_last
else
','
if
isinstance
(
v
,
dict
):
v_str
=
'
\n
'
+
_format_dict
(
v
)
if
use_mapping
:
k_str
=
f
"'
{
k
}
'"
if
isinstance
(
k
,
str
)
else
str
(
k
)
attr_str
=
f
'
{
k_str
}
: dict(
{
v_str
}
'
else
:
attr_str
=
f
'
{
str
(
k
)
}
=dict(
{
v_str
}
'
attr_str
=
_indent
(
attr_str
,
indent
)
+
')'
+
end
elif
isinstance
(
v
,
list
):
attr_str
=
_format_list
(
k
,
v
,
use_mapping
)
+
end
else
:
attr_str
=
_format_basic_types
(
k
,
v
,
use_mapping
)
+
end
s
.
append
(
attr_str
)
r
+=
'
\n
'
.
join
(
s
)
if
use_mapping
:
r
+=
'}'
return
r
cfg_dict
=
self
.
_cfg_dict
.
to_dict
()
text
=
_format_dict
(
cfg_dict
,
outest_level
=
True
)
# copied from setup.cfg
yapf_style
=
dict
(
based_on_style
=
'pep8'
,
blank_line_before_nested_class_or_def
=
True
,
split_before_expression_after_opening_paren
=
True
)
text
,
_
=
FormatCode
(
text
,
style_config
=
yapf_style
,
verify
=
True
)
return
text
def
__repr__
(
self
):
return
f
'Config (path:
{
self
.
filename
}
):
{
self
.
_cfg_dict
.
__repr__
()
}
'
def
__len__
(
self
):
return
len
(
self
.
_cfg_dict
)
def
__getattr__
(
self
,
name
):
return
getattr
(
self
.
_cfg_dict
,
name
)
def
__getitem__
(
self
,
name
):
return
self
.
_cfg_dict
.
__getitem__
(
name
)
def
__setattr__
(
self
,
name
,
value
):
if
isinstance
(
value
,
dict
):
value
=
ConfigDict
(
value
)
self
.
_cfg_dict
.
__setattr__
(
name
,
value
)
def
__setitem__
(
self
,
name
,
value
):
if
isinstance
(
value
,
dict
):
value
=
ConfigDict
(
value
)
self
.
_cfg_dict
.
__setitem__
(
name
,
value
)
def
__iter__
(
self
):
return
iter
(
self
.
_cfg_dict
)
def
__getstate__
(
self
):
return
(
self
.
_cfg_dict
,
self
.
_filename
,
self
.
_text
)
def
__setstate__
(
self
,
state
):
_cfg_dict
,
_filename
,
_text
=
state
super
(
Config
,
self
).
__setattr__
(
'_cfg_dict'
,
_cfg_dict
)
super
(
Config
,
self
).
__setattr__
(
'_filename'
,
_filename
)
super
(
Config
,
self
).
__setattr__
(
'_text'
,
_text
)
def
dump
(
self
,
file
=
None
):
cfg_dict
=
super
(
Config
,
self
).
__getattribute__
(
'_cfg_dict'
).
to_dict
()
if
self
.
filename
.
endswith
(
'.py'
):
if
file
is
None
:
return
self
.
pretty_text
else
:
with
open
(
file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
self
.
pretty_text
)
else
:
import
annotator.uniformer.mmcv
as
mmcv
if
file
is
None
:
file_format
=
self
.
filename
.
split
(
'.'
)[
-
1
]
return
mmcv
.
dump
(
cfg_dict
,
file_format
=
file_format
)
else
:
mmcv
.
dump
(
cfg_dict
,
file
)
def
merge_from_dict
(
self
,
options
,
allow_list_keys
=
True
):
"""Merge list into cfg_dict.
Merge the dict parsed by MultipleKVAction into this cfg.
Examples:
>>> options = {'model.backbone.depth': 50,
... 'model.backbone.with_cp':True}
>>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... model=dict(backbone=dict(depth=50, with_cp=True)))
# Merge list element
>>> cfg = Config(dict(pipeline=[
... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
>>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
>>> cfg.merge_from_dict(options, allow_list_keys=True)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(pipeline=[
... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
Args:
options (dict): dict of configs to merge from.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in ``options`` and will replace the element of the
corresponding index in the config if the config is a list.
Default: True.
"""
option_cfg_dict
=
{}
for
full_key
,
v
in
options
.
items
():
d
=
option_cfg_dict
key_list
=
full_key
.
split
(
'.'
)
for
subkey
in
key_list
[:
-
1
]:
d
.
setdefault
(
subkey
,
ConfigDict
())
d
=
d
[
subkey
]
subkey
=
key_list
[
-
1
]
d
[
subkey
]
=
v
cfg_dict
=
super
(
Config
,
self
).
__getattribute__
(
'_cfg_dict'
)
super
(
Config
,
self
).
__setattr__
(
'_cfg_dict'
,
Config
.
_merge_a_into_b
(
option_cfg_dict
,
cfg_dict
,
allow_list_keys
=
allow_list_keys
))
class
DictAction
(
Action
):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options can
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""
@
staticmethod
def
_parse_int_float_bool
(
val
):
try
:
return
int
(
val
)
except
ValueError
:
pass
try
:
return
float
(
val
)
except
ValueError
:
pass
if
val
.
lower
()
in
[
'true'
,
'false'
]:
return
True
if
val
.
lower
()
==
'true'
else
False
return
val
@
staticmethod
def
_parse_iterable
(
val
):
"""Parse iterable values in the string.
All elements inside '()' or '[]' are treated as iterable values.
Args:
val (str): Value string.
Returns:
list | tuple: The expanded list or tuple from the string.
Examples:
>>> DictAction._parse_iterable('1,2,3')
[1, 2, 3]
>>> DictAction._parse_iterable('[a, b, c]')
['a', 'b', 'c']
>>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
[(1, 2, 3), ['a', 'b'], 'c']
"""
def
find_next_comma
(
string
):
"""Find the position of next comma in the string.
If no ',' is found in the string, return the string length. All
chars inside '()' and '[]' are treated as one element and thus ','
inside these brackets are ignored.
"""
assert
(
string
.
count
(
'('
)
==
string
.
count
(
')'
))
and
(
string
.
count
(
'['
)
==
string
.
count
(
']'
)),
\
f
'Imbalanced brackets exist in
{
string
}
'
end
=
len
(
string
)
for
idx
,
char
in
enumerate
(
string
):
pre
=
string
[:
idx
]
# The string before this ',' is balanced
if
((
char
==
','
)
and
(
pre
.
count
(
'('
)
==
pre
.
count
(
')'
))
and
(
pre
.
count
(
'['
)
==
pre
.
count
(
']'
))):
end
=
idx
break
return
end
# Strip ' and " characters and replace whitespace.
val
=
val
.
strip
(
'
\'\"
'
).
replace
(
' '
,
''
)
is_tuple
=
False
if
val
.
startswith
(
'('
)
and
val
.
endswith
(
')'
):
is_tuple
=
True
val
=
val
[
1
:
-
1
]
elif
val
.
startswith
(
'['
)
and
val
.
endswith
(
']'
):
val
=
val
[
1
:
-
1
]
elif
','
not
in
val
:
# val is a single value
return
DictAction
.
_parse_int_float_bool
(
val
)
values
=
[]
while
len
(
val
)
>
0
:
comma_idx
=
find_next_comma
(
val
)
element
=
DictAction
.
_parse_iterable
(
val
[:
comma_idx
])
values
.
append
(
element
)
val
=
val
[
comma_idx
+
1
:]
if
is_tuple
:
values
=
tuple
(
values
)
return
values
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
options
=
{}
for
kv
in
values
:
key
,
val
=
kv
.
split
(
'='
,
maxsplit
=
1
)
options
[
key
]
=
self
.
_parse_iterable
(
val
)
setattr
(
namespace
,
self
.
dest
,
options
)
lavis/common/annotator/uniformer/mmcv/utils/env.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
"""This file holding some environment constant for sharing by other files."""
import
os.path
as
osp
import
subprocess
import
sys
from
collections
import
defaultdict
import
cv2
import
torch
import
annotator.uniformer.mmcv
as
mmcv
from
.parrots_wrapper
import
get_build_config
def
collect_env
():
"""Collect the information of the running environments.
Returns:
dict: The environment information. The following fields are contained.
- sys.platform: The variable of ``sys.platform``.
- Python: Python version.
- CUDA available: Bool, indicating if CUDA is available.
- GPU devices: Device type of each GPU.
- CUDA_HOME (optional): The env var ``CUDA_HOME``.
- NVCC (optional): NVCC version.
- GCC: GCC version, "n/a" if GCC is not installed.
- PyTorch: PyTorch version.
- PyTorch compiling details: The output of
\
``torch.__config__.show()``.
- TorchVision (optional): TorchVision version.
- OpenCV: OpenCV version.
- MMCV: MMCV version.
- MMCV Compiler: The GCC version for compiling MMCV ops.
- MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
"""
env_info
=
{}
env_info
[
'sys.platform'
]
=
sys
.
platform
env_info
[
'Python'
]
=
sys
.
version
.
replace
(
'
\n
'
,
''
)
cuda_available
=
torch
.
cuda
.
is_available
()
env_info
[
'CUDA available'
]
=
cuda_available
if
cuda_available
:
devices
=
defaultdict
(
list
)
for
k
in
range
(
torch
.
cuda
.
device_count
()):
devices
[
torch
.
cuda
.
get_device_name
(
k
)].
append
(
str
(
k
))
for
name
,
device_ids
in
devices
.
items
():
env_info
[
'GPU '
+
','
.
join
(
device_ids
)]
=
name
from
annotator.uniformer.mmcv.utils.parrots_wrapper
import
_get_cuda_home
CUDA_HOME
=
_get_cuda_home
()
env_info
[
'CUDA_HOME'
]
=
CUDA_HOME
if
CUDA_HOME
is
not
None
and
osp
.
isdir
(
CUDA_HOME
):
try
:
nvcc
=
osp
.
join
(
CUDA_HOME
,
'bin/nvcc'
)
nvcc
=
subprocess
.
check_output
(
f
'"
{
nvcc
}
" -V | tail -n1'
,
shell
=
True
)
nvcc
=
nvcc
.
decode
(
'utf-8'
).
strip
()
except
subprocess
.
SubprocessError
:
nvcc
=
'Not Available'
env_info
[
'NVCC'
]
=
nvcc
try
:
gcc
=
subprocess
.
check_output
(
'gcc --version | head -n1'
,
shell
=
True
)
gcc
=
gcc
.
decode
(
'utf-8'
).
strip
()
env_info
[
'GCC'
]
=
gcc
except
subprocess
.
CalledProcessError
:
# gcc is unavailable
env_info
[
'GCC'
]
=
'n/a'
env_info
[
'PyTorch'
]
=
torch
.
__version__
env_info
[
'PyTorch compiling details'
]
=
get_build_config
()
try
:
import
torchvision
env_info
[
'TorchVision'
]
=
torchvision
.
__version__
except
ModuleNotFoundError
:
pass
env_info
[
'OpenCV'
]
=
cv2
.
__version__
env_info
[
'MMCV'
]
=
mmcv
.
__version__
try
:
from
annotator.uniformer.mmcv.ops
import
get_compiler_version
,
get_compiling_cuda_version
except
ModuleNotFoundError
:
env_info
[
'MMCV Compiler'
]
=
'n/a'
env_info
[
'MMCV CUDA Compiler'
]
=
'n/a'
else
:
env_info
[
'MMCV Compiler'
]
=
get_compiler_version
()
env_info
[
'MMCV CUDA Compiler'
]
=
get_compiling_cuda_version
()
return
env_info
lavis/common/annotator/uniformer/mmcv/utils/ext_loader.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
importlib
import
os
import
pkgutil
import
warnings
from
collections
import
namedtuple
import
torch
if
torch
.
__version__
!=
'parrots'
:
def
load_ext
(
name
,
funcs
):
ext
=
importlib
.
import_module
(
'mmcv.'
+
name
)
for
fun
in
funcs
:
assert
hasattr
(
ext
,
fun
),
f
'
{
fun
}
miss in module
{
name
}
'
return
ext
else
:
from
parrots
import
extension
from
parrots.base
import
ParrotsException
has_return_value_ops
=
[
'nms'
,
'softnms'
,
'nms_match'
,
'nms_rotated'
,
'top_pool_forward'
,
'top_pool_backward'
,
'bottom_pool_forward'
,
'bottom_pool_backward'
,
'left_pool_forward'
,
'left_pool_backward'
,
'right_pool_forward'
,
'right_pool_backward'
,
'fused_bias_leakyrelu'
,
'upfirdn2d'
,
'ms_deform_attn_forward'
,
'pixel_group'
,
'contour_expand'
,
]
def
get_fake_func
(
name
,
e
):
def
fake_func
(
*
args
,
**
kwargs
):
warnings
.
warn
(
f
'
{
name
}
is not supported in parrots now'
)
raise
e
return
fake_func
def
load_ext
(
name
,
funcs
):
ExtModule
=
namedtuple
(
'ExtModule'
,
funcs
)
ext_list
=
[]
lib_root
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)))
for
fun
in
funcs
:
try
:
ext_fun
=
extension
.
load
(
fun
,
name
,
lib_dir
=
lib_root
)
except
ParrotsException
as
e
:
if
'No element registered'
not
in
e
.
message
:
warnings
.
warn
(
e
.
message
)
ext_fun
=
get_fake_func
(
fun
,
e
)
ext_list
.
append
(
ext_fun
)
else
:
if
fun
in
has_return_value_ops
:
ext_list
.
append
(
ext_fun
.
op
)
else
:
ext_list
.
append
(
ext_fun
.
op_
)
return
ExtModule
(
*
ext_list
)
def
check_ops_exist
():
ext_loader
=
pkgutil
.
find_loader
(
'mmcv._ext'
)
return
ext_loader
is
not
None
lavis/common/annotator/uniformer/mmcv/utils/logging.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
logging
import
torch.distributed
as
dist
logger_initialized
=
{}
def
get_logger
(
name
,
log_file
=
None
,
log_level
=
logging
.
INFO
,
file_mode
=
'w'
):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified and the process rank is 0, a FileHandler
will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
Returns:
logging.Logger: The expected logger.
"""
logger
=
logging
.
getLogger
(
name
)
if
name
in
logger_initialized
:
return
logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for
logger_name
in
logger_initialized
:
if
name
.
startswith
(
logger_name
):
return
logger
# handle duplicate logs to the console
# Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
# to the root logger. As logger.propagate is True by default, this root
# level handler causes logging messages from rank>0 processes to
# unexpectedly show up on the console, creating much unwanted clutter.
# To fix this issue, we set the root logger's StreamHandler, if any, to log
# at the ERROR level.
for
handler
in
logger
.
root
.
handlers
:
if
type
(
handler
)
is
logging
.
StreamHandler
:
handler
.
setLevel
(
logging
.
ERROR
)
stream_handler
=
logging
.
StreamHandler
()
handlers
=
[
stream_handler
]
if
dist
.
is_available
()
and
dist
.
is_initialized
():
rank
=
dist
.
get_rank
()
else
:
rank
=
0
# only rank 0 will add a FileHandler
if
rank
==
0
and
log_file
is
not
None
:
# Here, the default behaviour of the official logger is 'a'. Thus, we
# provide an interface to change the file mode to the default
# behaviour.
file_handler
=
logging
.
FileHandler
(
log_file
,
file_mode
)
handlers
.
append
(
file_handler
)
formatter
=
logging
.
Formatter
(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
for
handler
in
handlers
:
handler
.
setFormatter
(
formatter
)
handler
.
setLevel
(
log_level
)
logger
.
addHandler
(
handler
)
if
rank
==
0
:
logger
.
setLevel
(
log_level
)
else
:
logger
.
setLevel
(
logging
.
ERROR
)
logger_initialized
[
name
]
=
True
return
logger
def
print_log
(
msg
,
logger
=
None
,
level
=
logging
.
INFO
):
"""Print a log message.
Args:
msg (str): The message to be logged.
logger (logging.Logger | str | None): The logger to be used.
Some special loggers are:
- "silent": no message will be printed.
- other str: the logger obtained with `get_root_logger(logger)`.
- None: The `print()` method will be used to print log messages.
level (int): Logging level. Only available when `logger` is a Logger
object or "root".
"""
if
logger
is
None
:
print
(
msg
)
elif
isinstance
(
logger
,
logging
.
Logger
):
logger
.
log
(
level
,
msg
)
elif
logger
==
'silent'
:
pass
elif
isinstance
(
logger
,
str
):
_logger
=
get_logger
(
logger
)
_logger
.
log
(
level
,
msg
)
else
:
raise
TypeError
(
'logger should be either a logging.Logger object, str, '
f
'"silent" or None, but got
{
type
(
logger
)
}
'
)
lavis/common/annotator/uniformer/mmcv/utils/misc.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
collections.abc
import
functools
import
itertools
import
subprocess
import
warnings
from
collections
import
abc
from
importlib
import
import_module
from
inspect
import
getfullargspec
from
itertools
import
repeat
# From PyTorch internals
def
_ntuple
(
n
):
def
parse
(
x
):
if
isinstance
(
x
,
collections
.
abc
.
Iterable
):
return
x
return
tuple
(
repeat
(
x
,
n
))
return
parse
to_1tuple
=
_ntuple
(
1
)
to_2tuple
=
_ntuple
(
2
)
to_3tuple
=
_ntuple
(
3
)
to_4tuple
=
_ntuple
(
4
)
to_ntuple
=
_ntuple
def
is_str
(
x
):
"""Whether the input is an string instance.
Note: This method is deprecated since python 2 is no longer supported.
"""
return
isinstance
(
x
,
str
)
def
import_modules_from_strings
(
imports
,
allow_failed_imports
=
False
):
"""Import modules from the given list of strings.
Args:
imports (list | str | None): The given module names to be imported.
allow_failed_imports (bool): If True, the failed imports will return
None. Otherwise, an ImportError is raise. Default: False.
Returns:
list[module] | module | None: The imported modules.
Examples:
>>> osp, sys = import_modules_from_strings(
... ['os.path', 'sys'])
>>> import os.path as osp_
>>> import sys as sys_
>>> assert osp == osp_
>>> assert sys == sys_
"""
if
not
imports
:
return
single_import
=
False
if
isinstance
(
imports
,
str
):
single_import
=
True
imports
=
[
imports
]
if
not
isinstance
(
imports
,
list
):
raise
TypeError
(
f
'custom_imports must be a list but got type
{
type
(
imports
)
}
'
)
imported
=
[]
for
imp
in
imports
:
if
not
isinstance
(
imp
,
str
):
raise
TypeError
(
f
'
{
imp
}
is of type
{
type
(
imp
)
}
and cannot be imported.'
)
try
:
imported_tmp
=
import_module
(
imp
)
except
ImportError
:
if
allow_failed_imports
:
warnings
.
warn
(
f
'
{
imp
}
failed to import and is ignored.'
,
UserWarning
)
imported_tmp
=
None
else
:
raise
ImportError
imported
.
append
(
imported_tmp
)
if
single_import
:
imported
=
imported
[
0
]
return
imported
def
iter_cast
(
inputs
,
dst_type
,
return_type
=
None
):
"""Cast elements of an iterable object into some type.
Args:
inputs (Iterable): The input object.
dst_type (type): Destination type.
return_type (type, optional): If specified, the output object will be
converted to this type, otherwise an iterator.
Returns:
iterator or specified type: The converted object.
"""
if
not
isinstance
(
inputs
,
abc
.
Iterable
):
raise
TypeError
(
'inputs must be an iterable object'
)
if
not
isinstance
(
dst_type
,
type
):
raise
TypeError
(
'"dst_type" must be a valid type'
)
out_iterable
=
map
(
dst_type
,
inputs
)
if
return_type
is
None
:
return
out_iterable
else
:
return
return_type
(
out_iterable
)
def
list_cast
(
inputs
,
dst_type
):
"""Cast elements of an iterable object into a list of some type.
A partial method of :func:`iter_cast`.
"""
return
iter_cast
(
inputs
,
dst_type
,
return_type
=
list
)
def
tuple_cast
(
inputs
,
dst_type
):
"""Cast elements of an iterable object into a tuple of some type.
A partial method of :func:`iter_cast`.
"""
return
iter_cast
(
inputs
,
dst_type
,
return_type
=
tuple
)
def
is_seq_of
(
seq
,
expected_type
,
seq_type
=
None
):
"""Check whether it is a sequence of some type.
Args:
seq (Sequence): The sequence to be checked.
expected_type (type): Expected type of sequence items.
seq_type (type, optional): Expected sequence type.
Returns:
bool: Whether the sequence is valid.
"""
if
seq_type
is
None
:
exp_seq_type
=
abc
.
Sequence
else
:
assert
isinstance
(
seq_type
,
type
)
exp_seq_type
=
seq_type
if
not
isinstance
(
seq
,
exp_seq_type
):
return
False
for
item
in
seq
:
if
not
isinstance
(
item
,
expected_type
):
return
False
return
True
def
is_list_of
(
seq
,
expected_type
):
"""Check whether it is a list of some type.
A partial method of :func:`is_seq_of`.
"""
return
is_seq_of
(
seq
,
expected_type
,
seq_type
=
list
)
def
is_tuple_of
(
seq
,
expected_type
):
"""Check whether it is a tuple of some type.
A partial method of :func:`is_seq_of`.
"""
return
is_seq_of
(
seq
,
expected_type
,
seq_type
=
tuple
)
def
slice_list
(
in_list
,
lens
):
"""Slice a list into several sub lists by a list of given length.
Args:
in_list (list): The list to be sliced.
lens(int or list): The expected length of each out list.
Returns:
list: A list of sliced list.
"""
if
isinstance
(
lens
,
int
):
assert
len
(
in_list
)
%
lens
==
0
lens
=
[
lens
]
*
int
(
len
(
in_list
)
/
lens
)
if
not
isinstance
(
lens
,
list
):
raise
TypeError
(
'"indices" must be an integer or a list of integers'
)
elif
sum
(
lens
)
!=
len
(
in_list
):
raise
ValueError
(
'sum of lens and list length does not '
f
'match:
{
sum
(
lens
)
}
!=
{
len
(
in_list
)
}
'
)
out_list
=
[]
idx
=
0
for
i
in
range
(
len
(
lens
)):
out_list
.
append
(
in_list
[
idx
:
idx
+
lens
[
i
]])
idx
+=
lens
[
i
]
return
out_list
def
concat_list
(
in_list
):
"""Concatenate a list of list into a single list.
Args:
in_list (list): The list of list to be merged.
Returns:
list: The concatenated flat list.
"""
return
list
(
itertools
.
chain
(
*
in_list
))
def
check_prerequisites
(
prerequisites
,
checker
,
msg_tmpl
=
'Prerequisites "{}" are required in method "{}" but not '
'found, please install them first.'
):
# yapf: disable
"""A decorator factory to check if prerequisites are satisfied.
Args:
prerequisites (str of list[str]): Prerequisites to be checked.
checker (callable): The checker method that returns True if a
prerequisite is meet, False otherwise.
msg_tmpl (str): The message template with two variables.
Returns:
decorator: A specific decorator.
"""
def
wrap
(
func
):
@
functools
.
wraps
(
func
)
def
wrapped_func
(
*
args
,
**
kwargs
):
requirements
=
[
prerequisites
]
if
isinstance
(
prerequisites
,
str
)
else
prerequisites
missing
=
[]
for
item
in
requirements
:
if
not
checker
(
item
):
missing
.
append
(
item
)
if
missing
:
print
(
msg_tmpl
.
format
(
', '
.
join
(
missing
),
func
.
__name__
))
raise
RuntimeError
(
'Prerequisites not meet.'
)
else
:
return
func
(
*
args
,
**
kwargs
)
return
wrapped_func
return
wrap
def
_check_py_package
(
package
):
try
:
import_module
(
package
)
except
ImportError
:
return
False
else
:
return
True
def
_check_executable
(
cmd
):
if
subprocess
.
call
(
f
'which
{
cmd
}
'
,
shell
=
True
)
!=
0
:
return
False
else
:
return
True
def
requires_package
(
prerequisites
):
"""A decorator to check if some python packages are installed.
Example:
>>> @requires_package('numpy')
>>> func(arg1, args):
>>> return numpy.zeros(1)
array([0.])
>>> @requires_package(['numpy', 'non_package'])
>>> func(arg1, args):
>>> return numpy.zeros(1)
ImportError
"""
return
check_prerequisites
(
prerequisites
,
checker
=
_check_py_package
)
def
requires_executable
(
prerequisites
):
"""A decorator to check if some executable files are installed.
Example:
>>> @requires_executable('ffmpeg')
>>> func(arg1, args):
>>> print(1)
1
"""
return
check_prerequisites
(
prerequisites
,
checker
=
_check_executable
)
def
deprecated_api_warning
(
name_dict
,
cls_name
=
None
):
"""A decorator to check if some arguments are deprecate and try to replace
deprecate src_arg_name to dst_arg_name.
Args:
name_dict(dict):
key (str): Deprecate argument names.
val (str): Expected argument names.
Returns:
func: New function.
"""
def
api_warning_wrapper
(
old_func
):
@
functools
.
wraps
(
old_func
)
def
new_func
(
*
args
,
**
kwargs
):
# get the arg spec of the decorated method
args_info
=
getfullargspec
(
old_func
)
# get name of the function
func_name
=
old_func
.
__name__
if
cls_name
is
not
None
:
func_name
=
f
'
{
cls_name
}
.
{
func_name
}
'
if
args
:
arg_names
=
args_info
.
args
[:
len
(
args
)]
for
src_arg_name
,
dst_arg_name
in
name_dict
.
items
():
if
src_arg_name
in
arg_names
:
warnings
.
warn
(
f
'"
{
src_arg_name
}
" is deprecated in '
f
'`
{
func_name
}
`, please use "
{
dst_arg_name
}
" '
'instead'
)
arg_names
[
arg_names
.
index
(
src_arg_name
)]
=
dst_arg_name
if
kwargs
:
for
src_arg_name
,
dst_arg_name
in
name_dict
.
items
():
if
src_arg_name
in
kwargs
:
assert
dst_arg_name
not
in
kwargs
,
(
f
'The expected behavior is to replace '
f
'the deprecated key `
{
src_arg_name
}
` to '
f
'new key `
{
dst_arg_name
}
`, but got them '
f
'in the arguments at the same time, which '
f
'is confusing. `
{
src_arg_name
}
will be '
f
'deprecated in the future, please '
f
'use `
{
dst_arg_name
}
` instead.'
)
warnings
.
warn
(
f
'"
{
src_arg_name
}
" is deprecated in '
f
'`
{
func_name
}
`, please use "
{
dst_arg_name
}
" '
'instead'
)
kwargs
[
dst_arg_name
]
=
kwargs
.
pop
(
src_arg_name
)
# apply converted arguments to the decorated method
output
=
old_func
(
*
args
,
**
kwargs
)
return
output
return
new_func
return
api_warning_wrapper
def
is_method_overridden
(
method
,
base_class
,
derived_class
):
"""Check if a method of base class is overridden in derived class.
Args:
method (str): the method name to check.
base_class (type): the class of the base class.
derived_class (type | Any): the class or instance of the derived class.
"""
assert
isinstance
(
base_class
,
type
),
\
"base_class doesn't accept instance, Please pass class instead."
if
not
isinstance
(
derived_class
,
type
):
derived_class
=
derived_class
.
__class__
base_method
=
getattr
(
base_class
,
method
)
derived_method
=
getattr
(
derived_class
,
method
)
return
derived_method
!=
base_method
def
has_method
(
obj
:
object
,
method
:
str
)
->
bool
:
"""Check whether the object has a method.
Args:
method (str): The method name to check.
obj (object): The object to check.
Returns:
bool: True if the object has the method else False.
"""
return
hasattr
(
obj
,
method
)
and
callable
(
getattr
(
obj
,
method
))
lavis/common/annotator/uniformer/mmcv/utils/parrots_jit.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
os
from
.parrots_wrapper
import
TORCH_VERSION
parrots_jit_option
=
os
.
getenv
(
'PARROTS_JIT_OPTION'
)
if
TORCH_VERSION
==
'parrots'
and
parrots_jit_option
==
'ON'
:
from
parrots.jit
import
pat
as
jit
else
:
def
jit
(
func
=
None
,
check_input
=
None
,
full_shape
=
True
,
derivate
=
False
,
coderize
=
False
,
optimize
=
False
):
def
wrapper
(
func
):
def
wrapper_inner
(
*
args
,
**
kargs
):
return
func
(
*
args
,
**
kargs
)
return
wrapper_inner
if
func
is
None
:
return
wrapper
else
:
return
func
if
TORCH_VERSION
==
'parrots'
:
from
parrots.utils.tester
import
skip_no_elena
else
:
def
skip_no_elena
(
func
):
def
wrapper
(
*
args
,
**
kargs
):
return
func
(
*
args
,
**
kargs
)
return
wrapper
lavis/common/annotator/uniformer/mmcv/utils/parrots_wrapper.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
from
functools
import
partial
import
torch
TORCH_VERSION
=
torch
.
__version__
def
is_rocm_pytorch
()
->
bool
:
is_rocm
=
False
if
TORCH_VERSION
!=
'parrots'
:
try
:
from
torch.utils.cpp_extension
import
ROCM_HOME
is_rocm
=
True
if
((
torch
.
version
.
hip
is
not
None
)
and
(
ROCM_HOME
is
not
None
))
else
False
except
ImportError
:
pass
return
is_rocm
def
_get_cuda_home
():
if
TORCH_VERSION
==
'parrots'
:
from
parrots.utils.build_extension
import
CUDA_HOME
else
:
if
is_rocm_pytorch
():
from
torch.utils.cpp_extension
import
ROCM_HOME
CUDA_HOME
=
ROCM_HOME
else
:
from
torch.utils.cpp_extension
import
CUDA_HOME
return
CUDA_HOME
def
get_build_config
():
if
TORCH_VERSION
==
'parrots'
:
from
parrots.config
import
get_build_info
return
get_build_info
()
else
:
return
torch
.
__config__
.
show
()
def
_get_conv
():
if
TORCH_VERSION
==
'parrots'
:
from
parrots.nn.modules.conv
import
_ConvNd
,
_ConvTransposeMixin
else
:
from
torch.nn.modules.conv
import
_ConvNd
,
_ConvTransposeMixin
return
_ConvNd
,
_ConvTransposeMixin
def
_get_dataloader
():
if
TORCH_VERSION
==
'parrots'
:
from
torch.utils.data
import
DataLoader
,
PoolDataLoader
else
:
from
torch.utils.data
import
DataLoader
PoolDataLoader
=
DataLoader
return
DataLoader
,
PoolDataLoader
def
_get_extension
():
if
TORCH_VERSION
==
'parrots'
:
from
parrots.utils.build_extension
import
BuildExtension
,
Extension
CppExtension
=
partial
(
Extension
,
cuda
=
False
)
CUDAExtension
=
partial
(
Extension
,
cuda
=
True
)
else
:
from
torch.utils.cpp_extension
import
(
BuildExtension
,
CppExtension
,
CUDAExtension
)
return
BuildExtension
,
CppExtension
,
CUDAExtension
def
_get_pool
():
if
TORCH_VERSION
==
'parrots'
:
from
parrots.nn.modules.pool
import
(
_AdaptiveAvgPoolNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_MaxPoolNd
)
else
:
from
torch.nn.modules.pooling
import
(
_AdaptiveAvgPoolNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_MaxPoolNd
)
return
_AdaptiveAvgPoolNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_MaxPoolNd
def
_get_norm
():
if
TORCH_VERSION
==
'parrots'
:
from
parrots.nn.modules.batchnorm
import
_BatchNorm
,
_InstanceNorm
SyncBatchNorm_
=
torch
.
nn
.
SyncBatchNorm2d
else
:
from
torch.nn.modules.instancenorm
import
_InstanceNorm
from
torch.nn.modules.batchnorm
import
_BatchNorm
SyncBatchNorm_
=
torch
.
nn
.
SyncBatchNorm
return
_BatchNorm
,
_InstanceNorm
,
SyncBatchNorm_
_ConvNd
,
_ConvTransposeMixin
=
_get_conv
()
DataLoader
,
PoolDataLoader
=
_get_dataloader
()
BuildExtension
,
CppExtension
,
CUDAExtension
=
_get_extension
()
_BatchNorm
,
_InstanceNorm
,
SyncBatchNorm_
=
_get_norm
()
_AdaptiveAvgPoolNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_MaxPoolNd
=
_get_pool
()
class
SyncBatchNorm
(
SyncBatchNorm_
):
def
_check_input_dim
(
self
,
input
):
if
TORCH_VERSION
==
'parrots'
:
if
input
.
dim
()
<
2
:
raise
ValueError
(
f
'expected at least 2D input (got
{
input
.
dim
()
}
D input)'
)
else
:
super
().
_check_input_dim
(
input
)
lavis/common/annotator/uniformer/mmcv/utils/path.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
os.path
as
osp
from
pathlib
import
Path
from
.misc
import
is_str
def
is_filepath
(
x
):
return
is_str
(
x
)
or
isinstance
(
x
,
Path
)
def
fopen
(
filepath
,
*
args
,
**
kwargs
):
if
is_str
(
filepath
):
return
open
(
filepath
,
*
args
,
**
kwargs
)
elif
isinstance
(
filepath
,
Path
):
return
filepath
.
open
(
*
args
,
**
kwargs
)
raise
ValueError
(
'`filepath` should be a string or a Path'
)
def
check_file_exist
(
filename
,
msg_tmpl
=
'file "{}" does not exist'
):
if
not
osp
.
isfile
(
filename
):
raise
FileNotFoundError
(
msg_tmpl
.
format
(
filename
))
def
mkdir_or_exist
(
dir_name
,
mode
=
0o777
):
if
dir_name
==
''
:
return
dir_name
=
osp
.
expanduser
(
dir_name
)
os
.
makedirs
(
dir_name
,
mode
=
mode
,
exist_ok
=
True
)
def
symlink
(
src
,
dst
,
overwrite
=
True
,
**
kwargs
):
if
os
.
path
.
lexists
(
dst
)
and
overwrite
:
os
.
remove
(
dst
)
os
.
symlink
(
src
,
dst
,
**
kwargs
)
def
scandir
(
dir_path
,
suffix
=
None
,
recursive
=
False
,
case_sensitive
=
True
):
"""Scan a directory to find the interested files.
Args:
dir_path (str | obj:`Path`): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
case_sensitive (bool, optional) : If set to False, ignore the case of
suffix. Default: True.
Returns:
A generator for all the interested files with relative paths.
"""
if
isinstance
(
dir_path
,
(
str
,
Path
)):
dir_path
=
str
(
dir_path
)
else
:
raise
TypeError
(
'"dir_path" must be a string or Path object'
)
if
(
suffix
is
not
None
)
and
not
isinstance
(
suffix
,
(
str
,
tuple
)):
raise
TypeError
(
'"suffix" must be a string or tuple of strings'
)
if
suffix
is
not
None
and
not
case_sensitive
:
suffix
=
suffix
.
lower
()
if
isinstance
(
suffix
,
str
)
else
tuple
(
item
.
lower
()
for
item
in
suffix
)
root
=
dir_path
def
_scandir
(
dir_path
,
suffix
,
recursive
,
case_sensitive
):
for
entry
in
os
.
scandir
(
dir_path
):
if
not
entry
.
name
.
startswith
(
'.'
)
and
entry
.
is_file
():
rel_path
=
osp
.
relpath
(
entry
.
path
,
root
)
_rel_path
=
rel_path
if
case_sensitive
else
rel_path
.
lower
()
if
suffix
is
None
or
_rel_path
.
endswith
(
suffix
):
yield
rel_path
elif
recursive
and
os
.
path
.
isdir
(
entry
.
path
):
# scan recursively if entry.path is a directory
yield
from
_scandir
(
entry
.
path
,
suffix
,
recursive
,
case_sensitive
)
return
_scandir
(
dir_path
,
suffix
,
recursive
,
case_sensitive
)
def
find_vcs_root
(
path
,
markers
=
(
'.git'
,
)):
"""Finds the root directory (including itself) of specified markers.
Args:
path (str): Path of directory or file.
markers (list[str], optional): List of file or directory names.
Returns:
The directory contained one of the markers or None if not found.
"""
if
osp
.
isfile
(
path
):
path
=
osp
.
dirname
(
path
)
prev
,
cur
=
None
,
osp
.
abspath
(
osp
.
expanduser
(
path
))
while
cur
!=
prev
:
if
any
(
osp
.
exists
(
osp
.
join
(
cur
,
marker
))
for
marker
in
markers
):
return
cur
prev
,
cur
=
cur
,
osp
.
split
(
cur
)[
0
]
return
None
lavis/common/annotator/uniformer/mmcv/utils/progressbar.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
sys
from
collections.abc
import
Iterable
from
multiprocessing
import
Pool
from
shutil
import
get_terminal_size
from
.timer
import
Timer
class
ProgressBar
:
"""A progress bar which can print the progress."""
def
__init__
(
self
,
task_num
=
0
,
bar_width
=
50
,
start
=
True
,
file
=
sys
.
stdout
):
self
.
task_num
=
task_num
self
.
bar_width
=
bar_width
self
.
completed
=
0
self
.
file
=
file
if
start
:
self
.
start
()
@
property
def
terminal_width
(
self
):
width
,
_
=
get_terminal_size
()
return
width
def
start
(
self
):
if
self
.
task_num
>
0
:
self
.
file
.
write
(
f
'[
{
" "
*
self
.
bar_width
}
] 0/
{
self
.
task_num
}
, '
'elapsed: 0s, ETA:'
)
else
:
self
.
file
.
write
(
'completed: 0, elapsed: 0s'
)
self
.
file
.
flush
()
self
.
timer
=
Timer
()
def
update
(
self
,
num_tasks
=
1
):
assert
num_tasks
>
0
self
.
completed
+=
num_tasks
elapsed
=
self
.
timer
.
since_start
()
if
elapsed
>
0
:
fps
=
self
.
completed
/
elapsed
else
:
fps
=
float
(
'inf'
)
if
self
.
task_num
>
0
:
percentage
=
self
.
completed
/
float
(
self
.
task_num
)
eta
=
int
(
elapsed
*
(
1
-
percentage
)
/
percentage
+
0.5
)
msg
=
f
'
\r
[{{}}]
{
self
.
completed
}
/
{
self
.
task_num
}
, '
\
f
'
{
fps
:.
1
f
}
task/s, elapsed:
{
int
(
elapsed
+
0.5
)
}
s, '
\
f
'ETA:
{
eta
:
5
}
s'
bar_width
=
min
(
self
.
bar_width
,
int
(
self
.
terminal_width
-
len
(
msg
))
+
2
,
int
(
self
.
terminal_width
*
0.6
))
bar_width
=
max
(
2
,
bar_width
)
mark_width
=
int
(
bar_width
*
percentage
)
bar_chars
=
'>'
*
mark_width
+
' '
*
(
bar_width
-
mark_width
)
self
.
file
.
write
(
msg
.
format
(
bar_chars
))
else
:
self
.
file
.
write
(
f
'completed:
{
self
.
completed
}
, elapsed:
{
int
(
elapsed
+
0.5
)
}
s,'
f
'
{
fps
:.
1
f
}
tasks/s'
)
self
.
file
.
flush
()
def
track_progress
(
func
,
tasks
,
bar_width
=
50
,
file
=
sys
.
stdout
,
**
kwargs
):
"""Track the progress of tasks execution with a progress bar.
Tasks are done with a simple for-loop.
Args:
func (callable): The function to be applied to each task.
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
bar_width (int): Width of progress bar.
Returns:
list: The task results.
"""
if
isinstance
(
tasks
,
tuple
):
assert
len
(
tasks
)
==
2
assert
isinstance
(
tasks
[
0
],
Iterable
)
assert
isinstance
(
tasks
[
1
],
int
)
task_num
=
tasks
[
1
]
tasks
=
tasks
[
0
]
elif
isinstance
(
tasks
,
Iterable
):
task_num
=
len
(
tasks
)
else
:
raise
TypeError
(
'"tasks" must be an iterable object or a (iterator, int) tuple'
)
prog_bar
=
ProgressBar
(
task_num
,
bar_width
,
file
=
file
)
results
=
[]
for
task
in
tasks
:
results
.
append
(
func
(
task
,
**
kwargs
))
prog_bar
.
update
()
prog_bar
.
file
.
write
(
'
\n
'
)
return
results
def
init_pool
(
process_num
,
initializer
=
None
,
initargs
=
None
):
if
initializer
is
None
:
return
Pool
(
process_num
)
elif
initargs
is
None
:
return
Pool
(
process_num
,
initializer
)
else
:
if
not
isinstance
(
initargs
,
tuple
):
raise
TypeError
(
'"initargs" must be a tuple'
)
return
Pool
(
process_num
,
initializer
,
initargs
)
def
track_parallel_progress
(
func
,
tasks
,
nproc
,
initializer
=
None
,
initargs
=
None
,
bar_width
=
50
,
chunksize
=
1
,
skip_first
=
False
,
keep_order
=
True
,
file
=
sys
.
stdout
):
"""Track the progress of parallel task execution with a progress bar.
The built-in :mod:`multiprocessing` module is used for process pools and
tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
Args:
func (callable): The function to be applied to each task.
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
nproc (int): Process (worker) number.
initializer (None or callable): Refer to :class:`multiprocessing.Pool`
for details.
initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
details.
chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
bar_width (int): Width of progress bar.
skip_first (bool): Whether to skip the first sample for each worker
when estimating fps, since the initialization step may takes
longer.
keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
:func:`Pool.imap_unordered` is used.
Returns:
list: The task results.
"""
if
isinstance
(
tasks
,
tuple
):
assert
len
(
tasks
)
==
2
assert
isinstance
(
tasks
[
0
],
Iterable
)
assert
isinstance
(
tasks
[
1
],
int
)
task_num
=
tasks
[
1
]
tasks
=
tasks
[
0
]
elif
isinstance
(
tasks
,
Iterable
):
task_num
=
len
(
tasks
)
else
:
raise
TypeError
(
'"tasks" must be an iterable object or a (iterator, int) tuple'
)
pool
=
init_pool
(
nproc
,
initializer
,
initargs
)
start
=
not
skip_first
task_num
-=
nproc
*
chunksize
*
int
(
skip_first
)
prog_bar
=
ProgressBar
(
task_num
,
bar_width
,
start
,
file
=
file
)
results
=
[]
if
keep_order
:
gen
=
pool
.
imap
(
func
,
tasks
,
chunksize
)
else
:
gen
=
pool
.
imap_unordered
(
func
,
tasks
,
chunksize
)
for
result
in
gen
:
results
.
append
(
result
)
if
skip_first
:
if
len
(
results
)
<
nproc
*
chunksize
:
continue
elif
len
(
results
)
==
nproc
*
chunksize
:
prog_bar
.
start
()
continue
prog_bar
.
update
()
prog_bar
.
file
.
write
(
'
\n
'
)
pool
.
close
()
pool
.
join
()
return
results
def
track_iter_progress
(
tasks
,
bar_width
=
50
,
file
=
sys
.
stdout
):
"""Track the progress of tasks iteration or enumeration with a progress
bar.
Tasks are yielded with a simple for-loop.
Args:
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
bar_width (int): Width of progress bar.
Yields:
list: The task results.
"""
if
isinstance
(
tasks
,
tuple
):
assert
len
(
tasks
)
==
2
assert
isinstance
(
tasks
[
0
],
Iterable
)
assert
isinstance
(
tasks
[
1
],
int
)
task_num
=
tasks
[
1
]
tasks
=
tasks
[
0
]
elif
isinstance
(
tasks
,
Iterable
):
task_num
=
len
(
tasks
)
else
:
raise
TypeError
(
'"tasks" must be an iterable object or a (iterator, int) tuple'
)
prog_bar
=
ProgressBar
(
task_num
,
bar_width
,
file
=
file
)
for
task
in
tasks
:
yield
task
prog_bar
.
update
()
prog_bar
.
file
.
write
(
'
\n
'
)
lavis/common/annotator/uniformer/mmcv/utils/registry.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
inspect
import
warnings
from
functools
import
partial
from
.misc
import
is_seq_of
def
build_from_cfg
(
cfg
,
registry
,
default_args
=
None
):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if
not
isinstance
(
cfg
,
dict
):
raise
TypeError
(
f
'cfg must be a dict, but got
{
type
(
cfg
)
}
'
)
if
'type'
not
in
cfg
:
if
default_args
is
None
or
'type'
not
in
default_args
:
raise
KeyError
(
'`cfg` or `default_args` must contain the key "type", '
f
'but got
{
cfg
}
\n
{
default_args
}
'
)
if
not
isinstance
(
registry
,
Registry
):
raise
TypeError
(
'registry must be an mmcv.Registry object, '
f
'but got
{
type
(
registry
)
}
'
)
if
not
(
isinstance
(
default_args
,
dict
)
or
default_args
is
None
):
raise
TypeError
(
'default_args must be a dict or None, '
f
'but got
{
type
(
default_args
)
}
'
)
args
=
cfg
.
copy
()
if
default_args
is
not
None
:
for
name
,
value
in
default_args
.
items
():
args
.
setdefault
(
name
,
value
)
obj_type
=
args
.
pop
(
'type'
)
if
isinstance
(
obj_type
,
str
):
obj_cls
=
registry
.
get
(
obj_type
)
if
obj_cls
is
None
:
raise
KeyError
(
f
'
{
obj_type
}
is not in the
{
registry
.
name
}
registry'
)
elif
inspect
.
isclass
(
obj_type
):
obj_cls
=
obj_type
else
:
raise
TypeError
(
f
'type must be a str or valid type, but got
{
type
(
obj_type
)
}
'
)
try
:
return
obj_cls
(
**
args
)
except
Exception
as
e
:
# Normal TypeError does not print class name.
raise
type
(
e
)(
f
'
{
obj_cls
.
__name__
}
:
{
e
}
'
)
class
Registry
:
"""A registry to map strings to classes.
Registered object could be built from registry.
Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> resnet = MODELS.build(dict(type='ResNet'))
Please refer to
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
advanced usage.
Args:
name (str): Registry name.
build_func(func, optional): Build function to construct instance from
Registry, func:`build_from_cfg` is used if neither ``parent`` or
``build_func`` is specified. If ``parent`` is specified and
``build_func`` is not given, ``build_func`` will be inherited
from ``parent``. Default: None.
parent (Registry, optional): Parent registry. The class registered in
children registry could be built from parent. Default: None.
scope (str, optional): The scope of registry. It is the key to search
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Default: None.
"""
def
__init__
(
self
,
name
,
build_func
=
None
,
parent
=
None
,
scope
=
None
):
self
.
_name
=
name
self
.
_module_dict
=
dict
()
self
.
_children
=
dict
()
self
.
_scope
=
self
.
infer_scope
()
if
scope
is
None
else
scope
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if
build_func
is
None
:
if
parent
is
not
None
:
self
.
build_func
=
parent
.
build_func
else
:
self
.
build_func
=
build_from_cfg
else
:
self
.
build_func
=
build_func
if
parent
is
not
None
:
assert
isinstance
(
parent
,
Registry
)
parent
.
_add_children
(
self
)
self
.
parent
=
parent
else
:
self
.
parent
=
None
def
__len__
(
self
):
return
len
(
self
.
_module_dict
)
def
__contains__
(
self
,
key
):
return
self
.
get
(
key
)
is
not
None
def
__repr__
(
self
):
format_str
=
self
.
__class__
.
__name__
+
\
f
'(name=
{
self
.
_name
}
, '
\
f
'items=
{
self
.
_module_dict
}
)'
return
format_str
@
staticmethod
def
infer_scope
():
"""Infer the scope of registry.
The name of the package where registry is defined will be returned.
Example:
# in mmdet/models/backbone/resnet.py
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
The scope of ``ResNet`` will be ``mmdet``.
Returns:
scope (str): The inferred scope name.
"""
# inspect.stack() trace where this function is called, the index-2
# indicates the frame where `infer_scope()` is called
filename
=
inspect
.
getmodule
(
inspect
.
stack
()[
2
][
0
]).
__name__
split_filename
=
filename
.
split
(
'.'
)
return
split_filename
[
0
]
@
staticmethod
def
split_scope_key
(
key
):
"""Split scope and key.
The first scope will be split from key.
Examples:
>>> Registry.split_scope_key('mmdet.ResNet')
'mmdet', 'ResNet'
>>> Registry.split_scope_key('ResNet')
None, 'ResNet'
Return:
scope (str, None): The first scope.
key (str): The remaining key.
"""
split_index
=
key
.
find
(
'.'
)
if
split_index
!=
-
1
:
return
key
[:
split_index
],
key
[
split_index
+
1
:]
else
:
return
None
,
key
@
property
def
name
(
self
):
return
self
.
_name
@
property
def
scope
(
self
):
return
self
.
_scope
@
property
def
module_dict
(
self
):
return
self
.
_module_dict
@
property
def
children
(
self
):
return
self
.
_children
def
get
(
self
,
key
):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
scope
,
real_key
=
self
.
split_scope_key
(
key
)
if
scope
is
None
or
scope
==
self
.
_scope
:
# get from self
if
real_key
in
self
.
_module_dict
:
return
self
.
_module_dict
[
real_key
]
else
:
# get from self._children
if
scope
in
self
.
_children
:
return
self
.
_children
[
scope
].
get
(
real_key
)
else
:
# goto root
parent
=
self
.
parent
while
parent
.
parent
is
not
None
:
parent
=
parent
.
parent
return
parent
.
get
(
key
)
def
build
(
self
,
*
args
,
**
kwargs
):
return
self
.
build_func
(
*
args
,
**
kwargs
,
registry
=
self
)
def
_add_children
(
self
,
registry
):
"""Add children for a registry.
The ``registry`` will be added as children based on its scope.
The parent registry could build objects from children registry.
Example:
>>> models = Registry('models')
>>> mmdet_models = Registry('models', parent=models)
>>> @mmdet_models.register_module()
>>> class ResNet:
>>> pass
>>> resnet = models.build(dict(type='mmdet.ResNet'))
"""
assert
isinstance
(
registry
,
Registry
)
assert
registry
.
scope
is
not
None
assert
registry
.
scope
not
in
self
.
children
,
\
f
'scope
{
registry
.
scope
}
exists in
{
self
.
name
}
registry'
self
.
children
[
registry
.
scope
]
=
registry
def
_register_module
(
self
,
module_class
,
module_name
=
None
,
force
=
False
):
if
not
inspect
.
isclass
(
module_class
):
raise
TypeError
(
'module must be a class, '
f
'but got
{
type
(
module_class
)
}
'
)
if
module_name
is
None
:
module_name
=
module_class
.
__name__
if
isinstance
(
module_name
,
str
):
module_name
=
[
module_name
]
for
name
in
module_name
:
if
not
force
and
name
in
self
.
_module_dict
:
raise
KeyError
(
f
'
{
name
}
is already registered '
f
'in
{
self
.
name
}
'
)
self
.
_module_dict
[
name
]
=
module_class
def
deprecated_register_module
(
self
,
cls
=
None
,
force
=
False
):
warnings
.
warn
(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.'
)
if
cls
is
None
:
return
partial
(
self
.
deprecated_register_module
,
force
=
force
)
self
.
_register_module
(
cls
,
force
=
force
)
return
cls
def
register_module
(
self
,
name
=
None
,
force
=
False
,
module
=
None
):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
"""
if
not
isinstance
(
force
,
bool
):
raise
TypeError
(
f
'force must be a boolean, but got
{
type
(
force
)
}
'
)
# NOTE: This is a walkaround to be compatible with the old api,
# while it may introduce unexpected bugs.
if
isinstance
(
name
,
type
):
return
self
.
deprecated_register_module
(
name
,
force
=
force
)
# raise the error ahead of time
if
not
(
name
is
None
or
isinstance
(
name
,
str
)
or
is_seq_of
(
name
,
str
)):
raise
TypeError
(
'name must be either of None, an instance of str or a sequence'
f
' of str, but got
{
type
(
name
)
}
'
)
# use it as a normal method: x.register_module(module=SomeClass)
if
module
is
not
None
:
self
.
_register_module
(
module_class
=
module
,
module_name
=
name
,
force
=
force
)
return
module
# use it as a decorator: @x.register_module()
def
_register
(
cls
):
self
.
_register_module
(
module_class
=
cls
,
module_name
=
name
,
force
=
force
)
return
cls
return
_register
lavis/common/annotator/uniformer/mmcv/utils/testing.py
0 → 100644
View file @
c04f261a
# Copyright (c) Open-MMLab.
import
sys
from
collections.abc
import
Iterable
from
runpy
import
run_path
from
shlex
import
split
from
typing
import
Any
,
Dict
,
List
from
unittest.mock
import
patch
def
check_python_script
(
cmd
):
"""Run the python cmd script with `__main__`. The difference between
`os.system` is that, this function exectues code in the current process, so
that it can be tracked by coverage tools. Currently it supports two forms:
- ./tests/data/scripts/hello.py zz
- python tests/data/scripts/hello.py zz
"""
args
=
split
(
cmd
)
if
args
[
0
]
==
'python'
:
args
=
args
[
1
:]
with
patch
.
object
(
sys
,
'argv'
,
args
):
run_path
(
args
[
0
],
run_name
=
'__main__'
)
def
_any
(
judge_result
):
"""Since built-in ``any`` works only when the element of iterable is not
iterable, implement the function."""
if
not
isinstance
(
judge_result
,
Iterable
):
return
judge_result
try
:
for
element
in
judge_result
:
if
_any
(
element
):
return
True
except
TypeError
:
# Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
if
judge_result
:
return
True
return
False
def
assert_dict_contains_subset
(
dict_obj
:
Dict
[
Any
,
Any
],
expected_subset
:
Dict
[
Any
,
Any
])
->
bool
:
"""Check if the dict_obj contains the expected_subset.
Args:
dict_obj (Dict[Any, Any]): Dict object to be checked.
expected_subset (Dict[Any, Any]): Subset expected to be contained in
dict_obj.
Returns:
bool: Whether the dict_obj contains the expected_subset.
"""
for
key
,
value
in
expected_subset
.
items
():
if
key
not
in
dict_obj
.
keys
()
or
_any
(
dict_obj
[
key
]
!=
value
):
return
False
return
True
def
assert_attrs_equal
(
obj
:
Any
,
expected_attrs
:
Dict
[
str
,
Any
])
->
bool
:
"""Check if attribute of class object is correct.
Args:
obj (object): Class object to be checked.
expected_attrs (Dict[str, Any]): Dict of the expected attrs.
Returns:
bool: Whether the attribute of class object is correct.
"""
for
attr
,
value
in
expected_attrs
.
items
():
if
not
hasattr
(
obj
,
attr
)
or
_any
(
getattr
(
obj
,
attr
)
!=
value
):
return
False
return
True
def
assert_dict_has_keys
(
obj
:
Dict
[
str
,
Any
],
expected_keys
:
List
[
str
])
->
bool
:
"""Check if the obj has all the expected_keys.
Args:
obj (Dict[str, Any]): Object to be checked.
expected_keys (List[str]): Keys expected to contained in the keys of
the obj.
Returns:
bool: Whether the obj has the expected keys.
"""
return
set
(
expected_keys
).
issubset
(
set
(
obj
.
keys
()))
def
assert_keys_equal
(
result_keys
:
List
[
str
],
target_keys
:
List
[
str
])
->
bool
:
"""Check if target_keys is equal to result_keys.
Args:
result_keys (List[str]): Result keys to be checked.
target_keys (List[str]): Target keys to be checked.
Returns:
bool: Whether target_keys is equal to result_keys.
"""
return
set
(
result_keys
)
==
set
(
target_keys
)
def
assert_is_norm_layer
(
module
)
->
bool
:
"""Check if the module is a norm layer.
Args:
module (nn.Module): The module to be checked.
Returns:
bool: Whether the module is a norm layer.
"""
from
.parrots_wrapper
import
_BatchNorm
,
_InstanceNorm
from
torch.nn
import
GroupNorm
,
LayerNorm
norm_layer_candidates
=
(
_BatchNorm
,
_InstanceNorm
,
GroupNorm
,
LayerNorm
)
return
isinstance
(
module
,
norm_layer_candidates
)
def
assert_params_all_zeros
(
module
)
->
bool
:
"""Check if the parameters of the module is all zeros.
Args:
module (nn.Module): The module to be checked.
Returns:
bool: Whether the parameters of the module is all zeros.
"""
weight_data
=
module
.
weight
.
data
is_weight_zero
=
weight_data
.
allclose
(
weight_data
.
new_zeros
(
weight_data
.
size
()))
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
bias_data
=
module
.
bias
.
data
is_bias_zero
=
bias_data
.
allclose
(
bias_data
.
new_zeros
(
bias_data
.
size
()))
else
:
is_bias_zero
=
True
return
is_weight_zero
and
is_bias_zero
lavis/common/annotator/uniformer/mmcv/utils/timer.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
from
time
import
time
class
TimerError
(
Exception
):
def
__init__
(
self
,
message
):
self
.
message
=
message
super
(
TimerError
,
self
).
__init__
(
message
)
class
Timer
:
"""A flexible Timer class.
:Example:
>>> import time
>>> import annotator.uniformer.mmcv as mmcv
>>> with mmcv.Timer():
>>> # simulate a code block that will run for 1s
>>> time.sleep(1)
1.000
>>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
>>> # simulate a code block that will run for 1s
>>> time.sleep(1)
it takes 1.0 seconds
>>> timer = mmcv.Timer()
>>> time.sleep(0.5)
>>> print(timer.since_start())
0.500
>>> time.sleep(0.5)
>>> print(timer.since_last_check())
0.500
>>> print(timer.since_start())
1.000
"""
def
__init__
(
self
,
start
=
True
,
print_tmpl
=
None
):
self
.
_is_running
=
False
self
.
print_tmpl
=
print_tmpl
if
print_tmpl
else
'{:.3f}'
if
start
:
self
.
start
()
@
property
def
is_running
(
self
):
"""bool: indicate whether the timer is running"""
return
self
.
_is_running
def
__enter__
(
self
):
self
.
start
()
return
self
def
__exit__
(
self
,
type
,
value
,
traceback
):
print
(
self
.
print_tmpl
.
format
(
self
.
since_last_check
()))
self
.
_is_running
=
False
def
start
(
self
):
"""Start the timer."""
if
not
self
.
_is_running
:
self
.
_t_start
=
time
()
self
.
_is_running
=
True
self
.
_t_last
=
time
()
def
since_start
(
self
):
"""Total time since the timer is started.
Returns (float): Time in seconds.
"""
if
not
self
.
_is_running
:
raise
TimerError
(
'timer is not running'
)
self
.
_t_last
=
time
()
return
self
.
_t_last
-
self
.
_t_start
def
since_last_check
(
self
):
"""Time since the last checking.
Either :func:`since_start` or :func:`since_last_check` is a checking
operation.
Returns (float): Time in seconds.
"""
if
not
self
.
_is_running
:
raise
TimerError
(
'timer is not running'
)
dur
=
time
()
-
self
.
_t_last
self
.
_t_last
=
time
()
return
dur
_g_timers
=
{}
# global timers
def
check_time
(
timer_id
):
"""Add check points in a single line.
This method is suitable for running a task on a list of items. A timer will
be registered when the method is called for the first time.
:Example:
>>> import time
>>> import annotator.uniformer.mmcv as mmcv
>>> for i in range(1, 6):
>>> # simulate a code block
>>> time.sleep(i)
>>> mmcv.check_time('task1')
2.000
3.000
4.000
5.000
Args:
timer_id (str): Timer identifier.
"""
if
timer_id
not
in
_g_timers
:
_g_timers
[
timer_id
]
=
Timer
()
return
0
else
:
return
_g_timers
[
timer_id
].
since_last_check
()
lavis/common/annotator/uniformer/mmcv/utils/trace.py
0 → 100644
View file @
c04f261a
import
warnings
import
torch
from
annotator.uniformer.mmcv.utils
import
digit_version
def
is_jit_tracing
()
->
bool
:
if
(
torch
.
__version__
!=
'parrots'
and
digit_version
(
torch
.
__version__
)
>=
digit_version
(
'1.6.0'
)):
on_trace
=
torch
.
jit
.
is_tracing
()
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
# Refers to https://github.com/pytorch/pytorch/issues/42448
if
isinstance
(
on_trace
,
bool
):
return
on_trace
else
:
return
torch
.
_C
.
_is_tracing
()
else
:
warnings
.
warn
(
'torch.jit.is_tracing is only supported after v1.6.0. '
'Therefore is_tracing returns False automatically. Please '
'set on_trace manually if you are using trace.'
,
UserWarning
)
return
False
lavis/common/annotator/uniformer/mmcv/utils/version_utils.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
subprocess
import
warnings
from
packaging.version
import
parse
def
digit_version
(
version_str
:
str
,
length
:
int
=
4
):
"""Convert a version string into a tuple of integers.
This method is usually used for comparing two versions. For pre-release
versions: alpha < beta < rc.
Args:
version_str (str): The version string.
length (int): The maximum number of version levels. Default: 4.
Returns:
tuple[int]: The version info in digits (integers).
"""
assert
'parrots'
not
in
version_str
version
=
parse
(
version_str
)
assert
version
.
release
,
f
'failed to parse version
{
version_str
}
'
release
=
list
(
version
.
release
)
release
=
release
[:
length
]
if
len
(
release
)
<
length
:
release
=
release
+
[
0
]
*
(
length
-
len
(
release
))
if
version
.
is_prerelease
:
mapping
=
{
'a'
:
-
3
,
'b'
:
-
2
,
'rc'
:
-
1
}
val
=
-
4
# version.pre can be None
if
version
.
pre
:
if
version
.
pre
[
0
]
not
in
mapping
:
warnings
.
warn
(
f
'unknown prerelease version
{
version
.
pre
[
0
]
}
, '
'version checking may go wrong'
)
else
:
val
=
mapping
[
version
.
pre
[
0
]]
release
.
extend
([
val
,
version
.
pre
[
-
1
]])
else
:
release
.
extend
([
val
,
0
])
elif
version
.
is_postrelease
:
release
.
extend
([
1
,
version
.
post
])
else
:
release
.
extend
([
0
,
0
])
return
tuple
(
release
)
def
_minimal_ext_cmd
(
cmd
):
# construct minimal environment
env
=
{}
for
k
in
[
'SYSTEMROOT'
,
'PATH'
,
'HOME'
]:
v
=
os
.
environ
.
get
(
k
)
if
v
is
not
None
:
env
[
k
]
=
v
# LANGUAGE is used on win32
env
[
'LANGUAGE'
]
=
'C'
env
[
'LANG'
]
=
'C'
env
[
'LC_ALL'
]
=
'C'
out
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
env
=
env
).
communicate
()[
0
]
return
out
def
get_git_hash
(
fallback
=
'unknown'
,
digits
=
None
):
"""Get the git hash of the current repo.
Args:
fallback (str, optional): The fallback string when git hash is
unavailable. Defaults to 'unknown'.
digits (int, optional): kept digits of the hash. Defaults to None,
meaning all digits are kept.
Returns:
str: Git commit hash.
"""
if
digits
is
not
None
and
not
isinstance
(
digits
,
int
):
raise
TypeError
(
'digits must be None or an integer'
)
try
:
out
=
_minimal_ext_cmd
([
'git'
,
'rev-parse'
,
'HEAD'
])
sha
=
out
.
strip
().
decode
(
'ascii'
)
if
digits
is
not
None
:
sha
=
sha
[:
digits
]
except
OSError
:
sha
=
fallback
return
sha
lavis/common/annotator/uniformer/mmcv/version.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
__version__
=
'1.3.17'
def
parse_version_info
(
version_str
:
str
,
length
:
int
=
4
)
->
tuple
:
"""Parse a version string into a tuple.
Args:
version_str (str): The version string.
length (int): The maximum number of version levels. Default: 4.
Returns:
tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
(1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
(2, 0, 0, 0, 'rc', 1) (when length is set to 4).
"""
from
packaging.version
import
parse
version
=
parse
(
version_str
)
assert
version
.
release
,
f
'failed to parse version
{
version_str
}
'
release
=
list
(
version
.
release
)
release
=
release
[:
length
]
if
len
(
release
)
<
length
:
release
=
release
+
[
0
]
*
(
length
-
len
(
release
))
if
version
.
is_prerelease
:
release
.
extend
(
list
(
version
.
pre
))
elif
version
.
is_postrelease
:
release
.
extend
(
list
(
version
.
post
))
else
:
release
.
extend
([
0
,
0
])
return
tuple
(
release
)
version_info
=
tuple
(
int
(
x
)
for
x
in
__version__
.
split
(
'.'
)[:
3
])
__all__
=
[
'__version__'
,
'version_info'
,
'parse_version_info'
]
lavis/common/annotator/uniformer/mmcv/video/__init__.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
from
.io
import
Cache
,
VideoReader
,
frames2video
from
.optflow
import
(
dequantize_flow
,
flow_from_bytes
,
flow_warp
,
flowread
,
flowwrite
,
quantize_flow
,
sparse_flow_from_bytes
)
from
.processing
import
concat_video
,
convert_video
,
cut_video
,
resize_video
__all__
=
[
'Cache'
,
'VideoReader'
,
'frames2video'
,
'convert_video'
,
'resize_video'
,
'cut_video'
,
'concat_video'
,
'flowread'
,
'flowwrite'
,
'quantize_flow'
,
'dequantize_flow'
,
'flow_warp'
,
'flow_from_bytes'
,
'sparse_flow_from_bytes'
]
lavis/common/annotator/uniformer/mmcv/video/io.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
from
collections
import
OrderedDict
import
cv2
from
cv2
import
(
CAP_PROP_FOURCC
,
CAP_PROP_FPS
,
CAP_PROP_FRAME_COUNT
,
CAP_PROP_FRAME_HEIGHT
,
CAP_PROP_FRAME_WIDTH
,
CAP_PROP_POS_FRAMES
,
VideoWriter_fourcc
)
from
annotator.uniformer.mmcv.utils
import
(
check_file_exist
,
mkdir_or_exist
,
scandir
,
track_progress
)
class
Cache
:
def
__init__
(
self
,
capacity
):
self
.
_cache
=
OrderedDict
()
self
.
_capacity
=
int
(
capacity
)
if
capacity
<=
0
:
raise
ValueError
(
'capacity must be a positive integer'
)
@
property
def
capacity
(
self
):
return
self
.
_capacity
@
property
def
size
(
self
):
return
len
(
self
.
_cache
)
def
put
(
self
,
key
,
val
):
if
key
in
self
.
_cache
:
return
if
len
(
self
.
_cache
)
>=
self
.
capacity
:
self
.
_cache
.
popitem
(
last
=
False
)
self
.
_cache
[
key
]
=
val
def
get
(
self
,
key
,
default
=
None
):
val
=
self
.
_cache
[
key
]
if
key
in
self
.
_cache
else
default
return
val
class
VideoReader
:
"""Video class with similar usage to a list object.
This video warpper class provides convenient apis to access frames.
There exists an issue of OpenCV's VideoCapture class that jumping to a
certain frame may be inaccurate. It is fixed in this class by checking
the position after jumping each time.
Cache is used when decoding videos. So if the same frame is visited for
the second time, there is no need to decode again if it is stored in the
cache.
:Example:
>>> import annotator.uniformer.mmcv as mmcv
>>> v = mmcv.VideoReader('sample.mp4')
>>> len(v) # get the total frame number with `len()`
120
>>> for img in v: # v is iterable
>>> mmcv.imshow(img)
>>> v[5] # get the 6th frame
"""
def
__init__
(
self
,
filename
,
cache_capacity
=
10
):
# Check whether the video path is a url
if
not
filename
.
startswith
((
'https://'
,
'http://'
)):
check_file_exist
(
filename
,
'Video file not found: '
+
filename
)
self
.
_vcap
=
cv2
.
VideoCapture
(
filename
)
assert
cache_capacity
>
0
self
.
_cache
=
Cache
(
cache_capacity
)
self
.
_position
=
0
# get basic info
self
.
_width
=
int
(
self
.
_vcap
.
get
(
CAP_PROP_FRAME_WIDTH
))
self
.
_height
=
int
(
self
.
_vcap
.
get
(
CAP_PROP_FRAME_HEIGHT
))
self
.
_fps
=
self
.
_vcap
.
get
(
CAP_PROP_FPS
)
self
.
_frame_cnt
=
int
(
self
.
_vcap
.
get
(
CAP_PROP_FRAME_COUNT
))
self
.
_fourcc
=
self
.
_vcap
.
get
(
CAP_PROP_FOURCC
)
@
property
def
vcap
(
self
):
""":obj:`cv2.VideoCapture`: The raw VideoCapture object."""
return
self
.
_vcap
@
property
def
opened
(
self
):
"""bool: Indicate whether the video is opened."""
return
self
.
_vcap
.
isOpened
()
@
property
def
width
(
self
):
"""int: Width of video frames."""
return
self
.
_width
@
property
def
height
(
self
):
"""int: Height of video frames."""
return
self
.
_height
@
property
def
resolution
(
self
):
"""tuple: Video resolution (width, height)."""
return
(
self
.
_width
,
self
.
_height
)
@
property
def
fps
(
self
):
"""float: FPS of the video."""
return
self
.
_fps
@
property
def
frame_cnt
(
self
):
"""int: Total frames of the video."""
return
self
.
_frame_cnt
@
property
def
fourcc
(
self
):
"""str: "Four character code" of the video."""
return
self
.
_fourcc
@
property
def
position
(
self
):
"""int: Current cursor position, indicating frame decoded."""
return
self
.
_position
def
_get_real_position
(
self
):
return
int
(
round
(
self
.
_vcap
.
get
(
CAP_PROP_POS_FRAMES
)))
def
_set_real_position
(
self
,
frame_id
):
self
.
_vcap
.
set
(
CAP_PROP_POS_FRAMES
,
frame_id
)
pos
=
self
.
_get_real_position
()
for
_
in
range
(
frame_id
-
pos
):
self
.
_vcap
.
read
()
self
.
_position
=
frame_id
def
read
(
self
):
"""Read the next frame.
If the next frame have been decoded before and in the cache, then
return it directly, otherwise decode, cache and return it.
Returns:
ndarray or None: Return the frame if successful, otherwise None.
"""
# pos = self._position
if
self
.
_cache
:
img
=
self
.
_cache
.
get
(
self
.
_position
)
if
img
is
not
None
:
ret
=
True
else
:
if
self
.
_position
!=
self
.
_get_real_position
():
self
.
_set_real_position
(
self
.
_position
)
ret
,
img
=
self
.
_vcap
.
read
()
if
ret
:
self
.
_cache
.
put
(
self
.
_position
,
img
)
else
:
ret
,
img
=
self
.
_vcap
.
read
()
if
ret
:
self
.
_position
+=
1
return
img
def
get_frame
(
self
,
frame_id
):
"""Get frame by index.
Args:
frame_id (int): Index of the expected frame, 0-based.
Returns:
ndarray or None: Return the frame if successful, otherwise None.
"""
if
frame_id
<
0
or
frame_id
>=
self
.
_frame_cnt
:
raise
IndexError
(
f
'"frame_id" must be between 0 and
{
self
.
_frame_cnt
-
1
}
'
)
if
frame_id
==
self
.
_position
:
return
self
.
read
()
if
self
.
_cache
:
img
=
self
.
_cache
.
get
(
frame_id
)
if
img
is
not
None
:
self
.
_position
=
frame_id
+
1
return
img
self
.
_set_real_position
(
frame_id
)
ret
,
img
=
self
.
_vcap
.
read
()
if
ret
:
if
self
.
_cache
:
self
.
_cache
.
put
(
self
.
_position
,
img
)
self
.
_position
+=
1
return
img
def
current_frame
(
self
):
"""Get the current frame (frame that is just visited).
Returns:
ndarray or None: If the video is fresh, return None, otherwise
return the frame.
"""
if
self
.
_position
==
0
:
return
None
return
self
.
_cache
.
get
(
self
.
_position
-
1
)
def
cvt2frames
(
self
,
frame_dir
,
file_start
=
0
,
filename_tmpl
=
'{:06d}.jpg'
,
start
=
0
,
max_num
=
0
,
show_progress
=
True
):
"""Convert a video to frame images.
Args:
frame_dir (str): Output directory to store all the frame images.
file_start (int): Filenames will start from the specified number.
filename_tmpl (str): Filename template with the index as the
placeholder.
start (int): The starting frame index.
max_num (int): Maximum number of frames to be written.
show_progress (bool): Whether to show a progress bar.
"""
mkdir_or_exist
(
frame_dir
)
if
max_num
==
0
:
task_num
=
self
.
frame_cnt
-
start
else
:
task_num
=
min
(
self
.
frame_cnt
-
start
,
max_num
)
if
task_num
<=
0
:
raise
ValueError
(
'start must be less than total frame number'
)
if
start
>
0
:
self
.
_set_real_position
(
start
)
def
write_frame
(
file_idx
):
img
=
self
.
read
()
if
img
is
None
:
return
filename
=
osp
.
join
(
frame_dir
,
filename_tmpl
.
format
(
file_idx
))
cv2
.
imwrite
(
filename
,
img
)
if
show_progress
:
track_progress
(
write_frame
,
range
(
file_start
,
file_start
+
task_num
))
else
:
for
i
in
range
(
task_num
):
write_frame
(
file_start
+
i
)
def
__len__
(
self
):
return
self
.
frame_cnt
def
__getitem__
(
self
,
index
):
if
isinstance
(
index
,
slice
):
return
[
self
.
get_frame
(
i
)
for
i
in
range
(
*
index
.
indices
(
self
.
frame_cnt
))
]
# support negative indexing
if
index
<
0
:
index
+=
self
.
frame_cnt
if
index
<
0
:
raise
IndexError
(
'index out of range'
)
return
self
.
get_frame
(
index
)
def
__iter__
(
self
):
self
.
_set_real_position
(
0
)
return
self
def
__next__
(
self
):
img
=
self
.
read
()
if
img
is
not
None
:
return
img
else
:
raise
StopIteration
next
=
__next__
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
self
.
_vcap
.
release
()
def
frames2video
(
frame_dir
,
video_file
,
fps
=
30
,
fourcc
=
'XVID'
,
filename_tmpl
=
'{:06d}.jpg'
,
start
=
0
,
end
=
0
,
show_progress
=
True
):
"""Read the frame images from a directory and join them as a video.
Args:
frame_dir (str): The directory containing video frames.
video_file (str): Output filename.
fps (float): FPS of the output video.
fourcc (str): Fourcc of the output video, this should be compatible
with the output file type.
filename_tmpl (str): Filename template with the index as the variable.
start (int): Starting frame index.
end (int): Ending frame index.
show_progress (bool): Whether to show a progress bar.
"""
if
end
==
0
:
ext
=
filename_tmpl
.
split
(
'.'
)[
-
1
]
end
=
len
([
name
for
name
in
scandir
(
frame_dir
,
ext
)])
first_file
=
osp
.
join
(
frame_dir
,
filename_tmpl
.
format
(
start
))
check_file_exist
(
first_file
,
'The start frame not found: '
+
first_file
)
img
=
cv2
.
imread
(
first_file
)
height
,
width
=
img
.
shape
[:
2
]
resolution
=
(
width
,
height
)
vwriter
=
cv2
.
VideoWriter
(
video_file
,
VideoWriter_fourcc
(
*
fourcc
),
fps
,
resolution
)
def
write_frame
(
file_idx
):
filename
=
osp
.
join
(
frame_dir
,
filename_tmpl
.
format
(
file_idx
))
img
=
cv2
.
imread
(
filename
)
vwriter
.
write
(
img
)
if
show_progress
:
track_progress
(
write_frame
,
range
(
start
,
end
))
else
:
for
i
in
range
(
start
,
end
):
write_frame
(
i
)
vwriter
.
release
()
lavis/common/annotator/uniformer/mmcv/video/optflow.py
0 → 100644
View file @
c04f261a
# Copyright (c) OpenMMLab. All rights reserved.
import
warnings
import
cv2
import
numpy
as
np
from
annotator.uniformer.mmcv.arraymisc
import
dequantize
,
quantize
from
annotator.uniformer.mmcv.image
import
imread
,
imwrite
from
annotator.uniformer.mmcv.utils
import
is_str
def
flowread
(
flow_or_path
,
quantize
=
False
,
concat_axis
=
0
,
*
args
,
**
kwargs
):
"""Read an optical flow map.
Args:
flow_or_path (ndarray or str): A flow map or filepath.
quantize (bool): whether to read quantized pair, if set to True,
remaining args will be passed to :func:`dequantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
Returns:
ndarray: Optical flow represented as a (h, w, 2) numpy array
"""
if
isinstance
(
flow_or_path
,
np
.
ndarray
):
if
(
flow_or_path
.
ndim
!=
3
)
or
(
flow_or_path
.
shape
[
-
1
]
!=
2
):
raise
ValueError
(
f
'Invalid flow with shape
{
flow_or_path
.
shape
}
'
)
return
flow_or_path
elif
not
is_str
(
flow_or_path
):
raise
TypeError
(
f
'"flow_or_path" must be a filename or numpy array, '
f
'not
{
type
(
flow_or_path
)
}
'
)
if
not
quantize
:
with
open
(
flow_or_path
,
'rb'
)
as
f
:
try
:
header
=
f
.
read
(
4
).
decode
(
'utf-8'
)
except
Exception
:
raise
IOError
(
f
'Invalid flow file:
{
flow_or_path
}
'
)
else
:
if
header
!=
'PIEH'
:
raise
IOError
(
f
'Invalid flow file:
{
flow_or_path
}
, '
'header does not contain PIEH'
)
w
=
np
.
fromfile
(
f
,
np
.
int32
,
1
).
squeeze
()
h
=
np
.
fromfile
(
f
,
np
.
int32
,
1
).
squeeze
()
flow
=
np
.
fromfile
(
f
,
np
.
float32
,
w
*
h
*
2
).
reshape
((
h
,
w
,
2
))
else
:
assert
concat_axis
in
[
0
,
1
]
cat_flow
=
imread
(
flow_or_path
,
flag
=
'unchanged'
)
if
cat_flow
.
ndim
!=
2
:
raise
IOError
(
f
'
{
flow_or_path
}
is not a valid quantized flow file, '
f
'its dimension is
{
cat_flow
.
ndim
}
.'
)
assert
cat_flow
.
shape
[
concat_axis
]
%
2
==
0
dx
,
dy
=
np
.
split
(
cat_flow
,
2
,
axis
=
concat_axis
)
flow
=
dequantize_flow
(
dx
,
dy
,
*
args
,
**
kwargs
)
return
flow
.
astype
(
np
.
float32
)
def
flowwrite
(
flow
,
filename
,
quantize
=
False
,
concat_axis
=
0
,
*
args
,
**
kwargs
):
"""Write optical flow to file.
If the flow is not quantized, it will be saved as a .flo file losslessly,
otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
will be concatenated horizontally into a single image if quantize is True.)
Args:
flow (ndarray): (h, w, 2) array of optical flow.
filename (str): Output filepath.
quantize (bool): Whether to quantize the flow and save it to 2 jpeg
images. If set to True, remaining args will be passed to
:func:`quantize_flow`.
concat_axis (int): The axis that dx and dy are concatenated,
can be either 0 or 1. Ignored if quantize is False.
"""
if
not
quantize
:
with
open
(
filename
,
'wb'
)
as
f
:
f
.
write
(
'PIEH'
.
encode
(
'utf-8'
))
np
.
array
([
flow
.
shape
[
1
],
flow
.
shape
[
0
]],
dtype
=
np
.
int32
).
tofile
(
f
)
flow
=
flow
.
astype
(
np
.
float32
)
flow
.
tofile
(
f
)
f
.
flush
()
else
:
assert
concat_axis
in
[
0
,
1
]
dx
,
dy
=
quantize_flow
(
flow
,
*
args
,
**
kwargs
)
dxdy
=
np
.
concatenate
((
dx
,
dy
),
axis
=
concat_axis
)
imwrite
(
dxdy
,
filename
)
def
quantize_flow
(
flow
,
max_val
=
0.02
,
norm
=
True
):
"""Quantize flow to [0, 255].
After this step, the size of flow will be much smaller, and can be
dumped as jpeg images.
Args:
flow (ndarray): (h, w, 2) array of optical flow.
max_val (float): Maximum value of flow, values beyond
[-max_val, max_val] will be truncated.
norm (bool): Whether to divide flow values by image width/height.
Returns:
tuple[ndarray]: Quantized dx and dy.
"""
h
,
w
,
_
=
flow
.
shape
dx
=
flow
[...,
0
]
dy
=
flow
[...,
1
]
if
norm
:
dx
=
dx
/
w
# avoid inplace operations
dy
=
dy
/
h
# use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
flow_comps
=
[
quantize
(
d
,
-
max_val
,
max_val
,
255
,
np
.
uint8
)
for
d
in
[
dx
,
dy
]
]
return
tuple
(
flow_comps
)
def
dequantize_flow
(
dx
,
dy
,
max_val
=
0.02
,
denorm
=
True
):
"""Recover from quantized flow.
Args:
dx (ndarray): Quantized dx.
dy (ndarray): Quantized dy.
max_val (float): Maximum value used when quantizing.
denorm (bool): Whether to multiply flow values with width/height.
Returns:
ndarray: Dequantized flow.
"""
assert
dx
.
shape
==
dy
.
shape
assert
dx
.
ndim
==
2
or
(
dx
.
ndim
==
3
and
dx
.
shape
[
-
1
]
==
1
)
dx
,
dy
=
[
dequantize
(
d
,
-
max_val
,
max_val
,
255
)
for
d
in
[
dx
,
dy
]]
if
denorm
:
dx
*=
dx
.
shape
[
1
]
dy
*=
dx
.
shape
[
0
]
flow
=
np
.
dstack
((
dx
,
dy
))
return
flow
def
flow_warp
(
img
,
flow
,
filling_value
=
0
,
interpolate_mode
=
'nearest'
):
"""Use flow to warp img.
Args:
img (ndarray, float or uint8): Image to be warped.
flow (ndarray, float): Optical Flow.
filling_value (int): The missing pixels will be set with filling_value.
interpolate_mode (str): bilinear -> Bilinear Interpolation;
nearest -> Nearest Neighbor.
Returns:
ndarray: Warped image with the same shape of img
"""
warnings
.
warn
(
'This function is just for prototyping and cannot '
'guarantee the computational efficiency.'
)
assert
flow
.
ndim
==
3
,
'Flow must be in 3D arrays.'
height
=
flow
.
shape
[
0
]
width
=
flow
.
shape
[
1
]
channels
=
img
.
shape
[
2
]
output
=
np
.
ones
(
(
height
,
width
,
channels
),
dtype
=
img
.
dtype
)
*
filling_value
grid
=
np
.
indices
((
height
,
width
)).
swapaxes
(
0
,
1
).
swapaxes
(
1
,
2
)
dx
=
grid
[:,
:,
0
]
+
flow
[:,
:,
1
]
dy
=
grid
[:,
:,
1
]
+
flow
[:,
:,
0
]
sx
=
np
.
floor
(
dx
).
astype
(
int
)
sy
=
np
.
floor
(
dy
).
astype
(
int
)
valid
=
(
sx
>=
0
)
&
(
sx
<
height
-
1
)
&
(
sy
>=
0
)
&
(
sy
<
width
-
1
)
if
interpolate_mode
==
'nearest'
:
output
[
valid
,
:]
=
img
[
dx
[
valid
].
round
().
astype
(
int
),
dy
[
valid
].
round
().
astype
(
int
),
:]
elif
interpolate_mode
==
'bilinear'
:
# dirty walkround for integer positions
eps_
=
1e-6
dx
,
dy
=
dx
+
eps_
,
dy
+
eps_
left_top_
=
img
[
np
.
floor
(
dx
[
valid
]).
astype
(
int
),
np
.
floor
(
dy
[
valid
]).
astype
(
int
),
:]
*
(
np
.
ceil
(
dx
[
valid
])
-
dx
[
valid
])[:,
None
]
*
(
np
.
ceil
(
dy
[
valid
])
-
dy
[
valid
])[:,
None
]
left_down_
=
img
[
np
.
ceil
(
dx
[
valid
]).
astype
(
int
),
np
.
floor
(
dy
[
valid
]).
astype
(
int
),
:]
*
(
dx
[
valid
]
-
np
.
floor
(
dx
[
valid
]))[:,
None
]
*
(
np
.
ceil
(
dy
[
valid
])
-
dy
[
valid
])[:,
None
]
right_top_
=
img
[
np
.
floor
(
dx
[
valid
]).
astype
(
int
),
np
.
ceil
(
dy
[
valid
]).
astype
(
int
),
:]
*
(
np
.
ceil
(
dx
[
valid
])
-
dx
[
valid
])[:,
None
]
*
(
dy
[
valid
]
-
np
.
floor
(
dy
[
valid
]))[:,
None
]
right_down_
=
img
[
np
.
ceil
(
dx
[
valid
]).
astype
(
int
),
np
.
ceil
(
dy
[
valid
]).
astype
(
int
),
:]
*
(
dx
[
valid
]
-
np
.
floor
(
dx
[
valid
]))[:,
None
]
*
(
dy
[
valid
]
-
np
.
floor
(
dy
[
valid
]))[:,
None
]
output
[
valid
,
:]
=
left_top_
+
left_down_
+
right_top_
+
right_down_
else
:
raise
NotImplementedError
(
'We only support interpolation modes of nearest and bilinear, '
f
'but got
{
interpolate_mode
}
.'
)
return
output
.
astype
(
img
.
dtype
)
def
flow_from_bytes
(
content
):
"""Read dense optical flow from bytes.
.. note::
This load optical flow function works for FlyingChairs, FlyingThings3D,
Sintel, FlyingChairsOcc datasets, but cannot load the data from
ChairsSDHom.
Args:
content (bytes): Optical flow bytes got from files or other streams.
Returns:
ndarray: Loaded optical flow with the shape (H, W, 2).
"""
# header in first 4 bytes
header
=
content
[:
4
]
if
header
.
decode
(
'utf-8'
)
!=
'PIEH'
:
raise
Exception
(
'Flow file header does not contain PIEH'
)
# width in second 4 bytes
width
=
np
.
frombuffer
(
content
[
4
:],
np
.
int32
,
1
).
squeeze
()
# height in third 4 bytes
height
=
np
.
frombuffer
(
content
[
8
:],
np
.
int32
,
1
).
squeeze
()
# after first 12 bytes, all bytes are flow
flow
=
np
.
frombuffer
(
content
[
12
:],
np
.
float32
,
width
*
height
*
2
).
reshape
(
(
height
,
width
,
2
))
return
flow
def
sparse_flow_from_bytes
(
content
):
"""Read the optical flow in KITTI datasets from bytes.
This function is modified from RAFT load the `KITTI datasets
<https://github.com/princeton-vl/RAFT/blob/224320502d66c356d88e6c712f38129e60661e80/core/utils/frame_utils.py#L102>`_.
Args:
content (bytes): Optical flow bytes got from files or other streams.
Returns:
Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2)
and flow valid mask with the shape (H, W).
"""
# nopa
content
=
np
.
frombuffer
(
content
,
np
.
uint8
)
flow
=
cv2
.
imdecode
(
content
,
cv2
.
IMREAD_ANYDEPTH
|
cv2
.
IMREAD_COLOR
)
flow
=
flow
[:,
:,
::
-
1
].
astype
(
np
.
float32
)
# flow shape (H, W, 2) valid shape (H, W)
flow
,
valid
=
flow
[:,
:,
:
2
],
flow
[:,
:,
2
]
flow
=
(
flow
-
2
**
15
)
/
64.0
return
flow
,
valid
Prev
1
…
13
14
15
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment