OpenDAS / MMCV, commit 966b7428 (unverified)

Add type hints for mmcv/runner/optimizer (#2001)

Authored May 28, 2022 by tripleMu; committed by GitHub on May 28, 2022
Parent: 1577f407
Changes: 2 files, 18 additions and 9 deletions

  mmcv/runner/optimizer/builder.py               +4  -3
  mmcv/runner/optimizer/default_constructor.py   +14 -6
mmcv/runner/optimizer/builder.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import inspect
+from typing import Dict, List

 import torch
...
@@ -10,7 +11,7 @@ OPTIMIZERS = Registry('optimizer')
 OPTIMIZER_BUILDERS = Registry('optimizer builder')


-def register_torch_optimizers():
+def register_torch_optimizers() -> List:
     torch_optimizers = []
     for module_name in dir(torch.optim):
         if module_name.startswith('__'):
...
@@ -26,11 +27,11 @@ def register_torch_optimizers():
 TORCH_OPTIMIZERS = register_torch_optimizers()


-def build_optimizer_constructor(cfg):
+def build_optimizer_constructor(cfg: Dict):
     return build_from_cfg(cfg, OPTIMIZER_BUILDERS)


-def build_optimizer(model, cfg):
+def build_optimizer(model, cfg: Dict):
     optimizer_cfg = copy.deepcopy(cfg)
     constructor_type = optimizer_cfg.pop('constructor',
                                          'DefaultOptimizerConstructor')
...
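With cfg now annotated as Dict, a call site reads as below. This is a minimal usage sketch, assuming the public mmcv.runner.build_optimizer export; the model and hyperparameter values are placeholders, not part of this commit.

import torch.nn as nn
from mmcv.runner import build_optimizer

model = nn.Linear(4, 2)
# 'type' selects one of the torch.optim classes collected by
# register_torch_optimizers(); the remaining keys are optimizer kwargs.
cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer = build_optimizer(model, cfg)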
mmcv/runner/optimizer/default_constructor.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from typing import Dict, List, Optional, Union

 import torch
 import torch.nn as nn
 from torch.nn import GroupNorm, LayerNorm

 from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
...
@@ -93,7 +95,9 @@ class DefaultOptimizerConstructor:
         >>> # model.cls_head is (0.01, 0.95).
     """

-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self,
+                 optimizer_cfg: Dict,
+                 paramwise_cfg: Optional[Dict] = None):
         if not isinstance(optimizer_cfg, dict):
             raise TypeError('optimizer_cfg should be a dict',
                             f'but got {type(optimizer_cfg)}')
...
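The Optional[Dict] hint on paramwise_cfg mirrors the runtime isinstance check that follows it. A standalone sketch of the same pattern; the function name and values here are hypothetical, not from the diff:

from typing import Dict, Optional

def check_cfgs(optimizer_cfg: Dict, paramwise_cfg: Optional[Dict] = None) -> None:
    # The static hint admits None for paramwise_cfg; the runtime check still
    # guards optimizer_cfg against non-dict values from untyped call sites.
    if not isinstance(optimizer_cfg, dict):
        raise TypeError(f'optimizer_cfg should be a dict, but got {type(optimizer_cfg)}')

check_cfgs(dict(type='SGD', lr=0.01))        # passes
check_cfgs(dict(type='SGD', lr=0.01), None)  # passes: Optional default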
@@ -103,7 +107,7 @@ class DefaultOptimizerConstructor:
         self.base_wd = optimizer_cfg.get('weight_decay', None)
         self._validate_cfg()

-    def _validate_cfg(self):
+    def _validate_cfg(self) -> None:
         if not isinstance(self.paramwise_cfg, dict):
             raise TypeError('paramwise_cfg should be None or a dict, '
                             f'but got {type(self.paramwise_cfg)}')
...
@@ -126,7 +130,7 @@ class DefaultOptimizerConstructor:
         if self.base_wd is None:
             raise ValueError('base_wd should not be None')

-    def _is_in(self, param_group, param_group_list):
+    def _is_in(self, param_group: Dict, param_group_list: List) -> bool:
         assert is_list_of(param_group_list, dict)
         param = set(param_group['params'])
         param_set = set()
...
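_is_in now declares a bool return: it reports whether a candidate param group shares any parameter with the groups already collected, via a set-disjointness test. A simplified standalone sketch, with strings standing in for the parameter tensors the real method compares:

from typing import Dict, List

def is_in(param_group: Dict, param_group_list: List[Dict]) -> bool:
    # Overlap test: True if any entry of param_group['params'] already
    # appears in some group of param_group_list.
    param = set(param_group['params'])
    param_set = set()
    for group in param_group_list:
        param_set.update(set(group['params']))
    return not param.isdisjoint(param_set)

groups = [dict(params=['w1', 'b1'])]
print(is_in(dict(params=['b1', 'w2']), groups))  # True: 'b1' is duplicated
print(is_in(dict(params=['w3']), groups))        # False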
@@ -135,7 +139,11 @@ class DefaultOptimizerConstructor:
         return not param.isdisjoint(param_set)

-    def add_params(self, params, module, prefix='', is_dcn_module=None):
+    def add_params(self,
+                   params: List[Dict],
+                   module: nn.Module,
+                   prefix: str = '',
+                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
         """Add all parameters of module to the params list.

         The parameters of the given module will be added to the list of param
...
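add_params gets the widest annotation in the commit; is_dcn_module is Optional[Union[int, float]] because, when set, it carries a numeric learning-rate multiplier for deformable-conv offset layers. A sketch of a paramwise_cfg exercising the options this method honors; the keys follow the class docstring, while the values are illustrative only:

paramwise_cfg = dict(
    custom_keys={'backbone': dict(lr_mult=0.1)},  # per-name lr override
    bias_lr_mult=2.0,        # scale lr for bias parameters
    norm_decay_mult=0.0,     # drop weight decay on norm layers
    dcn_offset_lr_mult=0.1)  # the int/float threaded through is_dcn_module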
@@ -232,7 +240,7 @@ class DefaultOptimizerConstructor:
             prefix=child_prefix,
             is_dcn_module=is_dcn_module)

-    def __call__(self, model):
+    def __call__(self, model: nn.Module):
         if hasattr(model, 'module'):
             model = model.module
...
@@ -243,7 +251,7 @@ class DefaultOptimizerConstructor:
             return build_from_cfg(optimizer_cfg, OPTIMIZERS)

         # set param-wise lr and weight decay recursively
-        params = []
+        params: List[Dict] = []
         self.add_params(params, model)
         optimizer_cfg['params'] = params
...
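Putting the typed pieces together: a DefaultOptimizerConstructor instance is itself the callable that turns a model into an optimizer. A minimal end-to-end sketch, assuming the mmcv.runner exports; the model and values are placeholders:

import torch.nn as nn
from mmcv.runner import DefaultOptimizerConstructor

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
constructor = DefaultOptimizerConstructor(
    optimizer_cfg=dict(type='SGD', lr=0.01, weight_decay=0.0001),
    paramwise_cfg=dict(norm_decay_mult=0.0))
optimizer = constructor(model)  # invokes __call__(self, model: nn.Module)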