OpenDAS / MMCV

Commit ffdc1d45
authored Aug 28, 2018 by Kai Chen
add initial version of torchpack
parent 02ceae83

Changes: 22 files in the commit. This page shows 2 changed files with 421 additions and 0 deletions (+421, -0):
- mmcv/torchpack/runner/runner.py (+344, -0)
- mmcv/torchpack/utils.py (+77, -0)
mmcv/torchpack/runner/runner.py (new file, mode 100644)
import logging
import os.path as osp
import time

import mmcv
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel

from .log_buffer import LogBuffer
from .. import hooks
from ..hooks import (Hook, LrUpdaterHook, CheckpointSaverHook, IterTimerHook,
                     OptimizerStepperHook)
from ..io import load_checkpoint, save_checkpoint
from ..utils import (get_dist_info, get_host_info, get_time_str,
                     add_file_handler, obj_from_dict)


class Runner(object):
    """A training helper for PyTorch."""

    def __init__(self,
                 model,
                 optimizer,
                 batch_processor,
                 work_dir=None,
                 log_level=logging.INFO):
        assert callable(batch_processor)
        self.model = model
        self.optimizer = self.init_optimizer(optimizer)
        self.batch_processor = batch_processor
        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')
        # get model name from the model class
        if isinstance(self.model, (DataParallel, DistributedDataParallel)):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__
        self._rank, self._world_size = get_dist_info()
        self.logger = self.init_logger(work_dir, log_level)
        self.log_buffer = LogBuffer()

        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0

    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    def init_optimizer(self, optimizer):
        """Init the optimizer.

        Args:
            optimizer (dict or :obj:`~torch.optim.Optimizer`): Either an
                optimizer object or a dict used for constructing the
                optimizer. An example of the dict: ``{'type': 'SGD',
                'lr': 0.02, 'momentum': 0.9, 'weight_decay': 0.0001}``.

        Returns:
            :obj:`~torch.optim.Optimizer`: An optimizer object.
        """
        if isinstance(optimizer, dict):
            optimizer = obj_from_dict(
                optimizer, torch.optim, dict(params=self.model.parameters()))
        elif not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError(
                'optimizer must be either an Optimizer object or a dict, '
                'but got {}'.format(type(optimizer)))
        return optimizer

    def init_logger(self, log_dir=None, level=logging.INFO):
        """Init the logger.

        Args:
            log_dir (str, optional): Log file directory. If not specified, no
                log file will be used.
            level (int or str): See the built-in python logging module.

        Returns:
            :obj:`~logging.Logger`: Python logger.
        """
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s', level=level)
        logger = logging.getLogger(__name__)
        if log_dir:
            filename = '{}_{}.log'.format(get_time_str(), self.rank)
            log_file = osp.join(log_dir, filename)
            add_file_handler(logger, log_file, level=level)
        return logger

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list: Current learning rate of all param groups.
        """
        return [group['lr'] for group in self.optimizer.param_groups]

    def register_hook(self, hook, priority=50):
        """Register a hook into the hook list.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int): Hook priority. Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        assert isinstance(priority, int) and priority >= 0 and priority <= 100
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def load_checkpoint(self, filename, map_location='cpu', strict=False):
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None):
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = osp.join(out_dir, filename_tmpl.format(self.epoch))
        linkname = osp.join(out_dir, 'latest.pth')
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
        mmcv.symlink(filename, linkname)

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=True, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def resume(self, checkpoint, resume_optimizer=True,
               map_location='default'):
        if map_location == 'default':
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            'runner has no method named "{}" to run an '
                            'epoch'.format(mode))
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                else:
                    raise TypeError('mode in workflow must be a str or '
                                    'callable function, not {}'.format(
                                        type(mode)))
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def register_lr_hooks(self, lr_config):
        if isinstance(lr_config, LrUpdaterHook):
            self.register_hook(lr_config)
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            from ..hooks import lr_updater
            hook_name = lr_config['policy'].title() + 'LrUpdaterHook'
            if not hasattr(lr_updater, hook_name):
                raise ValueError('"{}" does not exist'.format(hook_name))
            hook_cls = getattr(lr_updater, hook_name)
            self.register_hook(hook_cls(**lr_config))
        else:
            raise TypeError('"lr_config" must be either a LrUpdaterHook '
                            'object or dict, not {}'.format(type(lr_config)))

    def register_logger_hooks(self, log_config):
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = obj_from_dict(
                info, hooks, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority=60)

    def register_default_hooks(self,
                               lr_config,
                               grad_clip_config=None,
                               checkpoint_config=None,
                               log_config=None):
        """Register several default hooks.

        Default hooks include:

        - LrUpdaterHook
        - OptimizerStepperHook
        - CheckpointSaverHook
        - IterTimerHook
        - LoggerHook
        """
        if grad_clip_config is None:
            grad_clip_config = {}
        if checkpoint_config is None:
            checkpoint_config = {}
        self.register_lr_hooks(lr_config)
        self.register_hook(OptimizerStepperHook(**grad_clip_config))
        self.register_hook(CheckpointSaverHook(**checkpoint_config))
        self.register_hook(IterTimerHook())
        if log_config is not None:
            self.register_logger_hooks(log_config)
mmcv/torchpack/utils.py (new file, mode 100644)
import functools
import logging
import time
from getpass import getuser
from socket import gethostname

import mmcv
import torch.distributed as dist


def get_host_info():
    return '{}@{}'.format(getuser(), gethostname())


def get_dist_info():
    if dist._initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def get_time_str():
    return time.strftime('%Y%m%d_%H%M%S', time.localtime())


def add_file_handler(logger, filename=None, mode='w', level=logging.INFO):
    file_handler = logging.FileHandler(filename, mode)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)
    return logger


def obj_from_dict(info, module, default_args=None):
    """Initialize an object from dict.

    The dict must contain the key "type", which indicates the object type. It
    can be either a string or a type, such as "list" or ``list``. Remaining
    fields are treated as the arguments for constructing the object.

    Args:
        info (dict): Object types and arguments.
        module (:class:`module`): Module which may contain the expected object
            classes.
        default_args (dict, optional): Default arguments for initializing the
            object.

    Returns:
        any: Object constructed from the dict.
    """
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None
    args = info.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_type = getattr(module, obj_type)
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)
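As a brief illustration of the two most reusable helpers above, the sketch below builds an optimizer from a config dict (mirroring Runner.init_optimizer) and restricts a function to the master process. The import path is guessed from the file layout in this commit, and the model and function names are placeholders for the example.

import logging

import torch

from mmcv.torchpack.utils import get_dist_info, master_only, obj_from_dict

# 'type' is looked up as an attribute of the given module; the remaining keys
# become constructor arguments, and default_args fills in anything missing.
model = torch.nn.Linear(4, 2)
optimizer = obj_from_dict(
    dict(type='SGD', lr=0.02, momentum=0.9),
    torch.optim,
    default_args=dict(params=model.parameters()))


# master_only turns a function into a no-op on every process except rank 0,
# which is handy for logging or checkpoint saving in distributed training.
@master_only
def log_on_master(msg):
    logging.getLogger(__name__).info(msg)


rank, world_size = get_dist_info()
log_on_master('only rank 0 of {} process(es) logs this'.format(world_size))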