Unverified Commit fe4e185a authored by Myle Ott, committed by GitHub

Merge internal changes

Changelog:
- `f472d141`: Support tied embeddings in LSTM encoder/decoder (see the sketch below)
- `89e19d42`: Don't print alignment by default (use `--print-alignment` to re-enable it)
- `d2e2a1d4`: Add Transformer-based language model
- `c2794070`: Add new Transformer configuration for IWSLT
- `2fbfda0d`: Misc changes for pytorch-translate
- Miscellaneous bug fixes
parents 7358296b 2fbfda0d
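The `f472d141` entry adds tied input/output embeddings to the LSTM models. As a minimal sketch of the general weight-tying technique (not fairseq's actual implementation; `TiedDecoderHead` is a made-up name for illustration), the decoder's output projection can reuse the embedding matrix as its weight:

```python
import torch.nn as nn

# Hypothetical sketch of embedding/output weight tying, not fairseq code.
# The output projection shares its weight tensor with the input embedding,
# which requires the decoder hidden size to equal the embedding size:
# nn.Embedding(V, E).weight and nn.Linear(E, V).weight are both (V, E).
class TiedDecoderHead(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.proj = nn.Linear(embed_dim, vocab_size, bias=False)
        self.proj.weight = self.embed.weight  # tie: a single shared parameter

    def forward(self, hidden):
        # hidden: (batch, embed_dim) -> logits: (batch, vocab_size)
        return self.proj(hidden)
```

Tying removes one of the two vocabulary-sized parameter matrices and often helps on smaller datasets.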
@@ -32,7 +32,7 @@ def main(args):
     task = tasks.setup_task(args)
 
     # Load dataset splits
-    load_dataset_splits(args, task, ['train', 'valid'])
+    load_dataset_splits(task, ['train', 'valid'])
 
     # Build model and criterion
     model = task.build_model(args)
@@ -263,16 +263,16 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     checkpoint_conds = collections.OrderedDict()
     checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
-        end_of_epoch and not args.no_epoch_checkpoints and
-        epoch % args.save_interval == 0
+        end_of_epoch and not args.no_epoch_checkpoints and
+        epoch % args.save_interval == 0
     )
     checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
-        not end_of_epoch and args.save_interval_updates > 0 and
-        updates % args.save_interval_updates == 0
+        not end_of_epoch and args.save_interval_updates > 0 and
+        updates % args.save_interval_updates == 0
     )
     checkpoint_conds['checkpoint_best.pt'] = (
-        val_loss is not None and
-        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
+        val_loss is not None and
+        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
     )
     checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink
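The `checkpoint_conds` map above pairs each checkpoint filename with the condition under which it should be written; keeping it ordered lets `checkpoint_last.pt` come last, as the inline comment notes. A rough sketch of how such a map could be consumed (`trainer.save_checkpoint` here is a stand-in, not necessarily the real fairseq call):

```python
import os

# Hedged sketch: write every checkpoint whose condition holds. `args.save_dir`
# and `checkpoint_conds` follow the surrounding diff, but the exact save call
# on `trainer` is an assumption for illustration.
checkpoints = [
    os.path.join(args.save_dir, name)
    for name, cond in checkpoint_conds.items()
    if cond
]
for path in checkpoints:
    trainer.save_checkpoint(path)
```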
@@ -316,17 +316,19 @@ def load_checkpoint(args, trainer, epoch_itr):
         save_checkpoint.best = extra_state['best']
 
 
-def load_dataset_splits(args, task, splits):
+def load_dataset_splits(task, splits):
     for split in splits:
-        for k in itertools.count():
-            split_k = split + (str(k) if k > 0 else '')
-            try:
-                task.load_dataset(split_k)
-                print('| {} {} {} examples'.format(args.data, split_k, len(task.dataset(split_k))))
-            except FileNotFoundError as e:
-                if k > 0:
-                    break
-                raise e
+        if split == 'train':
+            task.load_dataset(split, combine=True)
+        else:
+            for k in itertools.count():
+                split_k = split + (str(k) if k > 0 else '')
+                try:
+                    task.load_dataset(split_k, combine=False)
+                except FileNotFoundError as e:
+                    if k > 0:
+                        break
+                    raise e
 
 
 if __name__ == '__main__':
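The rewritten `load_dataset_splits` treats training data as a set of shards combined into one dataset (`combine=True`), while validation splits are probed one by one (`valid`, `valid1`, `valid2`, ...) until a shard is missing. The probing pattern in isolation, with a hypothetical `load_split` callable standing in for `task.load_dataset`:

```python
import itertools

def probe_shards(split, load_split):
    """Sketch of the shard-probing loop above: load `split`, `split1`,
    `split2`, ... until a FileNotFoundError signals there are no more
    shards. `load_split` is a hypothetical stand-in for task.load_dataset."""
    loaded = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        try:
            loaded.append(load_split(split_k))
        except FileNotFoundError:
            if k == 0:
                raise  # the base split must exist; only extra shards are optional
            break
    return loaded
```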