OpenDAS / MMCV · Commits

Commit fdeee889, authored May 25, 2025 by limm
Parent: df465820

    release v1.6.1 of mmcv

Changes: 490 files in total; the 20 changed files reproduced below account for 682 additions and 317 deletions (+682 −317).
mmcv/parallel/_functions.py              +15  −12
mmcv/parallel/collate.py                  +1   −1
mmcv/parallel/data_container.py          +18  −16
mmcv/parallel/data_parallel.py            +7   −5
mmcv/parallel/distributed.py             +66  −11
mmcv/parallel/distributed_deprecated.py  +13   −9
mmcv/parallel/scatter_gather.py          +19   −8
mmcv/parallel/utils.py                   +16   −4
mmcv/runner/__init__.py                  +34   −8
mmcv/runner/base_module.py               +33  −15
mmcv/runner/base_runner.py               +91  −67
mmcv/runner/builder.py                    +3   −2
mmcv/runner/checkpoint.py               +174  −74
mmcv/runner/default_constructor.py        +4   −1
mmcv/runner/dist_utils.py                +68  −21
mmcv/runner/epoch_based_runner.py        +19   −9
mmcv/runner/fp16_utils.py                +53  −28
mmcv/runner/hooks/__init__.py            +31  −12
mmcv/runner/hooks/checkpoint.py          +14  −13
mmcv/runner/hooks/closure.py              +3   −1
mmcv/parallel/_functions.py (+15 −12)

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
 import torch
+from torch import Tensor
 from torch.nn.parallel._functions import _get_stream


-def scatter(input, devices, streams=None):
+def scatter(input: Union[List, Tensor],
+            devices: List,
+            streams: Optional[List] = None) -> Union[List, Tensor]:
     """Scatters tensor across multiple GPUs."""
     if streams is None:
         streams = [None] * len(devices)
...
@@ -15,30 +20,28 @@ def scatter(input, devices, streams=None):
                     [streams[i // chunk_size]])
             for i in range(len(input))
         ]
         return outputs
-    elif isinstance(input, torch.Tensor):
+    elif isinstance(input, Tensor):
         output = input.contiguous()
         # TODO: copy to a pinned buffer first (if copying from CPU)
         stream = streams[0] if output.numel() > 0 else None
         if devices != [-1]:
             with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
                 output = output.cuda(devices[0], non_blocking=True)
         else:
             # unsqueeze the first dimension thus the tensor's shape is the
             # same as those scattered with GPU.
             output = output.unsqueeze(0)
         return output
     else:
         raise Exception(f'Unknown type {type(input)}.')


-def synchronize_stream(output, devices, streams):
+def synchronize_stream(output: Union[List, Tensor], devices: List,
+                       streams: List) -> None:
     if isinstance(output, list):
         chunk_size = len(output) // len(devices)
         for i in range(len(devices)):
             for j in range(chunk_size):
                 synchronize_stream(output[i * chunk_size + j], [devices[i]],
                                    [streams[i]])
-    elif isinstance(output, torch.Tensor):
+    elif isinstance(output, Tensor):
         if output.numel() != 0:
             with torch.cuda.device(devices[0]):
                 main_stream = torch.cuda.current_stream()
...
@@ -48,14 +51,14 @@ def synchronize_stream(output, devices, streams):
         raise Exception(f'Unknown type {type(output)}.')


-def get_input_device(input):
+def get_input_device(input: Union[List, Tensor]) -> int:
     if isinstance(input, list):
         for item in input:
             input_device = get_input_device(item)
             if input_device != -1:
                 return input_device
         return -1
-    elif isinstance(input, torch.Tensor):
+    elif isinstance(input, Tensor):
         return input.get_device() if input.is_cuda else -1
     else:
         raise Exception(f'Unknown type {type(input)}.')
...
@@ -64,7 +67,7 @@ def get_input_device(input):
 class Scatter:

     @staticmethod
-    def forward(target_gpus, input):
+    def forward(target_gpus: List[int], input: Union[List, Tensor]) -> tuple:
         input_device = get_input_device(input)
         streams = None
         if input_device == -1 and target_gpus != [-1]:
...
@@ -76,4 +79,4 @@ class Scatter:
         if streams is not None:
             synchronize_stream(outputs, target_gpus, streams)

-        return tuple(outputs)
+        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
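A minimal usage sketch (mine, not part of the commit) of the typed helpers above. It assumes a CPU-only environment, where the device list is [-1]; on that path scatter() skips the CUDA copy and only unsqueezes a leading dimension.

# Hedged sketch: exercises scatter()/get_input_device() on CPU only.
import torch
from mmcv.parallel._functions import get_input_device, scatter

x = torch.randn(4, 3)
assert get_input_device(x) == -1      # CPU tensors report device -1
out = scatter(x, devices=[-1])        # no CUDA copy on the CPU path
assert out.shape == (1, 4, 3)         # only a leading dim is added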
mmcv/parallel/collate.py (+1 −1)

...
@@ -8,7 +8,7 @@ from torch.utils.data.dataloader import default_collate
 from .data_container import DataContainer


-def collate(batch, samples_per_gpu=1):
+def collate(batch: Sequence, samples_per_gpu: int = 1):
     """Puts each data field into a tensor/DataContainer with outer dimension
     batch size.
...
mmcv/parallel/data_container.py (+18 −16)

 # Copyright (c) OpenMMLab. All rights reserved.
 import functools
+from typing import Callable, Type, Union

+import numpy as np
 import torch


-def assert_tensor_type(func):
+def assert_tensor_type(func: Callable) -> Callable:

     @functools.wraps(func)
     def wrapper(*args, **kwargs):
...
@@ -35,11 +37,11 @@ class DataContainer:
     """

     def __init__(self,
-                 data,
-                 stack=False,
-                 padding_value=0,
-                 cpu_only=False,
-                 pad_dims=2):
+                 data: Union[torch.Tensor, np.ndarray],
+                 stack: bool = False,
+                 padding_value: int = 0,
+                 cpu_only: bool = False,
+                 pad_dims: int = 2):
         self._data = data
         self._cpu_only = cpu_only
         self._stack = stack
...
@@ -47,43 +49,43 @@ class DataContainer:
         assert pad_dims in [None, 1, 2, 3]
         self._pad_dims = pad_dims

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f'{self.__class__.__name__}({repr(self.data)})'

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._data)

     @property
-    def data(self):
+    def data(self) -> Union[torch.Tensor, np.ndarray]:
         return self._data

     @property
-    def datatype(self):
+    def datatype(self) -> Union[Type, str]:
         if isinstance(self.data, torch.Tensor):
             return self.data.type()
         else:
             return type(self.data)

     @property
-    def cpu_only(self):
+    def cpu_only(self) -> bool:
         return self._cpu_only

     @property
-    def stack(self):
+    def stack(self) -> bool:
         return self._stack

     @property
-    def padding_value(self):
+    def padding_value(self) -> int:
         return self._padding_value

     @property
-    def pad_dims(self):
+    def pad_dims(self) -> int:
         return self._pad_dims

     @assert_tensor_type
-    def size(self, *args, **kwargs):
+    def size(self, *args, **kwargs) -> torch.Size:
         return self.data.size(*args, **kwargs)

     @assert_tensor_type
-    def dim(self):
+    def dim(self) -> int:
         return self.data.dim()
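A short sketch (mine, not from the commit) of how the annotated DataContainer is typically constructed and queried; the wrapped tensor is arbitrary.

# Hedged sketch: wrap a tensor in DataContainer and read the typed properties.
import torch
from mmcv.parallel import DataContainer

dc = DataContainer(torch.zeros(2, 3), stack=True, pad_dims=2)
print(dc.stack, dc.cpu_only)   # True False
print(dc.size())               # torch.Size([2, 3]), guarded by assert_tensor_type
print(dc.datatype)             # 'torch.FloatTensor'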
mmcv/parallel/data_parallel.py (+7 −5)

 # Copyright (c) OpenMMLab. All rights reserved.
 from itertools import chain
+from typing import List, Tuple

 from torch.nn.parallel import DataParallel

-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs


 class MMDataParallel(DataParallel):
...
@@ -13,7 +14,7 @@ class MMDataParallel(DataParallel):
     - It supports a custom type :class:`DataContainer` which allows more
       flexible control of input data during both GPU and CPU inference.
-    - It implement two more APIs ``train_step()`` and ``val_step()``.
+    - It implements two more APIs ``train_step()`` and ``val_step()``.

     .. warning::
         MMDataParallel only supports single GPU training, if you need to
...
@@ -31,8 +32,8 @@ class MMDataParallel(DataParallel):
         dim (int): Dimension used to scatter the data. Defaults to 0.
     """

-    def __init__(self, *args, dim=0, **kwargs):
-        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
+    def __init__(self, *args, dim: int = 0, **kwargs):
+        super().__init__(*args, dim=dim, **kwargs)
         self.dim = dim

     def forward(self, *inputs, **kwargs):
...
@@ -49,7 +50,8 @@ class MMDataParallel(DataParallel):
         else:
             return super().forward(*inputs, **kwargs)

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def train_step(self, *inputs, **kwargs):
...
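A minimal sketch (mine; ToyNet is a placeholder and a CUDA device is assumed) of wrapping a module with the annotated MMDataParallel for single-GPU use.

# Hedged sketch: single-GPU wrapping with MMDataParallel (assumes CUDA).
import torch.nn as nn
from mmcv.parallel import MMDataParallel

class ToyNet(nn.Module):          # placeholder module
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

model = MMDataParallel(ToyNet().cuda(), device_ids=[0])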
mmcv/parallel/distributed.py (+66 −11)

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, List, Tuple
+
 import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)

 from mmcv import print_log
 from mmcv.utils import TORCH_VERSION, digit_version
-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs


 class MMDistributedDataParallel(DistributedDataParallel):
...
@@ -18,12 +20,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
     - It implement two APIs ``train_step()`` and ``val_step()``.
     """

-    def to_kwargs(self, inputs, kwargs, device_id):
+    def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                  device_id: int) -> Tuple[tuple, tuple]:
         # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
         # to move all tensors to device_id
         return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def train_step(self, *inputs, **kwargs):
...
@@ -44,8 +48,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')

-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
+
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
...
@@ -57,8 +68,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
         else:
             output = self.module.train_step(*inputs, **kwargs)

-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
...
@@ -86,8 +103,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')

-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
+
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
...
@@ -99,8 +123,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
         else:
             output = self.module.val_step(*inputs, **kwargs)

-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
...
@@ -110,3 +140,28 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 and digit_version(TORCH_VERSION) > digit_version('1.2')):
             self.require_forward_param_sync = False
         return output
+
+    def _run_ddp_forward(self, *inputs, **kwargs) -> Any:
+        """Processes inputs and runs ``self.module.forward``.
+
+        Pytorch 1.12.0 performs ``self.module.forward`` in ``_run_ddp_forward``
+        and deprecates using ``DistributedDataParallel.to_kwargs`` to
+        process inputs, which leads to inputs cannot be processed by
+        :meth:`MMDistributedDataParallel.to_kwargs` anymore. Therefore,
+        ``MMDistributedDataParallel`` overrides this method to call
+        :meth:`to_kwargs` explicitly.
+
+        See more information in `<https://github.com/open-mmlab/mmsegmentation/issues/1742>`_.  # noqa: E501
+
+        Returns:
+            Any: Forward result of :attr:`module`.
+        """
+        module_to_run = self._replicated_tensor_module if \
+            self._use_replicated_tensor_module else self.module
+
+        if self.device_ids:
+            inputs, kwargs = self.to_kwargs(  # type: ignore
+                inputs, kwargs, self.device_ids[0])
+            return module_to_run(*inputs[0], **kwargs[0])  # type: ignore
+        else:
+            return module_to_run(*inputs, **kwargs)
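The recurring guard above keys the buffer-synchronisation path off the installed PyTorch version. A small sketch (mine) of how the same digit_version comparison behaves:

# Hedged sketch: digit_version turns a version string into a comparable tuple,
# which is how the >= '1.11.0a0' branch above gets selected.
from mmcv.utils import digit_version

print(digit_version('1.10.2') >= digit_version('1.11.0a0'))  # False -> legacy _sync_params path
print(digit_version('1.12.0') >= digit_version('1.11.0a0'))  # True  -> _sync_buffers path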
mmcv/parallel/distributed_deprecated.py (+13 −9)

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Sequence, Tuple
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
...
@@ -7,18 +9,18 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
 from mmcv.utils import TORCH_VERSION, digit_version
 from .registry import MODULE_WRAPPERS
-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs


 @MODULE_WRAPPERS.register_module()
 class MMDistributedDataParallel(nn.Module):

     def __init__(self,
-                 module,
-                 dim=0,
-                 broadcast_buffers=True,
-                 bucket_cap_mb=25):
-        super(MMDistributedDataParallel, self).__init__()
+                 module: nn.Module,
+                 dim: int = 0,
+                 broadcast_buffers: bool = True,
+                 bucket_cap_mb: int = 25):
+        super().__init__()
         self.module = module
         self.dim = dim
         self.broadcast_buffers = broadcast_buffers
...
@@ -26,7 +28,8 @@ class MMDistributedDataParallel(nn.Module):
         self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
         self._sync_params()

-    def _dist_broadcast_coalesced(self, tensors, buffer_size):
+    def _dist_broadcast_coalesced(self, tensors: Sequence[torch.Tensor],
+                                  buffer_size: int) -> None:
         for tensors in _take_tensors(tensors, buffer_size):
             flat_tensors = _flatten_dense_tensors(tensors)
             dist.broadcast(flat_tensors, 0)
...
@@ -34,7 +37,7 @@ class MMDistributedDataParallel(nn.Module):
                 tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
             tensor.copy_(synced)

-    def _sync_params(self):
+    def _sync_params(self) -> None:
         module_states = list(self.module.state_dict().values())
         if len(module_states) > 0:
             self._dist_broadcast_coalesced(module_states,
...
@@ -49,7 +52,8 @@ class MMDistributedDataParallel(nn.Module):
             self._dist_broadcast_coalesced(buffers,
                                            self.broadcast_bucket_size)

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def forward(self, *inputs, **kwargs):
...
mmcv/parallel/scatter_gather.py (+19 −8)

 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
+from typing import List, Tuple, Union
+
+from torch import Tensor
 from torch.nn.parallel._functions import Scatter as OrigScatter

 from ._functions import Scatter
 from .data_container import DataContainer

+ScatterInputs = Union[Tensor, DataContainer, tuple, list, dict]

-def scatter(inputs, target_gpus, dim=0):
+
+def scatter(inputs: ScatterInputs,
+            target_gpus: List[int],
+            dim: int = 0) -> list:
     """Scatter inputs to target gpus.

     The only difference from original :func:`scatter` is to add support for
...
@@ -14,7 +20,7 @@ def scatter(inputs, target_gpus, dim=0):
     """

     def scatter_map(obj):
-        if isinstance(obj, torch.Tensor):
+        if isinstance(obj, Tensor):
             if target_gpus != [-1]:
                 return OrigScatter.apply(target_gpus, None, dim, obj)
             else:
...
@@ -33,7 +39,7 @@ def scatter(inputs, target_gpus, dim=0):
         if isinstance(obj, dict) and len(obj) > 0:
             out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
             return out
-        return [obj for targets in target_gpus]
+        return [obj for _ in target_gpus]

     # After scatter_map is called, a scatter_map cell will exist. This cell
     # has a reference to the actual function scatter_map, which has references
...
@@ -43,17 +49,22 @@ def scatter(inputs, target_gpus, dim=0):
     try:
         return scatter_map(inputs)
     finally:
-        scatter_map = None
+        scatter_map = None  # type: ignore


-def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+def scatter_kwargs(inputs: ScatterInputs,
+                   kwargs: ScatterInputs,
+                   target_gpus: List[int],
+                   dim: int = 0) -> Tuple[tuple, tuple]:
     """Scatter with support for kwargs dictionary."""
     inputs = scatter(inputs, target_gpus, dim) if inputs else []
     kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
     if len(inputs) < len(kwargs):
-        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+        length = len(kwargs) - len(inputs)
+        inputs.extend([() for _ in range(length)])  # type: ignore
     elif len(kwargs) < len(inputs):
-        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+        length = len(inputs) - len(kwargs)
+        kwargs.extend([{} for _ in range(length)])  # type: ignore
     inputs = tuple(inputs)
     kwargs = tuple(kwargs)
     return inputs, kwargs
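A brief sketch (mine) of the padding behaviour scatter_kwargs implements, shown for the CPU case (target_gpus == [-1]): when only kwargs are given, inputs is padded with empty tuples so both results carry one entry per target.

# Hedged sketch: CPU-only call, so nothing is copied to a GPU.
from mmcv.parallel.scatter_gather import scatter_kwargs

inputs, kwargs = scatter_kwargs((), dict(x=1), target_gpus=[-1])
print(inputs)   # ((),)        padded to match kwargs
print(kwargs)   # ({'x': 1},)  one kwargs dict per target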
mmcv/parallel/utils.py (+16 −4)

 # Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+
 from .registry import MODULE_WRAPPERS


-def is_module_wrapper(module):
+def is_module_wrapper(module: nn.Module) -> bool:
     """Check if a module is a module wrapper.

     The following 3 modules in MMCV (and their subclasses) are regarded as
     module wrappers: DataParallel, DistributedDataParallel,
     MMDistributedDataParallel (the deprecated version). You may add you own
-    module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
+    module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS or
+    its children registries.

     Args:
         module (nn.Module): The module to be checked.
...
@@ -16,5 +19,14 @@ def is_module_wrapper(module):
     Returns:
         bool: True if the input module is a module wrapper.
     """
-    module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
-    return isinstance(module, module_wrappers)
+
+    def is_module_in_wrapper(module, module_wrapper):
+        module_wrappers = tuple(module_wrapper.module_dict.values())
+        if isinstance(module, module_wrappers):
+            return True
+        for child in module_wrapper.children.values():
+            if is_module_in_wrapper(module, child):
+                return True
+        return False
+
+    return is_module_in_wrapper(module, MODULE_WRAPPERS)
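The rewritten check now walks child registries as well as MODULE_WRAPPERS itself. A compact sketch (mine) of the basic behaviour:

# Hedged sketch: DataParallel is registered in MODULE_WRAPPERS, so the
# wrapped model is reported as a wrapper while the bare module is not.
import torch.nn as nn
from torch.nn.parallel import DataParallel
from mmcv.parallel import is_module_wrapper

model = nn.Linear(2, 2)
print(is_module_wrapper(model))                # False
print(is_module_wrapper(DataParallel(model)))  # True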
mmcv/runner/__init__.py (+34 −8)

 # Copyright (c) OpenMMLab. All rights reserved.
-from .base_module import BaseModule, ModuleList, Sequential
+from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .base_runner import BaseRunner
 from .builder import RUNNERS, build_runner
 from .checkpoint import (CheckpointLoader, _load_checkpoint,
...
@@ -10,14 +10,29 @@ from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
                          init_dist, master_only)
 from .epoch_based_runner import EpochBasedRunner, Runner
 from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
-from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
-                    DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
-                    Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
-                    GradientCumulativeOptimizerHook, Hook, IterTimerHook,
-                    LoggerHook, LrUpdaterHook, MlflowLoggerHook,
-                    NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
-                    SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
-                    WandbLoggerHook)
-from .hooks.lr_updater import StepLrUpdaterHook  # noqa
+from .hooks import (HOOKS, CheckpointHook, ClearMLLoggerHook, ClosureHook,
+                    DistEvalHook, DistSamplerSeedHook, DvcliveLoggerHook,
+                    EMAHook, EvalHook, Fp16OptimizerHook,
+                    GradientCumulativeFp16OptimizerHook,
+                    GradientCumulativeOptimizerHook, Hook, IterTimerHook,
+                    LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
+                    OptimizerHook, PaviLoggerHook, SegmindLoggerHook,
+                    SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
+                    WandbLoggerHook)
+from .hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
+                               CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
+                               ExpLrUpdaterHook, FixedLrUpdaterHook,
+                               FlatCosineAnnealingLrUpdaterHook,
+                               InvLrUpdaterHook, LinearAnnealingLrUpdaterHook,
+                               LrUpdaterHook, OneCycleLrUpdaterHook,
+                               PolyLrUpdaterHook)
+from .hooks.momentum_updater import (CosineAnnealingMomentumUpdaterHook,
+                                     CyclicMomentumUpdaterHook,
+                                     LinearAnnealingMomentumUpdaterHook,
+                                     MomentumUpdaterHook,
+                                     OneCycleMomentumUpdaterHook,
+                                     StepMomentumUpdaterHook)
 from .iter_based_runner import IterBasedRunner, IterLoader
 from .log_buffer import LogBuffer
 from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
...
@@ -26,9 +41,18 @@ from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
 from .priority import Priority, get_priority
 from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
+# initialize ipu to registor ipu runner to RUNNERS
+from mmcv.device import ipu  # isort:skip # noqa

 __all__ = [
     'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
     'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+    'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
+    'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
+    'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
+    'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'MomentumUpdaterHook',
+    'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
+    'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
     'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
     'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
     'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
...
@@ -42,6 +66,8 @@ __all__ = [
     'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
     'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
     '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
-    'ModuleList', 'GradientCumulativeOptimizerHook',
-    'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
+    'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook',
+    'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor',
+    'SegmindLoggerHook', 'LinearAnnealingMomentumUpdaterHook',
+    'LinearAnnealingLrUpdaterHook', 'ClearMLLoggerHook'
 ]
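For reference, a sketch (mine) confirming that a few of the names this release newly re-exports can be imported directly from the runner package:

# Hedged sketch: symbols added to the public mmcv.runner API in this release.
from mmcv.runner import (ClearMLLoggerHook, LinearAnnealingLrUpdaterHook,
                         ModuleDict, SegmindLoggerHook)

print(ClearMLLoggerHook.__name__, ModuleDict.__name__)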
mmcv/runner/base_module.py (+33 −15)

...
@@ -4,6 +4,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
+from typing import Iterable, Optional

 import torch.nn as nn
...
@@ -19,24 +20,23 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
     ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.

     - ``init_cfg``: the config to control the initialization.
-    - ``init_weights``: The function of parameter
-      initialization and recording initialization
-      information.
-    - ``_params_init_info``: Used to track the parameter
-      initialization information. This attribute only
-      exists during executing the ``init_weights``.
+    - ``init_weights``: The function of parameter initialization and recording
+      initialization information.
+    - ``_params_init_info``: Used to track the parameter initialization
+      information. This attribute only exists during executing the
+      ``init_weights``.

     Args:
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, init_cfg=None):
+    def __init__(self, init_cfg: Optional[dict] = None):
         """Initialize BaseModule, inherited from `torch.nn.Module`"""

         # NOTE init_cfg can be defined in different levels, but init_cfg
         # in low levels has a higher priority.

-        super(BaseModule, self).__init__()
+        super().__init__()
         # define default value of init_cfg instead of hard code
         # in init_weights() function
         self._is_init = False
...
@@ -50,10 +50,10 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

     @property
-    def is_init(self):
+    def is_init(self) -> bool:
         return self._is_init

-    def init_weights(self):
+    def init_weights(self) -> None:
         """Initialize the weights."""

         is_top_level_module = False
...
@@ -68,7 +68,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
             # which indicates whether the parameter has been modified.
             # this attribute would be deleted after all parameters
             # is initialized.
-            self._params_init_info = defaultdict(dict)
+            self._params_init_info: defaultdict = defaultdict(dict)
             is_top_level_module = True

             # Initialize the `_params_init_info`,
...
@@ -134,7 +134,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
                 del sub_module._params_init_info

     @master_only
-    def _dump_init_info(self, logger_name):
+    def _dump_init_info(self, logger_name: str) -> None:
         """Dump the initialization information to a file named
         `initialization.log.json` in workdir.
...
@@ -177,7 +177,7 @@ class Sequential(BaseModule, nn.Sequential):
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, *args, init_cfg=None):
+    def __init__(self, *args, init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.Sequential.__init__(self, *args)
...
@@ -190,6 +190,24 @@ class ModuleList(BaseModule, nn.ModuleList):
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, modules=None, init_cfg=None):
+    def __init__(self,
+                 modules: Optional[Iterable] = None,
+                 init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.ModuleList.__init__(self, modules)
+
+
+class ModuleDict(BaseModule, nn.ModuleDict):
+    """ModuleDict in openmmlab.
+
+    Args:
+        modules (dict, optional): a mapping (dictionary) of (string: module)
+            or an iterable of key-value pairs of type (string, module).
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 modules: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleDict.__init__(self, modules)
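A minimal sketch (mine) of the ModuleDict class added above; it behaves like nn.ModuleDict but carries an init_cfg, so init_weights() can be driven by config.

# Hedged sketch: ModuleDict keeps BaseModule's config-driven initialization.
import torch.nn as nn
from mmcv.runner import ModuleDict

heads = ModuleDict(
    dict(cls=nn.Linear(16, 10), reg=nn.Linear(16, 4)),
    init_cfg=dict(type='Xavier', layer='Linear'))
heads.init_weights()          # applies the Xavier init to both heads
print(list(heads.keys()))     # ['cls', 'reg']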
mmcv/runner/base_runner.py
View file @
fdeee889
...
@@ -4,9 +4,13 @@ import logging
...
@@ -4,9 +4,13 @@ import logging
import
os.path
as
osp
import
os.path
as
osp
import
warnings
import
warnings
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
OrderedDict
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
no_type_check
)
import
torch
import
torch
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
torch.utils.data
import
DataLoader
import
mmcv
import
mmcv
from
..parallel
import
is_module_wrapper
from
..parallel
import
is_module_wrapper
...
@@ -49,20 +53,22 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -49,20 +53,22 @@ class BaseRunner(metaclass=ABCMeta):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
model
,
model
:
torch
.
nn
.
Module
,
batch_processor
=
None
,
batch_processor
:
Optional
[
Callable
]
=
None
,
optimizer
=
None
,
optimizer
:
Union
[
Dict
,
torch
.
optim
.
Optimizer
,
None
]
=
None
,
work_dir
=
None
,
work_dir
:
Optional
[
str
]
=
None
,
logger
=
None
,
logger
:
Optional
[
logging
.
Logger
]
=
None
,
meta
=
None
,
meta
:
Optional
[
Dict
]
=
None
,
max_iters
=
None
,
max_iters
:
Optional
[
int
]
=
None
,
max_epochs
=
None
)
:
max_epochs
:
Optional
[
int
]
=
None
)
->
None
:
if
batch_processor
is
not
None
:
if
batch_processor
is
not
None
:
if
not
callable
(
batch_processor
):
if
not
callable
(
batch_processor
):
raise
TypeError
(
'batch_processor must be callable, '
raise
TypeError
(
'batch_processor must be callable, '
f
'but got
{
type
(
batch_processor
)
}
'
)
f
'but got
{
type
(
batch_processor
)
}
'
)
warnings
.
warn
(
'batch_processor is deprecated, please implement '
warnings
.
warn
(
'train_step() and val_step() in the model instead.'
)
'batch_processor is deprecated, please implement '
'train_step() and val_step() in the model instead.'
,
DeprecationWarning
)
# raise an error is `batch_processor` is not None and
# raise an error is `batch_processor` is not None and
# `model.train_step()` exists.
# `model.train_step()` exists.
if
is_module_wrapper
(
model
):
if
is_module_wrapper
(
model
):
...
@@ -104,8 +110,8 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -104,8 +110,8 @@ class BaseRunner(metaclass=ABCMeta):
self
.
logger
=
logger
self
.
logger
=
logger
self
.
meta
=
meta
self
.
meta
=
meta
# create work_dir
# create work_dir
if
mmcv
.
is_str
(
work_dir
):
if
isinstance
(
work_dir
,
str
):
self
.
work_dir
=
osp
.
abspath
(
work_dir
)
self
.
work_dir
:
Optional
[
str
]
=
osp
.
abspath
(
work_dir
)
mmcv
.
mkdir_or_exist
(
self
.
work_dir
)
mmcv
.
mkdir_or_exist
(
self
.
work_dir
)
elif
work_dir
is
None
:
elif
work_dir
is
None
:
self
.
work_dir
=
None
self
.
work_dir
=
None
...
@@ -120,8 +126,8 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -120,8 +126,8 @@ class BaseRunner(metaclass=ABCMeta):
self
.
_rank
,
self
.
_world_size
=
get_dist_info
()
self
.
_rank
,
self
.
_world_size
=
get_dist_info
()
self
.
timestamp
=
get_time_str
()
self
.
timestamp
=
get_time_str
()
self
.
mode
=
None
self
.
mode
:
Optional
[
str
]
=
None
self
.
_hooks
=
[]
self
.
_hooks
:
List
[
Hook
]
=
[]
self
.
_epoch
=
0
self
.
_epoch
=
0
self
.
_iter
=
0
self
.
_iter
=
0
self
.
_inner_iter
=
0
self
.
_inner_iter
=
0
...
@@ -136,38 +142,38 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -136,38 +142,38 @@ class BaseRunner(metaclass=ABCMeta):
self
.
log_buffer
=
LogBuffer
()
self
.
log_buffer
=
LogBuffer
()
@
property
@
property
def
model_name
(
self
):
def
model_name
(
self
)
->
str
:
"""str: Name of the model, usually the module class name."""
"""str: Name of the model, usually the module class name."""
return
self
.
_model_name
return
self
.
_model_name
@
property
@
property
def
rank
(
self
):
def
rank
(
self
)
->
int
:
"""int: Rank of current process. (distributed training)"""
"""int: Rank of current process. (distributed training)"""
return
self
.
_rank
return
self
.
_rank
@
property
@
property
def
world_size
(
self
):
def
world_size
(
self
)
->
int
:
"""int: Number of processes participating in the job.
"""int: Number of processes participating in the job.
(distributed training)"""
(distributed training)"""
return
self
.
_world_size
return
self
.
_world_size
@
property
@
property
def
hooks
(
self
):
def
hooks
(
self
)
->
List
[
Hook
]
:
"""list[:obj:`Hook`]: A list of registered hooks."""
"""list[:obj:`Hook`]: A list of registered hooks."""
return
self
.
_hooks
return
self
.
_hooks
@
property
@
property
def
epoch
(
self
):
def
epoch
(
self
)
->
int
:
"""int: Current epoch."""
"""int: Current epoch."""
return
self
.
_epoch
return
self
.
_epoch
@
property
@
property
def
iter
(
self
):
def
iter
(
self
)
->
int
:
"""int: Current iteration."""
"""int: Current iteration."""
return
self
.
_iter
return
self
.
_iter
@
property
@
property
def
inner_iter
(
self
):
def
inner_iter
(
self
)
->
int
:
"""int: Iteration in an epoch."""
"""int: Iteration in an epoch."""
return
self
.
_inner_iter
return
self
.
_inner_iter
...
@@ -190,26 +196,28 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -190,26 +196,28 @@ class BaseRunner(metaclass=ABCMeta):
pass
pass
@
abstractmethod
@
abstractmethod
def
run
(
self
,
data_loaders
,
workflow
,
**
kwargs
):
def
run
(
self
,
data_loaders
:
List
[
DataLoader
],
workflow
:
List
[
Tuple
[
str
,
int
]],
**
kwargs
)
->
Any
:
pass
pass
@
abstractmethod
@
abstractmethod
def
save_checkpoint
(
self
,
def
save_checkpoint
(
self
,
out_dir
,
out_dir
:
str
,
filename_tmpl
,
filename_tmpl
:
str
,
save_optimizer
=
True
,
save_optimizer
:
bool
=
True
,
meta
=
None
,
meta
:
Optional
[
Dict
]
=
None
,
create_symlink
=
True
)
:
create_symlink
:
bool
=
True
)
->
None
:
pass
pass
def
current_lr
(
self
):
def
current_lr
(
self
)
->
Union
[
List
[
float
],
Dict
[
str
,
List
[
float
]]]
:
"""Get current learning rates.
"""Get current learning rates.
Returns:
Returns:
list[float] | dict[str, list[float]]: Current learning rates of all
list[float] | dict[str, list[float]]: Current learning rates of all
param groups. If the runner has a dict of optimizers, this
param groups. If the runner has a dict of optimizers, this
method
method
will return a dict.
will return a dict.
"""
"""
lr
:
Union
[
List
[
float
],
Dict
[
str
,
List
[
float
]]]
if
isinstance
(
self
.
optimizer
,
torch
.
optim
.
Optimizer
):
if
isinstance
(
self
.
optimizer
,
torch
.
optim
.
Optimizer
):
lr
=
[
group
[
'lr'
]
for
group
in
self
.
optimizer
.
param_groups
]
lr
=
[
group
[
'lr'
]
for
group
in
self
.
optimizer
.
param_groups
]
elif
isinstance
(
self
.
optimizer
,
dict
):
elif
isinstance
(
self
.
optimizer
,
dict
):
...
@@ -221,13 +229,13 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -221,13 +229,13 @@ class BaseRunner(metaclass=ABCMeta):
'lr is not applicable because optimizer does not exist.'
)
'lr is not applicable because optimizer does not exist.'
)
return
lr
return
lr
def
current_momentum
(
self
):
def
current_momentum
(
self
)
->
Union
[
List
[
float
],
Dict
[
str
,
List
[
float
]]]
:
"""Get current momentums.
"""Get current momentums.
Returns:
Returns:
list[float] | dict[str, list[float]]: Current momentums of all
list[float] | dict[str, list[float]]: Current momentums of all
param groups. If the runner has a dict of optimizers, this
param groups. If the runner has a dict of optimizers, this
method
method
will return a dict.
will return a dict.
"""
"""
def
_get_momentum
(
optimizer
):
def
_get_momentum
(
optimizer
):
...
@@ -252,7 +260,9 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -252,7 +260,9 @@ class BaseRunner(metaclass=ABCMeta):
momentums
[
name
]
=
_get_momentum
(
optim
)
momentums
[
name
]
=
_get_momentum
(
optim
)
return
momentums
return
momentums
def
register_hook
(
self
,
hook
,
priority
=
'NORMAL'
):
def
register_hook
(
self
,
hook
:
Hook
,
priority
:
Union
[
int
,
str
,
Priority
]
=
'NORMAL'
)
->
None
:
"""Register a hook into the hook list.
"""Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified
The hook will be inserted into a priority queue, with the specified
...
@@ -269,25 +279,25 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -269,25 +279,25 @@ class BaseRunner(metaclass=ABCMeta):
if
hasattr
(
hook
,
'priority'
):
if
hasattr
(
hook
,
'priority'
):
raise
ValueError
(
'"priority" is a reserved attribute for hooks'
)
raise
ValueError
(
'"priority" is a reserved attribute for hooks'
)
priority
=
get_priority
(
priority
)
priority
=
get_priority
(
priority
)
hook
.
priority
=
priority
hook
.
priority
=
priority
# type: ignore
# insert the hook to a sorted list
# insert the hook to a sorted list
inserted
=
False
inserted
=
False
for
i
in
range
(
len
(
self
.
_hooks
)
-
1
,
-
1
,
-
1
):
for
i
in
range
(
len
(
self
.
_hooks
)
-
1
,
-
1
,
-
1
):
if
priority
>=
self
.
_hooks
[
i
].
priority
:
if
priority
>=
self
.
_hooks
[
i
].
priority
:
# type: ignore
self
.
_hooks
.
insert
(
i
+
1
,
hook
)
self
.
_hooks
.
insert
(
i
+
1
,
hook
)
inserted
=
True
inserted
=
True
break
break
if
not
inserted
:
if
not
inserted
:
self
.
_hooks
.
insert
(
0
,
hook
)
self
.
_hooks
.
insert
(
0
,
hook
)
def
register_hook_from_cfg
(
self
,
hook_cfg
)
:
def
register_hook_from_cfg
(
self
,
hook_cfg
:
Dict
)
->
None
:
"""Register a hook from its cfg.
"""Register a hook from its cfg.
Args:
Args:
hook_cfg (dict): Hook config. It should have at least keys 'type'
hook_cfg (dict): Hook config. It should have at least keys 'type'
and 'priority' indicating its type and priority.
and 'priority' indicating its type and priority.
Note
s
:
Note:
The specific hook class to register should not use 'type' and
The specific hook class to register should not use 'type' and
'priority' arguments during initialization.
'priority' arguments during initialization.
"""
"""
...
@@ -296,7 +306,7 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -296,7 +306,7 @@ class BaseRunner(metaclass=ABCMeta):
hook
=
mmcv
.
build_from_cfg
(
hook_cfg
,
HOOKS
)
hook
=
mmcv
.
build_from_cfg
(
hook_cfg
,
HOOKS
)
self
.
register_hook
(
hook
,
priority
=
priority
)
self
.
register_hook
(
hook
,
priority
=
priority
)
def
call_hook
(
self
,
fn_name
)
:
def
call_hook
(
self
,
fn_name
:
str
)
->
None
:
"""Call all hooks.
"""Call all hooks.
Args:
Args:
...
@@ -306,14 +316,14 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -306,14 +316,14 @@ class BaseRunner(metaclass=ABCMeta):
for
hook
in
self
.
_hooks
:
for
hook
in
self
.
_hooks
:
getattr
(
hook
,
fn_name
)(
self
)
getattr
(
hook
,
fn_name
)(
self
)
def
get_hook_info
(
self
):
def
get_hook_info
(
self
)
->
str
:
# Get hooks info in each stage
# Get hooks info in each stage
stage_hook_map
=
{
stage
:
[]
for
stage
in
Hook
.
stages
}
stage_hook_map
:
Dict
[
str
,
list
]
=
{
stage
:
[]
for
stage
in
Hook
.
stages
}
for
hook
in
self
.
hooks
:
for
hook
in
self
.
hooks
:
try
:
try
:
priority
=
Priority
(
hook
.
priority
).
name
priority
=
Priority
(
hook
.
priority
).
name
# type: ignore
except
ValueError
:
except
ValueError
:
priority
=
hook
.
priority
priority
=
hook
.
priority
# type: ignore
classname
=
hook
.
__class__
.
__name__
classname
=
hook
.
__class__
.
__name__
hook_info
=
f
'(
{
priority
:
<
12
}
)
{
classname
:
<
35
}
'
hook_info
=
f
'(
{
priority
:
<
12
}
)
{
classname
:
<
35
}
'
for
trigger_stage
in
hook
.
get_triggered_stages
():
for
trigger_stage
in
hook
.
get_triggered_stages
():
...
@@ -329,11 +339,13 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -329,11 +339,13 @@ class BaseRunner(metaclass=ABCMeta):
stage_hook_infos
.
append
(
info
)
stage_hook_infos
.
append
(
info
)
return
'
\n
'
.
join
(
stage_hook_infos
)
return
'
\n
'
.
join
(
stage_hook_infos
)
def
load_checkpoint
(
self
,
def
load_checkpoint
(
filename
,
self
,
map_location
=
'cpu'
,
filename
:
str
,
strict
=
False
,
map_location
:
Union
[
str
,
Callable
]
=
'cpu'
,
revise_keys
=
[(
r
'^module.'
,
''
)]):
strict
:
bool
=
False
,
revise_keys
:
List
=
[(
r
'^module.'
,
''
)],
)
->
Union
[
Dict
,
OrderedDict
]:
return
load_checkpoint
(
return
load_checkpoint
(
self
.
model
,
self
.
model
,
filename
,
filename
,
...
@@ -342,10 +354,11 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -342,10 +354,11 @@ class BaseRunner(metaclass=ABCMeta):
self
.
logger
,
self
.
logger
,
revise_keys
=
revise_keys
)
revise_keys
=
revise_keys
)
@
no_type_check
def
resume
(
self
,
def
resume
(
self
,
checkpoint
,
checkpoint
:
str
,
resume_optimizer
=
True
,
resume_optimizer
:
bool
=
True
,
map_location
=
'default'
):
map_location
:
Union
[
str
,
Callable
]
=
'default'
)
->
None
:
if
map_location
==
'default'
:
if
map_location
==
'default'
:
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
device_id
=
torch
.
cuda
.
current_device
()
device_id
=
torch
.
cuda
.
current_device
()
...
@@ -396,7 +409,7 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -396,7 +409,7 @@ class BaseRunner(metaclass=ABCMeta):
self
.
logger
.
info
(
'resumed epoch %d, iter %d'
,
self
.
epoch
,
self
.
iter
)
self
.
logger
.
info
(
'resumed epoch %d, iter %d'
,
self
.
epoch
,
self
.
iter
)
def
register_lr_hook
(
self
,
lr_config
)
:
def
register_lr_hook
(
self
,
lr_config
:
Union
[
Dict
,
Hook
,
None
])
->
None
:
if
lr_config
is
None
:
if
lr_config
is
None
:
return
return
elif
isinstance
(
lr_config
,
dict
):
elif
isinstance
(
lr_config
,
dict
):
...
@@ -417,7 +430,8 @@ class BaseRunner(metaclass=ABCMeta):
...
@@ -417,7 +430,8 @@ class BaseRunner(metaclass=ABCMeta):
hook
=
lr_config
hook
=
lr_config
self
.
register_hook
(
hook
,
priority
=
'VERY_HIGH'
)
self
.
register_hook
(
hook
,
priority
=
'VERY_HIGH'
)
def
register_momentum_hook
(
self
,
momentum_config
):
def
register_momentum_hook
(
self
,
momentum_config
:
Union
[
Dict
,
Hook
,
None
])
->
None
:
if
momentum_config
is
None
:
if
momentum_config
is
None
:
return
return
if
isinstance(momentum_config, dict):
             ...
@@ -438,7 +452,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = momentum_config
         self.register_hook(hook, priority='HIGH')

-    def register_optimizer_hook(self, optimizer_config):
+    def register_optimizer_hook(
+            self, optimizer_config: Union[Dict, Hook, None]) -> None:
         if optimizer_config is None:
             return
         if isinstance(optimizer_config, dict):
             ...
@@ -448,7 +463,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = optimizer_config
         self.register_hook(hook, priority='ABOVE_NORMAL')

-    def register_checkpoint_hook(self, checkpoint_config):
+    def register_checkpoint_hook(
+            self, checkpoint_config: Union[Dict, Hook, None]) -> None:
         if checkpoint_config is None:
             return
         if isinstance(checkpoint_config, dict):
             ...
@@ -458,7 +474,7 @@ class BaseRunner(metaclass=ABCMeta):
             hook = checkpoint_config
         self.register_hook(hook, priority='NORMAL')

-    def register_logger_hooks(self, log_config):
+    def register_logger_hooks(self, log_config: Optional[Dict]) -> None:
         if log_config is None:
             return
         log_interval = log_config['interval']
         ...
@@ -467,7 +483,10 @@ class BaseRunner(metaclass=ABCMeta):
                 info, HOOKS, default_args=dict(interval=log_interval))
             self.register_hook(logger_hook, priority='VERY_LOW')

-    def register_timer_hook(self, timer_config):
+    def register_timer_hook(
+            self, timer_config: Union[Dict, Hook, None]) -> None:
         if timer_config is None:
             return
         if isinstance(timer_config, dict):
             ...
@@ -477,7 +496,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = timer_config
         self.register_hook(hook, priority='LOW')

-    def register_custom_hooks(self, custom_config):
+    def register_custom_hooks(
+            self, custom_config: Union[List, Dict, Hook, None]) -> None:
         if custom_config is None:
             return
         ...
@@ -490,7 +510,10 @@ class BaseRunner(metaclass=ABCMeta):
             else:
                 self.register_hook(item, priority='NORMAL')

-    def register_profiler_hook(self, profiler_config):
+    def register_profiler_hook(
+            self, profiler_config: Union[Dict, Hook, None]) -> None:
         if profiler_config is None:
             return
         if isinstance(profiler_config, dict):
             ...
@@ -500,14 +523,15 @@ class BaseRunner(metaclass=ABCMeta):
             hook = profiler_config
         self.register_hook(hook)

     def register_training_hooks(self,
-                                lr_config,
-                                optimizer_config=None,
-                                checkpoint_config=None,
-                                log_config=None,
-                                momentum_config=None,
-                                timer_config=dict(type='IterTimerHook'),
-                                custom_hooks_config=None):
+                                lr_config: Union[Dict, Hook, None],
+                                optimizer_config: Union[Dict, Hook, None] = None,
+                                checkpoint_config: Union[Dict, Hook, None] = None,
+                                log_config: Optional[Dict] = None,
+                                momentum_config: Union[Dict, Hook, None] = None,
+                                timer_config: Union[Dict, Hook] = dict(
+                                    type='IterTimerHook'),
+                                custom_hooks_config: Union[List, Dict, Hook,
+                                                           None] = None) -> None:
         """Register default and custom hooks for training.

         Default and custom hooks include:
         ...
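The hunks above only add type annotations, but a short usage sketch may help show what these registration methods accept. Everything below is illustrative and not part of the commit: the toy model, optimizer and work directory are assumptions, and each *_config may equally be passed as an already-built Hook instance.

# Hedged sketch: driving the typed register_training_hooks API.
import logging

import torch.nn as nn
import torch.optim as optim
from mmcv.runner import EpochBasedRunner


class ToyModel(nn.Module):
    """Minimal model exposing the `train_step` interface the runner expects."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def train_step(self, data_batch, optimizer, **kwargs):
        loss = self.linear(data_batch).sum()
        return dict(loss=loss, log_vars=dict(loss=loss.item()), num_samples=1)


model = ToyModel()
runner = EpochBasedRunner(
    model,
    optimizer=optim.SGD(model.parameters(), lr=0.01),
    work_dir='./work_dir',
    logger=logging.getLogger(__name__),
    max_epochs=4)

# Each argument is either a config dict (resolved through the HOOKS registry)
# or a Hook instance; None simply skips that hook, matching the annotations.
runner.register_training_hooks(
    lr_config=dict(policy='step', step=[2, 3]),
    optimizer_config=dict(grad_clip=None),
    checkpoint_config=dict(interval=1),
    log_config=dict(interval=10, hooks=[dict(type='TextLoggerHook')]))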
mmcv/runner/builder.py
View file @ fdeee889

 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+from typing import Optional

 from ..utils import Registry
...
@@ -7,11 +8,11 @@ RUNNERS = Registry('runner')
 RUNNER_BUILDERS = Registry('runner builder')


-def build_runner_constructor(cfg):
+def build_runner_constructor(cfg: dict):
     return RUNNER_BUILDERS.build(cfg)


-def build_runner(cfg, default_args=None):
+def build_runner(cfg: dict, default_args: Optional[dict] = None):
     runner_cfg = copy.deepcopy(cfg)
     constructor_type = runner_cfg.pop('constructor',
                                       'DefaultRunnerConstructor')
...
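For context, this is how the now-typed `build_runner` is usually driven from a config. A hedged sketch that reuses the toy `model` idea from the previous example rather than real training objects:

# Hedged sketch: cfg picks the runner class, default_args supplies its kwargs.
import logging

import torch.optim as optim
from mmcv.runner import build_runner

runner = build_runner(
    dict(type='EpochBasedRunner', max_epochs=4),      # cfg: dict
    default_args=dict(                                # default_args: Optional[dict]
        model=model,                                  # assumed: a module with train_step
        optimizer=optim.SGD(model.parameters(), lr=0.01),
        work_dir='./work_dir',
        logger=logging.getLogger(__name__)))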
mmcv/runner/checkpoint.py
View file @ fdeee889

 # Copyright (c) OpenMMLab. All rights reserved.
 import io
+import logging
 import os
 import os.path as osp
 import pkgutil
...
@@ -9,8 +10,10 @@ import warnings
 from collections import OrderedDict
 from importlib import import_module
 from tempfile import TemporaryDirectory
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
+import torch.nn as nn
 import torchvision
 from torch.optim import Optimizer
...
@@ -18,7 +21,7 @@ import mmcv
 from ..fileio import FileClient
 from ..fileio import load as load_file
 from ..parallel import is_module_wrapper
-from ..utils import load_url, mkdir_or_exist
+from ..utils import digit_version, load_url, mkdir_or_exist
 from .dist_utils import get_dist_info

 ENV_MMCV_HOME = 'MMCV_HOME'
...
@@ -26,7 +29,7 @@ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
 DEFAULT_CACHE_DIR = '~/.cache'


-def _get_mmcv_home():
+def _get_mmcv_home() -> str:
     mmcv_home = os.path.expanduser(
         os.getenv(
             ENV_MMCV_HOME,
...
@@ -37,7 +40,10 @@ def _get_mmcv_home():
     return mmcv_home


-def load_state_dict(module, state_dict, strict=False, logger=None):
+def load_state_dict(module: nn.Module,
+                    state_dict: Union[dict, OrderedDict],
+                    strict: bool = False,
+                    logger: Optional[logging.Logger] = None) -> None:
     """Load state_dict to a module.

     This method is modified from :meth:`torch.nn.Module.load_state_dict`.
...
@@ -46,21 +52,21 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
     Args:
         module (Module): Module that receives the state_dict.
-        state_dict (OrderedDict): Weights.
+        state_dict (dict or OrderedDict): Weights.
         strict (bool): whether to strictly enforce that the keys
             in :attr:`state_dict` match the keys returned by this module's
             :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
         logger (:obj:`logging.Logger`, optional): Logger to log the error
             message. If not specified, print function will be used.
     """
-    unexpected_keys = []
-    all_missing_keys = []
-    err_msg = []
+    unexpected_keys: List[str] = []
+    all_missing_keys: List[str] = []
+    err_msg: List[str] = []
     metadata = getattr(state_dict, '_metadata', None)
-    state_dict = state_dict.copy()
+    state_dict = state_dict.copy()  # type: ignore
     if metadata is not None:
-        state_dict._metadata = metadata
+        state_dict._metadata = metadata  # type: ignore

     # use _load_from_state_dict to enable checkpoint version control
     def load(module, prefix=''):
...
@@ -78,7 +84,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
             load(child, prefix + name + '.')

     load(module)
-    load = None  # break load->load reference cycle
+    # break load->load reference cycle
+    load = None  # type: ignore

     # ignore "num_batches_tracked" of BN layers
     missing_keys = [
...
@@ -96,7 +103,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
     if len(err_msg) > 0 and rank == 0:
         err_msg.insert(
             0, 'The model and loaded state dict do not match exactly\n')
-        err_msg = '\n'.join(err_msg)
+        err_msg = '\n'.join(err_msg)  # type: ignore
         if strict:
             raise RuntimeError(err_msg)
         elif logger is not None:
...
@@ -106,14 +113,48 @@ def load_state_dict(module, state_dict, strict=False, logger=None):

 def get_torchvision_models():
-    model_urls = dict()
-    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
-        if ispkg:
-            continue
-        _zoo = import_module(f'torchvision.models.{name}')
-        if hasattr(_zoo, 'model_urls'):
-            _urls = getattr(_zoo, 'model_urls')
-            model_urls.update(_urls)
+    if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
+        model_urls = dict()
+        # When the version of torchvision is lower than 0.13, the model url is
+        # not declared in `torchvision.model.__init__.py`, so we need to
+        # iterate through `torchvision.models.__path__` to get the url for each
+        # model.
+        for _, name, ispkg in pkgutil.walk_packages(
+                torchvision.models.__path__):
+            if ispkg:
+                continue
+            _zoo = import_module(f'torchvision.models.{name}')
+            if hasattr(_zoo, 'model_urls'):
+                _urls = getattr(_zoo, 'model_urls')
+                model_urls.update(_urls)
+    else:
+        # Since torchvision bumps to v0.13, the weight loading logic,
+        # model keys and model urls have been changed. Here the URLs of old
+        # version is loaded to avoid breaking back compatibility. If the
+        # torchvision version>=0.13.0, new URLs will be added. Users can get
+        # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1',
+        # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config.
+        json_path = osp.join(mmcv.__path__[0],
+                             'model_zoo/torchvision_0.12.json')
+        model_urls = mmcv.load(json_path)
+        for cls_name, cls in torchvision.models.__dict__.items():
+            # The name of torchvision model weights classes ends with
+            # `_Weights` such as `ResNet18_Weights`. However, some model weight
+            # classes, such as `MNASNet0_75_Weights` does not have any urls in
+            # torchvision 0.13.0 and cannot be iterated. Here we simply check
+            # `DEFAULT` attribute to ensure the class is not empty.
+            if (not cls_name.endswith('_Weights')
+                    or not hasattr(cls, 'DEFAULT')):
+                continue
+            # Since `cls.DEFAULT` can not be accessed by iterating cls, we set
+            # default urls explicitly.
+            cls_key = cls_name.replace('_Weights', '').lower()
+            model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
+            for weight_enum in cls:
+                cls_key = cls_name.replace('_Weights', '').lower()
+                cls_key = f'{cls_key}.{weight_enum.name.lower()}'
+                model_urls[cls_key] = weight_enum.url

     return model_urls
...
@@ -147,7 +188,7 @@ def get_deprecated_model_names():
     return deprecate_urls


-def _process_mmcls_checkpoint(checkpoint):
+def _process_mmcls_checkpoint(checkpoint: Dict) -> Dict:
     if 'state_dict' in checkpoint:
         state_dict = checkpoint['state_dict']
     else:
...
@@ -166,10 +207,13 @@ def _process_mmcls_checkpoint(checkpoint):
 class CheckpointLoader:
     """A general checkpoint loader to manage all schemes."""

-    _schemes = {}
+    _schemes: dict = {}

     @classmethod
-    def _register_scheme(cls, prefixes, loader, force=False):
+    def _register_scheme(cls,
+                         prefixes: Union[str, List, Tuple],
+                         loader: Callable,
+                         force: bool = False) -> None:
         if isinstance(prefixes, str):
             prefixes = [prefixes]
         else:
...
@@ -186,13 +230,16 @@ class CheckpointLoader:
             sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))

     @classmethod
-    def register_scheme(cls, prefixes, loader=None, force=False):
+    def register_scheme(cls,
+                        prefixes: Union[str, List[str], Tuple[str, ...]],
+                        loader: Optional[Callable] = None,
+                        force: bool = False) -> Callable:
         """Register a loader to CheckpointLoader.

         This method can be used as a normal class method or a decorator.

         Args:
-            prefixes (str or list[str] or tuple[str]):
+            prefixes (str or Sequence[str]):
                 The prefix of the registered loader.
             loader (function, optional): The loader function to be registered.
                 When this method is used as a decorator, loader is None.
...
@@ -203,7 +250,7 @@ class CheckpointLoader:
         if loader is not None:
             cls._register_scheme(prefixes, loader, force=force)
-            return
+            return  # type: ignore

         def _register(loader_cls):
             cls._register_scheme(prefixes, loader_cls, force=force)
...
@@ -212,7 +259,7 @@ class CheckpointLoader:
         return _register

     @classmethod
-    def _get_checkpoint_loader(cls, path):
+    def _get_checkpoint_loader(cls, path: str):
         """Finds a loader that supports the given path. Falls back to the local
         loader if no other loader is found.
...
@@ -220,15 +267,22 @@ class CheckpointLoader:
             path (str): checkpoint path

         Returns:
-            loader (function): checkpoint loader
+            callable: checkpoint loader
         """
         for p in cls._schemes:
-            if path.startswith(p):
+            # use regular match to handle some cases that where the prefix of
+            # loader has a prefix. For example, both 's3://path' and
+            # 'open-mmlab:s3://path' should return `load_from_ceph`
+            if re.match(p, path) is not None:
                 return cls._schemes[p]

     @classmethod
-    def load_checkpoint(cls, filename, map_location=None, logger=None):
+    def load_checkpoint(cls,
+                        filename: str,
+                        map_location: Union[str, Callable, None] = None,
+                        logger: Optional[logging.Logger] = None
+                        ) -> Union[dict, OrderedDict]:
         """load checkpoint through URL scheme path.

         Args:
...
@@ -243,14 +297,17 @@ class CheckpointLoader:
         """

         checkpoint_loader = cls._get_checkpoint_loader(filename)
-        class_name = checkpoint_loader.__name__
+        class_name = checkpoint_loader.__name__  # type: ignore
         mmcv.print_log(
             f'load checkpoint from {class_name[10:]} path: {filename}', logger)
-        return checkpoint_loader(filename, map_location)
+        return checkpoint_loader(filename, map_location)  # type: ignore


 @CheckpointLoader.register_scheme(prefixes='')
-def load_from_local(filename, map_location):
+def load_from_local(filename: str,
+                    map_location: Union[str, Callable, None] = None
+                    ) -> Union[dict, OrderedDict]:
     """load checkpoint by local file path.

     Args:
...
@@ -260,15 +317,18 @@ def load_from_local(filename, map_location):
     Returns:
         dict or OrderedDict: The loaded checkpoint.
     """
+    filename = osp.expanduser(filename)
     if not osp.isfile(filename):
-        raise IOError(f'{filename} is not a checkpoint file')
+        raise FileNotFoundError(f'{filename} can not be found.')
     checkpoint = torch.load(filename, map_location=map_location)
     return checkpoint


 @CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
-def load_from_http(filename, map_location=None, model_dir=None):
+def load_from_http(filename: str,
+                   map_location: Union[str, Callable, None] = None,
+                   model_dir: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through HTTP or HTTPS scheme path. In distributed
     setting, this function only download checkpoint at local rank 0.
...
@@ -276,7 +336,7 @@ def load_from_http(filename, map_location=None, model_dir=None):
         filename (str): checkpoint file path with modelzoo or
             torchvision prefix
         map_location (str, optional): Same as :func:`torch.load`.
-        model_dir (string, optional): directory in which to save the object,
+        model_dir (str, optional): directory in which to save the object,
             Default: None

     Returns:
...
@@ -295,7 +355,10 @@ def load_from_http(filename, map_location=None, model_dir=None):

 @CheckpointLoader.register_scheme(prefixes='pavi://')
-def load_from_pavi(filename, map_location=None):
+def load_from_pavi(filename: str,
+                   map_location: Union[str, Callable, None] = None
+                   ) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with pavi. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.
...
@@ -326,16 +389,23 @@ def load_from_pavi(filename, map_location=None):
     return checkpoint


-@CheckpointLoader.register_scheme(prefixes='s3://')
-def load_from_ceph(filename, map_location=None, backend='petrel'):
+@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
+def load_from_ceph(filename: str,
+                   map_location: Union[str, Callable, None] = None,
+                   backend: str = 'petrel') -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with s3. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.

+    Note:
+        Since v1.4.1, the registered scheme prefixes have been enhanced to
+        support bucket names in the path prefix, e.g. 's3://xx.xx/xx.path',
+        'bucket1:s3://xx.xx/xx.path'.
+
     Args:
         filename (str): checkpoint file path with s3 prefix
         map_location (str, optional): Same as :func:`torch.load`.
-        backend (str, optional): The storage backend type. Options are 'ceph',
+        backend (str): The storage backend type. Options are 'ceph',
             'petrel'. Default: 'petrel'.

     .. warning::
...
@@ -351,7 +421,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
     if backend == 'ceph':
         warnings.warn(
-            'CephBackend will be deprecated, please use PetrelBackend instead')
+            'CephBackend will be deprecated, please use PetrelBackend instead',
+            DeprecationWarning)

     # CephClient and PetrelBackend have the same prefix 's3://' and the latter
     # will be chosen as default. If PetrelBackend can not be instantiated
...
@@ -368,7 +439,10 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):

 @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
-def load_from_torchvision(filename, map_location=None):
+def load_from_torchvision(filename: str,
+                          map_location: Union[str, Callable, None] = None
+                          ) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with modelzoo or
     torchvision.
...
@@ -382,16 +456,25 @@ def load_from_torchvision(filename, map_location=None):
     """
     model_urls = get_torchvision_models()
     if filename.startswith('modelzoo://'):
-        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
-                      'use "torchvision://" instead')
+        warnings.warn(
+            'The URL scheme of "modelzoo://" is deprecated, please '
+            'use "torchvision://" instead', DeprecationWarning)
         model_name = filename[11:]
     else:
         model_name = filename[14:]
+
+    # Support getting model urls in the same way as torchvision
+    # `ResNet50_Weights.IMAGENET1K_V1` will be mapped to
+    # resnet50.imagenet1k_v1.
+    model_name = model_name.lower().replace('_weights', '')
     return load_from_http(model_urls[model_name], map_location=map_location)


 @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
-def load_from_openmmlab(filename, map_location=None):
+def load_from_openmmlab(filename: str,
+                        map_location: Union[str, Callable, None] = None
+                        ) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with open-mmlab or
     openmmlab.
...
@@ -415,8 +498,10 @@ def load_from_openmmlab(filename, map_location=None):
     deprecated_urls = get_deprecated_model_names()
     if model_name in deprecated_urls:
-        warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
-                      f'of {prefix_str}{deprecated_urls[model_name]}')
+        warnings.warn(
+            f'{prefix_str}{model_name} is deprecated in favor '
+            f'of {prefix_str}{deprecated_urls[model_name]}',
+            DeprecationWarning)
         model_name = deprecated_urls[model_name]
     model_url = model_urls[model_name]
     # check if is url
...
@@ -425,13 +510,16 @@ def load_from_openmmlab(filename, map_location=None):
     else:
         filename = osp.join(_get_mmcv_home(), model_url)
         if not osp.isfile(filename):
-            raise IOError(f'{filename} is not a checkpoint file')
+            raise FileNotFoundError(f'{filename} can not be found.')
         checkpoint = torch.load(filename, map_location=map_location)
     return checkpoint


 @CheckpointLoader.register_scheme(prefixes='mmcls://')
-def load_from_mmcls(filename, map_location=None):
+def load_from_mmcls(filename: str,
+                    map_location: Union[str, Callable, None] = None
+                    ) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with mmcls.

     Args:
...
@@ -450,7 +538,10 @@ def load_from_mmcls(filename, map_location=None):
     return checkpoint


-def _load_checkpoint(filename, map_location=None, logger=None):
+def _load_checkpoint(filename: str,
+                     map_location: Union[str, Callable, None] = None,
+                     logger: Optional[logging.Logger] = None
+                     ) -> Union[dict, OrderedDict]:
     """Load checkpoint from somewhere (modelzoo, file, url).

     Args:
...
@@ -470,7 +561,11 @@ def _load_checkpoint(filename, map_location=None, logger=None):
     return CheckpointLoader.load_checkpoint(filename, map_location, logger)


-def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
+def _load_checkpoint_with_prefix(prefix: str,
+                                 filename: str,
+                                 map_location: Union[str, Callable, None] = None
+                                 ) -> Union[dict, OrderedDict]:
     """Load partial pretrained model with specific prefix.

     Args:
...
@@ -503,12 +598,13 @@ def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
     return state_dict


-def load_checkpoint(model,
-                    filename,
-                    map_location=None,
-                    strict=False,
-                    logger=None,
-                    revise_keys=[(r'^module\.', '')]):
+def load_checkpoint(model: torch.nn.Module,
+                    filename: str,
+                    map_location: Union[str, Callable, None] = None,
+                    strict: bool = False,
+                    logger: Optional[logging.Logger] = None,
+                    revise_keys: list = [(r'^module\.', '')]
+                    ) -> Union[dict, OrderedDict]:
     """Load checkpoint from a file or URI.

     Args:
...
@@ -553,7 +649,7 @@ def load_checkpoint(model,
     return checkpoint


-def weights_to_cpu(state_dict):
+def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
     """Copy a model state_dict to cpu.

     Args:
...
@@ -566,11 +662,13 @@ def weights_to_cpu(state_dict):
     for key, val in state_dict.items():
         state_dict_cpu[key] = val.cpu()
     # Keep metadata in state_dict
-    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
+    state_dict_cpu._metadata = getattr(  # type: ignore
+        state_dict, '_metadata', OrderedDict())
     return state_dict_cpu


-def _save_to_state_dict(module, destination, prefix, keep_vars):
+def _save_to_state_dict(module: torch.nn.Module, destination: dict,
+                        prefix: str, keep_vars: bool) -> None:
     """Saves module state to `destination` dictionary.

     This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
...
@@ -590,7 +688,10 @@ def _save_to_state_dict(module, destination, prefix, keep_vars):
             destination[prefix + name] = buf if keep_vars else buf.detach()


-def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+def get_state_dict(module: torch.nn.Module,
+                   destination: Optional[OrderedDict] = None,
+                   prefix: str = '',
+                   keep_vars: bool = False) -> OrderedDict:
     """Returns a dictionary containing a whole state of the module.

     Both parameters and persistent buffers (e.g. running averages) are
...
@@ -619,10 +720,10 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
     # below is the same as torch.nn.Module.state_dict()
     if destination is None:
         destination = OrderedDict()
-        destination._metadata = OrderedDict()
+        destination._metadata = OrderedDict()  # type: ignore
-    destination._metadata[prefix[:-1]] = local_metadata = dict(
+    destination._metadata[prefix[:-1]] = local_metadata = dict(  # type: ignore
         version=module._version)
-    _save_to_state_dict(module, destination, prefix, keep_vars)
+    _save_to_state_dict(module, destination, prefix, keep_vars)  # type: ignore
     for name, child in module._modules.items():
         if child is not None:
             get_state_dict(
...
@@ -631,14 +732,14 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
         hook_result = hook(module, destination, prefix, local_metadata)
         if hook_result is not None:
             destination = hook_result
-    return destination
+    return destination  # type: ignore


-def save_checkpoint(model,
-                    filename,
-                    optimizer=None,
-                    meta=None,
-                    file_client_args=None):
+def save_checkpoint(model: torch.nn.Module,
+                    filename: str,
+                    optimizer: Optional[Optimizer] = None,
+                    meta: Optional[dict] = None,
+                    file_client_args: Optional[dict] = None) -> None:
     """Save checkpoint to file.

     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
...
@@ -669,7 +770,7 @@ def save_checkpoint(model,
     checkpoint = {
         'meta': meta,
-        'state_dict': weights_to_cpu(get_state_dict(model))
+        'state_dict': weights_to_cpu(get_state_dict(model))  # type: ignore
     }
     # save optimizer state dict in the checkpoint
     if isinstance(optimizer, Optimizer):
...
@@ -685,8 +786,7 @@ def save_checkpoint(model,
                 'file_client_args should be "None" if filename starts with'
                 f'"pavi://", but got {file_client_args}')
         try:
-            from pavi import modelcloud
-            from pavi import exception
+            from pavi import exception, modelcloud
         except ImportError:
             raise ImportError(
                 'Please install pavi to load checkpoint from modelcloud.')
...
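As a usage note for the loader registry touched above, here is a hedged sketch of registering a custom scheme and loading through it. The `myfs://` prefix, the toy loader and the checkpoint path are invented for illustration; only `CheckpointLoader`, `register_scheme` and `load_checkpoint` come from mmcv, and `model` is assumed to exist.

# Hedged sketch: a custom checkpoint scheme resolved via CheckpointLoader.
import torch
from mmcv.runner import CheckpointLoader, load_checkpoint


@CheckpointLoader.register_scheme(prefixes='myfs://')
def load_from_myfs(filename, map_location=None):
    # Strip the invented prefix and fall back to a plain local torch.load.
    return torch.load(filename[len('myfs://'):], map_location=map_location)


# `model` is assumed to exist; 'module.' prefixes are stripped by revise_keys.
checkpoint = load_checkpoint(
    model, 'myfs:///tmp/epoch_4.pth', map_location='cpu', strict=False)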
mmcv/runner/default_constructor.py
View file @ fdeee889

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
 from .builder import RUNNER_BUILDERS, RUNNERS
...
@@ -33,7 +36,7 @@ class DefaultRunnerConstructor:
         >>> runner = build_runner(runner_cfg)
     """

-    def __init__(self, runner_cfg, default_args=None):
+    def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
         if not isinstance(runner_cfg, dict):
             raise TypeError('runner_cfg should be a dict',
                             f'but got {type(runner_cfg)}')
...
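A hedged sketch of the constructor indirection that `DefaultRunnerConstructor` implements; `MyRunnerConstructor` is hypothetical and only mirrors the delegation pattern, it is not part of the commit.

# Hedged sketch: a custom runner constructor selected via the 'constructor' key.
from mmcv.runner import RUNNER_BUILDERS, RUNNERS


@RUNNER_BUILDERS.register_module()
class MyRunnerConstructor:

    def __init__(self, runner_cfg: dict, default_args: dict = None):
        self.runner_cfg = runner_cfg
        self.default_args = default_args or {}

    def __call__(self):
        # Inject an extra default, then delegate to the RUNNERS registry,
        # mirroring what DefaultRunnerConstructor.__call__ does.
        self.default_args.setdefault('max_epochs', 4)
        return RUNNERS.build(self.runner_cfg, default_args=self.default_args)


# Picked up by build_runner through the popped 'constructor' key, e.g.
# build_runner(dict(type='EpochBasedRunner', constructor='MyRunnerConstructor'),
#              default_args=dict(model=model, logger=logger, work_dir='./work_dir'))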
mmcv/runner/dist_utils.py
View file @ fdeee889

 # Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import functools
 import os
+import socket
 import subprocess
 from collections import OrderedDict
+from typing import Callable, List, Optional, Tuple

 import torch
 import torch.multiprocessing as mp
...
@@ -10,8 +13,28 @@ from torch import distributed as dist
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)

+from mmcv.utils import IS_MLU_AVAILABLE

-def init_dist(launcher, backend='nccl', **kwargs):
+
+def _find_free_port() -> str:
+    # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    # Binding to port 0 will cause the OS to find an available port for us
+    sock.bind(('', 0))
+    port = sock.getsockname()[1]
+    sock.close()
+    # NOTE: there is still a chance the port could be taken by other processes.
+    return port
+
+
+def _is_free_port(port: int) -> bool:
+    ips = socket.gethostbyname_ex(socket.gethostname())[-1]
+    ips.append('localhost')
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        return all(s.connect_ex((ip, port)) != 0 for ip in ips)
+
+
+def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method('spawn')
     if launcher == 'pytorch':
...
@@ -24,23 +47,37 @@ def init_dist(launcher, backend='nccl', **kwargs):
         raise ValueError(f'Invalid launcher type: {launcher}')


-def _init_dist_pytorch(backend, **kwargs):
+def _init_dist_pytorch(backend: str, **kwargs) -> None:
     # TODO: use local_rank instead of rank % num_gpus
     rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    dist.init_process_group(backend=backend, **kwargs)
+    if IS_MLU_AVAILABLE:
+        import torch_mlu  # noqa: F401
+        torch.mlu.set_device(rank)
+        dist.init_process_group(
+            backend='cncl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
+    else:
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        dist.init_process_group(backend=backend, **kwargs)


-def _init_dist_mpi(backend, **kwargs):
-    # TODO: use local_rank instead of rank % num_gpus
-    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
+def _init_dist_mpi(backend: str, **kwargs) -> None:
+    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+    torch.cuda.set_device(local_rank)
+    if 'MASTER_PORT' not in os.environ:
+        # 29500 is torch.distributed default port
+        os.environ['MASTER_PORT'] = '29500'
+    if 'MASTER_ADDR' not in os.environ:
+        raise KeyError('The environment variable MASTER_ADDR is not set')
+    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
+    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
     dist.init_process_group(backend=backend, **kwargs)


-def _init_dist_slurm(backend, port=None):
+def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
     """Initialize slurm distributed training environment.

     If argument ``port`` is not specified, then the master port will be system
...
@@ -64,8 +101,12 @@ def _init_dist_slurm(backend, port=None):
     elif 'MASTER_PORT' in os.environ:
         pass  # use MASTER_PORT in the environment variable
     else:
-        # 29500 is torch.distributed default port
-        os.environ['MASTER_PORT'] = '29500'
+        # if torch.distributed default port(29500) is available
+        # then use it, else find a free port
+        if _is_free_port(29500):
+            os.environ['MASTER_PORT'] = '29500'
+        else:
+            os.environ['MASTER_PORT'] = str(_find_free_port())
     # use MASTER_ADDR in the environment variable if it already exists
     if 'MASTER_ADDR' not in os.environ:
         os.environ['MASTER_ADDR'] = addr
...
@@ -75,7 +116,7 @@ def _init_dist_slurm(backend, port=None):
     dist.init_process_group(backend=backend)


-def get_dist_info():
+def get_dist_info() -> Tuple[int, int]:
     if dist.is_available() and dist.is_initialized():
         rank = dist.get_rank()
         world_size = dist.get_world_size()
...
@@ -85,7 +126,7 @@ def get_dist_info():
     return rank, world_size


-def master_only(func):
+def master_only(func: Callable) -> Callable:

     @functools.wraps(func)
     def wrapper(*args, **kwargs):
...
@@ -96,12 +137,14 @@ def master_only(func):
     return wrapper


-def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
+def allreduce_params(params: List[torch.nn.Parameter],
+                     coalesce: bool = True,
+                     bucket_size_mb: int = -1) -> None:
     """Allreduce parameters.

     Args:
-        params (list[torch.Parameters]): List of parameters or buffers of a
-            model.
+        params (list[torch.nn.Parameter]): List of parameters or buffers
+            of a model.
         coalesce (bool, optional): Whether allreduce parameters as a whole.
             Defaults to True.
         bucket_size_mb (int, optional): Size of bucket, the unit is MB.
...
@@ -118,11 +161,13 @@ def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
         dist.all_reduce(tensor.div_(world_size))


-def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+def allreduce_grads(params: List[torch.nn.Parameter],
+                    coalesce: bool = True,
+                    bucket_size_mb: int = -1) -> None:
     """Allreduce gradients.

     Args:
-        params (list[torch.Parameters]): List of parameters of a model.
+        params (list[torch.nn.Parameter]): List of parameters of a model.
         coalesce (bool, optional): Whether allreduce parameters as a whole.
             Defaults to True.
         bucket_size_mb (int, optional): Size of bucket, the unit is MB.
...
@@ -142,7 +187,9 @@ def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
         dist.all_reduce(tensor.div_(world_size))


-def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+def _allreduce_coalesced(tensors: torch.Tensor,
+                         world_size: int,
+                         bucket_size_mb: int = -1) -> None:
     if bucket_size_mb > 0:
         bucket_size_bytes = bucket_size_mb * 1024 * 1024
         buckets = _take_tensors(tensors, bucket_size_bytes)
...
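A small, hedged usage sketch for two of the helpers typed above; it falls back to single-process values when torch.distributed has not been initialized.

# Hedged sketch: rank-aware logging with get_dist_info and master_only.
from mmcv.runner import get_dist_info, master_only


@master_only
def log_on_master(msg: str) -> None:
    # The wrapped function only runs on rank 0; other ranks return immediately.
    print(f'[rank 0] {msg}')


rank, world_size = get_dist_info()  # (0, 1) when dist is not initialized
log_on_master(f'world_size={world_size}; this message is printed once')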
mmcv/runner/epoch_based_runner.py
View file @ fdeee889

...
@@ -4,8 +4,10 @@ import platform
 import shutil
 import time
 import warnings
+from typing import Any, Dict, List, Optional, Tuple

 import torch
+from torch.utils.data import DataLoader

 import mmcv
 from .base_runner import BaseRunner
...
@@ -21,7 +23,7 @@ class EpochBasedRunner(BaseRunner):
     This runner train models epoch by epoch.
     """

-    def run_iter(self, data_batch, train_mode, **kwargs):
+    def run_iter(self, data_batch: Any, train_mode: bool, **kwargs) -> None:
         if self.batch_processor is not None:
             outputs = self.batch_processor(
                 self.model, data_batch, train_mode=train_mode, **kwargs)
...
@@ -45,10 +47,12 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('before_train_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
+            self.data_batch = data_batch
             self._inner_iter = i
             self.call_hook('before_train_iter')
             self.run_iter(data_batch, train_mode=True, **kwargs)
             self.call_hook('after_train_iter')
+            del self.data_batch
             self._iter += 1

         self.call_hook('after_train_epoch')
...
@@ -62,14 +66,19 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('before_val_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
+            self.data_batch = data_batch
             self._inner_iter = i
             self.call_hook('before_val_iter')
             self.run_iter(data_batch, train_mode=False)
             self.call_hook('after_val_iter')
+            del self.data_batch

         self.call_hook('after_val_epoch')

-    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
+    def run(self,
+            data_loaders: List[DataLoader],
+            workflow: List[Tuple[str, int]],
+            max_epochs: Optional[int] = None,
+            **kwargs) -> None:
         """Start running.

         Args:
...
@@ -130,11 +139,11 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('after_run')

     def save_checkpoint(self,
-                        out_dir,
-                        filename_tmpl='epoch_{}.pth',
-                        save_optimizer=True,
-                        meta=None,
-                        create_symlink=True):
+                        out_dir: str,
+                        filename_tmpl: str = 'epoch_{}.pth',
+                        save_optimizer: bool = True,
+                        meta: Optional[Dict] = None,
+                        create_symlink: bool = True) -> None:
         """Save the checkpoint.

         Args:
...
@@ -183,5 +192,6 @@ class Runner(EpochBasedRunner):

     def __init__(self, *args, **kwargs):
         warnings.warn(
-            'Runner was deprecated, please use EpochBasedRunner instead')
+            'Runner was deprecated, please use EpochBasedRunner instead',
+            DeprecationWarning)
         super().__init__(*args, **kwargs)
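For reference, a hedged sketch of the typed `run` signature; the runner and the two data loaders are assumed to exist (for instance, built as in the earlier sketches), so this is not a self-contained script.

# Hedged sketch: `workflow` is a list of (phase, epochs) tuples consumed by run().
# `runner`, `train_loader` and `val_loader` are assumed to be defined elsewhere.
runner.run(
    data_loaders=[train_loader, val_loader],
    workflow=[('train', 1), ('val', 1)])  # alternate 1 train epoch with 1 val epoch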
mmcv/runner/fp16_utils.py
View file @ fdeee889

...
@@ -3,10 +3,12 @@ import functools
 import warnings
 from collections import abc
 from inspect import getfullargspec
+from typing import Callable, Iterable, List, Optional

 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn.parameter import Parameter

 from mmcv.utils import TORCH_VERSION, digit_version
 from .dist_utils import allreduce_grads as _allreduce_grads
...
@@ -21,9 +23,18 @@ except ImportError:
     pass


-def cast_tensor_type(inputs, src_type, dst_type):
+def cast_tensor_type(inputs, src_type: torch.dtype, dst_type: torch.dtype):
     """Recursively convert Tensor in inputs from src_type to dst_type.

+    Note:
+        In v1.4.4 and later, ``cast_tersor_type`` will only convert the
+        torch.Tensor which is consistent with ``src_type`` to the ``dst_type``.
+        Before v1.4.4, it ignores the ``src_type`` argument, leading to some
+        potential problems. For example,
+        ``cast_tensor_type(inputs, torch.float, torch.half)`` will convert all
+        tensors in inputs to ``torch.half`` including those originally in
+        ``torch.Int`` or other types, which is not expected.
+
     Args:
         inputs: Inputs that to be casted.
         src_type (torch.dtype): Source type..
...
@@ -35,24 +46,30 @@ def cast_tensor_type(inputs, src_type, dst_type):
     if isinstance(inputs, nn.Module):
         return inputs
     elif isinstance(inputs, torch.Tensor):
-        return inputs.to(dst_type)
+        # we need to ensure that the type of inputs to be casted are the same
+        # as the argument `src_type`.
+        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
     elif isinstance(inputs, str):
         return inputs
     elif isinstance(inputs, np.ndarray):
         return inputs
     elif isinstance(inputs, abc.Mapping):
-        return type(inputs)({
+        return type(inputs)({  # type: ignore
             k: cast_tensor_type(v, src_type, dst_type)
             for k, v in inputs.items()
         })
     elif isinstance(inputs, abc.Iterable):
-        return type(inputs)(
+        return type(inputs)(  # type: ignore
             cast_tensor_type(item, src_type, dst_type) for item in inputs)
     else:
         return inputs


-def auto_fp16(apply_to=None, out_fp32=False):
+def auto_fp16(apply_to: Optional[Iterable] = None,
+              out_fp32: bool = False,
+              supported_types: tuple = (nn.Module, )) -> Callable:
     """Decorator to enable fp16 training automatically.

     This decorator is useful when you write custom modules and want to support
...
@@ -65,7 +82,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
         apply_to (Iterable, optional): The argument names to be converted.
             `None` indicates all arguments.
         out_fp32 (bool): Whether to convert the output back to fp32.
+        supported_types (tuple): Classes can be decorated by ``auto_fp16``.
+            `New in version 1.5.0.`

     Example:

         >>> import torch.nn as nn
...
@@ -85,15 +103,15 @@ def auto_fp16(apply_to=None, out_fp32=False):
         >>>     pass
     """

-    def auto_fp16_wrapper(old_func):
+    def auto_fp16_wrapper(old_func: Callable) -> Callable:

         @functools.wraps(old_func)
-        def new_func(*args, **kwargs):
+        def new_func(*args, **kwargs) -> Callable:
             # check if the module has set the attribute `fp16_enabled`, if not,
             # just fallback to the original method.
-            if not isinstance(args[0], torch.nn.Module):
+            if not isinstance(args[0], supported_types):
                 raise TypeError('@auto_fp16 can only be used to decorate the '
-                                'method of nn.Module')
+                                f'method of those classes {supported_types}')
             if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                 return old_func(*args, **kwargs)
...
@@ -138,7 +156,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
     return auto_fp16_wrapper


-def force_fp32(apply_to=None, out_fp16=False):
+def force_fp32(apply_to: Optional[Iterable] = None,
+               out_fp16: bool = False) -> Callable:
     """Decorator to convert input arguments to fp32 in force.

     This decorator is useful when you write custom modules and want to support
...
@@ -176,7 +195,7 @@ def force_fp32(apply_to=None, out_fp16=False):
     def force_fp32_wrapper(old_func):

         @functools.wraps(old_func)
-        def new_func(*args, **kwargs):
+        def new_func(*args, **kwargs) -> Callable:
             # check if the module has set the attribute `fp16_enabled`, if not,
             # just fallback to the original method.
             if not isinstance(args[0], torch.nn.Module):
...
@@ -224,14 +243,17 @@ def force_fp32(apply_to=None, out_fp16=False):
     return force_fp32_wrapper


-def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
-    warnings.warning(
+def allreduce_grads(params: List[Parameter],
+                    coalesce: bool = True,
+                    bucket_size_mb: int = -1) -> None:
+    warnings.warn(
         '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
-        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads')
+        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads',
+        DeprecationWarning)
     _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)


-def wrap_fp16_model(model):
+def wrap_fp16_model(model: nn.Module) -> None:
     """Wrap the FP32 model to FP16.

     If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
...
@@ -260,7 +282,7 @@ def wrap_fp16_model(model):
             m.fp16_enabled = True


-def patch_norm_fp32(module):
+def patch_norm_fp32(module: nn.Module) -> nn.Module:
     """Recursively convert normalization layers from FP16 to FP32.

     Args:
...
@@ -280,7 +302,10 @@ def patch_norm_fp32(module):
     return module


-def patch_forward_method(func, src_type, dst_type, convert_output=True):
+def patch_forward_method(func: Callable,
+                         src_type: torch.dtype,
+                         dst_type: torch.dtype,
+                         convert_output: bool = True) -> Callable:
     """Patch the forward method of a module.

     Args:
...
@@ -333,10 +358,10 @@ class LossScaler:
     """

     def __init__(self,
-                 init_scale=2**32,
-                 mode='dynamic',
-                 scale_factor=2.,
-                 scale_window=1000):
+                 init_scale: float = 2**32,
+                 mode: str = 'dynamic',
+                 scale_factor: float = 2.,
+                 scale_window: int = 1000):
         self.cur_scale = init_scale
         self.cur_iter = 0
         assert mode in ('dynamic',
...
@@ -346,7 +371,7 @@ class LossScaler:
         self.scale_factor = scale_factor
         self.scale_window = scale_window

-    def has_overflow(self, params):
+    def has_overflow(self, params: List[Parameter]) -> bool:
         """Check if params contain overflow."""
         if self.mode != 'dynamic':
             return False
...
@@ -355,7 +380,7 @@ class LossScaler:
                 return True
         return False

-    def _has_inf_or_nan(x):
+    def _has_inf_or_nan(x: torch.Tensor) -> bool:
         """Check if params contain NaN."""
         try:
             cpu_sum = float(x.float().sum())
...
@@ -369,7 +394,7 @@ class LossScaler:
             return True
         return False

-    def update_scale(self, overflow):
+    def update_scale(self, overflow: bool) -> None:
         """update the current loss scale value when overflow happens."""
         if self.mode != 'dynamic':
             return
...
@@ -382,7 +407,7 @@ class LossScaler:
             self.cur_scale *= self.scale_factor
         self.cur_iter += 1

-    def state_dict(self):
+    def state_dict(self) -> dict:
         """Returns the state of the scaler as a :class:`dict`."""
         return dict(
             cur_scale=self.cur_scale,
...
@@ -392,7 +417,7 @@ class LossScaler:
             scale_factor=self.scale_factor,
             scale_window=self.scale_window)

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: dict) -> None:
         """Loads the loss_scaler state dict.

         Args:
...
@@ -406,5 +431,5 @@ class LossScaler:
         self.scale_window = state_dict['scale_window']

     @property
-    def loss_scale(self):
+    def loss_scale(self) -> float:
         return self.cur_scale
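To show how the decorators annotated above are attached in practice, a hedged, self-contained sketch; `ToyHead` and its methods are invented for illustration.

# Hedged sketch: selective casting with auto_fp16 / force_fp32.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import auto_fp16, force_fp32


class ToyHead(nn.Module):

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False          # flipped to True by wrap_fp16_model
        self.fc = nn.Linear(8, 4)

    @auto_fp16(apply_to=('feats', ))       # cast `feats` to half when enabled
    def forward(self, feats):
        return self.fc(feats)

    @force_fp32(apply_to=('scores', ))     # keep the loss computation in fp32
    def loss(self, scores, target):
        return F.cross_entropy(scores, target)


head = ToyHead()
scores = head(torch.randn(2, 8))           # fp16 casting is a no-op while disabled
print(head.loss(scores, torch.tensor([0, 1])))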
mmcv/runner/hooks/__init__.py
View file @
fdeee889
@@ -5,12 +5,24 @@ from .ema import EMAHook
 from .evaluation import DistEvalHook, EvalHook
 from .hook import HOOKS, Hook
 from .iter_timer import IterTimerHook
-from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
-                     NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
-                     TextLoggerHook, WandbLoggerHook)
-from .lr_updater import LrUpdaterHook
+from .logger import (ClearMLLoggerHook, DvcliveLoggerHook, LoggerHook,
+                     MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook,
+                     SegmindLoggerHook, TensorboardLoggerHook, TextLoggerHook,
+                     WandbLoggerHook)
+from .lr_updater import (CosineAnnealingLrUpdaterHook,
+                         CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
+                         ExpLrUpdaterHook, FixedLrUpdaterHook,
+                         FlatCosineAnnealingLrUpdaterHook, InvLrUpdaterHook,
+                         LinearAnnealingLrUpdaterHook, LrUpdaterHook,
+                         OneCycleLrUpdaterHook, PolyLrUpdaterHook,
+                         StepLrUpdaterHook)
 from .memory import EmptyCacheHook
-from .momentum_updater import MomentumUpdaterHook
+from .momentum_updater import (CosineAnnealingMomentumUpdaterHook,
+                               CyclicMomentumUpdaterHook,
+                               LinearAnnealingMomentumUpdaterHook,
+                               MomentumUpdaterHook,
+                               OneCycleMomentumUpdaterHook,
+                               StepMomentumUpdaterHook)
 from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
                         GradientCumulativeOptimizerHook, OptimizerHook)
 from .profiler import ProfilerHook
@@ -19,11 +31,18 @@ from .sync_buffer import SyncBuffersHook
 __all__ = [
     'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
-    'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
-    'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
-    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
-    'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
-    'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
-    'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook',
-    'GradientCumulativeFp16OptimizerHook'
+    'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
+    'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
+    'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
+    'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'OptimizerHook',
+    'Fp16OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
+    'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
+    'TextLoggerHook', 'TensorboardLoggerHook', 'NeptuneLoggerHook',
+    'WandbLoggerHook', 'DvcliveLoggerHook', 'MomentumUpdaterHook',
+    'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
+    'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
+    'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook',
+    'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
+    'SegmindLoggerHook', 'LinearAnnealingLrUpdaterHook',
+    'LinearAnnealingMomentumUpdaterHook', 'ClearMLLoggerHook'
 ]
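Note: with the expanded __all__, the per-policy LR and momentum hooks can be imported directly from the package instead of only through their registry names. A small sketch (assumes mmcv 1.6.1 is installed):

from mmcv.runner.hooks import (HOOKS, ClearMLLoggerHook,
                               CosineAnnealingLrUpdaterHook,
                               LinearAnnealingMomentumUpdaterHook)

# The same classes remain registered in the HOOKS registry, so
# config-driven builds still work:
lr_hook = HOOKS.build(dict(type='CosineAnnealingLrUpdaterHook', min_lr=0.0))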
mmcv/runner/hooks/checkpoint.py
View file @
fdeee889
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
 import warnings
+from typing import Optional
 
 from mmcv.fileio import FileClient
 from ..dist_utils import allreduce_params, master_only
@@ -49,14 +50,14 @@ class CheckpointHook(Hook):
     """
 
     def __init__(self,
-                 interval=-1,
-                 by_epoch=True,
-                 save_optimizer=True,
-                 out_dir=None,
-                 max_keep_ckpts=-1,
-                 save_last=True,
-                 sync_buffer=False,
-                 file_client_args=None,
+                 interval: int = -1,
+                 by_epoch: bool = True,
+                 save_optimizer: bool = True,
+                 out_dir: Optional[str] = None,
+                 max_keep_ckpts: int = -1,
+                 save_last: bool = True,
+                 sync_buffer: bool = False,
+                 file_client_args: Optional[dict] = None,
                  **kwargs):
         self.interval = interval
         self.by_epoch = by_epoch
@@ -83,8 +84,8 @@ class CheckpointHook(Hook):
             basename = osp.basename(runner.work_dir.rstrip(osp.sep))
             self.out_dir = self.file_client.join_path(self.out_dir, basename)
-        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
-                            f'{self.file_client.name}.'))
+        runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
+                           f'{self.file_client.name}.')
 
         # disable the create_symlink option because some file backends do not
         # allow to create a symlink
@@ -93,9 +94,9 @@ class CheckpointHook(Hook):
                     'create_symlink'] and not self.file_client.allow_symlink:
                 self.args['create_symlink'] = False
                 warnings.warn(
-                    ('create_symlink is set as True by the user but is changed'
-                     'to be False because creating symbolic link is not '
-                     f'allowed in {self.file_client.name}'))
+                    'create_symlink is set as True by the user but is changed'
+                    'to be False because creating symbolic link is not '
+                    f'allowed in {self.file_client.name}')
             else:
                 self.args['create_symlink'] = self.file_client.allow_symlink
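Note: with these annotations, out_dir and file_client_args stay optional; when out_dir is given, checkpoints are written through FileClient and the basename of runner.work_dir is appended, as the log line above states. A hedged sketch of how the hook might be configured (the out_dir value is illustrative only):

from mmcv.runner.hooks import CheckpointHook

ckpt_hook = CheckpointHook(
    interval=1,            # save every epoch (by_epoch=True is the default)
    max_keep_ckpts=3,      # keep only the three most recent checkpoints
    save_optimizer=True,
    out_dir='/data/experiments/ckpts',        # work_dir basename is appended
    file_client_args=dict(backend='disk'))    # backend used to write there
# runner.register_hook(ckpt_hook) attaches it to a Runner; downstream
# codebases usually pass the same keys via checkpoint_config instead.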
mmcv/runner/hooks/closure.py
View file @
fdeee889
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable
+
 from .hook import HOOKS, Hook
 
 
 @HOOKS.register_module()
 class ClosureHook(Hook):
 
-    def __init__(self, fn_name, fn):
+    def __init__(self, fn_name: str, fn: Callable):
         assert hasattr(self, fn_name)
         assert callable(fn)
         setattr(self, fn_name, fn)
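Note: ClosureHook binds an arbitrary callable to whichever Hook stage is named by fn_name, which the hasattr assertion enforces. A hedged usage sketch (the runner object is supplied by the training loop):

from mmcv.runner.hooks import ClosureHook


def print_lr(runner):
    # Runs at the start of every training epoch.
    print(f'epoch {runner.epoch}: lr={runner.current_lr()}')


closure_hook = ClosureHook(fn_name='before_train_epoch', fn=print_lr)
# runner.register_hook(closure_hook) would attach it to a Runner instance.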