Commit 978c125a authored by alexeib, committed by Myle Ott

fix restoring from middle of epoch; fix defaulting transformer dropout params

parent 386847ee
fairseq/data/__init__.py
@@ -10,4 +10,3 @@ from .token_block_dataset import TokenBlockDataset
 from .language_dataset import LanguageDatasets
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
-from .offset_dataset import OffsetDataset
fairseq/data/offset_dataset.py (deleted)
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-from torch.utils.data import Dataset
-
-
-class OffsetDataset(Dataset):
-    """ Wraps an existing dataset, but starts iterating from a particular offset """
-
-    def __init__(self, dataset, offset):
-        """
-        Args:
-            dataset: Dataset to wrap
-            offset: An integer. offset from which to start iterating
-        """
-        super().__init__()
-        assert len(dataset) >= offset
-        self.dataset = dataset
-        self.offset = offset
-
-    def __getitem__(self, i):
-        return self.dataset[i + self.offset]
-
-    def __len__(self):
-        return len(self.dataset) - self.offset
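For context, the wrapper deleted above only shifted indexing: item 0 of the wrapped dataset is item offset of the underlying one, and the reported length shrinks by the same amount; the commit replaces it with skipping already-completed batches on the loader's iterator (see the train.py hunk below). A minimal standalone sketch of the old behaviour, using a plain Python list in place of a torch Dataset (the names here are illustrative, not from the repository):

class Offset:
    """Same indexing shift as the removed OffsetDataset, minus the torch base class."""

    def __init__(self, dataset, offset):
        assert len(dataset) >= offset
        self.dataset = dataset
        self.offset = offset

    def __getitem__(self, i):
        return self.dataset[i + self.offset]

    def __len__(self):
        return len(self.dataset) - self.offset

data = list(range(150))      # e.g. 150 batches in an epoch
tail = Offset(data, 100)     # resume after 100 completed batches
print(len(tail), tail[0])    # -> 50 100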
fairseq/models/transformer.py
@@ -31,11 +31,11 @@ class TransformerModel(FairseqModel):
     @staticmethod
     def add_args(parser):
         """Add model-specific arguments to the parser."""
-        parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
+        parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
-        parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
+        parser.add_argument('--attention-dropout', type=float, metavar='D',
                             help='dropout probability for attention weights')
-        parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
+        parser.add_argument('--relu-dropout', type=float, metavar='D',
                             help='dropout probability after ReLU in FFN')
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                             help='encoder embedding dimension')
@@ -399,6 +399,9 @@ def base_architecture(args):
     args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
     args.decoder_layers = getattr(args, 'decoder_layers', 6)
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.)
+    args.relu_dropout = getattr(args, 'relu_dropout', 0.)
+    args.dropout = getattr(args, 'dropout', 0.1)
 
 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
...
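Why the argparse defaults had to move: a default passed to add_argument is always written onto args, so by the time base_architecture ran it could not tell a user-supplied --dropout from the argparse default, and neither could the named architecture presets. With the defaults removed from the parser, the getattr(args, name, default) calls in base_architecture only fill in values that nothing else has set. A small standalone sketch of the idiom on a bare argparse.Namespace (the preset value 0.3 is made up for illustration; fairseq's real option-parsing pipeline is not reproduced here):

from argparse import Namespace

def base_architecture(args):
    # Fill in defaults only for attributes that are still unset.
    args.dropout = getattr(args, 'dropout', 0.1)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.)
    args.relu_dropout = getattr(args, 'relu_dropout', 0.)

def preset_architecture(args):
    # A named architecture can pin its own value first; base_architecture leaves it alone.
    args.dropout = getattr(args, 'dropout', 0.3)
    base_architecture(args)

args = Namespace()                 # user passed no dropout flags
preset_architecture(args)
print(args.dropout)                # 0.3 -- the preset wins over the base default

args = Namespace(dropout=0.05)     # user explicitly passed --dropout 0.05
preset_architecture(args)
print(args.dropout)                # 0.05 -- the user value wins over everything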
tests/test_train.py
@@ -6,6 +6,8 @@
 # can be found in the PATENTS file in the same directory.
 
 import unittest
+import itertools
+
 from unittest.mock import MagicMock, patch
 
 import train
@@ -19,10 +21,8 @@ def mock_trainer(epoch, num_updates):
 
 def mock_loader(length):
-    ds = MagicMock()
-    ds.__len__.return_value = length
     loader = MagicMock()
-    loader.__next__.return_value = ds
+    loader.__next__.return_value = list(range(length))
     return loader
@@ -42,16 +42,14 @@ class TestLoadCheckpoint(unittest.TestCase):
         loader = mock_loader(150)
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 2)
-        self.assertEqual(len(ds), 50)
-        self.assertNotIsInstance(ds, MagicMock)
+        self.assertEqual(next(ds), 50)
 
     def test_load_full_checkpoint(self):
         trainer = mock_trainer(2, 150)
         loader = mock_loader(150)
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 2)
-        self.assertEqual(len(ds), 150)
-        self.assertIsInstance(ds, MagicMock)
+        self.assertEqual(next(iter(ds)), 0)
 
     def test_load_no_checkpoint(self):
         trainer = mock_trainer(0, 0)
@@ -60,8 +58,7 @@ class TestLoadCheckpoint(unittest.TestCase):
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 1)
-        self.assertEqual(len(ds), 150)
-        self.assertIsInstance(ds, MagicMock)
+        self.assertEqual(next(iter(ds)), 0)
 
     def tearDown(self):
         patch.stopall()
...
train.py
@@ -11,8 +11,10 @@ import os
 import math
 
 import torch
+from itertools import islice
+
 from fairseq import criterions, models, options, progress_bar
-from fairseq.data import data_utils, data_loaders, OffsetDataset
+from fairseq.data import data_utils, data_loaders
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
@@ -323,7 +325,12 @@ def load_checkpoint(args, trainer, train_dataloader):
         updates += len(ds)
 
     if ds is not None and updates > trainer_updates:
-        ds = OffsetDataset(ds, updates - trainer_updates)
+        completed_batches = len(ds) - (updates - trainer_updates)
+        assert completed_batches >= 0
+        ds = iter(ds)
+        # consume completed batches
+        next(islice(ds, completed_batches, completed_batches), None)
     else:
         ds = next(train_dataloader)
         epoch += 1
...
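The replacement skip logic uses the standard itertools "consume" recipe: islice(it, n, n) is an empty slice whose start is n, so pulling one item from it with next(..., None) advances the underlying iterator by up to n elements and returns None instead of raising StopIteration. A standalone sketch of how the new load_checkpoint code positions the batch iterator, with a plain list standing in for the real epoch of batches:

from itertools import islice

def skip_completed(batches, completed_batches):
    """Return an iterator over batches, positioned after the already-completed ones."""
    assert completed_batches >= 0
    it = iter(batches)
    # islice(it, n, n) yields nothing but consumes the first n items of it;
    # the None default keeps next() from raising if fewer than n remain.
    next(islice(it, completed_batches, completed_batches), None)
    return it

epoch = list(range(150))          # 150 batches in the restored epoch
it = skip_completed(epoch, 100)   # 100 of them were finished before the checkpoint
print(next(it))                   # -> 100: training resumes with the 101st batch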