Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
080489e9
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "971c3e45b96bc5aa5868c45cd40e4f3c3d90d126"
Commit
080489e9
authored
Oct 09, 2018
by
Kai Chen
Browse files
set default log dir for TensorboardLoggerHook
parent
99f53d2a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
3 deletions
+10
-3
examples/config_cifar10.py
examples/config_cifar10.py
+2
-2
mmcv/runner/hooks/logger/tensorboard.py
mmcv/runner/hooks/logger/tensorboard.py
+8
-1
No files found.
examples/config_cifar10.py
View file @
080489e9
...
@@ -14,7 +14,7 @@ lr_config = dict(policy='step', step=2)
...
@@ -14,7 +14,7 @@ lr_config = dict(policy='step', step=2)
# runtime settings
# runtime settings
work_dir
=
'./demo'
work_dir
=
'./demo'
gpus
=
range
(
2
)
gpus
=
range
(
2
)
dist_params
=
dict
(
backend
=
'
gloo'
)
# gloo is much slower than
nccl
dist_params
=
dict
(
backend
=
'nccl
'
)
data_workers
=
2
# data workers per gpu
data_workers
=
2
# data workers per gpu
checkpoint_config
=
dict
(
interval
=
1
)
# save checkpoint at every epoch
checkpoint_config
=
dict
(
interval
=
1
)
# save checkpoint at every epoch
workflow
=
[(
'train'
,
1
),
(
'val'
,
1
)]
workflow
=
[(
'train'
,
1
),
(
'val'
,
1
)]
...
@@ -28,5 +28,5 @@ log_config = dict(
...
@@ -28,5 +28,5 @@ log_config = dict(
interval
=
50
,
# log at every 50 iterations
interval
=
50
,
# log at every 50 iterations
hooks
=
[
hooks
=
[
dict
(
type
=
'TextLoggerHook'
),
dict
(
type
=
'TextLoggerHook'
),
# dict(type='TensorboardLoggerHook'
, log_dir=work_dir + '/log'
),
# dict(type='TensorboardLoggerHook'),
])
])
mmcv/runner/hooks/logger/tensorboard.py
View file @
080489e9
import
os.path
as
osp
from
.base
import
LoggerHook
from
.base
import
LoggerHook
from
...utils
import
master_only
from
...utils
import
master_only
class
TensorboardLoggerHook
(
LoggerHook
):
class
TensorboardLoggerHook
(
LoggerHook
):
def
__init__
(
self
,
log_dir
,
interval
=
10
,
ignore_last
=
True
,
def
__init__
(
self
,
log_dir
=
None
,
interval
=
10
,
ignore_last
=
True
,
reset_flag
=
True
):
reset_flag
=
True
):
super
(
TensorboardLoggerHook
,
self
).
__init__
(
interval
,
ignore_last
,
super
(
TensorboardLoggerHook
,
self
).
__init__
(
interval
,
ignore_last
,
reset_flag
)
reset_flag
)
...
@@ -18,6 +23,8 @@ class TensorboardLoggerHook(LoggerHook):
...
@@ -18,6 +23,8 @@ class TensorboardLoggerHook(LoggerHook):
raise
ImportError
(
'Please install tensorflow and tensorboardX '
raise
ImportError
(
'Please install tensorflow and tensorboardX '
'to use TensorboardLoggerHook.'
)
'to use TensorboardLoggerHook.'
)
else
:
else
:
if
self
.
log_dir
is
None
:
self
.
log_dir
=
osp
.
join
(
runner
.
work_dir
,
'tf_logs'
)
self
.
writer
=
SummaryWriter
(
self
.
log_dir
)
self
.
writer
=
SummaryWriter
(
self
.
log_dir
)
@
master_only
@
master_only
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment