OpenDAS / MMCV
Commit 507e67e4, authored Aug 29, 2018 by Kai Chen
make optimizer an optional argument for Runner
parent 7d872508
Showing 1 changed file with 17 additions and 10 deletions.
mmcv/torchpack/runner/runner.py
@@ -20,13 +20,16 @@ class Runner(object):
     def __init__(self,
                  model,
-                 optimizer,
                  batch_processor,
+                 optimizer=None,
                  work_dir=None,
                  log_level=logging.INFO):
         assert callable(batch_processor)
         self.model = model
-        self.optimizer = self.init_optimizer(optimizer)
+        if optimizer is not None:
+            self.optimizer = self.init_optimizer(optimizer)
+        else:
+            self.optimizer = None
         self.batch_processor = batch_processor
         # create work_dir
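With optimizer now optional, a Runner can be built for evaluation-only use. A minimal sketch of both call styles; the import path is assumed from this file's location, and the model and batch_processor are hypothetical placeholders:

    import logging

    import torch
    import torch.nn as nn

    from mmcv.torchpack.runner.runner import Runner  # assumed path, per this file

    def batch_processor(model, data_batch, train_mode, **kwargs):
        # Hypothetical processor: must return a dict (see the val hunk below).
        inputs, labels = data_batch
        loss = nn.functional.cross_entropy(model(inputs), labels)
        return dict(loss=loss, log_vars=dict(loss=loss.item()),
                    num_samples=inputs.size(0))

    model = nn.Linear(10, 2)  # placeholder model

    # Training: pass an optimizer explicitly (now a keyword argument).
    train_runner = Runner(model, batch_processor,
                          optimizer=torch.optim.SGD(model.parameters(), lr=0.01))

    # Evaluation-only: omit it; self.optimizer is set to None.
    eval_runner = Runner(model, batch_processor)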
@@ -152,6 +155,9 @@ class Runner(object):
         Returns:
             list: Current learning rate of all param groups.
         """
+        if self.optimizer is None:
+            raise RuntimeError(
+                'lr is not applicable because optimizer does not exist.')
         return [group['lr'] for group in self.optimizer.param_groups]

     def register_hook(self, hook, priority=50):
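The new guard turns a bare AttributeError into an explicit failure when no optimizer was given. A sketch of the behaviour; the method name current_lr is an assumption inferred from the docstring, since it is not visible in this hunk:

    runner = Runner(model, batch_processor)  # optimizer defaults to None

    try:
        runner.current_lr()  # assumed method name for this docstring
    except RuntimeError as err:
        print(err)  # lr is not applicable because optimizer does not exist.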
@@ -234,8 +240,9 @@ class Runner(object):
         for i, data_batch in enumerate(data_loader):
             self._inner_iter = i
             self.call_hook('before_val_iter')
-            outputs = self.batch_processor(
-                self.model, data_batch, train_mode=False, **kwargs)
+            with torch.no_grad():
+                outputs = self.batch_processor(
+                    self.model, data_batch, train_mode=False, **kwargs)
             if not isinstance(outputs, dict):
                 raise TypeError('batch_processor() must return a dict')
             if 'log_vars' in outputs:
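Wrapping the validation forward pass in torch.no_grad() stops autograd from recording the computation graph, which cuts memory use during evaluation and rules out accidental backward passes. A self-contained illustration of the difference:

    import torch

    x = torch.randn(4, 10, requires_grad=True)
    w = torch.randn(10, 2, requires_grad=True)

    y_train = x @ w          # records an autograd graph
    with torch.no_grad():
        y_val = x @ w        # no graph is recorded

    print(y_train.requires_grad)  # True
    print(y_val.requires_grad)    # False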
@@ -321,12 +328,12 @@ class Runner(object):
                 info, hooks, default_args=dict(interval=log_interval))
             self.register_hook(logger_hook, priority=60)

-    def register_default_hooks(self,
-                               lr_config,
-                               grad_clip_config=None,
-                               checkpoint_config=None,
-                               log_config=None):
-        """Register several default hooks.
+    def register_training_hooks(self,
+                                lr_config,
+                                grad_clip_config=None,
+                                checkpoint_config=None,
+                                log_config=None):
+        """Register default hooks for training.

         Default hooks include:

         - LrUpdaterHook
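Because this hunk renames the method, existing call sites of register_default_hooks must move to the new name. A sketch, given a runner as constructed above; the four parameters come from the signature in this hunk, but the config values are hypothetical:

    # Before this commit:
    # runner.register_default_hooks(lr_config, ...)

    # After this commit:
    runner.register_training_hooks(
        lr_config=dict(policy='fixed'),      # hypothetical LR policy config
        grad_clip_config=None,
        checkpoint_config=dict(interval=1),  # hypothetical: checkpoint every epoch
        log_config=None)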