Unverified commit fe4e185a, authored by Myle Ott and committed by GitHub
Browse files

Merge internal changes

Changelog:
- `f472d141`: Support tied embeddings in LSTM encoder/decoder
- `89e19d42`: Don't print alignment by default (use `--print-alignment` to re-enable it)
- `d2e2a1d4`: Add Transformer-based language model
- `c2794070`: Add new Transformer configuration for IWSLT
- `2fbfda0d`: Misc changes for pytorch-translate
- Miscellaneous bug fixes
parents 7358296b 2fbfda0d
...@@ -32,7 +32,7 @@ def main(args): ...@@ -32,7 +32,7 @@ def main(args):
task = tasks.setup_task(args) task = tasks.setup_task(args)
# Load dataset splits # Load dataset splits
load_dataset_splits(args, task, ['train', 'valid']) load_dataset_splits(task, ['train', 'valid'])
# Build model and criterion # Build model and criterion
model = task.build_model(args) model = task.build_model(args)
...@@ -263,16 +263,16 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): ...@@ -263,16 +263,16 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
checkpoint_conds = collections.OrderedDict() checkpoint_conds = collections.OrderedDict()
checkpoint_conds['checkpoint{}.pt'.format(epoch)] = ( checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
end_of_epoch and not args.no_epoch_checkpoints and end_of_epoch and not args.no_epoch_checkpoints and
epoch % args.save_interval == 0 epoch % args.save_interval == 0
) )
checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = ( checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
not end_of_epoch and args.save_interval_updates > 0 and not end_of_epoch and args.save_interval_updates > 0 and
updates % args.save_interval_updates == 0 updates % args.save_interval_updates == 0
) )
checkpoint_conds['checkpoint_best.pt'] = ( checkpoint_conds['checkpoint_best.pt'] = (
val_loss is not None and val_loss is not None and
(not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best) (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
) )
checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink
...@@ -316,17 +316,19 @@ def load_checkpoint(args, trainer, epoch_itr): ...@@ -316,17 +316,19 @@ def load_checkpoint(args, trainer, epoch_itr):
save_checkpoint.best = extra_state['best'] save_checkpoint.best = extra_state['best']
def load_dataset_splits(args, task, splits): def load_dataset_splits(task, splits):
for split in splits: for split in splits:
for k in itertools.count(): if split == 'train':
split_k = split + (str(k) if k > 0 else '') task.load_dataset(split, combine=True)
try: else:
task.load_dataset(split_k) for k in itertools.count():
print('| {} {} {} examples'.format(args.data, split_k, len(task.dataset(split_k)))) split_k = split + (str(k) if k > 0 else '')
except FileNotFoundError as e: try:
if k > 0: task.load_dataset(split_k, combine=False)
break except FileNotFoundError as e:
raise e if k > 0:
break
raise e
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment