OpenDAS / MMCV · Commits

Commit fdeee889, authored May 25, 2025 by limm

    release v1.6.1 of mmcv

Parent: df465820
Changes: 457 files. Showing 20 changed files with 566 additions and 192 deletions (+566 -192).
Changed files shown below:

  mmcv/device/mps/data_parallel.py           +34    -0
  mmcv/device/scatter_gather.py              +64    -0
  mmcv/device/utils.py                       +18    -0
  mmcv/engine/test.py                        +21   -10
  mmcv/fileio/file_client.py                 +78   -53
  mmcv/fileio/handlers/base.py                +2    -2
  mmcv/fileio/handlers/pickle_handler.py      +2    -4
  mmcv/fileio/handlers/yaml_handler.py        +3    -2
  mmcv/fileio/io.py                          +22   -10
  mmcv/fileio/parse.py                       +12   -10
  mmcv/image/__init__.py                      +6    -5
  mmcv/image/colorspace.py                   +20   -17
  mmcv/image/geometric.py                    +36   -23
  mmcv/image/io.py                           +99   -43
  mmcv/image/misc.py                         +18    -9
  mmcv/image/photometric.py                  +43    -0
  mmcv/model_zoo/torchvision_0.12.json       +57    -0
  mmcv/onnx/info.py                          +15    -1
  mmcv/onnx/onnx_utils/symbolic_helper.py     +2    -2
  mmcv/onnx/symbolic.py                      +14    -1
mmcv/device/mps/data_parallel.py  (new file, +34 -0)

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.parallel import MMDataParallel

from ..scatter_gather import scatter_kwargs


class MPSDataParallel(MMDataParallel):
    """The MPSDataParallel module that supports DataContainer.

    MPSDataParallel is a class inherited from MMDataParallel, which supports
    MPS training and inference only.

    The main differences with MMDataParallel:

    - It only supports a single MPS card and only uses the first card to
      run training and inference.
    - It uses direct host-to-device copy instead of stream-background
      scatter.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)
        self.device_ids = [0]
        self.src_device_obj = torch.device('mps:0')

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
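For orientation, a minimal usage sketch of the new wrapper (not part of this commit; it assumes a PyTorch build with MPS available, that MPSDataParallel is re-exported from mmcv.device.mps, and it uses a toy nn.Module in place of a real model):

import torch
import torch.nn as nn

from mmcv.device.mps import MPSDataParallel

# Hypothetical toy module; in practice this would be a detector or segmentor.
model = nn.Linear(8, 2).to('mps')

# Single-card MPS wrapper: inputs and kwargs are scattered to mps:0 through
# scatter_kwargs, mirroring how MMDataParallel is used on CUDA.
mps_model = MPSDataParallel(model, dim=0)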
mmcv/device/scatter_gather.py  (new file, +64 -0)

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.parallel.data_container import DataContainer
from mmcv.utils import deprecated_api_warning

from ._functions import Scatter
from .utils import get_device


@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter(inputs, target_devices, dim=0):
    """Scatter inputs to target devices.

    The only difference from original :func:`scatter` is to add support for
    :type:`~mmcv.parallel.DataContainer`.
    """
    current_device = get_device()

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            if target_devices != [-1]:
                obj = obj.to(current_device)
                return [obj]
            else:
                # for CPU inference we use self-implemented scatter
                return Scatter.forward(target_devices, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return Scatter.forward(target_devices, obj.data)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
        return [obj for _ in target_devices]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
    """Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_devices, dim) if inputs else []
    kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
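A small sketch of how the new helpers are called (not part of the commit; the CPU path with target_devices=[-1] is shown, and the exact nesting of the returned values depends on mmcv's Scatter implementation):

import torch

from mmcv.device.scatter_gather import scatter_kwargs

inputs = (torch.ones(2, 3),)
kwargs = {'return_loss': False}

# Tensors are moved to the current device (cuda/mlu/mps) or kept on CPU when
# target_devices is [-1]; inputs and kwargs are padded to the same length
# and returned as per-device tuples.
scattered_inputs, scattered_kwargs = scatter_kwargs(
    inputs, kwargs, target_devices=[-1])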
mmcv/device/utils.py  (new file, +18 -0)

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def get_device() -> str:
    """Returns the currently existing device type.

    Returns:
        str: cuda | mlu | mps | cpu.
    """
    if IS_CUDA_AVAILABLE:
        return 'cuda'
    elif IS_MLU_AVAILABLE:
        return 'mlu'
    elif IS_MPS_AVAILABLE:
        return 'mps'
    else:
        return 'cpu'
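Typical use of the new helper looks like the following sketch (illustrative only; the toy nn.Linear stands in for a real model):

import torch.nn as nn

from mmcv.device.utils import get_device

# Priority order baked into get_device(): cuda > mlu > mps > cpu.
device = get_device()

model = nn.Linear(4, 2).to(device)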
mmcv/engine/test.py  (+21 -10)

@@ -4,15 +4,18 @@ import pickle
 import shutil
 import tempfile
 import time
+from typing import Optional
 
 import torch
 import torch.distributed as dist
+import torch.nn as nn
+from torch.utils.data import DataLoader
 
 import mmcv
 from mmcv.runner import get_dist_info
 
 
-def single_gpu_test(model, data_loader):
+def single_gpu_test(model: nn.Module, data_loader: DataLoader) -> list:
     """Test model with a single gpu.
 
     This method tests model with a single gpu and displays test progress bar.
...
@@ -41,7 +44,10 @@ def single_gpu_test(model, data_loader):
     return results
 
 
-def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+def multi_gpu_test(model: nn.Module,
+                   data_loader: DataLoader,
+                   tmpdir: Optional[str] = None,
+                   gpu_collect: bool = False) -> Optional[list]:
     """Test model with multiple gpus.
 
     This method tests model with multiple gpus and collects the results
...
@@ -82,13 +88,15 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
     # collect results from all ranks
     if gpu_collect:
-        results = collect_results_gpu(results, len(dataset))
+        result_from_ranks = collect_results_gpu(results, len(dataset))
     else:
-        results = collect_results_cpu(results, len(dataset), tmpdir)
-    return results
+        result_from_ranks = collect_results_cpu(results, len(dataset), tmpdir)
+    return result_from_ranks
 
 
-def collect_results_cpu(result_part, size, tmpdir=None):
+def collect_results_cpu(result_part: list,
+                        size: int,
+                        tmpdir: Optional[str] = None) -> Optional[list]:
     """Collect results under cpu mode.
 
     On cpu mode, this function will save the results on different gpus to
...
@@ -126,7 +134,8 @@ def collect_results_cpu(result_part, size, tmpdir=None):
     else:
         mmcv.mkdir_or_exist(tmpdir)
     # dump the part result to the dir
-    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+    part_file = osp.join(tmpdir, f'part_{rank}.pkl')  # type: ignore
+    mmcv.dump(result_part, part_file)
     dist.barrier()
     # collect all parts
     if rank != 0:
...
@@ -135,7 +144,7 @@ def collect_results_cpu(result_part, size, tmpdir=None):
         # load results of all parts from tmp dir
         part_list = []
         for i in range(world_size):
-            part_file = osp.join(tmpdir, f'part_{i}.pkl')
+            part_file = osp.join(tmpdir, f'part_{i}.pkl')  # type: ignore
             part_result = mmcv.load(part_file)
             # When data is severely insufficient, an empty part_result
             # on a certain gpu could makes the overall outputs empty.
...
@@ -148,11 +157,11 @@ def collect_results_cpu(result_part, size, tmpdir=None):
         # the dataloader may pad some samples
         ordered_results = ordered_results[:size]
         # remove tmp dir
-        shutil.rmtree(tmpdir)
+        shutil.rmtree(tmpdir)  # type: ignore
         return ordered_results
 
 
-def collect_results_gpu(result_part, size):
+def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
     """Collect results under gpu mode.
 
     On gpu mode, this function will encode results to gpu tensors and use gpu
...
@@ -200,3 +209,5 @@ def collect_results_gpu(result_part, size):
         # the dataloader may pad some samples
         ordered_results = ordered_results[:size]
         return ordered_results
+    else:
+        return None
mmcv/fileio/file_client.py  (+78 -53)

@@ -8,7 +8,7 @@ import warnings
 from abc import ABCMeta, abstractmethod
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Iterable, Iterator, Optional, Tuple, Union
+from typing import Any, Generator, Iterator, Optional, Tuple, Union
 from urllib.request import urlopen
 
 import mmcv
...
@@ -64,7 +64,8 @@ class CephBackend(BaseStorageBackend):
             raise ImportError('Please install ceph to enable CephBackend.')
 
         warnings.warn(
-            'CephBackend will be deprecated, please use PetrelBackend instead')
+            'CephBackend will be deprecated, please use PetrelBackend instead',
+            DeprecationWarning)
         self._client = ceph.S3Client()
         assert isinstance(path_mapping, dict) or path_mapping is None
         self.path_mapping = path_mapping
...
@@ -209,9 +210,9 @@ class PetrelBackend(BaseStorageBackend):
         """
         if not has_method(self._client, 'delete'):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `delete` method, please use a higher version or dev'
-                 ' branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `delete` method, please use a higher version or dev'
+                ' branch instead.')
 
         filepath = self._map_path(filepath)
         filepath = self._format_path(filepath)
...
@@ -229,9 +230,9 @@ class PetrelBackend(BaseStorageBackend):
         if not (has_method(self._client, 'contains')
                 and has_method(self._client, 'isdir')):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `contains` and `isdir` methods, please use a higher'
-                 'version or dev branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `contains` and `isdir` methods, please use a higher'
+                'version or dev branch instead.')
 
         filepath = self._map_path(filepath)
         filepath = self._format_path(filepath)
...
@@ -246,13 +247,13 @@ class PetrelBackend(BaseStorageBackend):
         Returns:
             bool: Return ``True`` if ``filepath`` points to a directory,
             ``False`` otherwise.
         """
         if not has_method(self._client, 'isdir'):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `isdir` method, please use a higher version or dev'
-                 ' branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `isdir` method, please use a higher version or dev'
+                ' branch instead.')
 
         filepath = self._map_path(filepath)
         filepath = self._format_path(filepath)
...
@@ -266,13 +267,13 @@ class PetrelBackend(BaseStorageBackend):
         Returns:
             bool: Return ``True`` if ``filepath`` points to a file, ``False``
             otherwise.
         """
         if not has_method(self._client, 'contains'):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `contains` method, please use a higher version or '
-                 'dev branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `contains` method, please use a higher version or '
+                'dev branch instead.')
 
         filepath = self._map_path(filepath)
         filepath = self._format_path(filepath)
...
@@ -297,7 +298,10 @@ class PetrelBackend(BaseStorageBackend):
         return '/'.join(formatted_paths)
 
     @contextmanager
-    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+    def get_local_path(
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
         """Download a file from ``filepath`` and return a temporary path.
 
         ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
...
@@ -362,9 +366,9 @@ class PetrelBackend(BaseStorageBackend):
         """
         if not has_method(self._client, 'list'):
             raise NotImplementedError(
-                ('Current version of Petrel Python SDK has not supported '
-                 'the `list` method, please use a higher version or dev'
-                 ' branch instead.'))
+                'Current version of Petrel Python SDK has not supported '
+                'the `list` method, please use a higher version or dev'
+                ' branch instead.')
 
         dir_path = self._map_path(dir_path)
         dir_path = self._format_path(dir_path)
...
@@ -473,17 +477,16 @@ class LmdbBackend(BaseStorageBackend):
                  readahead=False,
                  **kwargs):
         try:
-            import lmdb
+            import lmdb  # NOQA
         except ImportError:
             raise ImportError('Please install lmdb to enable LmdbBackend.')
 
         self.db_path = str(db_path)
-        self._client = lmdb.open(
-            self.db_path,
-            readonly=readonly,
-            lock=lock,
-            readahead=readahead,
-            **kwargs)
+        self.readonly = readonly
+        self.lock = lock
+        self.readahead = readahead
+        self.kwargs = kwargs
+        self._client = None
 
     def get(self, filepath):
         """Get values according to the filepath.
...
@@ -491,14 +494,29 @@ class LmdbBackend(BaseStorageBackend):
         Args:
             filepath (str | obj:`Path`): Here, filepath is the lmdb key.
         """
-        filepath = str(filepath)
+        if self._client is None:
+            self._client = self._get_client()
 
         with self._client.begin(write=False) as txn:
-            value_buf = txn.get(filepath.encode('ascii'))
+            value_buf = txn.get(str(filepath).encode('utf-8'))
         return value_buf
 
     def get_text(self, filepath, encoding=None):
         raise NotImplementedError
 
+    def _get_client(self):
+        import lmdb
+
+        return lmdb.open(
+            self.db_path,
+            readonly=self.readonly,
+            lock=self.lock,
+            readahead=self.readahead,
+            **self.kwargs)
+
+    def __del__(self):
+        self._client.close()
+
 
 class HardDiskBackend(BaseStorageBackend):
     """Raw hard disks storage backend."""
...
@@ -531,7 +549,7 @@ class HardDiskBackend(BaseStorageBackend):
         Returns:
             str: Expected text reading from ``filepath``.
         """
-        with open(filepath, 'r', encoding=encoding) as f:
+        with open(filepath, encoding=encoding) as f:
             value_buf = f.read()
         return value_buf
...
@@ -598,7 +616,7 @@ class HardDiskBackend(BaseStorageBackend):
...
@@ -610,7 +628,7 @@ class HardDiskBackend(BaseStorageBackend):
...
@@ -631,7 +649,9 @@ class HardDiskBackend(BaseStorageBackend):
     @contextmanager
     def get_local_path(
-            self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
         """Only for unified API and do nothing."""
         yield filepath
...
@@ -700,7 +720,8 @@ class HTTPBackend(BaseStorageBackend):
         return value_buf.decode(encoding)
 
     @contextmanager
-    def get_local_path(self, filepath: str) -> Iterable[str]:
+    def get_local_path(
+            self, filepath: str) -> Generator[Union[str, Path], None, None]:
         """Download a file from ``filepath``.
 
         ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
...
@@ -770,19 +791,16 @@ class FileClient:
         'petrel': PetrelBackend,
         'http': HTTPBackend,
     }
-    # This collection is used to record the overridden backends, and when a
-    # backend appears in the collection, the singleton pattern is disabled for
-    # that backend, because if the singleton pattern is used, then the object
-    # returned will be the backend before overwriting
-    _overridden_backends = set()
+
     _prefix_to_backends = {
         's3': PetrelBackend,
         'http': HTTPBackend,
         'https': HTTPBackend,
     }
-    _overridden_prefixes = set()
 
-    _instances = {}
+    _instances: dict = {}
+
+    client: Any
 
     def __new__(cls, backend=None, prefix=None, **kwargs):
         if backend is None and prefix is None:
...
@@ -802,10 +820,7 @@ class FileClient:
             for key, value in kwargs.items():
                 arg_key += f':{key}:{value}'
 
-        # if a backend was overridden, it will create a new object
-        if (arg_key in cls._instances
-                and backend not in cls._overridden_backends
-                and prefix not in cls._overridden_prefixes):
+        if arg_key in cls._instances:
             _instance = cls._instances[arg_key]
         else:
             # create a new object and put it to _instance
...
@@ -839,8 +854,8 @@ class FileClient:
...
@@ -899,7 +914,9 @@ class FileClient:
                 'add "force=True" if you want to override it')
         if name in cls._backends and force:
-            cls._overridden_backends.add(name)
+            for arg_key, instance in list(cls._instances.items()):
+                if isinstance(instance.client, cls._backends[name]):
+                    cls._instances.pop(arg_key)
         cls._backends[name] = backend
 
         if prefixes is not None:
...
@@ -911,7 +928,12 @@ class FileClient:
             if prefix not in cls._prefix_to_backends:
                 cls._prefix_to_backends[prefix] = backend
             elif (prefix in cls._prefix_to_backends) and force:
-                cls._overridden_prefixes.add(prefix)
+                overridden_backend = cls._prefix_to_backends[prefix]
+                if isinstance(overridden_backend, list):
+                    overridden_backend = tuple(overridden_backend)
+                for arg_key, instance in list(cls._instances.items()):
+                    if isinstance(instance.client, overridden_backend):
+                        cls._instances.pop(arg_key)
                 cls._prefix_to_backends[prefix] = backend
             else:
                 raise KeyError(
...
@@ -987,7 +1009,7 @@ class FileClient:
...
@@ -1060,7 +1082,7 @@ class FileClient:
...
@@ -1072,7 +1094,7 @@ class FileClient:
...
@@ -1092,7 +1114,10 @@ class FileClient:
         return self.client.join_path(filepath, *filepaths)
 
     @contextmanager
-    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+    def get_local_path(
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
         """Download data from ``filepath`` and write the data to local path.
 
         ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
...
mmcv/fileio/handlers/base.py  (+2 -2)

@@ -21,10 +21,10 @@ class BaseFileHandler(metaclass=ABCMeta):
     def dump_to_str(self, obj, **kwargs):
         pass
 
-    def load_from_path(self, filepath, mode='r', **kwargs):
+    def load_from_path(self, filepath: str, mode: str = 'r', **kwargs):
         with open(filepath, mode) as f:
             return self.load_from_fileobj(f, **kwargs)
 
-    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
+    def dump_to_path(self, obj, filepath: str, mode: str = 'w', **kwargs):
         with open(filepath, mode) as f:
             self.dump_to_fileobj(obj, f, **kwargs)
mmcv/fileio/handlers/pickle_handler.py  (+2 -4)

@@ -12,8 +12,7 @@ class PickleHandler(BaseFileHandler):
         return pickle.load(file, **kwargs)
 
     def load_from_path(self, filepath, **kwargs):
-        return super(PickleHandler, self).load_from_path(
-            filepath, mode='rb', **kwargs)
+        return super().load_from_path(filepath, mode='rb', **kwargs)
 
     def dump_to_str(self, obj, **kwargs):
         kwargs.setdefault('protocol', 2)
...
@@ -24,5 +23,4 @@ class PickleHandler(BaseFileHandler):
         pickle.dump(obj, file, **kwargs)
 
     def dump_to_path(self, obj, filepath, **kwargs):
-        super(PickleHandler, self).dump_to_path(
-            obj, filepath, mode='wb', **kwargs)
+        super().dump_to_path(obj, filepath, mode='wb', **kwargs)
mmcv/fileio/handlers/yaml_handler.py  (+3 -2)

@@ -2,9 +2,10 @@
 import yaml
 
 try:
-    from yaml import CLoader as Loader, CDumper as Dumper
+    from yaml import CDumper as Dumper
+    from yaml import CLoader as Loader
 except ImportError:
-    from yaml import Loader, Dumper
+    from yaml import Loader, Dumper  # type: ignore
 
 from .base import BaseFileHandler  # isort:skip
...
mmcv/fileio/io.py  (+22 -10)

 # Copyright (c) OpenMMLab. All rights reserved.
 from io import BytesIO, StringIO
 from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, TextIO, Union
 
-from ..utils import is_list_of, is_str
+from ..utils import is_list_of
 from .file_client import FileClient
 from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
 
+FileLikeObject = Union[TextIO, StringIO, BytesIO]
+
 file_handlers = {
     'json': JsonHandler(),
     'yaml': YamlHandler(),
...
@@ -15,7 +18,10 @@ file_handlers = {
 }
 
 
-def load(file, file_format=None, file_client_args=None, **kwargs):
+def load(file: Union[str, Path, FileLikeObject],
+         file_format: Optional[str] = None,
+         file_client_args: Optional[Dict] = None,
+         **kwargs):
     """Load data from json/yaml/pickle files.
 
     This method provides a unified api for loading data from serialized files.
...
@@ -45,13 +51,14 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
     """
     if isinstance(file, Path):
         file = str(file)
-    if file_format is None and is_str(file):
+    if file_format is None and isinstance(file, str):
         file_format = file.split('.')[-1]
     if file_format not in file_handlers:
         raise TypeError(f'Unsupported format: {file_format}')
 
     handler = file_handlers[file_format]
-    if is_str(file):
+    f: FileLikeObject
+    if isinstance(file, str):
         file_client = FileClient.infer_client(file_client_args, file)
         if handler.str_like:
             with StringIO(file_client.get_text(file)) as f:
...
@@ -66,7 +73,11 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
     return obj
 
 
-def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
+def dump(obj: Any,
+         file: Optional[Union[str, Path, FileLikeObject]] = None,
+         file_format: Optional[str] = None,
+         file_client_args: Optional[Dict] = None,
+         **kwargs):
     """Dump data to json/yaml/pickle strings or files.
 
     This method provides a unified api for dumping data as strings or to files,
...
@@ -96,18 +107,18 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
     if isinstance(file, Path):
         file = str(file)
     if file_format is None:
-        if is_str(file):
+        if isinstance(file, str):
             file_format = file.split('.')[-1]
         elif file is None:
             raise ValueError(
                 'file_format must be specified since file is None')
     if file_format not in file_handlers:
         raise TypeError(f'Unsupported format: {file_format}')
 
+    f: FileLikeObject
     handler = file_handlers[file_format]
     if file is None:
         return handler.dump_to_str(obj, **kwargs)
-    elif is_str(file):
+    elif isinstance(file, str):
         file_client = FileClient.infer_client(file_client_args, file)
         if handler.str_like:
             with StringIO() as f:
...
@@ -123,7 +134,8 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
         raise TypeError('"file" must be a filename str or a file-object')
 
 
-def _register_handler(handler, file_formats):
+def _register_handler(handler: BaseFileHandler,
+                      file_formats: Union[str, List[str]]) -> None:
     """Register a handler for some file extensions.
 
     Args:
...
@@ -142,7 +154,7 @@ def _register_handler(handler, file_formats):
     file_handlers[ext] = handler
 
 
-def register_handler(file_formats, **kwargs):
+def register_handler(file_formats: Union[str, list], **kwargs) -> Callable:
     def wrap(cls):
         _register_handler(cls(**kwargs), file_formats)
...
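mmcv.load and mmcv.dump remain the unified entry points these annotations describe. A quick round-trip sketch (not part of the commit; the temporary path is only an example, and file_client_args could instead route the I/O to a petrel/http backend):

import mmcv

data = {'classes': ['cat', 'dog'], 'num_images': 2}

# The format is inferred from the file extension ('json' here).
mmcv.dump(data, '/tmp/meta.json')
loaded = mmcv.load('/tmp/meta.json')
assert loaded == data

# With file=None, dump() returns the serialized string instead of writing.
json_str = mmcv.dump(data, file_format='json')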
mmcv/fileio/parse.py  (+12 -10)

 # Copyright (c) OpenMMLab. All rights reserved.
 from io import StringIO
+from pathlib import Path
+from typing import Dict, List, Optional, Union
 
 from .file_client import FileClient
 
 
-def list_from_file(filename,
-                   prefix='',
-                   offset=0,
-                   max_num=0,
-                   encoding='utf-8',
-                   file_client_args=None):
+def list_from_file(filename: Union[str, Path],
+                   prefix: str = '',
+                   offset: int = 0,
+                   max_num: int = 0,
+                   encoding: str = 'utf-8',
+                   file_client_args: Optional[Dict] = None) -> List:
     """Load a text file and parse the content as a list of strings.
 
     Note:
...
@@ -52,10 +54,10 @@ def list_from_file(filename,
     return item_list
 
 
-def dict_from_file(filename,
-                   key_type=str,
-                   encoding='utf-8',
-                   file_client_args=None):
+def dict_from_file(filename: Union[str, Path],
+                   key_type: type = str,
+                   encoding: str = 'utf-8',
+                   file_client_args: Optional[Dict] = None) -> Dict:
     """Load a text file and parse the content as a dict.
 
     Each line of the text file will be two or more columns split by
...
mmcv/image/__init__.py  (+6 -5)

@@ -9,10 +9,10 @@ from .geometric import (cutout, imcrop, imflip, imflip_, impad,
 from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
 from .misc import tensor2imgs
 from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
-                          adjust_lighting, adjust_sharpness, auto_contrast,
-                          clahe, imdenormalize, imequalize, iminvert,
-                          imnormalize, imnormalize_, lut_transform, posterize,
-                          solarize)
+                          adjust_hue, adjust_lighting, adjust_sharpness,
+                          auto_contrast, clahe, imdenormalize, imequalize,
+                          iminvert, imnormalize, imnormalize_, lut_transform,
+                          posterize, solarize)
 
 __all__ = [
     'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
...
@@ -24,5 +24,6 @@ __all__ = [
     'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
     'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
     'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
-    'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
+    'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting',
+    'adjust_hue'
 ]
mmcv/image/colorspace.py  (+20 -17)

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, Union
+
 import cv2
 import numpy as np

The remaining hunks in this file add type annotations to the conversion helpers; the changed signatures are:

-def imconvert(img, src, dst):
+def imconvert(img: np.ndarray, src: str, dst: str) -> np.ndarray:

-def bgr2gray(img, keepdim=False):
+def bgr2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:

-def rgb2gray(img, keepdim=False):
+def rgb2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:

-def gray2bgr(img):
+def gray2bgr(img: np.ndarray) -> np.ndarray:

-def gray2rgb(img):
+def gray2rgb(img: np.ndarray) -> np.ndarray:

-def _convert_input_type_range(img):
+def _convert_input_type_range(img: np.ndarray) -> np.ndarray:

-def _convert_output_type_range(img, dst_type):
+def _convert_output_type_range(
+        img: np.ndarray, dst_type: Union[np.uint8, np.float32]) -> np.ndarray:

-def rgb2ycbcr(img, y_only=False):
+def rgb2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:

-def bgr2ycbcr(img, y_only=False):
+def bgr2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:

-def ycbcr2rgb(img):
+def ycbcr2rgb(img: np.ndarray) -> np.ndarray:

-def ycbcr2bgr(img):
+def ycbcr2bgr(img: np.ndarray) -> np.ndarray:

-def convert_color_factory(src, dst):
+def convert_color_factory(src: str, dst: str) -> Callable:

-    def convert_color(img):
+    def convert_color(img: np.ndarray) -> np.ndarray:
...
mmcv/image/geometric.py  (+36 -23)

@@ -37,15 +37,27 @@ cv2_interp_codes = {
     'lanczos': cv2.INTER_LANCZOS4
 }
 
+# Pillow >=v9.1.0 use a slightly different naming scheme for filters.
+# Set pillow_interp_codes according to the naming scheme used.
 if Image is not None:
-    pillow_interp_codes = {
-        'nearest': Image.NEAREST,
-        'bilinear': Image.BILINEAR,
-        'bicubic': Image.BICUBIC,
-        'box': Image.BOX,
-        'lanczos': Image.LANCZOS,
-        'hamming': Image.HAMMING
-    }
+    if hasattr(Image, 'Resampling'):
+        pillow_interp_codes = {
+            'nearest': Image.Resampling.NEAREST,
+            'bilinear': Image.Resampling.BILINEAR,
+            'bicubic': Image.Resampling.BICUBIC,
+            'box': Image.Resampling.BOX,
+            'lanczos': Image.Resampling.LANCZOS,
+            'hamming': Image.Resampling.HAMMING
+        }
+    else:
+        pillow_interp_codes = {
+            'nearest': Image.NEAREST,
+            'bilinear': Image.BILINEAR,
+            'bicubic': Image.BICUBIC,
+            'box': Image.BOX,
+            'lanczos': Image.LANCZOS,
+            'hamming': Image.HAMMING
+        }
 
 
 def imresize(img,
...
@@ -70,7 +82,7 @@ def imresize(img,
...
@@ -130,7 +142,7 @@ def imresize_to_multiple(img,
...
@@ -145,7 +157,7 @@ def imresize_to_multiple(img,
     size = _scale_size((w, h), scale_factor)
 
     divisor = to_2tuple(divisor)
-    size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
+    size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
     resized_img, w_scale, h_scale = imresize(
         img,
         size,
...
@@ -175,7 +187,7 @@ def imresize_like(img,
...
@@ -460,18 +472,17 @@ def impad(img,
(the padding_mode bullet list in the impad docstring is re-wrapped; its wording is unchanged)
             - constant: pads with a constant value, this value is specified
               with pad_val.
             - edge: pads with the last value at the edge of the image.
             - reflect: pads with reflection of image without repeating the last
               value on the edge. For example, padding [1, 2, 3, 4] with 2
               elements on both sides in reflect mode will result in
               [3, 2, 1, 2, 3, 4, 3, 2].
             - symmetric: pads with reflection of image repeating the last value
               on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
               both sides in symmetric mode will result in
               [2, 1, 1, 2, 3, 4, 4, 3]
 
     Returns:
         ndarray: The padded image.
...
@@ -479,7 +490,9 @@ def impad(img,
     assert (shape is not None) ^ (padding is not None)
     if shape is not None:
-        padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0])
+        width = max(shape[1] - img.shape[1], 0)
+        height = max(shape[0] - img.shape[0], 0)
+        padding = (0, 0, width, height)
 
     # check pad_val
     if isinstance(pad_val, tuple):
...
mmcv/image/io.py
View file @
fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
io
import
io
import
os.path
as
osp
import
os.path
as
osp
import
warnings
from
pathlib
import
Path
from
pathlib
import
Path
import
cv2
import
cv2
...
@@ -8,7 +9,8 @@ import numpy as np
...
@@ -8,7 +9,8 @@ import numpy as np
from
cv2
import
(
IMREAD_COLOR
,
IMREAD_GRAYSCALE
,
IMREAD_IGNORE_ORIENTATION
,
from
cv2
import
(
IMREAD_COLOR
,
IMREAD_GRAYSCALE
,
IMREAD_IGNORE_ORIENTATION
,
IMREAD_UNCHANGED
)
IMREAD_UNCHANGED
)
from
mmcv.utils
import
check_file_exist
,
is_str
,
mkdir_or_exist
from
mmcv.fileio
import
FileClient
from
mmcv.utils
import
is_filepath
,
is_str
try
:
try
:
from
turbojpeg
import
TJCS_RGB
,
TJPF_BGR
,
TJPF_GRAY
,
TurboJPEG
from
turbojpeg
import
TJCS_RGB
,
TJPF_BGR
,
TJPF_GRAY
,
TurboJPEG
...
@@ -137,9 +139,16 @@ def _pillow2array(img, flag='color', channel_order='bgr'):
...
@@ -137,9 +139,16 @@ def _pillow2array(img, flag='color', channel_order='bgr'):
return
array
return
array
def
imread
(
img_or_path
,
flag
=
'color'
,
channel_order
=
'bgr'
,
backend
=
None
):
def
imread
(
img_or_path
,
flag
=
'color'
,
channel_order
=
'bgr'
,
backend
=
None
,
file_client_args
=
None
):
"""Read an image.
"""Read an image.
Note:
In v1.4.1 and later, add `file_client_args` parameters.
Args:
Args:
img_or_path (ndarray or str or Path): Either a numpy array or str or
img_or_path (ndarray or str or Path): Either a numpy array or str or
pathlib.Path. If it is a numpy array (loaded image), then
pathlib.Path. If it is a numpy array (loaded image), then
...
@@ -157,44 +166,42 @@ def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
...
@@ -157,44 +166,42 @@ def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
`cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
`cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
If backend is None, the global imread_backend specified by
If backend is None, the global imread_backend specified by
``mmcv.use_backend()`` will be used. Default: None.
``mmcv.use_backend()`` will be used. Default: None.
file_client_args (dict | None): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
Returns:
Returns:
ndarray: Loaded image array.
ndarray: Loaded image array.
Examples:
>>> import mmcv
>>> img_path = '/path/to/img.jpg'
>>> img = mmcv.imread(img_path)
>>> img = mmcv.imread(img_path, flag='color', channel_order='rgb',
... backend='cv2')
>>> img = mmcv.imread(img_path, flag='color', channel_order='bgr',
... backend='pillow')
>>> s3_img_path = 's3://bucket/img.jpg'
>>> # infer the file backend by the prefix s3
>>> img = mmcv.imread(s3_img_path)
>>> # manually set the file backend petrel
>>> img = mmcv.imread(s3_img_path, file_client_args={
... 'backend': 'petrel'})
>>> http_img_path = 'http://path/to/img.jpg'
>>> img = mmcv.imread(http_img_path)
>>> img = mmcv.imread(http_img_path, file_client_args={
... 'backend': 'http'})
"""
"""
if
backend
is
None
:
backend
=
imread_backend
if
backend
not
in
supported_backends
:
raise
ValueError
(
f
'backend:
{
backend
}
is not supported. Supported '
"backends are 'cv2', 'turbojpeg', 'pillow'"
)
if
isinstance
(
img_or_path
,
Path
):
if
isinstance
(
img_or_path
,
Path
):
img_or_path
=
str
(
img_or_path
)
img_or_path
=
str
(
img_or_path
)
if
isinstance
(
img_or_path
,
np
.
ndarray
):
if
isinstance
(
img_or_path
,
np
.
ndarray
):
return
img_or_path
return
img_or_path
elif
is_str
(
img_or_path
):
elif
is_str
(
img_or_path
):
check_file_exist
(
img_or_path
,
file_client
=
FileClient
.
infer_client
(
file_client_args
,
img_or_path
)
f
'img file does not exist:
{
img_or_path
}
'
)
img_bytes
=
file_client
.
get
(
img_or_path
)
if
backend
==
'turbojpeg'
:
return
imfrombytes
(
img_bytes
,
flag
,
channel_order
,
backend
)
with
open
(
img_or_path
,
'rb'
)
as
in_file
:
img
=
jpeg
.
decode
(
in_file
.
read
(),
_jpegflag
(
flag
,
channel_order
))
if
img
.
shape
[
-
1
]
==
1
:
img
=
img
[:,
:,
0
]
return
img
elif
backend
==
'pillow'
:
img
=
Image
.
open
(
img_or_path
)
img
=
_pillow2array
(
img
,
flag
,
channel_order
)
return
img
elif
backend
==
'tifffile'
:
img
=
tifffile
.
imread
(
img_or_path
)
return
img
else
:
flag
=
imread_flags
[
flag
]
if
is_str
(
flag
)
else
flag
img
=
cv2
.
imread
(
img_or_path
,
flag
)
if
flag
==
IMREAD_COLOR
and
channel_order
==
'rgb'
:
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
,
img
)
return
img
else
:
else
:
raise
TypeError
(
'"img" must be a numpy array or a str or '
raise
TypeError
(
'"img" must be a numpy array or a str or '
'a pathlib.Path object'
)
'a pathlib.Path object'
)
...
@@ -206,29 +213,45 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Same as :func:`imread`.
+        channel_order (str): The channel order of the output, candidates
+            are 'bgr' and 'rgb'. Default to 'bgr'.
        backend (str | None): The image decoding backend type. Options are
-            `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
-            global imread_backend specified by ``mmcv.use_backend()`` will be
-            used. Default: None.
+            `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. If backend is
+            None, the global imread_backend specified by ``mmcv.use_backend()``
+            will be used. Default: None.

    Returns:
        ndarray: Loaded image array.
+
+    Examples:
+        >>> img_path = '/path/to/img.jpg'
+        >>> with open(img_path, 'rb') as f:
+        >>>     img_buff = f.read()
+        >>> img = mmcv.imfrombytes(img_buff)
+        >>> img = mmcv.imfrombytes(img_buff, flag='color', channel_order='rgb')
+        >>> img = mmcv.imfrombytes(img_buff, backend='pillow')
+        >>> img = mmcv.imfrombytes(img_buff, backend='cv2')
    """
    if backend is None:
        backend = imread_backend
    if backend not in supported_backends:
-        raise ValueError(f'backend: {backend} is not supported. Supported '
-                         "backends are 'cv2', 'turbojpeg', 'pillow'")
+        raise ValueError(
+            f'backend: {backend} is not supported. Supported '
+            "backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'")
    if backend == 'turbojpeg':
        img = jpeg.decode(content, _jpegflag(flag, channel_order))
        if img.shape[-1] == 1:
            img = img[:, :, 0]
        return img
    elif backend == 'pillow':
-        buff = io.BytesIO(content)
-        img = Image.open(buff)
-        img = _pillow2array(img, flag, channel_order)
+        with io.BytesIO(content) as buff:
+            img = Image.open(buff)
+            img = _pillow2array(img, flag, channel_order)
+        return img
+    elif backend == 'tifffile':
+        with io.BytesIO(content) as buff:
+            img = tifffile.imread(buff)
        return img
    else:
        img_np = np.frombuffer(content, np.uint8)
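With the new `tifffile` branch, TIFF content can be decoded straight from bytes rather than from a path. A hedged sketch of that path; the file name is a placeholder and `tifffile` must be installed:

    import mmcv

    with open('/path/to/img.tiff', 'rb') as f:
        buff = f.read()
    img = mmcv.imfrombytes(buff, backend='tifffile')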
...
@@ -239,20 +262,53 @@ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
    return img


-def imwrite(img, file_path, params=None, auto_mkdir=True):
+def imwrite(img,
+            file_path,
+            params=None,
+            auto_mkdir=None,
+            file_client_args=None):
    """Write image to file.

+    Note:
+        In v1.4.1 and later, add `file_client_args` parameters.
+
+    Warning:
+        The parameter `auto_mkdir` will be deprecated in the future and every
+        file clients will make directory automatically.
+
    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
-            whether to create it automatically.
+            whether to create it automatically. It will be deprecated.
+        file_client_args (dict | None): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.

    Returns:
        bool: Successful or not.
+
+    Examples:
+        >>> # write to hard disk client
+        >>> ret = mmcv.imwrite(img, '/path/to/img.jpg')
+        >>> # infer the file backend by the prefix s3
+        >>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg')
+        >>> # manually set the file backend petrel
+        >>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg', file_client_args={
+        ...     'backend': 'petrel'})
    """
-    if auto_mkdir:
-        dir_name = osp.abspath(osp.dirname(file_path))
-        mkdir_or_exist(dir_name)
-    return cv2.imwrite(file_path, img, params)
+    assert is_filepath(file_path)
+    file_path = str(file_path)
+    if auto_mkdir is not None:
+        warnings.warn(
+            'The parameter `auto_mkdir` will be deprecated in the future and '
+            'every file clients will make directory automatically.')
+    file_client = FileClient.infer_client(file_client_args, file_path)
+    img_ext = osp.splitext(file_path)[-1]
+    # Encode image according to image suffix.
+    # For example, if image path is '/path/your/img.jpg', the encode
+    # format is '.jpg'.
+    flag, img_buff = cv2.imencode(img_ext, img, params)
+    file_client.put(img_buff.tobytes(), file_path)
+    return flag
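The rewrite swaps `cv2.imwrite` for an encode-in-memory-then-put flow: the extension of `file_path` selects the encoder, and the resulting bytes go through whichever `FileClient` backend matches the path, which is why (per the warning above) callers no longer need to create directories themselves. A small local sketch under those assumptions; the output path is a placeholder, and a `pathlib.Path` should also be accepted since the new code converts the path with `str()`:

    from pathlib import Path

    import numpy as np
    import mmcv

    img = np.zeros((32, 32, 3), dtype=np.uint8)
    ok = mmcv.imwrite(img, Path('/tmp/mmcv_demo') / 'black.png')  # '.png' drives the encoder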
mmcv/image/misc.py
View file @ fdeee889
...
@@ -9,18 +9,21 @@ except ImportError:
    torch = None


-def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
-    """Convert tensor to 3-channel images.
+def tensor2imgs(tensor, mean=None, std=None, to_rgb=True):
+    """Convert tensor to 3-channel images or 1-channel gray images.

    Args:
        tensor (torch.Tensor): Tensor that contains multiple images, shape (
-            N, C, H, W).
-        mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
-        std (tuple[float], optional): Standard deviation of images.
-            Defaults to (1, 1, 1).
+            N, C, H, W). :math:`C` can be either 3 or 1.
+        mean (tuple[float], optional): Mean of images. If None,
+            (0, 0, 0) will be used for tensor with 3-channel,
+            while (0, ) for tensor with 1-channel. Defaults to None.
+        std (tuple[float], optional): Standard deviation of images. If None,
+            (1, 1, 1) will be used for tensor with 3-channel,
+            while (1, ) for tensor with 1-channel. Defaults to None.
        to_rgb (bool, optional): Whether the tensor was converted to RGB
            format in the first place. If so, convert it back to BGR.
-            Defaults to True.
+            For the tensor with 1 channel, it must be False.
+            Defaults to True.

    Returns:
        list[np.ndarray]: A list that contains multiple images.
...
@@ -29,8 +32,14 @@ def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
    if torch is None:
        raise RuntimeError('pytorch is not installed')
    assert torch.is_tensor(tensor) and tensor.ndim == 4
-    assert len(mean) == 3
-    assert len(std) == 3
+    channels = tensor.size(1)
+    assert channels in [1, 3]
+    if mean is None:
+        mean = (0, ) * channels
+    if std is None:
+        std = (1, ) * channels
+    assert (channels == len(mean) == len(std) == 3) or \
+        (channels == len(mean) == len(std) == 1 and not to_rgb)
    num_imgs = tensor.size(0)
    mean = np.array(mean, dtype=np.float32)
...
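With mean/std now defaulting per channel count, single-channel tensors no longer need dummy 3-tuples. An illustrative call sketch; shapes and normalization values are made up, and the import uses the module path touched by this commit:

    import torch
    from mmcv.image.misc import tensor2imgs

    gray_batch = torch.rand(2, 1, 64, 64)
    gray_imgs = tensor2imgs(gray_batch, to_rgb=False)  # mean=(0,), std=(1,) implied

    color_batch = torch.rand(2, 3, 64, 64)
    color_imgs = tensor2imgs(
        color_batch, mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375))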
mmcv/image/photometric.py
View file @ fdeee889
...
@@ -426,3 +426,46 @@ def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
    clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
    return clahe.apply(np.array(img, dtype=np.uint8))
+
+
+def adjust_hue(img: np.ndarray, hue_factor: float) -> np.ndarray:
+    """Adjust hue of an image.
+
+    The image hue is adjusted by converting the image to HSV and cyclically
+    shifting the intensities in the hue channel (H). The image is then
+    converted back to original image mode.
+
+    `hue_factor` is the amount of shift in H channel and must be in the
+    interval `[-0.5, 0.5]`.
+
+    Modified from
+    https://github.com/pytorch/vision/blob/main/torchvision/
+    transforms/functional.py
+
+    Args:
+        img (ndarray): Image to be adjusted.
+        hue_factor (float): How much to shift the hue channel. Should be in
+            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
+            HSV space in positive and negative direction respectively.
+            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
+            with complementary colors while 0 gives the original image.
+
+    Returns:
+        ndarray: Hue adjusted image.
+    """
+    if not (-0.5 <= hue_factor <= 0.5):
+        raise ValueError(f'hue_factor:{hue_factor} is not in [-0.5, 0.5].')
+    if not (isinstance(img, np.ndarray) and (img.ndim in {2, 3})):
+        raise TypeError('img should be ndarray with dim=[2 or 3].')
+
+    dtype = img.dtype
+    img = img.astype(np.uint8)
+    hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL)
+    h, s, v = cv2.split(hsv_img)
+    h = h.astype(np.uint8)
+    # uint8 addition take cares of rotation across boundaries
+    with np.errstate(over='ignore'):
+        h += np.uint8(hue_factor * 255)
+    hsv_img = cv2.merge([h, s, v])
+    return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB_FULL).astype(dtype)
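A quick illustration of the new helper; the input array is synthetic, and note the function assumes RGB channel order (it converts with COLOR_RGB2HSV_FULL):

    import numpy as np
    from mmcv.image.photometric import adjust_hue

    rgb = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    same_hue = adjust_hue(rgb, 0.0)   # hue left in place
    flipped = adjust_hue(rgb, 0.5)    # hue rotated to the complementary side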
mmcv/model_zoo/torchvision_0.12.json
0 → 100644
View file @ fdeee889
{
    "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
    "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
    "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
    "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
    "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
    "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
    "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
    "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
    "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
    "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
    "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
    "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
    "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
    "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
    "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
    "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
    "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
    "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
    "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
    "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
    "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
    "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
    "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
    "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
    "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
    "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
    "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
    "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
    "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
    "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
    "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
    "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
    "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
    "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
    "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
    "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
    "shufflenetv2_x1.5": null,
    "shufflenetv2_x2.0": null,
    "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
    "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
    "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
}
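This table mirrors the torchvision 0.12 checkpoint URLs, presumably so checkpoints can still be resolved by model name with newer torchvision releases that no longer expose `model_urls`. A hedged sketch of how such a name-to-URL mapping is typically consumed; the local path below is only for illustration, since inside mmcv the file ships as package data:

    import json

    import torch

    with open('mmcv/model_zoo/torchvision_0.12.json') as f:
        urls = json.load(f)

    # Download and cache the checkpoint for a given model name.
    state_dict = torch.hub.load_state_dict_from_url(urls['resnet50'])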
mmcv/onnx/info.py
View file @ fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
import os
+import warnings

import torch


-def is_custom_op_loaded():
+def is_custom_op_loaded() -> bool:
+    # Following strings of text style are from colorama package
+    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
+    red_text, blue_text = '\x1b[31m', '\x1b[34m'
+    white_background = '\x1b[107m'
+
+    msg = white_background + bright_style + red_text
+    msg += 'DeprecationWarning: This function will be deprecated in future. '
+    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
+    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
+    msg += reset_style
+    warnings.warn(msg)
+
    flag = False
    try:
        from ..tensorrt import is_tensorrt_plugin_loaded
...
mmcv/onnx/onnx_utils/symbolic_helper.py
View file @ fdeee889
...
@@ -59,7 +59,7 @@ def _parse_arg(value, desc):
        raise RuntimeError(
            "ONNX symbolic doesn't know to interpret ListConstruct node")

-    raise RuntimeError('Unexpected node type: {}'.format(value.node().kind()))
+    raise RuntimeError(f'Unexpected node type: {value.node().kind()}')


def _maybe_get_const(value, desc):
...
@@ -328,4 +328,4 @@ cast_pytorch_to_onnx = {
# Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT
# -> C2 via ONNX.
-_quantized_ops = set()
+_quantized_ops: set = set()
mmcv/onnx/symbolic.py
View file @ fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/pytorch/pytorch."""
import os
+import warnings

import numpy as np
import torch
...
@@ -409,8 +410,8 @@ def cummin(g, input, dim):

@parse_args('v', 'v', 'is')
def roll(g, input, shifts, dims):
-    from torch.onnx.symbolic_opset9 import squeeze
    from packaging import version
+    from torch.onnx.symbolic_opset9 import squeeze

    input_shape = g.op('Shape', input)

    need_flatten = len(dims) == 0
...
@@ -467,6 +468,18 @@ def roll(g, input, shifts, dims):

def register_extra_symbolics(opset=11):
+    # Following strings of text style are from colorama package
+    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
+    red_text, blue_text = '\x1b[31m', '\x1b[34m'
+    white_background = '\x1b[107m'
+
+    msg = white_background + bright_style + red_text
+    msg += 'DeprecationWarning: This function will be deprecated in future. '
+    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
+    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
+    msg += reset_style
+    warnings.warn(msg)
+
    register_op('one_hot', one_hot, '', opset)
    register_op('im2col', im2col, '', opset)
    register_op('topk', topk, '', opset)
...
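The call itself is unchanged by this commit; only the deprecation notice is new. An illustrative usage sketch, assuming the function is re-exported from `mmcv.onnx` as in earlier releases:

    from mmcv.onnx import register_extra_symbolics

    # Typically called once before torch.onnx.export so the extra ops
    # (one_hot, im2col, topk, roll, ...) resolve during tracing.
    register_extra_symbolics(opset=11)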