Commit 693894b6 authored by Naman Goyal's avatar Naman Goyal Committed by Facebook Github Bot
Browse files

Merge small fixes from internal

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/374

Differential Revision: D13116074

Pulled By: myleott

fbshipit-source-id: 485724cc5a40e8360d21e4bf9c35821baa0ddc57
parent 0864a9c4
...@@ -55,7 +55,7 @@ def main(parsed_args): ...@@ -55,7 +55,7 @@ def main(parsed_args):
# Load ensemble # Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path)) print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task) models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides))
for arg in vars(parsed_args).keys(): for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}: if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
......
...@@ -274,6 +274,8 @@ def add_common_eval_args(group): ...@@ -274,6 +274,8 @@ def add_common_eval_args(group):
group.add_argument('--cpu', action='store_true', help='generate on CPU') group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--quiet', action='store_true', group.add_argument('--quiet', action='store_true',
help='only print final scores') help='only print final scores')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation that were used during model training')
def add_eval_lm_args(parser): def add_eval_lm_args(parser):
...@@ -330,8 +332,6 @@ def add_generation_args(parser): ...@@ -330,8 +332,6 @@ def add_generation_args(parser):
help='strength of diversity penalty for Diverse Beam Search') help='strength of diversity penalty for Diverse Beam Search')
group.add_argument('--print-alignment', action='store_true', group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens') help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation that were used during model training')
return group return group
......
...@@ -160,6 +160,10 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): ...@@ -160,6 +160,10 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
model.load_state_dict(state['model'], strict=True) model.load_state_dict(state['model'], strict=True)
ensemble.append(model) ensemble.append(model)
# some args (e.g., tokens_per_sample) might have been updated while building the model
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
return ensemble, args return ensemble, args
......
...@@ -288,10 +288,11 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): ...@@ -288,10 +288,11 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
if val_loss is not None: if val_loss is not None:
save_checkpoint.best = min(val_loss, prev_best) save_checkpoint.best = min(val_loss, prev_best)
extra_state = { extra_state = {
'best': save_checkpoint.best,
'train_iterator': epoch_itr.state_dict(), 'train_iterator': epoch_itr.state_dict(),
'val_loss': val_loss, 'val_loss': val_loss,
} }
if hasattr(save_checkpoint, 'best'):
extra_state.update({'best': save_checkpoint.best})
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond] checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
if len(checkpoints) > 0: if len(checkpoints) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment