Commit 257a3b89 authored by Noe Casas's avatar Noe Casas Committed by Facebook Github Bot
Browse files

Add args and sys.argv to tensorboard (#673)

Summary:
Log fairseq's `args` and `sys.argv` in tensorboard to easily identify run hyperparameters from within tensorboard.

The idea was suggested in https://twitter.com/Thom_Wolf/status/1106300583835766786
Pull Request resolved: https://github.com/pytorch/fairseq/pull/673

Differential Revision: D15114159

Pulled By: myleott

fbshipit-source-id: d48133a7f629dffe984836712390c317916cf413
parent 8bf8399d
...@@ -41,7 +41,7 @@ def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', ...@@ -41,7 +41,7 @@ def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm',
raise ValueError('Unknown log format: {}'.format(args.log_format)) raise ValueError('Unknown log format: {}'.format(args.log_format))
if args.tensorboard_logdir and distributed_utils.is_master(args): if args.tensorboard_logdir and distributed_utils.is_master(args):
bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir) bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir, args)
return bar return bar
...@@ -214,9 +214,10 @@ class tqdm_progress_bar(progress_bar): ...@@ -214,9 +214,10 @@ class tqdm_progress_bar(progress_bar):
class tensorboard_log_wrapper(progress_bar): class tensorboard_log_wrapper(progress_bar):
"""Log to tensorboard.""" """Log to tensorboard."""
def __init__(self, wrapped_bar, tensorboard_logdir): def __init__(self, wrapped_bar, tensorboard_logdir, args):
self.wrapped_bar = wrapped_bar self.wrapped_bar = wrapped_bar
self.tensorboard_logdir = tensorboard_logdir self.tensorboard_logdir = tensorboard_logdir
self.args = args
try: try:
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
...@@ -234,6 +235,8 @@ class tensorboard_log_wrapper(progress_bar): ...@@ -234,6 +235,8 @@ class tensorboard_log_wrapper(progress_bar):
self._writers[key] = self.SummaryWriter( self._writers[key] = self.SummaryWriter(
log_dir=os.path.join(self.tensorboard_logdir, key), log_dir=os.path.join(self.tensorboard_logdir, key),
) )
self._writers[key].add_text('args', str(vars(self.args)))
self._writers[key].add_text('sys.argv', " ".join(sys.argv))
return self._writers[key] return self._writers[key]
def __iter__(self): def __iter__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment