Commit 1b42c8c4 authored by Myle Ott's avatar Myle Ott
Browse files

Fallback to `--log-format=simple` for non-TTY terminals

parent e5b3c1f4
...@@ -18,7 +18,7 @@ def get_parser(desc): ...@@ -18,7 +18,7 @@ def get_parser(desc):
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N', parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
help='log progress every N updates (when progress bar is disabled)') help='log progress every N updates (when progress bar is disabled)')
parser.add_argument('--log-format', default='tqdm', help='log format to use', parser.add_argument('--log-format', default=None, help='log format to use',
choices=['json', 'none', 'simple', 'tqdm']) choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N', parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed') help='pseudo random number generator seed')
......
...@@ -10,6 +10,7 @@ import logging ...@@ -10,6 +10,7 @@ import logging
import os import os
import torch import torch
import traceback import traceback
import sys
from torch.autograd import Variable from torch.autograd import Variable
from torch.serialization import default_restore_location from torch.serialization import default_restore_location
...@@ -37,14 +38,19 @@ def build_criterion(args, src_dict, dst_dict): ...@@ -37,14 +38,19 @@ def build_criterion(args, src_dict, dst_dict):
def build_progress_bar(args, iterator, epoch=None, prefix=None): def build_progress_bar(args, iterator, epoch=None, prefix=None):
if args.log_format is None:
args.log_format = 'tqdm' if sys.stderr.isatty() else 'simple'
if args.log_format == 'json': if args.log_format == 'json':
bar = progress_bar.json_progress_bar(iterator, epoch, prefix, args.log_interval) bar = progress_bar.json_progress_bar(iterator, epoch, prefix, args.log_interval)
elif args.log_format == 'none': elif args.log_format == 'none':
bar = progress_bar.noop_progress_bar(iterator, epoch, prefix) bar = progress_bar.noop_progress_bar(iterator, epoch, prefix)
elif args.log_format == 'simple':
bar = progress_bar.simple_progress_bar(iterator, epoch, prefix, args.log_interval)
elif args.log_format == 'tqdm': elif args.log_format == 'tqdm':
bar = progress_bar.tqdm_progress_bar(iterator, epoch, prefix) bar = progress_bar.tqdm_progress_bar(iterator, epoch, prefix)
else: else:
bar = progress_bar.simple_progress_bar(iterator, epoch, prefix, args.log_interval) raise ValueError(f'Unknown log format: {args.log_format}')
return bar return bar
......
...@@ -26,7 +26,7 @@ def main(): ...@@ -26,7 +26,7 @@ def main():
options.add_generation_args(parser) options.add_generation_args(parser)
args = parser.parse_args() args = parser.parse_args()
if args.no_progress_bar: if args.no_progress_bar and args.log_format is None:
args.log_format = 'none' args.log_format = 'none'
print(args) print(args)
......
...@@ -36,7 +36,7 @@ def main(): ...@@ -36,7 +36,7 @@ def main():
args = utils.parse_args_and_arch(parser) args = utils.parse_args_and_arch(parser)
if args.no_progress_bar and args.log_format == 'tqdm': if args.no_progress_bar and args.log_format is None:
args.log_format = 'simple' args.log_format = 'simple'
if not os.path.exists(args.save_dir): if not os.path.exists(args.save_dir):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment