OpenDAS / MMCV / Commits

Commit fdeee889, authored May 25, 2025 by limm
release v1.6.1 of mmcv
Parent: df465820
Changes: 457
Showing 17 changed files with 634 additions and 291 deletions (+634, -291)
mmcv/parallel/_functions.py                +15   -12
mmcv/parallel/collate.py                    +1    -1
mmcv/parallel/data_container.py            +18   -16
mmcv/parallel/data_parallel.py              +7    -5
mmcv/parallel/distributed.py               +66   -11
mmcv/parallel/distributed_deprecated.py    +13    -9
mmcv/parallel/scatter_gather.py            +19    -8
mmcv/parallel/utils.py                     +16    -4
mmcv/runner/__init__.py                    +34    -8
mmcv/runner/base_module.py                 +33   -15
mmcv/runner/base_runner.py                 +91   -67
mmcv/runner/builder.py                      +3    -2
mmcv/runner/checkpoint.py                 +174   -74
mmcv/runner/default_constructor.py          +4    -1
mmcv/runner/dist_utils.py                  +68   -21
mmcv/runner/epoch_based_runner.py          +19    -9
mmcv/runner/fp16_utils.py                  +53   -28
Too many changes to show: to preserve performance only 457 of 457+ files are displayed.
mmcv/parallel/_functions.py

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
 import torch
+from torch import Tensor
 from torch.nn.parallel._functions import _get_stream


-def scatter(input, devices, streams=None):
+def scatter(input: Union[List, Tensor],
+            devices: List,
+            streams: Optional[List] = None) -> Union[List, Tensor]:
     """Scatters tensor across multiple GPUs."""
     if streams is None:
         streams = [None] * len(devices)
...
@@ -15,30 +20,28 @@ def scatter(input, devices, streams=None):
                     [streams[i // chunk_size]]) for i in range(len(input))
         ]
         return outputs
-    elif isinstance(input, torch.Tensor):
+    elif isinstance(input, Tensor):
         output = input.contiguous()
         # TODO: copy to a pinned buffer first (if copying from CPU)
         stream = streams[0] if output.numel() > 0 else None
         if devices != [-1]:
             with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
                 output = output.cuda(devices[0], non_blocking=True)
         else:
             # unsqueeze the first dimension thus the tensor's shape is the
             # same as those scattered with GPU.
             output = output.unsqueeze(0)
         return output
     else:
         raise Exception(f'Unknown type {type(input)}.')


-def synchronize_stream(output, devices, streams):
+def synchronize_stream(output: Union[List, Tensor], devices: List,
+                       streams: List) -> None:
     if isinstance(output, list):
         chunk_size = len(output) // len(devices)
         for i in range(len(devices)):
             for j in range(chunk_size):
                 synchronize_stream(output[i * chunk_size + j],
                                    [devices[i]], [streams[i]])
-    elif isinstance(output, torch.Tensor):
+    elif isinstance(output, Tensor):
         if output.numel() != 0:
             with torch.cuda.device(devices[0]):
                 main_stream = torch.cuda.current_stream()
...
@@ -48,14 +51,14 @@ def synchronize_stream(output, devices, streams):
         raise Exception(f'Unknown type {type(output)}.')


-def get_input_device(input):
+def get_input_device(input: Union[List, Tensor]) -> int:
     if isinstance(input, list):
         for item in input:
             input_device = get_input_device(item)
             if input_device != -1:
                 return input_device
         return -1
-    elif isinstance(input, torch.Tensor):
+    elif isinstance(input, Tensor):
         return input.get_device() if input.is_cuda else -1
     else:
         raise Exception(f'Unknown type {type(input)}.')
...
@@ -64,7 +67,7 @@ def get_input_device(input):
 class Scatter:

     @staticmethod
-    def forward(target_gpus, input):
+    def forward(target_gpus: List[int],
+                input: Union[List, Tensor]) -> tuple:
         input_device = get_input_device(input)
         streams = None
         if input_device == -1 and target_gpus != [-1]:
...
@@ -76,4 +79,4 @@ class Scatter:
         if streams is not None:
             synchronize_stream(outputs, target_gpus, streams)

-        return tuple(outputs)
+        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
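The last hunk above changes the return of Scatter.forward so that a single tensor is wrapped in a one-element tuple instead of being split by tuple(). A minimal sketch (not part of the commit) exercising the CPU path, where devices == [-1] and no GPU is required:

import torch
from mmcv.parallel._functions import Scatter, get_input_device, scatter

x = torch.randn(4, 3)
print(get_input_device(x))             # -1, i.e. a CPU tensor
out = scatter(x, devices=[-1])         # CPU fallback only unsqueezes: (1, 4, 3)
print(out.shape)

outputs = Scatter.forward(target_gpus=[-1], input=x)
print(type(outputs), outputs[0].shape)  # <class 'tuple'> torch.Size([1, 4, 3])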
mmcv/parallel/collate.py

...
@@ -8,7 +8,7 @@ from torch.utils.data.dataloader import default_collate
 from .data_container import DataContainer


-def collate(batch, samples_per_gpu=1):
+def collate(batch: Sequence, samples_per_gpu: int = 1):
     """Puts each data field into a tensor/DataContainer with outer dimension
     batch size.
...
mmcv/parallel/data_container.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import functools
+from typing import Callable, Type, Union

 import numpy as np
 import torch


-def assert_tensor_type(func):
+def assert_tensor_type(func: Callable) -> Callable:

     @functools.wraps(func)
     def wrapper(*args, **kwargs):
...
@@ -35,11 +37,11 @@ class DataContainer:
     """

     def __init__(self,
-                 data,
-                 stack=False,
-                 padding_value=0,
-                 cpu_only=False,
-                 pad_dims=2):
+                 data: Union[torch.Tensor, np.ndarray],
+                 stack: bool = False,
+                 padding_value: int = 0,
+                 cpu_only: bool = False,
+                 pad_dims: int = 2):
         self._data = data
         self._cpu_only = cpu_only
         self._stack = stack
...
@@ -47,43 +49,43 @@ class DataContainer:
         assert pad_dims in [None, 1, 2, 3]
         self._pad_dims = pad_dims

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f'{self.__class__.__name__}({repr(self.data)})'

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._data)

     @property
-    def data(self):
+    def data(self) -> Union[torch.Tensor, np.ndarray]:
         return self._data

     @property
-    def datatype(self):
+    def datatype(self) -> Union[Type, str]:
         if isinstance(self.data, torch.Tensor):
             return self.data.type()
         else:
             return type(self.data)

     @property
-    def cpu_only(self):
+    def cpu_only(self) -> bool:
         return self._cpu_only

     @property
-    def stack(self):
+    def stack(self) -> bool:
         return self._stack

     @property
-    def padding_value(self):
+    def padding_value(self) -> int:
         return self._padding_value

     @property
-    def pad_dims(self):
+    def pad_dims(self) -> int:
         return self._pad_dims

     @assert_tensor_type
-    def size(self, *args, **kwargs):
+    def size(self, *args, **kwargs) -> torch.Size:
         return self.data.size(*args, **kwargs)

     @assert_tensor_type
-    def dim(self):
+    def dim(self) -> int:
         return self.data.dim()
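A minimal sketch (not part of the commit) of the DataContainer API whose constructor and accessors gained the annotations above; the argument values simply mirror the defaults shown in the diff:

import numpy as np
import torch
from mmcv.parallel import DataContainer

dc = DataContainer(torch.zeros(2, 3), stack=True, padding_value=0, pad_dims=2)
print(dc.stack, dc.cpu_only, dc.pad_dims)   # True False 2
print(dc.datatype)                          # 'torch.FloatTensor'
print(dc.size(), dc.dim())                  # torch.Size([2, 3]) 2

meta = DataContainer(np.array([1, 2, 3]), cpu_only=True)
print(meta.datatype)                        # <class 'numpy.ndarray'>
# meta.size() would fail: size()/dim() are guarded by @assert_tensor_type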
mmcv/parallel/data_parallel.py

 # Copyright (c) OpenMMLab. All rights reserved.
 from itertools import chain
+from typing import List, Tuple

 from torch.nn.parallel import DataParallel

-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs


 class MMDataParallel(DataParallel):
...
@@ -13,7 +14,7 @@ class MMDataParallel(DataParallel):
     - It supports a custom type :class:`DataContainer` which allows more
       flexible control of input data during both GPU and CPU inference.
-    - It implement two more APIs ``train_step()`` and ``val_step()``.
+    - It implements two more APIs ``train_step()`` and ``val_step()``.

     .. warning::
         MMDataParallel only supports single GPU training, if you need to
...
@@ -31,8 +32,8 @@ class MMDataParallel(DataParallel):
         dim (int): Dimension used to scatter the data. Defaults to 0.
     """

-    def __init__(self, *args, dim=0, **kwargs):
-        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
+    def __init__(self, *args, dim: int = 0, **kwargs):
+        super().__init__(*args, dim=dim, **kwargs)
         self.dim = dim

     def forward(self, *inputs, **kwargs):
...
@@ -49,7 +50,8 @@ class MMDataParallel(DataParallel):
         else:
             return super().forward(*inputs, **kwargs)

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def train_step(self, *inputs, **kwargs):
...
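A small sketch (not part of the commit) that calls the scatter() override directly on the CPU "device" -1, illustrating the new Tuple[tuple, tuple] annotation and the added self.dim assignment above; wrapping nn.Linear here is purely illustrative:

import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel

wrapped = MMDataParallel(nn.Linear(4, 2), dim=0)
inputs, kwargs = wrapped.scatter((torch.randn(2, 4),), {}, [-1])
print(type(inputs), type(kwargs))   # <class 'tuple'> <class 'tuple'>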
mmcv/parallel/distributed.py

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, List, Tuple
+
 import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)

 from mmcv import print_log
 from mmcv.utils import TORCH_VERSION, digit_version
-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs


 class MMDistributedDataParallel(DistributedDataParallel):
...
@@ -18,12 +20,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
     - It implement two APIs ``train_step()`` and ``val_step()``.
     """

-    def to_kwargs(self, inputs, kwargs, device_id):
+    def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                  device_id: int) -> Tuple[tuple, tuple]:
         # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
         # to move all tensors to device_id
         return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def train_step(self, *inputs, **kwargs):
...
@@ -44,8 +48,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')

-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
+
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
...
@@ -57,8 +68,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
         else:
             output = self.module.train_step(*inputs, **kwargs)

-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
...
@@ -86,8 +103,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 'Reducer buckets have been rebuilt in this iteration.',
                 logger='mmcv')

-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
+            if self._check_sync_bufs_pre_fwd():
+                self._sync_buffers()
+        else:
+            if (getattr(self, 'require_forward_param_sync', False)
+                    and self.require_forward_param_sync):
+                self._sync_params()
+
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
...
@@ -99,8 +123,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
         else:
             output = self.module.val_step(*inputs, **kwargs)

-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
+            if self._check_sync_bufs_post_fwd():
+                self._sync_buffers()
+
+        if (torch.is_grad_enabled()
+                and getattr(self, 'require_backward_grad_sync', False)
+                and self.require_backward_grad_sync):
             if self.find_unused_parameters:
                 self.reducer.prepare_for_backward(list(_find_tensors(output)))
             else:
...
@@ -110,3 +140,28 @@ class MMDistributedDataParallel(DistributedDataParallel):
                 and digit_version(TORCH_VERSION) > digit_version('1.2')):
             self.require_forward_param_sync = False
         return output
+
+    def _run_ddp_forward(self, *inputs, **kwargs) -> Any:
+        """Processes inputs and runs ``self.module.forward``.
+
+        Pytorch 1.12.0 performs ``self.module.forward`` in ``_run_ddp_forward``
+        and deprecates using ``DistributedDataParallel.to_kwargs`` to
+        process inputs, which leads to inputs cannot be processed by
+        :meth:`MMDistributedDataParallel.to_kwargs` anymore. Therefore,
+        ``MMDistributedDataParallel`` overrides this method to call
+        :meth:`to_kwargs` explicitly.
+
+        See more information in `<https://github.com/open-mmlab/mmsegmentation/issues/1742>`_.  # noqa: E501
+
+        Returns:
+            Any: Forward result of :attr:`module`.
+        """
+        module_to_run = self._replicated_tensor_module if \
+            self._use_replicated_tensor_module else self.module
+
+        if self.device_ids:
+            inputs, kwargs = self.to_kwargs(  # type: ignore
+                inputs, kwargs, self.device_ids[0])
+            return module_to_run(*inputs[0], **kwargs[0])  # type: ignore
+        else:
+            return module_to_run(*inputs, **kwargs)
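A small sketch (not part of the commit) of the version gate used in the new train_step/val_step branches above: mmcv.utils.digit_version turns a version string into a comparable tuple, which is how the code chooses between the pre- and post-1.11 DDP synchronization APIs.

from mmcv.utils import TORCH_VERSION, digit_version

if ('parrots' not in TORCH_VERSION
        and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
    print('use _check_sync_bufs_pre_fwd()/_sync_buffers()')
else:
    print('fall back to the legacy _sync_params() path')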
mmcv/parallel/distributed_deprecated.py

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Sequence, Tuple
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
...
@@ -7,18 +9,18 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
 from mmcv.utils import TORCH_VERSION, digit_version
 from .registry import MODULE_WRAPPERS
-from .scatter_gather import scatter_kwargs
+from .scatter_gather import ScatterInputs, scatter_kwargs


 @MODULE_WRAPPERS.register_module()
 class MMDistributedDataParallel(nn.Module):

     def __init__(self,
-                 module,
-                 dim=0,
-                 broadcast_buffers=True,
-                 bucket_cap_mb=25):
-        super(MMDistributedDataParallel, self).__init__()
+                 module: nn.Module,
+                 dim: int = 0,
+                 broadcast_buffers: bool = True,
+                 bucket_cap_mb: int = 25):
+        super().__init__()
         self.module = module
         self.dim = dim
         self.broadcast_buffers = broadcast_buffers
...
@@ -26,7 +28,8 @@ class MMDistributedDataParallel(nn.Module):
         self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
         self._sync_params()

-    def _dist_broadcast_coalesced(self, tensors, buffer_size):
+    def _dist_broadcast_coalesced(self, tensors: Sequence[torch.Tensor],
+                                  buffer_size: int) -> None:
         for tensors in _take_tensors(tensors, buffer_size):
             flat_tensors = _flatten_dense_tensors(tensors)
             dist.broadcast(flat_tensors, 0)
...
@@ -34,7 +37,7 @@ class MMDistributedDataParallel(nn.Module):
                     tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                 tensor.copy_(synced)

-    def _sync_params(self):
+    def _sync_params(self) -> None:
         module_states = list(self.module.state_dict().values())
         if len(module_states) > 0:
             self._dist_broadcast_coalesced(module_states,
...
@@ -49,7 +52,8 @@ class MMDistributedDataParallel(nn.Module):
                 self._dist_broadcast_coalesced(buffers,
                                                self.broadcast_bucket_size)

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
+                device_ids: List[int]) -> Tuple[tuple, tuple]:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

     def forward(self, *inputs, **kwargs):
...
mmcv/parallel/scatter_gather.py

 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
+from typing import List, Tuple, Union
+
+from torch import Tensor
 from torch.nn.parallel._functions import Scatter as OrigScatter

 from ._functions import Scatter
 from .data_container import DataContainer

+ScatterInputs = Union[Tensor, DataContainer, tuple, list, dict]
+

-def scatter(inputs, target_gpus, dim=0):
+def scatter(inputs: ScatterInputs,
+            target_gpus: List[int],
+            dim: int = 0) -> list:
     """Scatter inputs to target gpus.

     The only difference from original :func:`scatter` is to add support for
...
@@ -14,7 +20,7 @@ def scatter(inputs, target_gpus, dim=0):
     """

     def scatter_map(obj):
-        if isinstance(obj, torch.Tensor):
+        if isinstance(obj, Tensor):
             if target_gpus != [-1]:
                 return OrigScatter.apply(target_gpus, None, dim, obj)
             else:
...
@@ -33,7 +39,7 @@ def scatter(inputs, target_gpus, dim=0):
         if isinstance(obj, dict) and len(obj) > 0:
             out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
             return out
-        return [obj for targets in target_gpus]
+        return [obj for _ in target_gpus]

     # After scatter_map is called, a scatter_map cell will exist. This cell
     # has a reference to the actual function scatter_map, which has references
...
@@ -43,17 +49,22 @@ def scatter(inputs, target_gpus, dim=0):
     try:
         return scatter_map(inputs)
     finally:
-        scatter_map = None
+        scatter_map = None  # type: ignore


-def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+def scatter_kwargs(inputs: ScatterInputs,
+                   kwargs: ScatterInputs,
+                   target_gpus: List[int],
+                   dim: int = 0) -> Tuple[tuple, tuple]:
     """Scatter with support for kwargs dictionary."""
     inputs = scatter(inputs, target_gpus, dim) if inputs else []
     kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
     if len(inputs) < len(kwargs):
-        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+        length = len(kwargs) - len(inputs)
+        inputs.extend([() for _ in range(length)])  # type: ignore
     elif len(kwargs) < len(inputs):
-        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+        length = len(inputs) - len(kwargs)
+        kwargs.extend([{} for _ in range(length)])  # type: ignore
     inputs = tuple(inputs)
     kwargs = tuple(kwargs)
     return inputs, kwargs
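A minimal sketch (not part of the commit) of the typed scatter_kwargs() on the CPU path (target_gpus == [-1]); it shows the padding that keeps the returned inputs/kwargs tuples the same length, which is what the reworked extend() lines above compute.

import torch
from mmcv.parallel.scatter_gather import scatter_kwargs

inputs, kwargs = scatter_kwargs((), dict(flag=True), [-1])
print(inputs, kwargs)             # ((),) ({'flag': True},)  inputs padded to match kwargs

inputs, kwargs = scatter_kwargs((torch.ones(2, 2),), {}, [-1])
print(len(inputs), len(kwargs))   # 1 1  kwargs padded with an empty dict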
mmcv/parallel/utils.py

 # Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+
 from .registry import MODULE_WRAPPERS


-def is_module_wrapper(module):
+def is_module_wrapper(module: nn.Module) -> bool:
     """Check if a module is a module wrapper.

     The following 3 modules in MMCV (and their subclasses) are regarded as
     module wrappers: DataParallel, DistributedDataParallel,
     MMDistributedDataParallel (the deprecated version). You may add you own
-    module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
+    module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS or
+    its children registries.

     Args:
         module (nn.Module): The module to be checked.
...
@@ -16,5 +19,14 @@ def is_module_wrapper(module):
     Returns:
         bool: True if the input module is a module wrapper.
     """
-    module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
-    return isinstance(module, module_wrappers)
+
+    def is_module_in_wrapper(module, module_wrapper):
+        module_wrappers = tuple(module_wrapper.module_dict.values())
+        if isinstance(module, module_wrappers):
+            return True
+        for child in module_wrapper.children.values():
+            if is_module_in_wrapper(module, child):
+                return True
+        return False
+
+    return is_module_in_wrapper(module, MODULE_WRAPPERS)
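A hedged sketch (not part of the commit): registering a custom wrapper class in MODULE_WRAPPERS so that the rewritten is_module_wrapper() above recognizes it; the wrapper class itself is illustrative only.

import torch.nn as nn
from mmcv.parallel import MODULE_WRAPPERS, is_module_wrapper


@MODULE_WRAPPERS.register_module()
class MyWrapper(nn.Module):

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


print(is_module_wrapper(MyWrapper(nn.Linear(2, 2))))   # True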
mmcv/runner/__init__.py

 # Copyright (c) OpenMMLab. All rights reserved.
-from .base_module import BaseModule, ModuleList, Sequential
+from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .base_runner import BaseRunner
 from .builder import RUNNERS, build_runner
 from .checkpoint import (CheckpointLoader, _load_checkpoint,
...
@@ -10,14 +10,29 @@ from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
                          init_dist, master_only)
 from .epoch_based_runner import EpochBasedRunner, Runner
 from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
-from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
-                    DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
-                    Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+from .hooks import (HOOKS, CheckpointHook, ClearMLLoggerHook, ClosureHook,
+                    DistEvalHook, DistSamplerSeedHook, DvcliveLoggerHook,
+                    EMAHook, EvalHook, Fp16OptimizerHook,
+                    GradientCumulativeFp16OptimizerHook,
                     GradientCumulativeOptimizerHook, Hook, IterTimerHook,
-                    LoggerHook, LrUpdaterHook, MlflowLoggerHook,
-                    NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
+                    LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
+                    OptimizerHook, PaviLoggerHook, SegmindLoggerHook,
                     SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
                     WandbLoggerHook)
+from .hooks.lr_updater import StepLrUpdaterHook  # noqa
+from .hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
+                               CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
+                               ExpLrUpdaterHook, FixedLrUpdaterHook,
+                               FlatCosineAnnealingLrUpdaterHook,
+                               InvLrUpdaterHook, LinearAnnealingLrUpdaterHook,
+                               LrUpdaterHook, OneCycleLrUpdaterHook,
+                               PolyLrUpdaterHook)
+from .hooks.momentum_updater import (CosineAnnealingMomentumUpdaterHook,
+                                     CyclicMomentumUpdaterHook,
+                                     LinearAnnealingMomentumUpdaterHook,
+                                     MomentumUpdaterHook,
+                                     OneCycleMomentumUpdaterHook,
+                                     StepMomentumUpdaterHook)
 from .iter_based_runner import IterBasedRunner, IterLoader
 from .log_buffer import LogBuffer
 from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
...
@@ -26,9 +41,18 @@ from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
 from .priority import Priority, get_priority
 from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
+# initialize ipu to registor ipu runner to RUNNERS
+from mmcv.device import ipu  # isort:skip  # noqa

 __all__ = [
     'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
     'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
     'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
     'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
     'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
     'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'MomentumUpdaterHook',
     'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
     'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
     'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
     'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
     'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
...
@@ -42,6 +66,8 @@ __all__ = [
     'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
     'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
     '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
-    'ModuleList', 'GradientCumulativeOptimizerHook',
-    'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
+    'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook',
+    'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor',
+    'SegmindLoggerHook', 'LinearAnnealingMomentumUpdaterHook',
+    'LinearAnnealingLrUpdaterHook', 'ClearMLLoggerHook'
 ]
mmcv/runner/base_module.py

...
@@ -4,6 +4,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
+from typing import Iterable, Optional

 import torch.nn as nn
...
@@ -18,25 +19,24 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
     functionality of parameter initialization. Compared with
     ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.

-        - ``init_cfg``: the config to control the initialization.
-        - ``init_weights``: The function of parameter
-            initialization and recording initialization
-            information.
-        - ``_params_init_info``: Used to track the parameter
-            initialization information. This attribute only
-            exists during executing the ``init_weights``.
+    - ``init_cfg``: the config to control the initialization.
+    - ``init_weights``: The function of parameter initialization and recording
+      initialization information.
+    - ``_params_init_info``: Used to track the parameter initialization
+      information. This attribute only exists during executing the
+      ``init_weights``.

     Args:
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, init_cfg=None):
+    def __init__(self, init_cfg: Optional[dict] = None):
         """Initialize BaseModule, inherited from `torch.nn.Module`"""

         # NOTE init_cfg can be defined in different levels, but init_cfg
         # in low levels has a higher priority.

-        super(BaseModule, self).__init__()
+        super().__init__()
         # define default value of init_cfg instead of hard code
         # in init_weights() function
         self._is_init = False
...
@@ -50,10 +50,10 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

     @property
-    def is_init(self):
+    def is_init(self) -> bool:
         return self._is_init

-    def init_weights(self):
+    def init_weights(self) -> None:
         """Initialize the weights."""

         is_top_level_module = False
...
@@ -68,7 +68,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
             # which indicates whether the parameter has been modified.
             # this attribute would be deleted after all parameters
             # is initialized.
-            self._params_init_info = defaultdict(dict)
+            self._params_init_info: defaultdict = defaultdict(dict)
             is_top_level_module = True

             # Initialize the `_params_init_info`,
...
@@ -134,7 +134,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
                 del sub_module._params_init_info

     @master_only
-    def _dump_init_info(self, logger_name):
+    def _dump_init_info(self, logger_name: str) -> None:
         """Dump the initialization information to a file named
         `initialization.log.json` in workdir.
...
@@ -177,7 +177,7 @@ class Sequential(BaseModule, nn.Sequential):
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, *args, init_cfg=None):
+    def __init__(self, *args, init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.Sequential.__init__(self, *args)
...
@@ -190,6 +190,24 @@ class ModuleList(BaseModule, nn.ModuleList):
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, modules=None, init_cfg=None):
+    def __init__(self,
+                 modules: Optional[Iterable] = None,
+                 init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.ModuleList.__init__(self, modules)
+
+
+class ModuleDict(BaseModule, nn.ModuleDict):
+    """ModuleDict in openmmlab.
+
+    Args:
+        modules (dict, optional): a mapping (dictionary) of (string: module)
+            or an iterable of key-value pairs of type (string, module).
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 modules: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleDict.__init__(self, modules)
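A minimal sketch (not part of the commit) of the new ModuleDict container together with the init_cfg mechanism it inherits from BaseModule; the Normal-initializer config is just one example of a valid init_cfg.

import torch.nn as nn
from mmcv.runner import ModuleDict

heads = ModuleDict(
    dict(cls=nn.Linear(16, 10), reg=nn.Linear(16, 4)),
    init_cfg=dict(type='Normal', layer='Linear', std=0.01))
heads.init_weights()          # applies the Normal init to both linear heads
print(list(heads.keys()))     # ['cls', 'reg']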
mmcv/runner/base_runner.py

...
@@ -4,9 +4,13 @@ import logging
 import os.path as osp
 import warnings
 from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+from typing import (Any, Callable, Dict, List, Optional, Tuple, Union,
+                    no_type_check)

 import torch
 from torch.optim import Optimizer
+from torch.utils.data import DataLoader

 import mmcv
 from ..parallel import is_module_wrapper
...
@@ -49,20 +53,22 @@ class BaseRunner(metaclass=ABCMeta):
     """

     def __init__(self,
-                 model,
-                 batch_processor=None,
-                 optimizer=None,
-                 work_dir=None,
-                 logger=None,
-                 meta=None,
-                 max_iters=None,
-                 max_epochs=None):
+                 model: torch.nn.Module,
+                 batch_processor: Optional[Callable] = None,
+                 optimizer: Union[Dict, torch.optim.Optimizer, None] = None,
+                 work_dir: Optional[str] = None,
+                 logger: Optional[logging.Logger] = None,
+                 meta: Optional[Dict] = None,
+                 max_iters: Optional[int] = None,
+                 max_epochs: Optional[int] = None) -> None:
         if batch_processor is not None:
             if not callable(batch_processor):
                 raise TypeError('batch_processor must be callable, '
                                 f'but got {type(batch_processor)}')
-            warnings.warn('batch_processor is deprecated, please implement '
-                          'train_step() and val_step() in the model instead.')
+            warnings.warn(
+                'batch_processor is deprecated, please implement '
+                'train_step() and val_step() in the model instead.',
+                DeprecationWarning)
             # raise an error is `batch_processor` is not None and
             # `model.train_step()` exists.
             if is_module_wrapper(model):
...
@@ -104,8 +110,8 @@ class BaseRunner(metaclass=ABCMeta):
         self.logger = logger
         self.meta = meta
         # create work_dir
-        if mmcv.is_str(work_dir):
-            self.work_dir = osp.abspath(work_dir)
+        if isinstance(work_dir, str):
+            self.work_dir: Optional[str] = osp.abspath(work_dir)
             mmcv.mkdir_or_exist(self.work_dir)
         elif work_dir is None:
             self.work_dir = None
...
@@ -120,8 +126,8 @@ class BaseRunner(metaclass=ABCMeta):
         self._rank, self._world_size = get_dist_info()
         self.timestamp = get_time_str()
-        self.mode = None
-        self._hooks = []
+        self.mode: Optional[str] = None
+        self._hooks: List[Hook] = []
         self._epoch = 0
         self._iter = 0
         self._inner_iter = 0
...
@@ -136,38 +142,38 @@ class BaseRunner(metaclass=ABCMeta):
         self.log_buffer = LogBuffer()

     @property
-    def model_name(self):
+    def model_name(self) -> str:
         """str: Name of the model, usually the module class name."""
         return self._model_name

     @property
-    def rank(self):
+    def rank(self) -> int:
         """int: Rank of current process. (distributed training)"""
         return self._rank

     @property
-    def world_size(self):
+    def world_size(self) -> int:
         """int: Number of processes participating in the job.
        (distributed training)"""
         return self._world_size

     @property
-    def hooks(self):
+    def hooks(self) -> List[Hook]:
         """list[:obj:`Hook`]: A list of registered hooks."""
         return self._hooks

     @property
-    def epoch(self):
+    def epoch(self) -> int:
         """int: Current epoch."""
         return self._epoch

     @property
-    def iter(self):
+    def iter(self) -> int:
         """int: Current iteration."""
         return self._iter

     @property
-    def inner_iter(self):
+    def inner_iter(self) -> int:
         """int: Iteration in an epoch."""
         return self._inner_iter
...
@@ -190,26 +196,28 @@ class BaseRunner(metaclass=ABCMeta):
         pass

     @abstractmethod
-    def run(self, data_loaders, workflow, **kwargs):
+    def run(self, data_loaders: List[DataLoader],
+            workflow: List[Tuple[str, int]], **kwargs) -> Any:
         pass

     @abstractmethod
     def save_checkpoint(self,
-                        out_dir,
-                        filename_tmpl,
-                        save_optimizer=True,
-                        meta=None,
-                        create_symlink=True):
+                        out_dir: str,
+                        filename_tmpl: str,
+                        save_optimizer: bool = True,
+                        meta: Optional[Dict] = None,
+                        create_symlink: bool = True) -> None:
         pass

-    def current_lr(self):
+    def current_lr(self) -> Union[List[float], Dict[str, List[float]]]:
         """Get current learning rates.

         Returns:
-            list[float] | dict[str, list[float]]: Current learning rates of all
-                param groups. If the runner has a dict of optimizers, this
-                method will return a dict.
+            list[float] | dict[str, list[float]]: Current learning rates of all
+            param groups. If the runner has a dict of optimizers, this method
+            will return a dict.
         """
+        lr: Union[List[float], Dict[str, List[float]]]
         if isinstance(self.optimizer, torch.optim.Optimizer):
             lr = [group['lr'] for group in self.optimizer.param_groups]
         elif isinstance(self.optimizer, dict):
...
@@ -221,13 +229,13 @@ class BaseRunner(metaclass=ABCMeta):
                 'lr is not applicable because optimizer does not exist.')
         return lr

-    def current_momentum(self):
+    def current_momentum(self) -> Union[List[float], Dict[str, List[float]]]:
         """Get current momentums.

         Returns:
-            list[float] | dict[str, list[float]]: Current momentums of all
-                param groups. If the runner has a dict of optimizers, this
-                method will return a dict.
+            list[float] | dict[str, list[float]]: Current momentums of all
+            param groups. If the runner has a dict of optimizers, this method
+            will return a dict.
         """

         def _get_momentum(optimizer):
...
@@ -252,7 +260,9 @@ class BaseRunner(metaclass=ABCMeta):
                 momentums[name] = _get_momentum(optim)
         return momentums

-    def register_hook(self, hook, priority='NORMAL'):
+    def register_hook(self,
+                      hook: Hook,
+                      priority: Union[int, str, Priority] = 'NORMAL') -> None:
         """Register a hook into the hook list.

         The hook will be inserted into a priority queue, with the specified
...
@@ -269,25 +279,25 @@ class BaseRunner(metaclass=ABCMeta):
         if hasattr(hook, 'priority'):
             raise ValueError('"priority" is a reserved attribute for hooks')
         priority = get_priority(priority)
-        hook.priority = priority
+        hook.priority = priority  # type: ignore
         # insert the hook to a sorted list
         inserted = False
         for i in range(len(self._hooks) - 1, -1, -1):
-            if priority >= self._hooks[i].priority:
+            if priority >= self._hooks[i].priority:  # type: ignore
                 self._hooks.insert(i + 1, hook)
                 inserted = True
                 break
         if not inserted:
             self._hooks.insert(0, hook)

-    def register_hook_from_cfg(self, hook_cfg):
+    def register_hook_from_cfg(self, hook_cfg: Dict) -> None:
         """Register a hook from its cfg.

         Args:
             hook_cfg (dict): Hook config. It should have at least keys 'type'
               and 'priority' indicating its type and priority.

-        Notes:
+        Note:
             The specific hook class to register should not use 'type' and
             'priority' arguments during initialization.
         """
...
@@ -296,7 +306,7 @@ class BaseRunner(metaclass=ABCMeta):
         hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
         self.register_hook(hook, priority=priority)

-    def call_hook(self, fn_name):
+    def call_hook(self, fn_name: str) -> None:
         """Call all hooks.

         Args:
...
@@ -306,14 +316,14 @@ class BaseRunner(metaclass=ABCMeta):
         for hook in self._hooks:
             getattr(hook, fn_name)(self)

-    def get_hook_info(self):
+    def get_hook_info(self) -> str:
         # Get hooks info in each stage
-        stage_hook_map = {stage: [] for stage in Hook.stages}
+        stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
         for hook in self.hooks:
             try:
-                priority = Priority(hook.priority).name
+                priority = Priority(hook.priority).name  # type: ignore
             except ValueError:
-                priority = hook.priority
+                priority = hook.priority  # type: ignore
             classname = hook.__class__.__name__
             hook_info = f'({priority:<12}) {classname:<35}'
             for trigger_stage in hook.get_triggered_stages():
...
@@ -329,11 +339,13 @@ class BaseRunner(metaclass=ABCMeta):
             stage_hook_infos.append(info)
         return '\n'.join(stage_hook_infos)

-    def load_checkpoint(self,
-                        filename,
-                        map_location='cpu',
-                        strict=False,
-                        revise_keys=[(r'^module.', '')]):
+    def load_checkpoint(
+            self,
+            filename: str,
+            map_location: Union[str, Callable] = 'cpu',
+            strict: bool = False,
+            revise_keys: List = [(r'^module.', '')],
+    ) -> Union[Dict, OrderedDict]:
         return load_checkpoint(
             self.model,
             filename,
...
@@ -342,10 +354,11 @@ class BaseRunner(metaclass=ABCMeta):
             self.logger,
             revise_keys=revise_keys)

+    @no_type_check
     def resume(self,
-               checkpoint,
-               resume_optimizer=True,
-               map_location='default'):
+               checkpoint: str,
+               resume_optimizer: bool = True,
+               map_location: Union[str, Callable] = 'default') -> None:
         if map_location == 'default':
             if torch.cuda.is_available():
                 device_id = torch.cuda.current_device()
...
@@ -396,7 +409,7 @@ class BaseRunner(metaclass=ABCMeta):
         self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

-    def register_lr_hook(self, lr_config):
+    def register_lr_hook(self, lr_config: Union[Dict, Hook, None]) -> None:
         if lr_config is None:
             return
         elif isinstance(lr_config, dict):
...
@@ -417,7 +430,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = lr_config
         self.register_hook(hook, priority='VERY_HIGH')

-    def register_momentum_hook(self, momentum_config):
+    def register_momentum_hook(
+            self, momentum_config: Union[Dict, Hook, None]) -> None:
         if momentum_config is None:
             return
         if isinstance(momentum_config, dict):
...
@@ -438,7 +452,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = momentum_config
         self.register_hook(hook, priority='HIGH')

-    def register_optimizer_hook(self, optimizer_config):
+    def register_optimizer_hook(
+            self, optimizer_config: Union[Dict, Hook, None]) -> None:
         if optimizer_config is None:
             return
         if isinstance(optimizer_config, dict):
...
@@ -448,7 +463,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = optimizer_config
         self.register_hook(hook, priority='ABOVE_NORMAL')

-    def register_checkpoint_hook(self, checkpoint_config):
+    def register_checkpoint_hook(
+            self, checkpoint_config: Union[Dict, Hook, None]) -> None:
         if checkpoint_config is None:
             return
         if isinstance(checkpoint_config, dict):
...
@@ -458,7 +474,7 @@ class BaseRunner(metaclass=ABCMeta):
             hook = checkpoint_config
         self.register_hook(hook, priority='NORMAL')

-    def register_logger_hooks(self, log_config):
+    def register_logger_hooks(self, log_config: Optional[Dict]) -> None:
         if log_config is None:
             return
         log_interval = log_config['interval']
...
@@ -467,7 +483,10 @@ class BaseRunner(metaclass=ABCMeta):
                 info, HOOKS, default_args=dict(interval=log_interval))
             self.register_hook(logger_hook, priority='VERY_LOW')

-    def register_timer_hook(self, timer_config):
+    def register_timer_hook(
+            self,
+            timer_config: Union[Dict, Hook, None],
+    ) -> None:
         if timer_config is None:
             return
         if isinstance(timer_config, dict):
...
@@ -477,7 +496,8 @@ class BaseRunner(metaclass=ABCMeta):
             hook = timer_config
         self.register_hook(hook, priority='LOW')

-    def register_custom_hooks(self, custom_config):
+    def register_custom_hooks(
+            self, custom_config: Union[List, Dict, Hook, None]) -> None:
         if custom_config is None:
             return
...
@@ -490,7 +510,10 @@ class BaseRunner(metaclass=ABCMeta):
         else:
             self.register_hook(item, priority='NORMAL')

-    def register_profiler_hook(self, profiler_config):
+    def register_profiler_hook(
+            self,
+            profiler_config: Union[Dict, Hook, None],
+    ) -> None:
         if profiler_config is None:
             return
         if isinstance(profiler_config, dict):
...
@@ -500,14 +523,15 @@ class BaseRunner(metaclass=ABCMeta):
             hook = profiler_config
         self.register_hook(hook)

-    def register_training_hooks(self,
-                                lr_config,
-                                optimizer_config=None,
-                                checkpoint_config=None,
-                                log_config=None,
-                                momentum_config=None,
-                                timer_config=dict(type='IterTimerHook'),
-                                custom_hooks_config=None):
+    def register_training_hooks(
+            self,
+            lr_config: Union[Dict, Hook, None],
+            optimizer_config: Union[Dict, Hook, None] = None,
+            checkpoint_config: Union[Dict, Hook, None] = None,
+            log_config: Optional[Dict] = None,
+            momentum_config: Union[Dict, Hook, None] = None,
+            timer_config: Union[Dict, Hook] = dict(type='IterTimerHook'),
+            custom_hooks_config: Union[List, Dict, Hook, None] = None) -> None:
         """Register default and custom hooks for training.

         Default and custom hooks include:
...
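A hedged sketch (not part of the commit): a custom hook registered in HOOKS and attached through the now-typed register_hook_from_cfg(); `runner` stands for any concrete runner instance (for example an EpochBasedRunner) and is assumed to already exist.

from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class DummyMetricHook(Hook):

    def after_train_iter(self, runner):
        # push a constant value into the runner's log buffer
        runner.log_buffer.update(dict(dummy_metric=0.0))


# runner.register_hook_from_cfg(dict(type='DummyMetricHook', priority='LOW'))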
mmcv/runner/builder.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+from typing import Optional

 from ..utils import Registry
...
@@ -7,11 +8,11 @@ RUNNERS = Registry('runner')
 RUNNER_BUILDERS = Registry('runner builder')


-def build_runner_constructor(cfg):
+def build_runner_constructor(cfg: dict):
     return RUNNER_BUILDERS.build(cfg)


-def build_runner(cfg, default_args=None):
+def build_runner(cfg: dict, default_args: Optional[dict] = None):
     runner_cfg = copy.deepcopy(cfg)
     constructor_type = runner_cfg.pop('constructor',
                                       'DefaultRunnerConstructor')
...
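A minimal sketch (not part of the commit) of build_runner() with the new dict annotations. EpochBasedRunner is used as the registered runner type; ToyModel only exists so that the runner's train_step requirement is met, since real OpenMMLab models provide train_step()/val_step() themselves.

import logging

import torch
import torch.nn as nn
from mmcv.runner import build_runner


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

    def train_step(self, data, optimizer=None, **kwargs):
        # stub that satisfies BaseRunner's train_step check
        return dict(loss=torch.tensor(0.0), log_vars=dict(), num_samples=1)


model = ToyModel()
runner = build_runner(
    dict(type='EpochBasedRunner', max_epochs=12),
    default_args=dict(
        model=model,
        optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
        work_dir=None,
        logger=logging.getLogger('mmcv'),
        meta=None))
print(type(runner).__name__)   # EpochBasedRunner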
mmcv/runner/checkpoint.py
View file @
fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
import
io
import
logging
import
os
import
os.path
as
osp
import
pkgutil
...
...
@@ -9,8 +10,10 @@ import warnings
from
collections
import
OrderedDict
from
importlib
import
import_module
from
tempfile
import
TemporaryDirectory
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
import
torchvision
from
torch.optim
import
Optimizer
...
...
@@ -18,7 +21,7 @@ import mmcv
from
..fileio
import
FileClient
from
..fileio
import
load
as
load_file
from
..parallel
import
is_module_wrapper
from
..utils
import
load_url
,
mkdir_or_exist
from
..utils
import
digit_version
,
load_url
,
mkdir_or_exist
from
.dist_utils
import
get_dist_info
ENV_MMCV_HOME
=
'MMCV_HOME'
...
...
@@ -26,7 +29,7 @@ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR
=
'~/.cache'
def
_get_mmcv_home
():
def
_get_mmcv_home
()
->
str
:
mmcv_home
=
os
.
path
.
expanduser
(
os
.
getenv
(
ENV_MMCV_HOME
,
...
...
@@ -37,7 +40,10 @@ def _get_mmcv_home():
return
mmcv_home
def
load_state_dict
(
module
,
state_dict
,
strict
=
False
,
logger
=
None
):
def
load_state_dict
(
module
:
nn
.
Module
,
state_dict
:
Union
[
dict
,
OrderedDict
],
strict
:
bool
=
False
,
logger
:
Optional
[
logging
.
Logger
]
=
None
)
->
None
:
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
...
...
@@ -46,21 +52,21 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
state_dict (
dict or
OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys
=
[]
all_missing_keys
=
[]
err_msg
=
[]
unexpected_keys
:
List
[
str
]
=
[]
all_missing_keys
:
List
[
str
]
=
[]
err_msg
:
List
[
str
]
=
[]
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
state_dict
=
state_dict
.
copy
()
state_dict
=
state_dict
.
copy
()
# type: ignore
if
metadata
is
not
None
:
state_dict
.
_metadata
=
metadata
state_dict
.
_metadata
=
metadata
# type: ignore
# use _load_from_state_dict to enable checkpoint version control
def
load
(
module
,
prefix
=
''
):
...
...
@@ -78,7 +84,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
load
(
child
,
prefix
+
name
+
'.'
)
load
(
module
)
load
=
None
# break load->load reference cycle
# break load->load reference cycle
load
=
None
# type: ignore
# ignore "num_batches_tracked" of BN layers
missing_keys
=
[
...
...
@@ -96,7 +103,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if
len
(
err_msg
)
>
0
and
rank
==
0
:
err_msg
.
insert
(
0
,
'The model and loaded state dict do not match exactly
\n
'
)
err_msg
=
'
\n
'
.
join
(
err_msg
)
err_msg
=
'
\n
'
.
join
(
err_msg
)
# type: ignore
if
strict
:
raise
RuntimeError
(
err_msg
)
elif
logger
is
not
None
:
...
...
@@ -106,14 +113,48 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
def
get_torchvision_models
():
model_urls
=
dict
()
for
_
,
name
,
ispkg
in
pkgutil
.
walk_packages
(
torchvision
.
models
.
__path__
):
if
ispkg
:
continue
_zoo
=
import_module
(
f
'torchvision.models.
{
name
}
'
)
if
hasattr
(
_zoo
,
'model_urls'
):
_urls
=
getattr
(
_zoo
,
'model_urls'
)
model_urls
.
update
(
_urls
)
if
digit_version
(
torchvision
.
__version__
)
<
digit_version
(
'0.13.0a0'
):
model_urls
=
dict
()
# When the version of torchvision is lower than 0.13, the model url is
# not declared in `torchvision.model.__init__.py`, so we need to
# iterate through `torchvision.models.__path__` to get the url for each
# model.
for
_
,
name
,
ispkg
in
pkgutil
.
walk_packages
(
torchvision
.
models
.
__path__
):
if
ispkg
:
continue
_zoo
=
import_module
(
f
'torchvision.models.
{
name
}
'
)
if
hasattr
(
_zoo
,
'model_urls'
):
_urls
=
getattr
(
_zoo
,
'model_urls'
)
model_urls
.
update
(
_urls
)
else
:
# Since torchvision bumps to v0.13, the weight loading logic,
# model keys and model urls have been changed. Here the URLs of old
# version is loaded to avoid breaking back compatibility. If the
# torchvision version>=0.13.0, new URLs will be added. Users can get
# the resnet50 checkpoint by setting 'resnet50.imagent1k_v1',
# 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config.
json_path
=
osp
.
join
(
mmcv
.
__path__
[
0
],
'model_zoo/torchvision_0.12.json'
)
model_urls
=
mmcv
.
load
(
json_path
)
for
cls_name
,
cls
in
torchvision
.
models
.
__dict__
.
items
():
# The name of torchvision model weights classes ends with
# `_Weights` such as `ResNet18_Weights`. However, some model weight
# classes, such as `MNASNet0_75_Weights` does not have any urls in
# torchvision 0.13.0 and cannot be iterated. Here we simply check
# `DEFAULT` attribute to ensure the class is not empty.
if
(
not
cls_name
.
endswith
(
'_Weights'
)
or
not
hasattr
(
cls
,
'DEFAULT'
)):
continue
# Since `cls.DEFAULT` can not be accessed by iterating cls, we set
# default urls explicitly.
cls_key
=
cls_name
.
replace
(
'_Weights'
,
''
).
lower
()
model_urls
[
f
'
{
cls_key
}
.default'
]
=
cls
.
DEFAULT
.
url
for
weight_enum
in
cls
:
cls_key
=
cls_name
.
replace
(
'_Weights'
,
''
).
lower
()
cls_key
=
f
'
{
cls_key
}
.
{
weight_enum
.
name
.
lower
()
}
'
model_urls
[
cls_key
]
=
weight_enum
.
url
return
model_urls
...
...
@@ -147,7 +188,7 @@ def get_deprecated_model_names():
return
deprecate_urls
def
_process_mmcls_checkpoint
(
checkpoint
)
:
def
_process_mmcls_checkpoint
(
checkpoint
:
Dict
)
->
Dict
:
if
'state_dict'
in
checkpoint
:
state_dict
=
checkpoint
[
'state_dict'
]
else
:
...
...
@@ -166,10 +207,13 @@ def _process_mmcls_checkpoint(checkpoint):
class
CheckpointLoader
:
"""A general checkpoint loader to manage all schemes."""
_schemes
=
{}
_schemes
:
dict
=
{}
@
classmethod
def
_register_scheme
(
cls
,
prefixes
,
loader
,
force
=
False
):
def
_register_scheme
(
cls
,
prefixes
:
Union
[
str
,
List
,
Tuple
],
loader
:
Callable
,
force
:
bool
=
False
)
->
None
:
if
isinstance
(
prefixes
,
str
):
prefixes
=
[
prefixes
]
else
:
...
...
@@ -186,13 +230,16 @@ class CheckpointLoader:
sorted
(
cls
.
_schemes
.
items
(),
key
=
lambda
t
:
t
[
0
],
reverse
=
True
))
@
classmethod
def
register_scheme
(
cls
,
prefixes
,
loader
=
None
,
force
=
False
):
def
register_scheme
(
cls
,
prefixes
:
Union
[
str
,
List
[
str
],
Tuple
[
str
,
...]],
loader
:
Optional
[
Callable
]
=
None
,
force
:
bool
=
False
)
->
Callable
:
"""Register a loader to CheckpointLoader.
This method can be used as a normal class method or a decorator.
Args:
prefixes (str or
list[str] or tupl
e[str]):
prefixes (str or
Sequenc
e[str]):
The prefix of the registered loader.
loader (function, optional): The loader function to be registered.
When this method is used as a decorator, loader is None.
...
...
@@ -203,7 +250,7 @@ class CheckpointLoader:
if
loader
is
not
None
:
cls
.
_register_scheme
(
prefixes
,
loader
,
force
=
force
)
return
return
# type: ignore
def
_register
(
loader_cls
):
cls
.
_register_scheme
(
prefixes
,
loader_cls
,
force
=
force
)
...
...
@@ -212,7 +259,7 @@ class CheckpointLoader:
return
_register
@
classmethod
def
_get_checkpoint_loader
(
cls
,
path
):
def
_get_checkpoint_loader
(
cls
,
path
:
str
):
"""Finds a loader that supports the given path. Falls back to the local
loader if no other loader is found.
...
...
@@ -220,15 +267,22 @@ class CheckpointLoader:
path (str): checkpoint path
Returns:
loader (function)
: checkpoint loader
callable
: checkpoint loader
"""
for
p
in
cls
.
_schemes
:
if
path
.
startswith
(
p
):
# use regular match to handle some cases that where the prefix of
# loader has a prefix. For example, both 's3://path' and
# 'open-mmlab:s3://path' should return `load_from_ceph`
if
re
.
match
(
p
,
path
)
is
not
None
:
return
cls
.
_schemes
[
p
]
@
classmethod
def
load_checkpoint
(
cls
,
filename
,
map_location
=
None
,
logger
=
None
):
def
load_checkpoint
(
cls
,
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
logger
:
Optional
[
logging
.
Logger
]
=
None
)
->
Union
[
dict
,
OrderedDict
]:
"""load checkpoint through URL scheme path.
Args:
...
...
@@ -243,14 +297,17 @@ class CheckpointLoader:
"""
checkpoint_loader
=
cls
.
_get_checkpoint_loader
(
filename
)
class_name
=
checkpoint_loader
.
__name__
class_name
=
checkpoint_loader
.
__name__
# type: ignore
mmcv
.
print_log
(
f
'load checkpoint from
{
class_name
[
10
:]
}
path:
{
filename
}
'
,
logger
)
return
checkpoint_loader
(
filename
,
map_location
)
return
checkpoint_loader
(
filename
,
map_location
)
# type: ignore
@
CheckpointLoader
.
register_scheme
(
prefixes
=
''
)
def
load_from_local
(
filename
,
map_location
):
def
load_from_local
(
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
)
->
Union
[
dict
,
OrderedDict
]:
"""load checkpoint by local file path.
Args:
...
...
@@ -260,15 +317,18 @@ def load_from_local(filename, map_location):
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
filename
=
osp
.
expanduser
(
filename
)
if
not
osp
.
isfile
(
filename
):
raise
IO
Error
(
f
'
{
filename
}
is
not
a checkpoint file
'
)
raise
FileNotFound
Error
(
f
'
{
filename
}
can
not
be found.
'
)
checkpoint
=
torch
.
load
(
filename
,
map_location
=
map_location
)
return
checkpoint
@
CheckpointLoader
.
register_scheme
(
prefixes
=
(
'http://'
,
'https://'
))
def
load_from_http
(
filename
,
map_location
=
None
,
model_dir
=
None
):
def
load_from_http
(
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
model_dir
:
Optional
[
str
]
=
None
)
->
Union
[
dict
,
OrderedDict
]:
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0.
...
...
@@ -276,7 +336,7 @@ def load_from_http(filename, map_location=None, model_dir=None):
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): Same as :func:`torch.load`.
model_dir (str
ing
, optional): directory in which to save the object,
model_dir (str, optional): directory in which to save the object,
Default: None
Returns:
...
...
@@ -295,7 +355,10 @@ def load_from_http(filename, map_location=None, model_dir=None):
@
CheckpointLoader
.
register_scheme
(
prefixes
=
'pavi://'
)
def
load_from_pavi
(
filename
,
map_location
=
None
):
def
load_from_pavi
(
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
)
->
Union
[
dict
,
OrderedDict
]:
"""load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
...
...
@@ -326,16 +389,23 @@ def load_from_pavi(filename, map_location=None):
return
checkpoint
@
CheckpointLoader
.
register_scheme
(
prefixes
=
's3://'
)
def
load_from_ceph
(
filename
,
map_location
=
None
,
backend
=
'petrel'
):
@
CheckpointLoader
.
register_scheme
(
prefixes
=
r
'(\S+\:)?s3://'
)
def
load_from_ceph
(
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
backend
:
str
=
'petrel'
)
->
Union
[
dict
,
OrderedDict
]:
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
Note:
Since v1.4.1, the registered scheme prefixes have been enhanced to
support bucket names in the path prefix, e.g. 's3://xx.xx/xx.path',
'bucket1:s3://xx.xx/xx.path'.
Args:
filename (str): checkpoint file path with s3 prefix
map_location (str, optional): Same as :func:`torch.load`.
backend (str
, optional
): The storage backend type. Options are 'ceph',
backend (str): The storage backend type. Options are 'ceph',
'petrel'. Default: 'petrel'.
.. warning::
...
...
@@ -351,7 +421,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
if
backend
==
'ceph'
:
warnings
.
warn
(
'CephBackend will be deprecated, please use PetrelBackend instead'
)
'CephBackend will be deprecated, please use PetrelBackend instead'
,
DeprecationWarning
)
# CephClient and PetrelBackend have the same prefix 's3://' and the latter
# will be chosen as default. If PetrelBackend can not be instantiated
...
...
@@ -368,7 +439,10 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
@
CheckpointLoader
.
register_scheme
(
prefixes
=
(
'modelzoo://'
,
'torchvision://'
))
def
load_from_torchvision
(
filename
,
map_location
=
None
):
def
load_from_torchvision
(
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
)
->
Union
[
dict
,
OrderedDict
]:
"""load checkpoint through the file path prefixed with modelzoo or
torchvision.
...
...
@@ -382,16 +456,25 @@ def load_from_torchvision(filename, map_location=None):
"""
model_urls
=
get_torchvision_models
()
if
filename
.
startswith
(
'modelzoo://'
):
warnings
.
warn
(
'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead'
)
warnings
.
warn
(
'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead'
,
DeprecationWarning
)
model_name
=
filename
[
11
:]
else
:
model_name
=
filename
[
14
:]
# Support getting model urls in the same way as torchvision
# `ResNet50_Weights.IMAGENET1K_V1` will be mapped to
# resnet50.imagenet1k_v1.
model_name
=
model_name
.
lower
().
replace
(
'_weights'
,
''
)
return
load_from_http
(
model_urls
[
model_name
],
map_location
=
map_location
)
@
CheckpointLoader
.
register_scheme
(
prefixes
=
(
'open-mmlab://'
,
'openmmlab://'
))
def
load_from_openmmlab
(
filename
,
map_location
=
None
):
def
load_from_openmmlab
(
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
)
->
Union
[
dict
,
OrderedDict
]:
"""load checkpoint through the file path prefixed with open-mmlab or
openmmlab.
...
...
@@ -415,8 +498,10 @@ def load_from_openmmlab(filename, map_location=None):
deprecated_urls
=
get_deprecated_model_names
()
if
model_name
in
deprecated_urls
:
warnings
.
warn
(
f
'
{
prefix_str
}{
model_name
}
is deprecated in favor '
f
'of
{
prefix_str
}{
deprecated_urls
[
model_name
]
}
'
)
warnings
.
warn
(
f
'
{
prefix_str
}{
model_name
}
is deprecated in favor '
f
'of
{
prefix_str
}{
deprecated_urls
[
model_name
]
}
'
,
DeprecationWarning
)
model_name
=
deprecated_urls
[
model_name
]
model_url
=
model_urls
[
model_name
]
# check if is url
...
...
@@ -425,13 +510,16 @@ def load_from_openmmlab(filename, map_location=None):
else
:
filename
=
osp
.
join
(
_get_mmcv_home
(),
model_url
)
if
not
osp
.
isfile
(
filename
):
raise
IO
Error
(
f
'
{
filename
}
is
not
a checkpoint file
'
)
raise
FileNotFound
Error
(
f
'
{
filename
}
can
not
be found.
'
)
checkpoint
=
torch
.
load
(
filename
,
map_location
=
map_location
)
return
checkpoint
@
CheckpointLoader
.
register_scheme
(
prefixes
=
'mmcls://'
)
def
load_from_mmcls
(
filename
,
map_location
=
None
):
def
load_from_mmcls
(
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
)
->
Union
[
dict
,
OrderedDict
]:
"""load checkpoint through the file path prefixed with mmcls.
Args:
...
...
@@ -450,7 +538,10 @@ def load_from_mmcls(filename, map_location=None):
return
checkpoint
def
_load_checkpoint
(
filename
,
map_location
=
None
,
logger
=
None
):
def
_load_checkpoint
(
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
logger
:
Optional
[
logging
.
Logger
]
=
None
)
->
Union
[
dict
,
OrderedDict
]:
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
...
...
@@ -470,7 +561,11 @@ def _load_checkpoint(filename, map_location=None, logger=None):
return
CheckpointLoader
.
load_checkpoint
(
filename
,
map_location
,
logger
)
def
_load_checkpoint_with_prefix
(
prefix
,
filename
,
map_location
=
None
):
def
_load_checkpoint_with_prefix
(
prefix
:
str
,
filename
:
str
,
map_location
:
Union
[
str
,
Callable
,
None
]
=
None
,
)
->
Union
[
dict
,
OrderedDict
]:
"""Load partial pretrained model with specific prefix.
Args:
...
...
@@ -503,12 +598,13 @@ def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
    return state_dict


def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None,
                    revise_keys=[(r'^module\.', '')]):
def load_checkpoint(
        model: torch.nn.Module,
        filename: str,
        map_location: Union[str, Callable, None] = None,
        strict: bool = False,
        logger: Optional[logging.Logger] = None,
        revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
    """Load checkpoint from a file or URI.

    Args:
    ...
    ...
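The default ``revise_keys=[(r'^module\.', '')]`` above strips the ``module.`` prefix that ``DataParallel`` adds to parameter names. A small self-contained sketch of that key rewriting (the helper name is illustrative):

import re
from collections import OrderedDict

def apply_revise_keys(state_dict, revise_keys=((r'^module\.', ''), )):
    # Apply each (pattern, replacement) pair to every key in turn.
    for pattern, repl in revise_keys:
        state_dict = OrderedDict(
            (re.sub(pattern, repl, k), v) for k, v in state_dict.items())
    return state_dict

print(apply_revise_keys(OrderedDict({'module.conv.weight': 0})))
# OrderedDict([('conv.weight', 0)])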
@@ -553,7 +649,7 @@ def load_checkpoint(model,
    return checkpoint


def weights_to_cpu(state_dict):
def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
    """Copy a model state_dict to cpu.

    Args:
    ...
    ...
@@ -566,11 +662,13 @@ def weights_to_cpu(state_dict):
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    # Keep metadata in state_dict
    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
    state_dict_cpu._metadata = getattr(  # type: ignore
        state_dict, '_metadata', OrderedDict())
    return state_dict_cpu


def _save_to_state_dict(module, destination, prefix, keep_vars):
def _save_to_state_dict(module: torch.nn.Module, destination: dict,
                        prefix: str, keep_vars: bool) -> None:
    """Saves module state to `destination` dictionary.

    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
    ...
    ...
@@ -590,7 +688,10 @@ def _save_to_state_dict(module, destination, prefix, keep_vars):
            destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
def get_state_dict(module: torch.nn.Module,
                   destination: Optional[OrderedDict] = None,
                   prefix: str = '',
                   keep_vars: bool = False) -> OrderedDict:
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    ...
    ...
@@ -619,10 +720,10 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        destination._metadata = OrderedDict()  # type: ignore
    destination._metadata[prefix[:-1]] = local_metadata = dict(  # type: ignore
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    _save_to_state_dict(module, destination, prefix, keep_vars)  # type: ignore
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
    ...
    ...
@@ -631,14 +732,14 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
    return destination  # type: ignore


def save_checkpoint(model,
                    filename,
                    optimizer=None,
                    meta=None,
                    file_client_args=None):
def save_checkpoint(model: torch.nn.Module,
                    filename: str,
                    optimizer: Optional[Optimizer] = None,
                    meta: Optional[dict] = None,
                    file_client_args: Optional[dict] = None) -> None:
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ...
    ...
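A usage sketch based on the signature above; the path and ``meta`` values are illustrative:

import torch.nn as nn
from torch.optim import SGD

model = nn.Linear(2, 2)
optimizer = SGD(model.parameters(), lr=0.1)
# Saves 'meta', 'state_dict' and (optionally) 'optimizer' into one file.
save_checkpoint(model, 'latest.pth', optimizer=optimizer,
                meta=dict(epoch=3, iter=3000))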
@@ -669,7 +770,7 @@ def save_checkpoint(model,
    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
        'state_dict': weights_to_cpu(get_state_dict(model))  # type: ignore
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
    ...
    ...
@@ -685,8 +786,7 @@ def save_checkpoint(model,
                'file_client_args should be "None" if filename starts with'
                f'"pavi://", but got {file_client_args}')
        try:
            from pavi import modelcloud
            from pavi import exception
            from pavi import exception, modelcloud
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
    ...
    ...
mmcv/runner/default_constructor.py
View file @ fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from .builder import RUNNER_BUILDERS, RUNNERS
...
...
@@ -33,7 +36,7 @@ class DefaultRunnerConstructor:
        >>> runner = build_runner(runner_cfg)
    """

    def __init__(self, runner_cfg, default_args=None):
    def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
        if not isinstance(runner_cfg, dict):
            raise TypeError('runner_cfg should be a dict',
                            f'but got {type(runner_cfg)}')
    ...
    ...
mmcv/runner/dist_utils.py
View file @ fdeee889
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import os
import socket
import subprocess
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple

import torch
import torch.multiprocessing as mp
...
...
@@ -10,8 +13,28 @@ from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)

from mmcv.utils import IS_MLU_AVAILABLE


def init_dist(launcher, backend='nccl', **kwargs):
def _find_free_port() -> str:
    # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(('', 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def _is_free_port(port: int) -> bool:
    ips = socket.gethostbyname_ex(socket.gethostname())[-1]
    ips.append('localhost')
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return all(s.connect_ex((ip, port)) != 0 for ip in ips)


def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
    ...
    ...
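A quick illustrative check of the two helpers above: probe for a candidate port and confirm nothing is currently listening on it (a race with other processes is still possible, as the NOTE says):

port = int(_find_free_port())
assert _is_free_port(port)
print(f'candidate MASTER_PORT: {port}')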
@@ -24,23 +47,37 @@ def init_dist(launcher, backend='nccl', **kwargs):
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
def _init_dist_pytorch(backend: str, **kwargs) -> None:
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
    if IS_MLU_AVAILABLE:
        import torch_mlu  # noqa: F401
        torch.mlu.set_device(rank)
        dist.init_process_group(
            backend='cncl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    else:
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
def _init_dist_mpi(backend: str, **kwargs) -> None:
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    if 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    ...
    ...
@@ -64,8 +101,12 @@ def _init_dist_slurm(backend, port=None):
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
        # if torch.distributed default port(29500) is available
        # then use it, else find a free port
        if _is_free_port(29500):
            os.environ['MASTER_PORT'] = '29500'
        else:
            os.environ['MASTER_PORT'] = str(_find_free_port())
    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    ...
    ...
@@ -75,7 +116,7 @@ def _init_dist_slurm(backend, port=None):
    dist.init_process_group(backend=backend)


def get_dist_info():
def get_dist_info() -> Tuple[int, int]:
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    ...
    ...
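``get_dist_info()`` falls back to ``(0, 1)`` when the default process group is unavailable or uninitialized, so it is safe to call from single-process scripts, e.g.:

rank, world_size = get_dist_info()
if rank == 0:
    print(f'running with world size {world_size}')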
@@ -85,7 +126,7 @@ def get_dist_info():
    return rank, world_size


def master_only(func):
def master_only(func: Callable) -> Callable:

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
    ...
    ...
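A typical use of the decorator above (the helper below is just an illustration): the wrapped function only runs on rank 0, and other ranks get ``None``:

@master_only
def log_on_master(msg: str) -> None:
    print(msg)

log_on_master('checkpoint saved')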
@@ -96,12 +137,14 @@ def master_only(func):
    return wrapper


def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
def allreduce_params(params: List[torch.nn.Parameter],
                     coalesce: bool = True,
                     bucket_size_mb: int = -1) -> None:
    """Allreduce parameters.

    Args:
        params (list[torch.Parameters]): List of parameters or buffers
            of a model.
        params (list[torch.nn.Parameter]): List of parameters or buffers
            of a model.
        coalesce (bool, optional): Whether allreduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
    ...
    ...
@@ -118,11 +161,13 @@ def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
            dist.all_reduce(tensor.div_(world_size))


def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
def allreduce_grads(params: List[torch.nn.Parameter],
                    coalesce: bool = True,
                    bucket_size_mb: int = -1) -> None:
    """Allreduce gradients.

    Args:
        params (list[torch.Parameters]): List of parameters of a model
        params (list[torch.nn.Parameter]): List of parameters of a model.
        coalesce (bool, optional): Whether allreduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
    ...
    ...
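A hedged usage sketch for the two helpers: with an initialized process group, buffers (e.g. BN running statistics) and gradients can be averaged across ranks; the guard keeps the snippet harmless in a single process:

import torch.nn as nn

model = nn.BatchNorm2d(8)
if dist.is_available() and dist.is_initialized():
    allreduce_params(list(model.buffers()))
    allreduce_grads([p for p in model.parameters() if p.grad is not None])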
@@ -142,7 +187,9 @@ def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
            dist.all_reduce(tensor.div_(world_size))


def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
def _allreduce_coalesced(tensors: torch.Tensor,
                         world_size: int,
                         bucket_size_mb: int = -1) -> None:
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    ...
    ...
mmcv/runner/epoch_based_runner.py
View file @ fdeee889
...
...
@@ -4,8 +4,10 @@ import platform
import shutil
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.utils.data import DataLoader

import mmcv
from .base_runner import BaseRunner
...
...
@@ -21,7 +23,7 @@ class EpochBasedRunner(BaseRunner):
    This runner train models epoch by epoch.
    """

    def run_iter(self, data_batch, train_mode, **kwargs):
    def run_iter(self, data_batch: Any, train_mode: bool, **kwargs) -> None:
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
    ...
    ...
@@ -45,10 +47,12 @@ class EpochBasedRunner(BaseRunner):
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1

        self.call_hook('after_train_epoch')
    ...
    ...
@@ -62,14 +66,19 @@ class EpochBasedRunner(BaseRunner):
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')
            del self.data_batch

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
    def run(self,
            data_loaders: List[DataLoader],
            workflow: List[Tuple[str, int]],
            max_epochs: Optional[int] = None,
            **kwargs) -> None:
        """Start running.

        Args:
    ...
    ...
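A usage sketch for the typed ``run`` signature above; ``runner``, ``train_loader`` and ``val_loader`` are placeholders for an already constructed runner and its dataloaders:

# Alternate one training epoch and one validation epoch until max_epochs.
runner.run(
    data_loaders=[train_loader, val_loader],
    workflow=[('train', 1), ('val', 1)],
    max_epochs=12)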
@@ -130,11 +139,11 @@ class EpochBasedRunner(BaseRunner):
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
                        out_dir: str,
                        filename_tmpl: str = 'epoch_{}.pth',
                        save_optimizer: bool = True,
                        meta: Optional[Dict] = None,
                        create_symlink: bool = True) -> None:
        """Save the checkpoint.

        Args:
    ...
    ...
@@ -183,5 +192,6 @@ class Runner(EpochBasedRunner):
    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner was deprecated, please use EpochBasedRunner instead')
            'Runner was deprecated, please use EpochBasedRunner instead',
            DeprecationWarning)
        super().__init__(*args, **kwargs)
mmcv/runner/fp16_utils.py
View file @ fdeee889
...
...
@@ -3,10 +3,12 @@ import functools
import warnings
from collections import abc
from inspect import getfullargspec
from typing import Callable, Iterable, List, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

from mmcv.utils import TORCH_VERSION, digit_version
from .dist_utils import allreduce_grads as _allreduce_grads
...
...
@@ -21,9 +23,18 @@ except ImportError:
    pass


def cast_tensor_type(inputs, src_type, dst_type):
def cast_tensor_type(inputs, src_type: torch.dtype, dst_type: torch.dtype):
    """Recursively convert Tensor in inputs from src_type to dst_type.

    Note:
        In v1.4.4 and later, ``cast_tersor_type`` will only convert the
        torch.Tensor which is consistent with ``src_type`` to the ``dst_type``.
        Before v1.4.4, it ignores the ``src_type`` argument, leading to some
        potential problems. For example,
        ``cast_tensor_type(inputs, torch.float, torch.half)`` will convert all
        tensors in inputs to ``torch.half`` including those originally in
        ``torch.Int`` or other types, which is not expected.

    Args:
        inputs: Inputs that to be casted.
        src_type (torch.dtype): Source type..
    ...
    ...
@@ -35,24 +46,30 @@ def cast_tensor_type(inputs, src_type, dst_type):
    if isinstance(inputs, nn.Module):
        return inputs
    elif isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
        # we need to ensure that the type of inputs to be casted are the same
        # as the argument `src_type`.
        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
    elif isinstance(inputs, str):
        return inputs
    elif isinstance(inputs, np.ndarray):
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
        return type(inputs)({  # type: ignore
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
        return type(inputs)(  # type: ignore
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    else:
        return inputs


def auto_fp16(apply_to=None, out_fp32=False):
def auto_fp16(apply_to: Optional[Iterable] = None,
              out_fp32: bool = False,
              supported_types: tuple = (nn.Module, ),
              ) -> Callable:
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    ...
    ...
@@ -65,7 +82,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp32 (bool): Whether to convert the output back to fp32.
        supported_types (tuple): Classes can be decorated by ``auto_fp16``.
            `New in version 1.5.0.`

    Example:

        >>> import torch.nn as nn
    ...
    ...
@@ -85,15 +103,15 @@ def auto_fp16(apply_to=None, out_fp32=False):
        >>>     pass
    """

    def auto_fp16_wrapper(old_func):
    def auto_fp16_wrapper(old_func: Callable) -> Callable:

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
        def new_func(*args, **kwargs) -> Callable:
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
            if not isinstance(args[0], supported_types):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
                                f'method of those classes {supported_types}')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)
    ...
    ...
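A hedged example of the new ``supported_types`` argument: it lets the decorator be applied to methods of classes other than plain ``nn.Module`` subclasses. ``BaseWrapper`` and ``PointsWrapper`` below are hypothetical:

import torch

class BaseWrapper:
    # Hypothetical non-nn.Module container that opts into fp16.
    fp16_enabled = True

class PointsWrapper(BaseWrapper):

    @auto_fp16(apply_to=('tensor', ), supported_types=(BaseWrapper, ))
    def scale(self, tensor):
        return tensor * 2

out = PointsWrapper().scale(torch.ones(2))  # 'tensor' is cast before the call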
@@ -138,7 +156,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
    return auto_fp16_wrapper


def force_fp32(apply_to=None, out_fp16=False):
def force_fp32(apply_to: Optional[Iterable] = None,
               out_fp16: bool = False) -> Callable:
    """Decorator to convert input arguments to fp32 in force.

    This decorator is useful when you write custom modules and want to support
    ...
    ...
@@ -176,7 +195,7 @@ def force_fp32(apply_to=None, out_fp16=False):
    def force_fp32_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
        def new_func(*args, **kwargs) -> Callable:
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
    ...
    ...
@@ -224,14 +243,17 @@ def force_fp32(apply_to=None, out_fp16=False):
    return force_fp32_wrapper


def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    warnings.warning(
def allreduce_grads(params: List[Parameter],
                    coalesce: bool = True,
                    bucket_size_mb: int = -1) -> None:
    warnings.warn(
        '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads')
        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads',
        DeprecationWarning)
    _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)


def wrap_fp16_model(model):
def wrap_fp16_model(model: nn.Module) -> None:
    """Wrap the FP32 model to FP16.

    If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    ...
    ...
@@ -260,7 +282,7 @@ def wrap_fp16_model(model):
            m.fp16_enabled = True


def patch_norm_fp32(module):
def patch_norm_fp32(module: nn.Module) -> nn.Module:
    """Recursively convert normalization layers from FP16 to FP32.

    Args:
    ...
    ...
@@ -280,7 +302,10 @@ def patch_norm_fp32(module):
    return module


def patch_forward_method(func, src_type, dst_type, convert_output=True):
def patch_forward_method(func: Callable,
                         src_type: torch.dtype,
                         dst_type: torch.dtype,
                         convert_output: bool = True) -> Callable:
    """Patch the forward method of a module.

    Args:
    ...
    ...
@@ -333,10 +358,10 @@ class LossScaler:
"""
def
__init__
(
self
,
init_scale
=
2
**
32
,
mode
=
'dynamic'
,
scale_factor
=
2.
,
scale_window
=
1000
):
init_scale
:
float
=
2
**
32
,
mode
:
str
=
'dynamic'
,
scale_factor
:
float
=
2.
,
scale_window
:
int
=
1000
):
self
.
cur_scale
=
init_scale
self
.
cur_iter
=
0
assert
mode
in
(
'dynamic'
,
...
...
@@ -346,7 +371,7 @@ class LossScaler:
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    def has_overflow(self, params):
    def has_overflow(self, params: List[Parameter]) -> bool:
        """Check if params contain overflow."""
        if self.mode != 'dynamic':
            return False
    ...
    ...
@@ -355,7 +380,7 @@ class LossScaler:
                return True
        return False

    def _has_inf_or_nan(x):
    def _has_inf_or_nan(x: torch.Tensor) -> bool:
        """Check if params contain NaN."""
        try:
            cpu_sum = float(x.float().sum())
    ...
    ...
@@ -369,7 +394,7 @@ class LossScaler:
                return True
            return False

    def update_scale(self, overflow):
    def update_scale(self, overflow: bool) -> None:
        """update the current loss scale value when overflow happens."""
        if self.mode != 'dynamic':
            return
    ...
    ...
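A hedged sketch of how the dynamic mode behaves, using only the pieces shown above: an overflow shrinks the scale by ``scale_factor``, and ``scale_window`` consecutive clean iterations grow it back:

scaler = LossScaler(init_scale=2.**16, mode='dynamic', scale_factor=2.,
                    scale_window=2)
scaler.update_scale(overflow=True)     # overflow: scale is reduced
scaler.update_scale(overflow=False)    # clean iterations accumulate...
scaler.update_scale(overflow=False)    # ...until the scale grows again
print(scaler.loss_scale)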
@@ -382,7 +407,7 @@ class LossScaler:
            self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    def state_dict(self):
    def state_dict(self) -> dict:
        """Returns the state of the scaler as a :class:`dict`."""
        return dict(
            cur_scale=self.cur_scale,
    ...
    ...
@@ -392,7 +417,7 @@ class LossScaler:
            scale_factor=self.scale_factor,
            scale_window=self.scale_window)

    def load_state_dict(self, state_dict):
    def load_state_dict(self, state_dict: dict) -> None:
        """Loads the loss_scaler state dict.

        Args:
    ...
    ...
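A round-trip sketch with the two methods above, so the scaler state can be persisted alongside a checkpoint and restored later:

scaler = LossScaler(mode='dynamic')
state = scaler.state_dict()
restored = LossScaler(mode='dynamic')
restored.load_state_dict(state)
assert restored.loss_scale == scaler.loss_scale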
@@ -406,5 +431,5 @@ class LossScaler:
        self.scale_window = state_dict['scale_window']

    @property
    def loss_scale(self):
    def loss_scale(self) -> float:
        return self.cur_scale