Commit e98bf7e6 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Move fb_pathmgr registration out of train.py

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/903

Reviewed By: sujitoc

Differential Revision: D18327653

fbshipit-source-id: 739ddbaf54862acdf7b4f1bc3ad538bde5ae00fd
parent e9171ce1
......@@ -173,8 +173,6 @@ def get_parser(desc, default_task='translation'):
parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
help='path to save logs for tensorboard, should match --logdir '
'of running tensorboard (default: no tensorboard logging)')
parser.add_argument("--tbmf-wrapper", action="store_true",
help="[FB only] ")
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
......
......@@ -16,8 +16,6 @@ import sys
from fairseq import distributed_utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
g_tbmf_wrapper = None
def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
if args.log_format is None:
......@@ -37,16 +35,13 @@ def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm',
else:
raise ValueError('Unknown log format: {}'.format(args.log_format))
if args.tbmf_wrapper and distributed_utils.is_master(args):
global g_tbmf_wrapper
if g_tbmf_wrapper is None:
if args.tensorboard_logdir and distributed_utils.is_master(args):
try:
# [FB only] custom wrapper for TensorBoard
import palaas # noqa
from fairseq.fb_tbmf_wrapper import fb_tbmf_wrapper
except Exception:
raise ImportError("fb_tbmf_wrapper package not found.")
g_tbmf_wrapper = fb_tbmf_wrapper
bar = g_tbmf_wrapper(bar, args, args.log_interval)
elif args.tensorboard_logdir and distributed_utils.is_master(args):
bar = fb_tbmf_wrapper(bar, args, args.log_interval)
except ImportError:
bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir, args)
return bar
......
......@@ -173,7 +173,7 @@ class Trainer(object):
try:
from fairseq.fb_pathmgr import fb_pathmgr
bexists = fb_pathmgr.isfile(filename)
except Exception:
except (ModuleNotFoundError, ImportError):
bexists = os.path.exists(filename)
if bexists:
......
......@@ -19,21 +19,10 @@ from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
fb_pathmgr_registerd = False
def main(args, init_distributed=False):
utils.import_user_module(args)
try:
from fairseq.fb_pathmgr import fb_pathmgr
global fb_pathmgr_registerd
if not fb_pathmgr_registerd:
fb_pathmgr.register()
fb_pathmgr_registerd = True
except (ModuleNotFoundError, ImportError):
pass
assert args.max_tokens is not None or args.max_sentences is not None, \
'Must specify batch size either with --max-tokens or --max-sentences'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment