Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
fc5319b6
Commit
fc5319b6
authored
Sep 21, 2018
by
pangjm
Browse files
add build_hook & other minor modify
parent
ee54f9cf
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
22 additions
and
35 deletions
+22
-35
mmcv/image/transforms/geometry.py
mmcv/image/transforms/geometry.py
+1
-15
mmcv/torchpack/hooks/__init__.py
mmcv/torchpack/hooks/__init__.py
+2
-2
mmcv/torchpack/hooks/checkpoint_saver.py
mmcv/torchpack/hooks/checkpoint_saver.py
+1
-1
mmcv/torchpack/hooks/optimizer_stepper.py
mmcv/torchpack/hooks/optimizer_stepper.py
+1
-1
mmcv/torchpack/runner/runner.py
mmcv/torchpack/runner/runner.py
+17
-16
No files found.
mmcv/image/transforms/geometry.py
View file @
fc5319b6
...
...
@@ -3,7 +3,7 @@ from __future__ import division
import
cv2
import
numpy
as
np
__all__
=
[
'imflip'
,
'imrotate'
,
'imcrop'
,
'impad'
,
'impad_to_multiple'
,
'bbox_flip'
]
__all__
=
[
'imflip'
,
'imrotate'
,
'imcrop'
,
'impad'
,
'impad_to_multiple'
]
def
imflip
(
img
,
direction
=
'horizontal'
):
...
...
@@ -111,20 +111,6 @@ def bbox_scaling(bboxes, scale, clip_shape=None):
return
scaled_bboxes
def
bbox_flip
(
bboxes
,
img_shape
):
"""Flip bboxes horizontally
Args:
bboxes(ndarray): shape (..., 4*k)
img_shape(tuple): (height, width)
"""
assert
bboxes
.
shape
[
-
1
]
%
4
==
0
w
=
img_shape
[
1
]
flipped
=
bboxes
.
copy
()
flipped
[...,
0
::
4
]
=
w
-
bboxes
[...,
2
::
4
]
-
1
flipped
[...,
2
::
4
]
=
w
-
bboxes
[...,
0
::
4
]
-
1
return
flipped
def
imcrop
(
img
,
bboxes
,
scale_ratio
=
1.0
,
pad_fill
=
None
):
"""Crop image patches.
...
...
mmcv/torchpack/hooks/__init__.py
View file @
fc5319b6
from
.hook
import
Hook
from
.checkpoint_saver
import
Checkpoint
Saver
Hook
from
.checkpoint_saver
import
CheckpointHook
from
.closure
import
ClosureHook
from
.lr_updater
import
LrUpdaterHook
from
.optimizer_stepper
import
Optimizer
Stepper
Hook
from
.optimizer_stepper
import
OptimizerHook
from
.iter_timer
import
IterTimerHook
from
.logger
import
*
mmcv/torchpack/hooks/checkpoint_saver.py
View file @
fc5319b6
...
...
@@ -2,7 +2,7 @@ from .hook import Hook
from
..utils
import
master_only
class
Checkpoint
Saver
Hook
(
Hook
):
class
CheckpointHook
(
Hook
):
def
__init__
(
self
,
interval
=-
1
,
...
...
mmcv/torchpack/hooks/optimizer_stepper.py
View file @
fc5319b6
...
...
@@ -3,7 +3,7 @@ from torch.nn.utils import clip_grad
from
.hook
import
Hook
class
Optimizer
Stepper
Hook
(
Hook
):
class
OptimizerHook
(
Hook
):
def
__init__
(
self
,
grad_clip
=
False
,
max_norm
=
35
,
norm_type
=
2
):
self
.
grad_clip
=
grad_clip
...
...
mmcv/torchpack/runner/runner.py
View file @
fc5319b6
...
...
@@ -8,8 +8,8 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
from
.log_buffer
import
LogBuffer
from
..
import
hooks
from
..hooks
import
(
Hook
,
LrUpdaterHook
,
Checkpoint
Saver
Hook
,
IterTimerHook
,
Optimizer
Stepper
Hook
)
from
..hooks
import
(
Hook
,
LrUpdaterHook
,
CheckpointHook
,
IterTimerHook
,
OptimizerHook
)
from
..io
import
load_checkpoint
,
save_checkpoint
from
..utils
import
(
get_dist_info
,
get_host_info
,
get_time_str
,
add_file_handler
,
obj_from_dict
)
...
...
@@ -182,6 +182,16 @@ class Runner(object):
if
not
inserted
:
self
.
_hooks
.
insert
(
0
,
hook
)
def
build_hook
(
self
,
hook
,
args
):
assert
issubclass
(
hook
,
Hook
),
'"hook" must be a Hook object'
if
isinstance
(
args
,
dict
):
self
.
register_hook
(
hook
(
**
args
))
elif
isinstance
(
args
,
Hook
):
self
.
register_hook
(
args
)
else
:
raise
TypeError
(
'"args" must be either a Hook object'
' or dict, not {}'
.
format
(
type
(
args
)))
def
call_hook
(
self
,
fn_name
):
for
hook
in
self
.
_hooks
:
getattr
(
hook
,
fn_name
)(
self
)
...
...
@@ -329,7 +339,7 @@ class Runner(object):
def
register_training_hooks
(
self
,
lr_config
,
grad_clip
_config
=
None
,
optimizer
_config
=
None
,
checkpoint_config
=
None
,
log_config
=
None
):
"""Register default hooks for training.
...
...
@@ -341,22 +351,13 @@ class Runner(object):
- IterTimerHook
- LoggerHook
"""
if
grad_clip
_config
is
None
:
grad_clip
_config
=
{}
if
optimizer
_config
is
None
:
optimizer
_config
=
{}
if
checkpoint_config
is
None
:
checkpoint_config
=
{}
self
.
register_lr_hooks
(
lr_config
)
if
isinstance
(
grad_clip_config
,
Hook
):
self
.
register_hook
(
grad_clip_config
)
elif
isinstance
(
grad_clip_config
,
dict
):
self
.
register_hook
(
OptimizerStepperHook
(
**
grad_clip_config
))
else
:
raise
TypeError
(
"OptimizerStepperHook should be a Hook object or dict, not {}"
.
format
(
type
(
grad_clip_config
)))
self
.
register_hook
(
CheckpointSaverHook
(
**
checkpoint_config
))
self
.
build_hook
(
OptimizerHook
,
optimizer_config
)
self
.
build_hook
(
CheckpointHook
,
checkpoint_config
)
self
.
register_hook
(
IterTimerHook
())
if
log_config
is
not
None
:
self
.
register_logger_hooks
(
log_config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment