Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
0fe1c647
Unverified
Commit
0fe1c647
authored
Aug 16, 2022
by
Zaida Zhou
Committed by
GitHub
Aug 16, 2022
Browse files
Remove fileio from mmcv and use mmengine.fileio instead (#2179)
parent
0b4285d9
Changes
36
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
37 additions
and
1107 deletions
+37
-1107
mmcv/runner/hooks/evaluation.py
mmcv/runner/hooks/evaluation.py
+3
-3
mmcv/runner/hooks/logger/pavi.py
mmcv/runner/hooks/logger/pavi.py
+3
-2
mmcv/runner/hooks/logger/text.py
mmcv/runner/hooks/logger/text.py
+4
-4
mmcv/transforms/loading.py
mmcv/transforms/loading.py
+5
-4
mmcv/utils/config.py
mmcv/utils/config.py
+4
-5
tests/test_fileclient.py
tests/test_fileclient.py
+0
-862
tests/test_fileio.py
tests/test_fileio.py
+0
-211
tests/test_image/test_io.py
tests/test_image/test_io.py
+3
-2
tests/test_load_model_zoo.py
tests/test_load_model_zoo.py
+2
-1
tests/test_ops/test_nms.py
tests/test_ops/test_nms.py
+2
-2
tests/test_ops/test_tensorrt.py
tests/test_ops/test_tensorrt.py
+3
-4
tests/test_runner/test_basemodule.py
tests/test_runner/test_basemodule.py
+3
-3
tests/test_runner/test_checkpoint.py
tests/test_runner/test_checkpoint.py
+1
-1
tests/test_runner/test_eval_hook.py
tests/test_runner/test_eval_hook.py
+1
-1
tests/test_runner/test_hooks.py
tests/test_runner/test_hooks.py
+1
-1
tests/test_utils/test_config.py
tests/test_utils/test_config.py
+2
-1
No files found.
mmcv/runner/hooks/evaluation.py
View file @
0fe1c647
...
@@ -5,10 +5,10 @@ from math import inf
...
@@ -5,10 +5,10 @@ from math import inf
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
mmengine.fileio
import
FileClient
from
torch.nn.modules.batchnorm
import
_BatchNorm
from
torch.nn.modules.batchnorm
import
_BatchNorm
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
mmcv.fileio
import
FileClient
from
mmcv.utils
import
is_seq_of
from
mmcv.utils
import
is_seq_of
from
.hook
import
Hook
from
.hook
import
Hook
from
.logger
import
LoggerHook
from
.logger
import
LoggerHook
...
@@ -61,7 +61,7 @@ class EvalHook(Hook):
...
@@ -61,7 +61,7 @@ class EvalHook(Hook):
level directory of `runner.work_dir`.
level directory of `runner.work_dir`.
`New in version 1.3.16.`
`New in version 1.3.16.`
file_client_args (dict): Arguments to instantiate a FileClient.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mm
cv
.fileio.FileClient` for details. Default: None.
See :class:`mm
engine
.fileio.FileClient` for details. Default: None.
`New in version 1.3.16.`
`New in version 1.3.16.`
**eval_kwargs: Evaluation arguments fed into the evaluate function of
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
the dataset.
...
@@ -437,7 +437,7 @@ class DistEvalHook(EvalHook):
...
@@ -437,7 +437,7 @@ class DistEvalHook(EvalHook):
the `out_dir` will be the concatenation of `out_dir` and the last
the `out_dir` will be the concatenation of `out_dir` and the last
level directory of `runner.work_dir`.
level directory of `runner.work_dir`.
file_client_args (dict): Arguments to instantiate a FileClient.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mm
cv
.fileio.FileClient` for details. Default: None.
See :class:`mm
engine
.fileio.FileClient` for details. Default: None.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
the dataset.
"""
"""
...
...
mmcv/runner/hooks/logger/pavi.py
View file @
0fe1c647
...
@@ -4,6 +4,7 @@ import os
...
@@ -4,6 +4,7 @@ import os
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
import
mmengine
import
torch
import
torch
import
yaml
import
yaml
...
@@ -96,9 +97,9 @@ class PaviLoggerHook(LoggerHook):
...
@@ -96,9 +97,9 @@ class PaviLoggerHook(LoggerHook):
config_dict
=
config_dict
.
copy
()
config_dict
=
config_dict
.
copy
()
config_dict
.
setdefault
(
'max_iter'
,
runner
.
max_iters
)
config_dict
.
setdefault
(
'max_iter'
,
runner
.
max_iters
)
# non-serializable values are first converted in
# non-serializable values are first converted in
# mm
cv
.dump to json
# mm
engine
.dump to json
config_dict
=
json
.
loads
(
config_dict
=
json
.
loads
(
mm
cv
.
dump
(
config_dict
,
file_format
=
'json'
))
mm
engine
.
dump
(
config_dict
,
file_format
=
'json'
))
session_text
=
yaml
.
dump
(
config_dict
)
session_text
=
yaml
.
dump
(
config_dict
)
self
.
init_kwargs
.
setdefault
(
'session_text'
,
session_text
)
self
.
init_kwargs
.
setdefault
(
'session_text'
,
session_text
)
self
.
writer
=
SummaryWriter
(
**
self
.
init_kwargs
)
self
.
writer
=
SummaryWriter
(
**
self
.
init_kwargs
)
...
...
mmcv/runner/hooks/logger/text.py
View file @
0fe1c647
...
@@ -5,11 +5,11 @@ import os.path as osp
...
@@ -5,11 +5,11 @@ import os.path as osp
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Dict
,
Optional
,
Union
from
typing
import
Dict
,
Optional
,
Union
import
mmengine
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
mmengine.fileio.file_client
import
FileClient
import
mmcv
from
mmcv.fileio.file_client
import
FileClient
from
mmcv.utils
import
is_tuple_of
,
scandir
from
mmcv.utils
import
is_tuple_of
,
scandir
from
..hook
import
HOOKS
from
..hook
import
HOOKS
from
.base
import
LoggerHook
from
.base
import
LoggerHook
...
@@ -48,7 +48,7 @@ class TextLoggerHook(LoggerHook):
...
@@ -48,7 +48,7 @@ class TextLoggerHook(LoggerHook):
removed. Default: True.
removed. Default: True.
`New in version 1.3.16.`
`New in version 1.3.16.`
file_client_args (dict, optional): Arguments to instantiate a
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mm
cv
.fileio.FileClient` for details.
FileClient. See :class:`mm
engine
.fileio.FileClient` for details.
Default: None.
Default: None.
`New in version 1.3.16.`
`New in version 1.3.16.`
"""
"""
...
@@ -190,7 +190,7 @@ class TextLoggerHook(LoggerHook):
...
@@ -190,7 +190,7 @@ class TextLoggerHook(LoggerHook):
# only append log at last line
# only append log at last line
if
runner
.
rank
==
0
:
if
runner
.
rank
==
0
:
with
open
(
self
.
json_log_path
,
'a+'
)
as
f
:
with
open
(
self
.
json_log_path
,
'a+'
)
as
f
:
mm
cv
.
dump
(
json_log
,
f
,
file_format
=
'json'
)
mm
engine
.
dump
(
json_log
,
f
,
file_format
=
'json'
)
f
.
write
(
'
\n
'
)
f
.
write
(
'
\n
'
)
def
_round_float
(
self
,
items
):
def
_round_float
(
self
,
items
):
...
...
mmcv/transforms/loading.py
View file @
0fe1c647
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Optional
from
typing
import
Optional
import
mmengine
import
numpy
as
np
import
numpy
as
np
import
mmcv
import
mmcv
...
@@ -33,7 +34,7 @@ class LoadImageFromFile(BaseTransform):
...
@@ -33,7 +34,7 @@ class LoadImageFromFile(BaseTransform):
See :func:``mmcv.imfrombytes`` for details.
See :func:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
Defaults to 'cv2'.
file_client_args (dict): Arguments to instantiate a FileClient.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mm
cv
.fileio.FileClient` for details.
See :class:`mm
engine
.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
Defaults to ``dict(backend='disk')``.
ignore_empty (bool): Whether to allow loading empty image or file path
ignore_empty (bool): Whether to allow loading empty image or file path
not existent. Defaults to False.
not existent. Defaults to False.
...
@@ -50,7 +51,7 @@ class LoadImageFromFile(BaseTransform):
...
@@ -50,7 +51,7 @@ class LoadImageFromFile(BaseTransform):
self
.
color_type
=
color_type
self
.
color_type
=
color_type
self
.
imdecode_backend
=
imdecode_backend
self
.
imdecode_backend
=
imdecode_backend
self
.
file_client_args
=
file_client_args
.
copy
()
self
.
file_client_args
=
file_client_args
.
copy
()
self
.
file_client
=
mm
cv
.
FileClient
(
**
self
.
file_client_args
)
self
.
file_client
=
mm
engine
.
FileClient
(
**
self
.
file_client_args
)
def
transform
(
self
,
results
:
dict
)
->
Optional
[
dict
]:
def
transform
(
self
,
results
:
dict
)
->
Optional
[
dict
]:
"""Functions to load image.
"""Functions to load image.
...
@@ -168,7 +169,7 @@ class LoadAnnotations(BaseTransform):
...
@@ -168,7 +169,7 @@ class LoadAnnotations(BaseTransform):
See :fun:``mmcv.imfrombytes`` for details.
See :fun:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
Defaults to 'cv2'.
file_client_args (dict): Arguments to instantiate a FileClient.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:``mm
cv
.fileio.FileClient`` for details.
See :class:``mm
engine
.fileio.FileClient`` for details.
Defaults to ``dict(backend='disk')``.
Defaults to ``dict(backend='disk')``.
"""
"""
...
@@ -188,7 +189,7 @@ class LoadAnnotations(BaseTransform):
...
@@ -188,7 +189,7 @@ class LoadAnnotations(BaseTransform):
self
.
with_keypoints
=
with_keypoints
self
.
with_keypoints
=
with_keypoints
self
.
imdecode_backend
=
imdecode_backend
self
.
imdecode_backend
=
imdecode_backend
self
.
file_client_args
=
file_client_args
.
copy
()
self
.
file_client_args
=
file_client_args
.
copy
()
self
.
file_client
=
mm
cv
.
FileClient
(
**
self
.
file_client_args
)
self
.
file_client
=
mm
engine
.
FileClient
(
**
self
.
file_client_args
)
def
_load_bboxes
(
self
,
results
:
dict
)
->
None
:
def
_load_bboxes
(
self
,
results
:
dict
)
->
None
:
"""Private function to load bounding box annotations.
"""Private function to load bounding box annotations.
...
...
mmcv/utils/config.py
View file @
0fe1c647
...
@@ -15,6 +15,7 @@ from collections import abc
...
@@ -15,6 +15,7 @@ from collections import abc
from
importlib
import
import_module
from
importlib
import
import_module
from
pathlib
import
Path
from
pathlib
import
Path
import
mmengine
from
addict
import
Dict
from
addict
import
Dict
from
yapf.yapflib.yapf_api
import
FormatCode
from
yapf.yapflib.yapf_api
import
FormatCode
...
@@ -217,8 +218,7 @@ class Config:
...
@@ -217,8 +218,7 @@ class Config:
# delete imported module
# delete imported module
del
sys
.
modules
[
temp_module_name
]
del
sys
.
modules
[
temp_module_name
]
elif
filename
.
endswith
((
'.yml'
,
'.yaml'
,
'.json'
)):
elif
filename
.
endswith
((
'.yml'
,
'.yaml'
,
'.json'
)):
import
mmcv
cfg_dict
=
mmengine
.
load
(
temp_config_file
.
name
)
cfg_dict
=
mmcv
.
load
(
temp_config_file
.
name
)
# close temp file
# close temp file
temp_config_file
.
close
()
temp_config_file
.
close
()
...
@@ -583,20 +583,19 @@ class Config:
...
@@ -583,20 +583,19 @@ class Config:
file (str, optional): Path of the output file where the config
file (str, optional): Path of the output file where the config
will be dumped. Defaults to None.
will be dumped. Defaults to None.
"""
"""
import
mmcv
cfg_dict
=
super
().
__getattribute__
(
'_cfg_dict'
).
to_dict
()
cfg_dict
=
super
().
__getattribute__
(
'_cfg_dict'
).
to_dict
()
if
file
is
None
:
if
file
is
None
:
if
self
.
filename
is
None
or
self
.
filename
.
endswith
(
'.py'
):
if
self
.
filename
is
None
or
self
.
filename
.
endswith
(
'.py'
):
return
self
.
pretty_text
return
self
.
pretty_text
else
:
else
:
file_format
=
self
.
filename
.
split
(
'.'
)[
-
1
]
file_format
=
self
.
filename
.
split
(
'.'
)[
-
1
]
return
mm
cv
.
dump
(
cfg_dict
,
file_format
=
file_format
)
return
mm
engine
.
dump
(
cfg_dict
,
file_format
=
file_format
)
elif
file
.
endswith
(
'.py'
):
elif
file
.
endswith
(
'.py'
):
with
open
(
file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
self
.
pretty_text
)
f
.
write
(
self
.
pretty_text
)
else
:
else
:
file_format
=
file
.
split
(
'.'
)[
-
1
]
file_format
=
file
.
split
(
'.'
)[
-
1
]
return
mm
cv
.
dump
(
cfg_dict
,
file
=
file
,
file_format
=
file_format
)
return
mm
engine
.
dump
(
cfg_dict
,
file
=
file
,
file_format
=
file_format
)
def
merge_from_dict
(
self
,
options
,
allow_list_keys
=
True
):
def
merge_from_dict
(
self
,
options
,
allow_list_keys
=
True
):
"""Merge list into cfg_dict.
"""Merge list into cfg_dict.
...
...
tests/test_fileclient.py
deleted
100644 → 0
View file @
0b4285d9
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
os.path
as
osp
import
sys
import
tempfile
from
contextlib
import
contextmanager
from
copy
import
deepcopy
from
pathlib
import
Path
from
unittest.mock
import
MagicMock
,
patch
import
pytest
import
mmcv
from
mmcv
import
BaseStorageBackend
,
FileClient
from
mmcv.utils
import
has_method
sys
.
modules
[
'ceph'
]
=
MagicMock
()
sys
.
modules
[
'petrel_client'
]
=
MagicMock
()
sys
.
modules
[
'petrel_client.client'
]
=
MagicMock
()
sys
.
modules
[
'mc'
]
=
MagicMock
()
@
contextmanager
def
build_temporary_directory
():
"""Build a temporary directory containing many files to test
``FileClient.list_dir_or_file``.
.
\n
| -- dir1
\n
| -- | -- text3.txt
\n
| -- dir2
\n
| -- | -- dir3
\n
| -- | -- | -- text4.txt
\n
| -- | -- img.jpg
\n
| -- text1.txt
\n
| -- text2.txt
\n
"""
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
text1
=
Path
(
tmp_dir
)
/
'text1.txt'
text1
.
open
(
'w'
).
write
(
'text1'
)
text2
=
Path
(
tmp_dir
)
/
'text2.txt'
text2
.
open
(
'w'
).
write
(
'text2'
)
dir1
=
Path
(
tmp_dir
)
/
'dir1'
dir1
.
mkdir
()
text3
=
dir1
/
'text3.txt'
text3
.
open
(
'w'
).
write
(
'text3'
)
dir2
=
Path
(
tmp_dir
)
/
'dir2'
dir2
.
mkdir
()
jpg1
=
dir2
/
'img.jpg'
jpg1
.
open
(
'wb'
).
write
(
b
'img'
)
dir3
=
dir2
/
'dir3'
dir3
.
mkdir
()
text4
=
dir3
/
'text4.txt'
text4
.
open
(
'w'
).
write
(
'text4'
)
yield
tmp_dir
@
contextmanager
def
delete_and_reset_method
(
obj
,
method
):
method_obj
=
deepcopy
(
getattr
(
type
(
obj
),
method
))
try
:
delattr
(
type
(
obj
),
method
)
yield
finally
:
setattr
(
type
(
obj
),
method
,
method_obj
)
class
MockS3Client
:
def
__init__
(
self
,
enable_mc
=
True
):
self
.
enable_mc
=
enable_mc
def
Get
(
self
,
filepath
):
with
open
(
filepath
,
'rb'
)
as
f
:
content
=
f
.
read
()
return
content
class
MockPetrelClient
:
def
__init__
(
self
,
enable_mc
=
True
,
enable_multi_cluster
=
False
):
self
.
enable_mc
=
enable_mc
self
.
enable_multi_cluster
=
enable_multi_cluster
def
Get
(
self
,
filepath
):
with
open
(
filepath
,
'rb'
)
as
f
:
content
=
f
.
read
()
return
content
def
put
(
self
):
pass
def
delete
(
self
):
pass
def
contains
(
self
):
pass
def
isdir
(
self
):
pass
def
list
(
self
,
dir_path
):
for
entry
in
os
.
scandir
(
dir_path
):
if
not
entry
.
name
.
startswith
(
'.'
)
and
entry
.
is_file
():
yield
entry
.
name
elif
osp
.
isdir
(
entry
.
path
):
yield
entry
.
name
+
'/'
class
MockMemcachedClient
:
def
__init__
(
self
,
server_list_cfg
,
client_cfg
):
pass
def
Get
(
self
,
filepath
,
buffer
):
with
open
(
filepath
,
'rb'
)
as
f
:
buffer
.
content
=
f
.
read
()
class
TestFileClient
:
@
classmethod
def
setup_class
(
cls
):
cls
.
test_data_dir
=
Path
(
__file__
).
parent
/
'data'
cls
.
img_path
=
cls
.
test_data_dir
/
'color.jpg'
cls
.
img_shape
=
(
300
,
400
,
3
)
cls
.
text_path
=
cls
.
test_data_dir
/
'filelist.txt'
def
test_error
(
self
):
with
pytest
.
raises
(
ValueError
):
FileClient
(
'hadoop'
)
def
test_disk_backend
(
self
):
disk_backend
=
FileClient
(
'disk'
)
# test `name` attribute
assert
disk_backend
.
name
==
'HardDiskBackend'
# test `allow_symlink` attribute
assert
disk_backend
.
allow_symlink
# test `get`
# input path is Path object
img_bytes
=
disk_backend
.
get
(
self
.
img_path
)
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
self
.
img_path
.
open
(
'rb'
).
read
()
==
img_bytes
assert
img
.
shape
==
self
.
img_shape
# input path is str
img_bytes
=
disk_backend
.
get
(
str
(
self
.
img_path
))
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
self
.
img_path
.
open
(
'rb'
).
read
()
==
img_bytes
assert
img
.
shape
==
self
.
img_shape
# test `get_text`
# input path is Path object
value_buf
=
disk_backend
.
get_text
(
self
.
text_path
)
assert
self
.
text_path
.
open
(
'r'
).
read
()
==
value_buf
# input path is str
value_buf
=
disk_backend
.
get_text
(
str
(
self
.
text_path
))
assert
self
.
text_path
.
open
(
'r'
).
read
()
==
value_buf
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
# test `put`
filepath1
=
Path
(
tmp_dir
)
/
'test.jpg'
disk_backend
.
put
(
b
'disk'
,
filepath1
)
assert
filepath1
.
open
(
'rb'
).
read
()
==
b
'disk'
# test the `mkdir_or_exist` behavior in `put`
_filepath1
=
Path
(
tmp_dir
)
/
'not_existed_dir1'
/
'test.jpg'
disk_backend
.
put
(
b
'disk'
,
_filepath1
)
assert
_filepath1
.
open
(
'rb'
).
read
()
==
b
'disk'
# test `put_text`
filepath2
=
Path
(
tmp_dir
)
/
'test.txt'
disk_backend
.
put_text
(
'disk'
,
filepath2
)
assert
filepath2
.
open
(
'r'
).
read
()
==
'disk'
# test the `mkdir_or_exist` behavior in `put_text`
_filepath2
=
Path
(
tmp_dir
)
/
'not_existed_dir2'
/
'test.txt'
disk_backend
.
put_text
(
'disk'
,
_filepath2
)
assert
_filepath2
.
open
(
'r'
).
read
()
==
'disk'
# test `isfile`
assert
disk_backend
.
isfile
(
filepath2
)
assert
not
disk_backend
.
isfile
(
Path
(
tmp_dir
)
/
'not/existed/path'
)
# test `remove`
disk_backend
.
remove
(
filepath2
)
# test `exists`
assert
not
disk_backend
.
exists
(
filepath2
)
# test `get_local_path`
# if the backend is disk, `get_local_path` just return the input
with
disk_backend
.
get_local_path
(
filepath1
)
as
path
:
assert
str
(
filepath1
)
==
path
assert
osp
.
isfile
(
filepath1
)
# test `join_path`
disk_dir
=
'/path/of/your/directory'
assert
disk_backend
.
join_path
(
disk_dir
,
'file'
)
==
\
osp
.
join
(
disk_dir
,
'file'
)
assert
disk_backend
.
join_path
(
disk_dir
,
'dir'
,
'file'
)
==
\
osp
.
join
(
disk_dir
,
'dir'
,
'file'
)
# test `list_dir_or_file`
with
build_temporary_directory
()
as
tmp_dir
:
# 1. list directories and files
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
))
==
{
'dir1'
,
'dir2'
,
'text1.txt'
,
'text2.txt'
}
# 2. list directories and files recursively
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
recursive
=
True
))
==
{
'dir1'
,
osp
.
join
(
'dir1'
,
'text3.txt'
),
'dir2'
,
osp
.
join
(
'dir2'
,
'dir3'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
osp
.
join
(
'dir2'
,
'img.jpg'
),
'text1.txt'
,
'text2.txt'
}
# 3. only list directories
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
))
==
{
'dir1'
,
'dir2'
}
with
pytest
.
raises
(
TypeError
,
match
=
'`suffix` should be None when `list_dir` is True'
):
# Exception is raised among the `list_dir_or_file` of client,
# so we need to invode the client to trigger the exception
disk_backend
.
client
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
,
suffix
=
'.txt'
)
# 4. only list directories recursively
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
,
recursive
=
True
))
==
{
'dir1'
,
'dir2'
,
osp
.
join
(
'dir2'
,
'dir3'
)
}
# 5. only list files
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
))
==
{
'text1.txt'
,
'text2.txt'
}
# 6. only list files recursively
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
recursive
=
True
))
==
{
osp
.
join
(
'dir1'
,
'text3.txt'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
osp
.
join
(
'dir2'
,
'img.jpg'
),
'text1.txt'
,
'text2.txt'
}
# 7. only list files ending with suffix
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
))
==
{
'text1.txt'
,
'text2.txt'
}
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
(
'.txt'
,
'.jpg'
)))
==
{
'text1.txt'
,
'text2.txt'
}
with
pytest
.
raises
(
TypeError
,
match
=
'`suffix` must be a string or tuple of strings'
):
disk_backend
.
client
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
[
'.txt'
,
'.jpg'
])
# 8. only list files ending with suffix recursively
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
,
recursive
=
True
))
==
{
osp
.
join
(
'dir1'
,
'text3.txt'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
'text1.txt'
,
'text2.txt'
}
# 7. only list files ending with suffix
assert
set
(
disk_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
(
'.txt'
,
'.jpg'
),
recursive
=
True
))
==
{
osp
.
join
(
'dir1'
,
'text3.txt'
),
osp
.
join
(
'dir2'
,
'dir3'
,
'text4.txt'
),
osp
.
join
(
'dir2'
,
'img.jpg'
),
'text1.txt'
,
'text2.txt'
}
@
patch
(
'ceph.S3Client'
,
MockS3Client
)
def
test_ceph_backend
(
self
):
ceph_backend
=
FileClient
(
'ceph'
)
# test `allow_symlink` attribute
assert
not
ceph_backend
.
allow_symlink
# input path is Path object
with
pytest
.
raises
(
NotImplementedError
):
ceph_backend
.
get_text
(
self
.
text_path
)
# input path is str
with
pytest
.
raises
(
NotImplementedError
):
ceph_backend
.
get_text
(
str
(
self
.
text_path
))
# input path is Path object
img_bytes
=
ceph_backend
.
get
(
self
.
img_path
)
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
self
.
img_shape
# input path is str
img_bytes
=
ceph_backend
.
get
(
str
(
self
.
img_path
))
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
self
.
img_shape
# `path_mapping` is either None or dict
with
pytest
.
raises
(
AssertionError
):
FileClient
(
'ceph'
,
path_mapping
=
1
)
# test `path_mapping`
ceph_path
=
's3://user/data'
ceph_backend
=
FileClient
(
'ceph'
,
path_mapping
=
{
str
(
self
.
test_data_dir
):
ceph_path
})
ceph_backend
.
client
.
_client
.
Get
=
MagicMock
(
return_value
=
ceph_backend
.
client
.
_client
.
Get
(
self
.
img_path
))
img_bytes
=
ceph_backend
.
get
(
self
.
img_path
)
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
self
.
img_shape
ceph_backend
.
client
.
_client
.
Get
.
assert_called_with
(
str
(
self
.
img_path
).
replace
(
str
(
self
.
test_data_dir
),
ceph_path
))
@
patch
(
'petrel_client.client.Client'
,
MockPetrelClient
)
@
pytest
.
mark
.
parametrize
(
'backend,prefix'
,
[(
'petrel'
,
None
),
(
None
,
's3'
)])
def
test_petrel_backend
(
self
,
backend
,
prefix
):
petrel_backend
=
FileClient
(
backend
=
backend
,
prefix
=
prefix
)
# test `allow_symlink` attribute
assert
not
petrel_backend
.
allow_symlink
# input path is Path object
img_bytes
=
petrel_backend
.
get
(
self
.
img_path
)
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
self
.
img_shape
# input path is str
img_bytes
=
petrel_backend
.
get
(
str
(
self
.
img_path
))
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
self
.
img_shape
# `path_mapping` is either None or dict
with
pytest
.
raises
(
AssertionError
):
FileClient
(
'petrel'
,
path_mapping
=
1
)
# test `_map_path`
petrel_dir
=
's3://user/data'
petrel_backend
=
FileClient
(
'petrel'
,
path_mapping
=
{
str
(
self
.
test_data_dir
):
petrel_dir
})
assert
petrel_backend
.
client
.
_map_path
(
str
(
self
.
img_path
))
==
\
str
(
self
.
img_path
).
replace
(
str
(
self
.
test_data_dir
),
petrel_dir
)
petrel_path
=
f
'
{
petrel_dir
}
/test.jpg'
petrel_backend
=
FileClient
(
'petrel'
)
# test `_format_path`
assert
petrel_backend
.
client
.
_format_path
(
's3://user
\\
data
\\
test.jpg'
)
\
==
petrel_path
# test `get`
with
patch
.
object
(
petrel_backend
.
client
.
_client
,
'Get'
,
return_value
=
b
'petrel'
)
as
mock_get
:
assert
petrel_backend
.
get
(
petrel_path
)
==
b
'petrel'
mock_get
.
assert_called_once_with
(
petrel_path
)
# test `get_text`
with
patch
.
object
(
petrel_backend
.
client
.
_client
,
'Get'
,
return_value
=
b
'petrel'
)
as
mock_get
:
assert
petrel_backend
.
get_text
(
petrel_path
)
==
'petrel'
mock_get
.
assert_called_once_with
(
petrel_path
)
# test `put`
with
patch
.
object
(
petrel_backend
.
client
.
_client
,
'put'
)
as
mock_put
:
petrel_backend
.
put
(
b
'petrel'
,
petrel_path
)
mock_put
.
assert_called_once_with
(
petrel_path
,
b
'petrel'
)
# test `put_text`
with
patch
.
object
(
petrel_backend
.
client
.
_client
,
'put'
)
as
mock_put
:
petrel_backend
.
put_text
(
'petrel'
,
petrel_path
)
mock_put
.
assert_called_once_with
(
petrel_path
,
b
'petrel'
)
# test `remove`
assert
has_method
(
petrel_backend
.
client
.
_client
,
'delete'
)
# raise Exception if `delete` is not implemented
with
delete_and_reset_method
(
petrel_backend
.
client
.
_client
,
'delete'
):
assert
not
has_method
(
petrel_backend
.
client
.
_client
,
'delete'
)
with
pytest
.
raises
(
NotImplementedError
):
petrel_backend
.
remove
(
petrel_path
)
with
patch
.
object
(
petrel_backend
.
client
.
_client
,
'delete'
)
as
mock_delete
:
petrel_backend
.
remove
(
petrel_path
)
mock_delete
.
assert_called_once_with
(
petrel_path
)
# test `exists`
assert
has_method
(
petrel_backend
.
client
.
_client
,
'contains'
)
assert
has_method
(
petrel_backend
.
client
.
_client
,
'isdir'
)
# raise Exception if `delete` is not implemented
with
delete_and_reset_method
(
petrel_backend
.
client
.
_client
,
'contains'
),
delete_and_reset_method
(
petrel_backend
.
client
.
_client
,
'isdir'
):
assert
not
has_method
(
petrel_backend
.
client
.
_client
,
'contains'
)
assert
not
has_method
(
petrel_backend
.
client
.
_client
,
'isdir'
)
with
pytest
.
raises
(
NotImplementedError
):
petrel_backend
.
exists
(
petrel_path
)
with
patch
.
object
(
petrel_backend
.
client
.
_client
,
'contains'
,
return_value
=
True
)
as
mock_contains
:
assert
petrel_backend
.
exists
(
petrel_path
)
mock_contains
.
assert_called_once_with
(
petrel_path
)
# test `isdir`
assert
has_method
(
petrel_backend
.
client
.
_client
,
'isdir'
)
with
delete_and_reset_method
(
petrel_backend
.
client
.
_client
,
'isdir'
):
assert
not
has_method
(
petrel_backend
.
client
.
_client
,
'isdir'
)
with
pytest
.
raises
(
NotImplementedError
):
petrel_backend
.
isdir
(
petrel_path
)
with
patch
.
object
(
petrel_backend
.
client
.
_client
,
'isdir'
,
return_value
=
True
)
as
mock_isdir
:
assert
petrel_backend
.
isdir
(
petrel_dir
)
mock_isdir
.
assert_called_once_with
(
petrel_dir
)
# test `isfile`
assert
has_method
(
petrel_backend
.
client
.
_client
,
'contains'
)
with
delete_and_reset_method
(
petrel_backend
.
client
.
_client
,
'contains'
):
assert
not
has_method
(
petrel_backend
.
client
.
_client
,
'contains'
)
with
pytest
.
raises
(
NotImplementedError
):
petrel_backend
.
isfile
(
petrel_path
)
with
patch
.
object
(
petrel_backend
.
client
.
_client
,
'contains'
,
return_value
=
True
)
as
mock_contains
:
assert
petrel_backend
.
isfile
(
petrel_path
)
mock_contains
.
assert_called_once_with
(
petrel_path
)
# test `join_path`
assert
petrel_backend
.
join_path
(
petrel_dir
,
'file'
)
==
\
f
'
{
petrel_dir
}
/file'
assert
petrel_backend
.
join_path
(
f
'
{
petrel_dir
}
/'
,
'file'
)
==
\
f
'
{
petrel_dir
}
/file'
assert
petrel_backend
.
join_path
(
petrel_dir
,
'dir'
,
'file'
)
==
\
f
'
{
petrel_dir
}
/dir/file'
# test `get_local_path`
with
patch
.
object
(
petrel_backend
.
client
.
_client
,
'Get'
,
return_value
=
b
'petrel'
)
as
mock_get
,
\
patch
.
object
(
petrel_backend
.
client
.
_client
,
'contains'
,
return_value
=
True
)
as
mock_contains
:
with
petrel_backend
.
get_local_path
(
petrel_path
)
as
path
:
assert
Path
(
path
).
open
(
'rb'
).
read
()
==
b
'petrel'
# exist the with block and path will be released
assert
not
osp
.
isfile
(
path
)
mock_get
.
assert_called_once_with
(
petrel_path
)
mock_contains
.
assert_called_once_with
(
petrel_path
)
# test `list_dir_or_file`
assert
has_method
(
petrel_backend
.
client
.
_client
,
'list'
)
with
delete_and_reset_method
(
petrel_backend
.
client
.
_client
,
'list'
):
assert
not
has_method
(
petrel_backend
.
client
.
_client
,
'list'
)
with
pytest
.
raises
(
NotImplementedError
):
list
(
petrel_backend
.
list_dir_or_file
(
petrel_dir
))
with
build_temporary_directory
()
as
tmp_dir
:
# 1. list directories and files
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
))
==
{
'dir1'
,
'dir2'
,
'text1.txt'
,
'text2.txt'
}
# 2. list directories and files recursively
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
recursive
=
True
))
==
{
'dir1'
,
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'dir2'
,
'/'
.
join
(
(
'dir2'
,
'dir3'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'/'
.
join
(
(
'dir2'
,
'img.jpg'
)),
'text1.txt'
,
'text2.txt'
}
# 3. only list directories
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
))
==
{
'dir1'
,
'dir2'
}
with
pytest
.
raises
(
TypeError
,
match
=
(
'`list_dir` should be False when `suffix` is not '
'None'
)):
# Exception is raised among the `list_dir_or_file` of client,
# so we need to invode the client to trigger the exception
petrel_backend
.
client
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
,
suffix
=
'.txt'
)
# 4. only list directories recursively
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_file
=
False
,
recursive
=
True
))
==
{
'dir1'
,
'dir2'
,
'/'
.
join
((
'dir2'
,
'dir3'
))
}
# 5. only list files
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
))
==
{
'text1.txt'
,
'text2.txt'
}
# 6. only list files recursively
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
recursive
=
True
))
==
{
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'/'
.
join
(
(
'dir2'
,
'img.jpg'
)),
'text1.txt'
,
'text2.txt'
}
# 7. only list files ending with suffix
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
))
==
{
'text1.txt'
,
'text2.txt'
}
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
(
'.txt'
,
'.jpg'
)))
==
{
'text1.txt'
,
'text2.txt'
}
with
pytest
.
raises
(
TypeError
,
match
=
'`suffix` must be a string or tuple of strings'
):
petrel_backend
.
client
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
[
'.txt'
,
'.jpg'
])
# 8. only list files ending with suffix recursively
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
'.txt'
,
recursive
=
True
))
==
{
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'text1.txt'
,
'text2.txt'
}
# 7. only list files ending with suffix
assert
set
(
petrel_backend
.
list_dir_or_file
(
tmp_dir
,
list_dir
=
False
,
suffix
=
(
'.txt'
,
'.jpg'
),
recursive
=
True
))
==
{
'/'
.
join
((
'dir1'
,
'text3.txt'
)),
'/'
.
join
(
(
'dir2'
,
'dir3'
,
'text4.txt'
)),
'/'
.
join
(
(
'dir2'
,
'img.jpg'
)),
'text1.txt'
,
'text2.txt'
}
@
patch
(
'mc.MemcachedClient.GetInstance'
,
MockMemcachedClient
)
@
patch
(
'mc.pyvector'
,
MagicMock
)
@
patch
(
'mc.ConvertBuffer'
,
lambda
x
:
x
.
content
)
def
test_memcached_backend
(
self
):
mc_cfg
=
dict
(
server_list_cfg
=
''
,
client_cfg
=
''
,
sys_path
=
None
)
mc_backend
=
FileClient
(
'memcached'
,
**
mc_cfg
)
# test `allow_symlink` attribute
assert
not
mc_backend
.
allow_symlink
# input path is Path object
with
pytest
.
raises
(
NotImplementedError
):
mc_backend
.
get_text
(
self
.
text_path
)
# input path is str
with
pytest
.
raises
(
NotImplementedError
):
mc_backend
.
get_text
(
str
(
self
.
text_path
))
# input path is Path object
img_bytes
=
mc_backend
.
get
(
self
.
img_path
)
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
self
.
img_shape
# input path is str
img_bytes
=
mc_backend
.
get
(
str
(
self
.
img_path
))
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
self
.
img_shape
def
test_lmdb_backend
(
self
):
lmdb_path
=
self
.
test_data_dir
/
'demo.lmdb'
# db_path is Path object
lmdb_backend
=
FileClient
(
'lmdb'
,
db_path
=
lmdb_path
)
# test `allow_symlink` attribute
assert
not
lmdb_backend
.
allow_symlink
with
pytest
.
raises
(
NotImplementedError
):
lmdb_backend
.
get_text
(
self
.
text_path
)
img_bytes
=
lmdb_backend
.
get
(
'baboon'
)
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
(
120
,
125
,
3
)
# db_path is str
lmdb_backend
=
FileClient
(
'lmdb'
,
db_path
=
str
(
lmdb_path
))
with
pytest
.
raises
(
NotImplementedError
):
lmdb_backend
.
get_text
(
str
(
self
.
text_path
))
img_bytes
=
lmdb_backend
.
get
(
'baboon'
)
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
(
120
,
125
,
3
)
@
pytest
.
mark
.
parametrize
(
'backend,prefix'
,
[(
'http'
,
None
),
(
None
,
'http'
)])
def
test_http_backend
(
self
,
backend
,
prefix
):
http_backend
=
FileClient
(
backend
=
backend
,
prefix
=
prefix
)
img_url
=
'https://raw.githubusercontent.com/open-mmlab/mmcv/'
\
'master/tests/data/color.jpg'
text_url
=
'https://raw.githubusercontent.com/open-mmlab/mmcv/'
\
'master/tests/data/filelist.txt'
# test `allow_symlink` attribute
assert
not
http_backend
.
allow_symlink
# input is path or Path object
with
pytest
.
raises
(
Exception
):
http_backend
.
get
(
self
.
img_path
)
with
pytest
.
raises
(
Exception
):
http_backend
.
get
(
str
(
self
.
img_path
))
with
pytest
.
raises
(
Exception
):
http_backend
.
get_text
(
self
.
text_path
)
with
pytest
.
raises
(
Exception
):
http_backend
.
get_text
(
str
(
self
.
text_path
))
# input url is http image
img_bytes
=
http_backend
.
get
(
img_url
)
img
=
mmcv
.
imfrombytes
(
img_bytes
)
assert
img
.
shape
==
self
.
img_shape
# input url is http text
value_buf
=
http_backend
.
get_text
(
text_url
)
assert
self
.
text_path
.
open
(
'r'
).
read
()
==
value_buf
# test `_get_local_path`
# exist the with block and path will be released
with
http_backend
.
get_local_path
(
img_url
)
as
path
:
assert
mmcv
.
imread
(
path
).
shape
==
self
.
img_shape
assert
not
osp
.
isfile
(
path
)
def
test_new_magic_method
(
self
):
class
DummyBackend1
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
return
filepath
def
get_text
(
self
,
filepath
,
encoding
=
'utf-8'
):
return
filepath
FileClient
.
register_backend
(
'dummy_backend'
,
DummyBackend1
)
client1
=
FileClient
(
backend
=
'dummy_backend'
)
client2
=
FileClient
(
backend
=
'dummy_backend'
)
assert
client1
is
client2
# if a backend is overwrote, it will disable the singleton pattern for
# the backend
class
DummyBackend2
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
pass
def
get_text
(
self
,
filepath
):
pass
FileClient
.
register_backend
(
'dummy_backend'
,
DummyBackend2
,
force
=
True
)
client3
=
FileClient
(
backend
=
'dummy_backend'
)
client4
=
FileClient
(
backend
=
'dummy_backend'
)
assert
client2
is
not
client3
assert
client3
is
client4
def
test_parse_uri_prefix
(
self
):
# input path is None
with
pytest
.
raises
(
AssertionError
):
FileClient
.
parse_uri_prefix
(
None
)
# input path is list
with
pytest
.
raises
(
AssertionError
):
FileClient
.
parse_uri_prefix
([])
# input path is Path object
assert
FileClient
.
parse_uri_prefix
(
self
.
img_path
)
is
None
# input path is str
assert
FileClient
.
parse_uri_prefix
(
str
(
self
.
img_path
))
is
None
# input path starts with https
img_url
=
'https://raw.githubusercontent.com/open-mmlab/mmcv/'
\
'master/tests/data/color.jpg'
assert
FileClient
.
parse_uri_prefix
(
img_url
)
==
'https'
# input path starts with s3
img_url
=
's3://your_bucket/img.png'
assert
FileClient
.
parse_uri_prefix
(
img_url
)
==
's3'
# input path starts with clusterName:s3
img_url
=
'clusterName:s3://your_bucket/img.png'
assert
FileClient
.
parse_uri_prefix
(
img_url
)
==
's3'
def
test_infer_client
(
self
):
# HardDiskBackend
file_client_args
=
{
'backend'
:
'disk'
}
client
=
FileClient
.
infer_client
(
file_client_args
)
assert
client
.
name
==
'HardDiskBackend'
client
=
FileClient
.
infer_client
(
uri
=
self
.
img_path
)
assert
client
.
name
==
'HardDiskBackend'
# PetrelBackend
file_client_args
=
{
'backend'
:
'petrel'
}
client
=
FileClient
.
infer_client
(
file_client_args
)
assert
client
.
name
==
'PetrelBackend'
uri
=
's3://user_data'
client
=
FileClient
.
infer_client
(
uri
=
uri
)
assert
client
.
name
==
'PetrelBackend'
def
test_register_backend
(
self
):
# name must be a string
with
pytest
.
raises
(
TypeError
):
class
TestClass1
:
pass
FileClient
.
register_backend
(
1
,
TestClass1
)
# module must be a class
with
pytest
.
raises
(
TypeError
):
FileClient
.
register_backend
(
'int'
,
0
)
# module must be a subclass of BaseStorageBackend
with
pytest
.
raises
(
TypeError
):
class
TestClass1
:
pass
FileClient
.
register_backend
(
'TestClass1'
,
TestClass1
)
class
ExampleBackend
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
return
filepath
def
get_text
(
self
,
filepath
,
encoding
=
'utf-8'
):
return
filepath
FileClient
.
register_backend
(
'example'
,
ExampleBackend
)
example_backend
=
FileClient
(
'example'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
self
.
img_path
assert
example_backend
.
get_text
(
self
.
text_path
)
==
self
.
text_path
assert
'example'
in
FileClient
.
_backends
class
Example2Backend
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
return
b
'bytes2'
def
get_text
(
self
,
filepath
,
encoding
=
'utf-8'
):
return
'text2'
# force=False
with
pytest
.
raises
(
KeyError
):
FileClient
.
register_backend
(
'example'
,
Example2Backend
)
FileClient
.
register_backend
(
'example'
,
Example2Backend
,
force
=
True
)
example_backend
=
FileClient
(
'example'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes2'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text2'
@
FileClient
.
register_backend
(
name
=
'example3'
)
class
Example3Backend
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
return
b
'bytes3'
def
get_text
(
self
,
filepath
,
encoding
=
'utf-8'
):
return
'text3'
example_backend
=
FileClient
(
'example3'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes3'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text3'
assert
'example3'
in
FileClient
.
_backends
# force=False
with
pytest
.
raises
(
KeyError
):
@
FileClient
.
register_backend
(
name
=
'example3'
)
class
Example4Backend
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
return
b
'bytes4'
def
get_text
(
self
,
filepath
,
encoding
=
'utf-8'
):
return
'text4'
@
FileClient
.
register_backend
(
name
=
'example3'
,
force
=
True
)
class
Example5Backend
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
return
b
'bytes5'
def
get_text
(
self
,
filepath
,
encoding
=
'utf-8'
):
return
'text5'
example_backend
=
FileClient
(
'example3'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes5'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text5'
# prefixes is a str
class
Example6Backend
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
return
b
'bytes6'
def
get_text
(
self
,
filepath
,
encoding
=
'utf-8'
):
return
'text6'
FileClient
.
register_backend
(
'example4'
,
Example6Backend
,
force
=
True
,
prefixes
=
'example4_prefix'
)
example_backend
=
FileClient
(
'example4'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes6'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text6'
example_backend
=
FileClient
(
prefix
=
'example4_prefix'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes6'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text6'
example_backend
=
FileClient
(
'example4'
,
prefix
=
'example4_prefix'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes6'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text6'
# prefixes is a list of str
class
Example7Backend
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
return
b
'bytes7'
def
get_text
(
self
,
filepath
,
encoding
=
'utf-8'
):
return
'text7'
FileClient
.
register_backend
(
'example5'
,
Example7Backend
,
force
=
True
,
prefixes
=
[
'example5_prefix1'
,
'example5_prefix2'
])
example_backend
=
FileClient
(
'example5'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes7'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text7'
example_backend
=
FileClient
(
prefix
=
'example5_prefix1'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes7'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text7'
example_backend
=
FileClient
(
prefix
=
'example5_prefix2'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes7'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text7'
# backend has a higher priority than prefixes
class
Example8Backend
(
BaseStorageBackend
):
def
get
(
self
,
filepath
):
return
b
'bytes8'
def
get_text
(
self
,
filepath
,
encoding
=
'utf-8'
):
return
'text8'
FileClient
.
register_backend
(
'example6'
,
Example8Backend
,
force
=
True
,
prefixes
=
'example6_prefix'
)
example_backend
=
FileClient
(
'example6'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes8'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text8'
example_backend
=
FileClient
(
'example6'
,
prefix
=
'example4_prefix'
)
assert
example_backend
.
get
(
self
.
img_path
)
==
b
'bytes8'
assert
example_backend
.
get_text
(
self
.
text_path
)
==
'text8'
tests/test_fileio.py
deleted
100644 → 0
View file @
0b4285d9
# Copyright (c) OpenMMLab. All rights reserved.
import
os
import
os.path
as
osp
import
sys
import
tempfile
from
unittest.mock
import
MagicMock
,
patch
import
pytest
import
mmcv
from
mmcv.fileio.file_client
import
HTTPBackend
,
PetrelBackend
sys
.
modules
[
'petrel_client'
]
=
MagicMock
()
sys
.
modules
[
'petrel_client.client'
]
=
MagicMock
()
def
_test_handler
(
file_format
,
test_obj
,
str_checker
,
mode
=
'r+'
):
# dump to a string
dump_str
=
mmcv
.
dump
(
test_obj
,
file_format
=
file_format
)
str_checker
(
dump_str
)
# load/dump with filenames from disk
tmp_filename
=
osp
.
join
(
tempfile
.
gettempdir
(),
'mmcv_test_dump'
)
mmcv
.
dump
(
test_obj
,
tmp_filename
,
file_format
=
file_format
)
assert
osp
.
isfile
(
tmp_filename
)
load_obj
=
mmcv
.
load
(
tmp_filename
,
file_format
=
file_format
)
assert
load_obj
==
test_obj
os
.
remove
(
tmp_filename
)
# load/dump with filename from petrel
method
=
'put'
if
'b'
in
mode
else
'put_text'
with
patch
.
object
(
PetrelBackend
,
method
,
return_value
=
None
)
as
mock_method
:
filename
=
's3://path/of/your/file'
mmcv
.
dump
(
test_obj
,
filename
,
file_format
=
file_format
)
mock_method
.
assert_called
()
# json load/dump with a file-like object
with
tempfile
.
NamedTemporaryFile
(
mode
,
delete
=
False
)
as
f
:
tmp_filename
=
f
.
name
mmcv
.
dump
(
test_obj
,
f
,
file_format
=
file_format
)
assert
osp
.
isfile
(
tmp_filename
)
with
open
(
tmp_filename
,
mode
)
as
f
:
load_obj
=
mmcv
.
load
(
f
,
file_format
=
file_format
)
assert
load_obj
==
test_obj
os
.
remove
(
tmp_filename
)
# automatically inference the file format from the given filename
tmp_filename
=
osp
.
join
(
tempfile
.
gettempdir
(),
'mmcv_test_dump.'
+
file_format
)
mmcv
.
dump
(
test_obj
,
tmp_filename
)
assert
osp
.
isfile
(
tmp_filename
)
load_obj
=
mmcv
.
load
(
tmp_filename
)
assert
load_obj
==
test_obj
os
.
remove
(
tmp_filename
)
obj_for_test
=
[{
'a'
:
'abc'
,
'b'
:
1
},
2
,
'c'
]
def
test_json
():
def
json_checker
(
dump_str
):
assert
dump_str
in
[
'[{"a": "abc", "b": 1}, 2, "c"]'
,
'[{"b": 1, "a": "abc"}, 2, "c"]'
]
_test_handler
(
'json'
,
obj_for_test
,
json_checker
)
def
test_yaml
():
def
yaml_checker
(
dump_str
):
assert
dump_str
in
[
'- {a: abc, b: 1}
\n
- 2
\n
- c
\n
'
,
'- {b: 1, a: abc}
\n
- 2
\n
- c
\n
'
,
'- a: abc
\n
b: 1
\n
- 2
\n
- c
\n
'
,
'- b: 1
\n
a: abc
\n
- 2
\n
- c
\n
'
]
_test_handler
(
'yaml'
,
obj_for_test
,
yaml_checker
)
def
test_pickle
():
def
pickle_checker
(
dump_str
):
import
pickle
assert
pickle
.
loads
(
dump_str
)
==
obj_for_test
_test_handler
(
'pickle'
,
obj_for_test
,
pickle_checker
,
mode
=
'rb+'
)
def
test_exception
():
test_obj
=
[{
'a'
:
'abc'
,
'b'
:
1
},
2
,
'c'
]
with
pytest
.
raises
(
ValueError
):
mmcv
.
dump
(
test_obj
)
with
pytest
.
raises
(
TypeError
):
mmcv
.
dump
(
test_obj
,
'tmp.txt'
)
def
test_register_handler
():
@
mmcv
.
register_handler
(
'txt'
)
class
TxtHandler1
(
mmcv
.
BaseFileHandler
):
def
load_from_fileobj
(
self
,
file
):
return
file
.
read
()
def
dump_to_fileobj
(
self
,
obj
,
file
):
file
.
write
(
str
(
obj
))
def
dump_to_str
(
self
,
obj
,
**
kwargs
):
return
str
(
obj
)
@
mmcv
.
register_handler
([
'txt1'
,
'txt2'
])
class
TxtHandler2
(
mmcv
.
BaseFileHandler
):
def
load_from_fileobj
(
self
,
file
):
return
file
.
read
()
def
dump_to_fileobj
(
self
,
obj
,
file
):
file
.
write
(
'
\n
'
)
file
.
write
(
str
(
obj
))
def
dump_to_str
(
self
,
obj
,
**
kwargs
):
return
str
(
obj
)
content
=
mmcv
.
load
(
osp
.
join
(
osp
.
dirname
(
__file__
),
'data/filelist.txt'
))
assert
content
==
'1.jpg
\n
2.jpg
\n
3.jpg
\n
4.jpg
\n
5.jpg'
tmp_filename
=
osp
.
join
(
tempfile
.
gettempdir
(),
'mmcv_test.txt2'
)
mmcv
.
dump
(
content
,
tmp_filename
)
with
open
(
tmp_filename
)
as
f
:
written
=
f
.
read
()
os
.
remove
(
tmp_filename
)
assert
written
==
'
\n
'
+
content
def
test_list_from_file
():
# get list from disk
filename
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'data/filelist.txt'
)
filelist
=
mmcv
.
list_from_file
(
filename
)
assert
filelist
==
[
'1.jpg'
,
'2.jpg'
,
'3.jpg'
,
'4.jpg'
,
'5.jpg'
]
filelist
=
mmcv
.
list_from_file
(
filename
,
prefix
=
'a/'
)
assert
filelist
==
[
'a/1.jpg'
,
'a/2.jpg'
,
'a/3.jpg'
,
'a/4.jpg'
,
'a/5.jpg'
]
filelist
=
mmcv
.
list_from_file
(
filename
,
offset
=
2
)
assert
filelist
==
[
'3.jpg'
,
'4.jpg'
,
'5.jpg'
]
filelist
=
mmcv
.
list_from_file
(
filename
,
max_num
=
2
)
assert
filelist
==
[
'1.jpg'
,
'2.jpg'
]
filelist
=
mmcv
.
list_from_file
(
filename
,
offset
=
3
,
max_num
=
3
)
assert
filelist
==
[
'4.jpg'
,
'5.jpg'
]
# get list from http
with
patch
.
object
(
HTTPBackend
,
'get_text'
,
return_value
=
'1.jpg
\n
2.jpg
\n
3.jpg'
):
filename
=
'http://path/of/your/file'
filelist
=
mmcv
.
list_from_file
(
filename
,
file_client_args
=
{
'backend'
:
'http'
})
assert
filelist
==
[
'1.jpg'
,
'2.jpg'
,
'3.jpg'
]
filelist
=
mmcv
.
list_from_file
(
filename
,
file_client_args
=
{
'prefix'
:
'http'
})
assert
filelist
==
[
'1.jpg'
,
'2.jpg'
,
'3.jpg'
]
filelist
=
mmcv
.
list_from_file
(
filename
)
assert
filelist
==
[
'1.jpg'
,
'2.jpg'
,
'3.jpg'
]
# get list from petrel
with
patch
.
object
(
PetrelBackend
,
'get_text'
,
return_value
=
'1.jpg
\n
2.jpg
\n
3.jpg'
):
filename
=
's3://path/of/your/file'
filelist
=
mmcv
.
list_from_file
(
filename
,
file_client_args
=
{
'backend'
:
'petrel'
})
assert
filelist
==
[
'1.jpg'
,
'2.jpg'
,
'3.jpg'
]
filelist
=
mmcv
.
list_from_file
(
filename
,
file_client_args
=
{
'prefix'
:
's3'
})
assert
filelist
==
[
'1.jpg'
,
'2.jpg'
,
'3.jpg'
]
filelist
=
mmcv
.
list_from_file
(
filename
)
assert
filelist
==
[
'1.jpg'
,
'2.jpg'
,
'3.jpg'
]
def
test_dict_from_file
():
# get dict from disk
filename
=
osp
.
join
(
osp
.
dirname
(
__file__
),
'data/mapping.txt'
)
mapping
=
mmcv
.
dict_from_file
(
filename
)
assert
mapping
==
{
'1'
:
'cat'
,
'2'
:
[
'dog'
,
'cow'
],
'3'
:
'panda'
}
mapping
=
mmcv
.
dict_from_file
(
filename
,
key_type
=
int
)
assert
mapping
==
{
1
:
'cat'
,
2
:
[
'dog'
,
'cow'
],
3
:
'panda'
}
# get dict from http
with
patch
.
object
(
HTTPBackend
,
'get_text'
,
return_value
=
'1 cat
\n
2 dog cow
\n
3 panda'
):
filename
=
'http://path/of/your/file'
mapping
=
mmcv
.
dict_from_file
(
filename
,
file_client_args
=
{
'backend'
:
'http'
})
assert
mapping
==
{
'1'
:
'cat'
,
'2'
:
[
'dog'
,
'cow'
],
'3'
:
'panda'
}
mapping
=
mmcv
.
dict_from_file
(
filename
,
file_client_args
=
{
'prefix'
:
'http'
})
assert
mapping
==
{
'1'
:
'cat'
,
'2'
:
[
'dog'
,
'cow'
],
'3'
:
'panda'
}
mapping
=
mmcv
.
dict_from_file
(
filename
)
assert
mapping
==
{
'1'
:
'cat'
,
'2'
:
[
'dog'
,
'cow'
],
'3'
:
'panda'
}
# get dict from petrel
with
patch
.
object
(
PetrelBackend
,
'get_text'
,
return_value
=
'1 cat
\n
2 dog cow
\n
3 panda'
):
filename
=
's3://path/of/your/file'
mapping
=
mmcv
.
dict_from_file
(
filename
,
file_client_args
=
{
'backend'
:
'petrel'
})
assert
mapping
==
{
'1'
:
'cat'
,
'2'
:
[
'dog'
,
'cow'
],
'3'
:
'panda'
}
mapping
=
mmcv
.
dict_from_file
(
filename
,
file_client_args
=
{
'prefix'
:
's3'
})
assert
mapping
==
{
'1'
:
'cat'
,
'2'
:
[
'dog'
,
'cow'
],
'3'
:
'panda'
}
mapping
=
mmcv
.
dict_from_file
(
filename
)
assert
mapping
==
{
'1'
:
'cat'
,
'2'
:
[
'dog'
,
'cow'
],
'3'
:
'panda'
}
tests/test_image/test_io.py
View file @
0fe1c647
...
@@ -7,13 +7,14 @@ from pathlib import Path
...
@@ -7,13 +7,14 @@ from pathlib import Path
from
unittest.mock
import
MagicMock
,
patch
from
unittest.mock
import
MagicMock
,
patch
import
cv2
import
cv2
import
mmengine
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
from
mmengine.fileio.file_client
import
HTTPBackend
,
PetrelBackend
from
numpy.testing
import
assert_allclose
,
assert_array_equal
from
numpy.testing
import
assert_allclose
,
assert_array_equal
import
mmcv
import
mmcv
from
mmcv.fileio.file_client
import
HTTPBackend
,
PetrelBackend
if
torch
.
__version__
==
'parrots'
:
if
torch
.
__version__
==
'parrots'
:
pytest
.
skip
(
'not necessary in parrots test'
,
allow_module_level
=
True
)
pytest
.
skip
(
'not necessary in parrots test'
,
allow_module_level
=
True
)
...
@@ -46,7 +47,7 @@ class TestIO:
...
@@ -46,7 +47,7 @@ class TestIO:
@
classmethod
@
classmethod
def
teardown_class
(
cls
):
def
teardown_class
(
cls
):
# clean instances avoid to influence other unittest
# clean instances avoid to influence other unittest
mm
cv
.
FileClient
.
_instances
=
{}
mm
engine
.
FileClient
.
_instances
=
{}
def
assert_img_equal
(
self
,
img
,
ref_img
,
ratio_thr
=
0.999
):
def
assert_img_equal
(
self
,
img
,
ref_img
,
ratio_thr
=
0.999
):
assert
img
.
shape
==
ref_img
.
shape
assert
img
.
shape
==
ref_img
.
shape
...
...
tests/test_load_model_zoo.py
View file @
0fe1c647
...
@@ -3,6 +3,7 @@ import os
...
@@ -3,6 +3,7 @@ import os
import
os.path
as
osp
import
os.path
as
osp
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
mmengine
import
pytest
import
pytest
import
torchvision
import
torchvision
...
@@ -30,7 +31,7 @@ def test_default_mmcv_home():
...
@@ -30,7 +31,7 @@ def test_default_mmcv_home():
assert
_get_mmcv_home
()
==
os
.
path
.
expanduser
(
assert
_get_mmcv_home
()
==
os
.
path
.
expanduser
(
os
.
path
.
join
(
DEFAULT_CACHE_DIR
,
'mmcv'
))
os
.
path
.
join
(
DEFAULT_CACHE_DIR
,
'mmcv'
))
model_urls
=
get_external_models
()
model_urls
=
get_external_models
()
assert
model_urls
==
mm
cv
.
load
(
assert
model_urls
==
mm
engine
.
load
(
osp
.
join
(
mmcv
.
__path__
[
0
],
'model_zoo/open_mmlab.json'
))
osp
.
join
(
mmcv
.
__path__
[
0
],
'model_zoo/open_mmlab.json'
))
...
...
tests/test_ops/test_nms.py
View file @
0fe1c647
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
mmengine
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
...
@@ -144,9 +145,8 @@ class Testnms:
...
@@ -144,9 +145,8 @@ class Testnms:
nms_match
(
wrong_dets
,
iou_thr
)
nms_match
(
wrong_dets
,
iou_thr
)
def
test_batched_nms
(
self
):
def
test_batched_nms
(
self
):
import
mmcv
from
mmcv.ops
import
batched_nms
from
mmcv.ops
import
batched_nms
results
=
mm
cv
.
load
(
'./tests/data/batched_nms_data.pkl'
)
results
=
mm
engine
.
load
(
'./tests/data/batched_nms_data.pkl'
)
nms_max_num
=
100
nms_max_num
=
100
nms_cfg
=
dict
(
nms_cfg
=
dict
(
...
...
tests/test_ops/test_tensorrt.py
View file @
0fe1c647
...
@@ -3,6 +3,7 @@ import os
...
@@ -3,6 +3,7 @@ import os
from
functools
import
partial
from
functools
import
partial
from
typing
import
Callable
from
typing
import
Callable
import
mmengine
import
numpy
as
np
import
numpy
as
np
import
onnx
import
onnx
import
pytest
import
pytest
...
@@ -117,7 +118,6 @@ def test_roialign():
...
@@ -117,7 +118,6 @@ def test_roialign():
def
test_nms
():
def
test_nms
():
try
:
try
:
import
mmcv
from
mmcv.ops
import
nms
from
mmcv.ops
import
nms
except
(
ImportError
,
ModuleNotFoundError
):
except
(
ImportError
,
ModuleNotFoundError
):
pytest
.
skip
(
'test requires compilation'
)
pytest
.
skip
(
'test requires compilation'
)
...
@@ -125,7 +125,7 @@ def test_nms():
...
@@ -125,7 +125,7 @@ def test_nms():
# trt config
# trt config
fp16_mode
=
False
fp16_mode
=
False
max_workspace_size
=
1
<<
30
max_workspace_size
=
1
<<
30
data
=
mm
cv
.
load
(
'./tests/data/batched_nms_data.pkl'
)
data
=
mm
engine
.
load
(
'./tests/data/batched_nms_data.pkl'
)
boxes
=
torch
.
from_numpy
(
data
[
'boxes'
]).
cuda
()
boxes
=
torch
.
from_numpy
(
data
[
'boxes'
]).
cuda
()
scores
=
torch
.
from_numpy
(
data
[
'scores'
]).
cuda
()
scores
=
torch
.
from_numpy
(
data
[
'scores'
]).
cuda
()
nms
=
partial
(
nms
=
partial
(
...
@@ -188,7 +188,6 @@ def test_nms():
...
@@ -188,7 +188,6 @@ def test_nms():
def
test_batched_nms
():
def
test_batched_nms
():
try
:
try
:
import
mmcv
from
mmcv.ops
import
batched_nms
from
mmcv.ops
import
batched_nms
except
(
ImportError
,
ModuleNotFoundError
):
except
(
ImportError
,
ModuleNotFoundError
):
pytest
.
skip
(
'test requires compilation'
)
pytest
.
skip
(
'test requires compilation'
)
...
@@ -197,7 +196,7 @@ def test_batched_nms():
...
@@ -197,7 +196,7 @@ def test_batched_nms():
os
.
environ
[
'ONNX_BACKEND'
]
=
'MMCVTensorRT'
os
.
environ
[
'ONNX_BACKEND'
]
=
'MMCVTensorRT'
fp16_mode
=
False
fp16_mode
=
False
max_workspace_size
=
1
<<
30
max_workspace_size
=
1
<<
30
data
=
mm
cv
.
load
(
'./tests/data/batched_nms_data.pkl'
)
data
=
mm
engine
.
load
(
'./tests/data/batched_nms_data.pkl'
)
nms_cfg
=
dict
(
type
=
'nms'
,
iou_threshold
=
0.7
,
score_threshold
=
0.1
)
nms_cfg
=
dict
(
type
=
'nms'
,
iou_threshold
=
0.7
,
score_threshold
=
0.1
)
boxes
=
torch
.
from_numpy
(
data
[
'boxes'
]).
cuda
()
boxes
=
torch
.
from_numpy
(
data
[
'boxes'
]).
cuda
()
scores
=
torch
.
from_numpy
(
data
[
'scores'
]).
cuda
()
scores
=
torch
.
from_numpy
(
data
[
'scores'
]).
cuda
()
...
...
tests/test_runner/test_basemodule.py
View file @
0fe1c647
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
tempfile
import
tempfile
import
mmengine
import
pytest
import
pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
import
mmcv
from
mmcv.cnn.utils.weight_init
import
update_init_info
from
mmcv.cnn.utils.weight_init
import
update_init_info
from
mmcv.runner
import
BaseModule
,
ModuleDict
,
ModuleList
,
Sequential
from
mmcv.runner
import
BaseModule
,
ModuleDict
,
ModuleList
,
Sequential
from
mmcv.utils
import
Registry
,
build_from_cfg
from
mmcv.utils
import
Registry
,
build_from_cfg
...
@@ -135,7 +135,7 @@ def test_initilization_info_logger():
...
@@ -135,7 +135,7 @@ def test_initilization_info_logger():
# assert initialization information has been dumped
# assert initialization information has been dumped
assert
os
.
path
.
exists
(
log_file
)
assert
os
.
path
.
exists
(
log_file
)
lines
=
mm
cv
.
list_from_file
(
log_file
)
lines
=
mm
engine
.
list_from_file
(
log_file
)
# check initialization information is right
# check initialization information is right
for
i
,
line
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
...
@@ -210,7 +210,7 @@ def test_initilization_info_logger():
...
@@ -210,7 +210,7 @@ def test_initilization_info_logger():
# assert initialization information has been dumped
# assert initialization information has been dumped
assert
os
.
path
.
exists
(
log_file
)
assert
os
.
path
.
exists
(
log_file
)
lines
=
mm
cv
.
list_from_file
(
log_file
)
lines
=
mm
engine
.
list_from_file
(
log_file
)
# check initialization information is right
# check initialization information is right
for
i
,
line
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
'TopLevelModule'
in
line
and
'init_cfg'
not
in
line
:
if
'TopLevelModule'
in
line
and
'init_cfg'
not
in
line
:
...
...
tests/test_runner/test_checkpoint.py
View file @
0fe1c647
...
@@ -8,9 +8,9 @@ import pytest
...
@@ -8,9 +8,9 @@ import pytest
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
from
mmengine.fileio.file_client
import
PetrelBackend
from
torch.nn.parallel
import
DataParallel
from
torch.nn.parallel
import
DataParallel
from
mmcv.fileio.file_client
import
PetrelBackend
from
mmcv.parallel.registry
import
MODULE_WRAPPERS
from
mmcv.parallel.registry
import
MODULE_WRAPPERS
from
mmcv.runner.checkpoint
import
(
_load_checkpoint_with_prefix
,
from
mmcv.runner.checkpoint
import
(
_load_checkpoint_with_prefix
,
get_state_dict
,
load_checkpoint
,
get_state_dict
,
load_checkpoint
,
...
...
tests/test_runner/test_eval_hook.py
View file @
0fe1c647
...
@@ -11,9 +11,9 @@ import pytest
...
@@ -11,9 +11,9 @@ import pytest
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
from
mmengine.fileio.file_client
import
PetrelBackend
from
torch.utils.data
import
DataLoader
,
Dataset
from
torch.utils.data
import
DataLoader
,
Dataset
from
mmcv.fileio.file_client
import
PetrelBackend
from
mmcv.runner
import
DistEvalHook
as
BaseDistEvalHook
from
mmcv.runner
import
DistEvalHook
as
BaseDistEvalHook
from
mmcv.runner
import
EpochBasedRunner
from
mmcv.runner
import
EpochBasedRunner
from
mmcv.runner
import
EvalHook
as
BaseEvalHook
from
mmcv.runner
import
EvalHook
as
BaseEvalHook
...
...
tests/test_runner/test_hooks.py
View file @
0fe1c647
...
@@ -18,10 +18,10 @@ from unittest.mock import MagicMock, Mock, call, patch
...
@@ -18,10 +18,10 @@ from unittest.mock import MagicMock, Mock, call, patch
import
pytest
import
pytest
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmengine.fileio.file_client
import
PetrelBackend
from
torch.nn.init
import
constant_
from
torch.nn.init
import
constant_
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
mmcv.fileio.file_client
import
PetrelBackend
# yapf: disable
# yapf: disable
from
mmcv.runner
import
(
CheckpointHook
,
ClearMLLoggerHook
,
DvcliveLoggerHook
,
from
mmcv.runner
import
(
CheckpointHook
,
ClearMLLoggerHook
,
DvcliveLoggerHook
,
EMAHook
,
Fp16OptimizerHook
,
EMAHook
,
Fp16OptimizerHook
,
...
...
tests/test_utils/test_config.py
View file @
0fe1c647
...
@@ -10,8 +10,9 @@ from pathlib import Path
...
@@ -10,8 +10,9 @@ from pathlib import Path
import
pytest
import
pytest
import
yaml
import
yaml
from
mmengine
import
dump
,
load
from
mmcv
import
Config
,
ConfigDict
,
DictAction
,
dump
,
load
from
mmcv
import
Config
,
ConfigDict
,
DictAction
data_path
=
osp
.
join
(
osp
.
dirname
(
osp
.
dirname
(
__file__
)),
'data'
)
data_path
=
osp
.
join
(
osp
.
dirname
(
osp
.
dirname
(
__file__
)),
'data'
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment