Commit 978c125a authored by alexeib, committed by Myle Ott

fix restoring from middle of epoch; fix defaulting transformer dropout params

parent 386847ee
@@ -10,4 +10,3 @@ from .token_block_dataset import TokenBlockDataset
 from .language_dataset import LanguageDatasets
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
-from .offset_dataset import OffsetDataset
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-from torch.utils.data import Dataset
-
-
-class OffsetDataset(Dataset):
-    """ Wraps an existing dataset, but starts iterating from a particular offset """
-
-    def __init__(self, dataset, offset):
-        """
-        Args:
-            dataset: Dataset to wrap
-            offset: An integer. offset from which to start iterating
-        """
-        super().__init__()
-        assert len(dataset) >= offset
-        self.dataset = dataset
-        self.offset = offset
-
-    def __getitem__(self, i):
-        return self.dataset[i + self.offset]
-
-    def __len__(self):
-        return len(self.dataset) - self.offset
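For orientation, a minimal standalone sketch of what this wrapper did before the commit removed it, using the OffsetDataset class shown above (ListDataset is a hypothetical helper, not part of fairseq): indexing and length are simply shifted by the offset, so iterating an OffsetDataset amounts to skipping the first offset items of the wrapped dataset, which is the behaviour the islice-based code later in this commit reproduces without a wrapper class.

from torch.utils.data import Dataset

class ListDataset(Dataset):
    # toy dataset used only for this illustration
    def __init__(self, items):
        self.items = items

    def __getitem__(self, i):
        return self.items[i]

    def __len__(self):
        return len(self.items)

base = ListDataset(list(range(10)))
ds = OffsetDataset(base, 4)   # OffsetDataset as defined above
print(len(ds))                # -> 6
print(ds[0])                  # -> 4, i.e. iteration effectively starts at item 4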
@@ -31,11 +31,11 @@ class TransformerModel(FairseqModel):
     @staticmethod
     def add_args(parser):
         """Add model-specific arguments to the parser."""
-        parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
+        parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
-        parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
+        parser.add_argument('--attention-dropout', type=float, metavar='D',
                             help='dropout probability for attention weights')
-        parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
+        parser.add_argument('--relu-dropout', type=float, metavar='D',
                             help='dropout probability after ReLU in FFN')
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                             help='encoder embedding dimension')
@@ -399,6 +399,9 @@ def base_architecture(args):
     args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
     args.decoder_layers = getattr(args, 'decoder_layers', 6)
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.)
+    args.relu_dropout = getattr(args, 'relu_dropout', 0.)
+    args.dropout = getattr(args, 'dropout', 0.1)
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
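The three getattr lines added to base_architecture above are the counterpart of dropping the argparse defaults from add_args. A standalone sketch of the pattern (illustrative only; the argparse.SUPPRESS default stands in for fairseq's own option handling, which likewise leaves flags the user did not pass absent from the parsed namespace) shows why the default has to live in the architecture function: if argparse always injected 0.1, getattr would always find it, and a named architecture could neither supply its own default nor tell a user-provided value apart from the built-in one.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dropout', type=float, metavar='D',
                    default=argparse.SUPPRESS,   # unset flags stay absent from args
                    help='dropout probability')

def big_architecture(args):
    # hypothetical preset choosing a larger default than the base model
    args.dropout = getattr(args, 'dropout', 0.3)

args = parser.parse_args([])                      # user passed nothing
big_architecture(args)
print(args.dropout)                               # -> 0.3 (preset default applies)

args = parser.parse_args(['--dropout', '0.1'])    # explicit flag
big_architecture(args)
print(args.dropout)                               # -> 0.1 (user value wins)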
@@ -6,6 +6,8 @@
# can be found in the PATENTS file in the same directory.
import unittest
import itertools
from unittest.mock import MagicMock, patch
import train
@@ -19,10 +21,8 @@ def mock_trainer(epoch, num_updates):
 def mock_loader(length):
     ds = MagicMock()
     ds.__len__.return_value = length
     loader = MagicMock()
-    loader.__next__.return_value = ds
+    loader.__next__.return_value = list(range(length))
     return loader
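For reference, a small standalone sketch of what this mock now hands to load_checkpoint: each call to next() on the loader yields one epoch's worth of batches as a plain list, which is what lets the updated assertions below check which batch comes back first after restoring.

from unittest.mock import MagicMock

loader = MagicMock()
loader.__next__.return_value = list(range(150))

epoch_batches = next(loader)   # what load_checkpoint receives for one epoch
print(len(epoch_batches))      # -> 150
print(epoch_batches[0])        # -> 0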
@@ -42,16 +42,14 @@ class TestLoadCheckpoint(unittest.TestCase):
         loader = mock_loader(150)
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 2)
-        self.assertEqual(len(ds), 50)
-        self.assertNotIsInstance(ds, MagicMock)
+        self.assertEqual(next(ds), 50)

     def test_load_full_checkpoint(self):
         trainer = mock_trainer(2, 150)
         loader = mock_loader(150)
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 2)
         self.assertEqual(len(ds), 150)
-        self.assertIsInstance(ds, MagicMock)
+        self.assertEqual(next(iter(ds)), 0)

     def test_load_no_checkpoint(self):
         trainer = mock_trainer(0, 0)
@@ -60,8 +58,7 @@ class TestLoadCheckpoint(unittest.TestCase):
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 1)
         self.assertEqual(len(ds), 150)
-        self.assertIsInstance(ds, MagicMock)
+        self.assertEqual(next(iter(ds)), 0)

     def tearDown(self):
         patch.stopall()
@@ -11,8 +11,10 @@ import os
 import math
 import torch
+from itertools import islice
 from fairseq import criterions, models, options, progress_bar
-from fairseq.data import data_utils, data_loaders, OffsetDataset
+from fairseq.data import data_utils, data_loaders
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
@@ -323,7 +325,12 @@ def load_checkpoint(args, trainer, train_dataloader):
             updates += len(ds)
         if ds is not None and updates > trainer_updates:
-            ds = OffsetDataset(ds, updates - trainer_updates)
+            completed_batches = len(ds) - (updates - trainer_updates)
+            assert completed_batches >= 0
+            ds = iter(ds)
+            # consume completed batches
+            next(islice(ds, completed_batches, completed_batches), None)
         else:
             ds = next(train_dataloader)
             epoch += 1
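A standalone sketch of the replacement logic above, with illustrative numbers: instead of wrapping the dataset in OffsetDataset, the epoch's batches are turned into a plain iterator and the batches already trained on are skipped with the standard itertools consume idiom, so iteration resumes at the first unfinished batch.

from itertools import islice

ds = list(range(150))        # stands in for one epoch's batches
updates = 150                # batches offered up to the end of this epoch
trainer_updates = 100        # updates the restored trainer had actually done

completed_batches = len(ds) - (updates - trainer_updates)   # 100
assert completed_batches >= 0

it = iter(ds)
# islice(it, n, n) yields nothing but advances the underlying iterator by n items;
# next(..., None) consumes that empty slice without raising StopIteration
next(islice(it, completed_batches, completed_batches), None)

print(next(it))              # -> 100, the first batch not yet trained on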