OpenDAS / Fairseq · Commit 978c125a
Authored May 27, 2018 by alexeib; committed by Myle Ott on Jun 15, 2018
fix restoring from middle of epoch; fix defaulting transformer dropout params
Parent: 386847ee
Showing 5 changed files with 21 additions and 47 deletions (+21 -47):
fairseq/data/__init__.py        +0  -1
fairseq/data/offset_dataset.py  +0  -32
fairseq/models/transformer.py   +6  -3
tests/test_train.py             +6  -9
train.py                        +9  -2
fairseq/data/__init__.py
@@ -10,4 +10,3 @@ from .token_block_dataset import TokenBlockDataset
 from .language_dataset import LanguageDatasets
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
-from .offset_dataset import OffsetDataset
fairseq/data/offset_dataset.py (deleted, file mode 100644 → 0)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from torch.utils.data import Dataset


class OffsetDataset(Dataset):
    """ Wraps an existing dataset, but starts iterating from a particular offset """

    def __init__(self, dataset, offset):
        """
        Args:
            dataset: Dataset to wrap
            offset: An integer. offset from which to start iterating
        """
        super().__init__()
        assert len(dataset) >= offset
        self.dataset = dataset
        self.offset = offset

    def __getitem__(self, i):
        return self.dataset[i + self.offset]

    def __len__(self):
        return len(self.dataset) - self.offset
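For reference, a minimal usage sketch of the class being deleted (a plain list stands in for a torch Dataset; the example is illustrative, not part of the commit). It shows that the wrapper relies on random access via __getitem__ and __len__, an assumption the iterator-based restore in train.py below no longer makes:

    # Wrap a 10-item dataset and view it starting at item 4.
    ds = OffsetDataset(list(range(10)), offset=4)
    assert len(ds) == 6
    assert [ds[i] for i in range(len(ds))] == [4, 5, 6, 7, 8, 9]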
fairseq/models/transformer.py
@@ -31,11 +31,11 @@ class TransformerModel(FairseqModel):
     @staticmethod
     def add_args(parser):
         """Add model-specific arguments to the parser."""
-        parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
+        parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
-        parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
+        parser.add_argument('--attention-dropout', type=float, metavar='D',
                             help='dropout probability for attention weights')
-        parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
+        parser.add_argument('--relu-dropout', type=float, metavar='D',
                             help='dropout probability after ReLU in FFN')
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                             help='encoder embedding dimension')
@@ -399,6 +399,9 @@ def base_architecture(args):
     args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
     args.decoder_layers = getattr(args, 'decoder_layers', 6)
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.)
+    args.relu_dropout = getattr(args, 'relu_dropout', 0.)
+    args.dropout = getattr(args, 'dropout', 0.1)


 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
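The point of dropping the argparse defaults, as far as this diff shows, is that a default baked into add_argument always populates the attribute, so a named architecture preset that falls back via getattr can never supply its own value; moving the fallback into base_architecture lets presets override anything the user left unset. A minimal sketch of the pattern on a plain namespace (transformer_wmt_big and the specific numbers are illustrative, not from the diff):

    from types import SimpleNamespace

    def base_architecture(args):
        # Shared fallbacks, applied only where the attribute is still missing.
        args.dropout = getattr(args, 'dropout', 0.1)
        args.attention_dropout = getattr(args, 'attention_dropout', 0.)

    def transformer_wmt_big(args):
        # Hypothetical preset: prefers a larger dropout unless the user set one.
        args.dropout = getattr(args, 'dropout', 0.3)
        base_architecture(args)

    args = SimpleNamespace()              # user passed no dropout flags
    transformer_wmt_big(args)
    assert args.dropout == 0.3 and args.attention_dropout == 0.

    args = SimpleNamespace(dropout=0.05)  # user passed --dropout 0.05
    transformer_wmt_big(args)
    assert args.dropout == 0.05           # neither function overrides it

    # With default=0.1 left in add_argument, args.dropout would always exist,
    # so the preset's 0.3 fallback could never take effect.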
tests/test_train.py
@@ -6,6 +6,8 @@
 # can be found in the PATENTS file in the same directory.

 import unittest
+import itertools
 from unittest.mock import MagicMock, patch

 import train
@@ -19,10 +21,8 @@ def mock_trainer(epoch, num_updates):
 def mock_loader(length):
-    ds = MagicMock()
-    ds.__len__.return_value = length
     loader = MagicMock()
-    loader.__next__.return_value = ds
+    loader.__next__.return_value = list(range(length))
     return loader
@@ -42,16 +42,14 @@ class TestLoadCheckpoint(unittest.TestCase):
         loader = mock_loader(150)
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 2)
-        self.assertEqual(len(ds), 50)
-        self.assertNotIsInstance(ds, MagicMock)
+        self.assertEqual(next(ds), 50)

     def test_load_full_checkpoint(self):
         trainer = mock_trainer(2, 150)
         loader = mock_loader(150)
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 2)
-        self.assertEqual(len(ds), 150)
-        self.assertIsInstance(ds, MagicMock)
+        self.assertEqual(next(iter(ds)), 0)

     def test_load_no_checkpoint(self):
         trainer = mock_trainer(0, 0)
@@ -60,8 +58,7 @@ class TestLoadCheckpoint(unittest.TestCase):
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 1)
-        self.assertEqual(len(ds), 150)
-        self.assertIsInstance(ds, MagicMock)
+        self.assertEqual(next(iter(ds)), 0)

     def tearDown(self):
         patch.stopall()
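The reworked mock_loader makes next(loader) return a real list of ints instead of a nested MagicMock, so the tests can assert which batch the restored iterator resumes at rather than inspecting mock attributes. A small standalone illustration of that MagicMock behaviour (not part of the diff):

    from unittest.mock import MagicMock

    loader = MagicMock()
    loader.__next__.return_value = list(range(150))

    ds = next(loader)              # the mocked "epoch": 150 batches, 0..149
    assert ds == list(range(150))

    it = iter(ds)
    for _ in range(50):            # pretend 50 batches were already trained on
        next(it)
    assert next(it) == 50          # the first remaining batch, as the test checks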
train.py
@@ -11,8 +11,10 @@ import os
 import math
 import torch
+from itertools import islice

 from fairseq import criterions, models, options, progress_bar
-from fairseq.data import data_utils, data_loaders, OffsetDataset
+from fairseq.data import data_utils, data_loaders
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
@@ -323,7 +325,12 @@ def load_checkpoint(args, trainer, train_dataloader):
             updates += len(ds)
         if ds is not None and updates > trainer_updates:
-            ds = OffsetDataset(ds, updates - trainer_updates)
+            completed_batches = len(ds) - (updates - trainer_updates)
+            assert completed_batches >= 0
+            ds = iter(ds)
+            # consume completed batches
+            next(islice(ds, completed_batches, completed_batches), None)
         else:
             ds = next(train_dataloader)
             epoch += 1
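To make the restore arithmetic concrete: load_checkpoint accumulates each epoch's batch count into updates, and once that total exceeds the trainer's recorded update count, the overshoot tells it how far into the current epoch training stopped. A worked sketch follows; resume_position and the specific update counts are illustrative, not taken from the diff:

    from itertools import islice

    def resume_position(epoch_lengths, trainer_updates):
        """Return (epoch, completed_batches) for a run that stopped after
        trainer_updates updates, given the batch count of each epoch pulled."""
        updates = 0
        for epoch, length in enumerate(epoch_lengths, start=1):
            updates += length
            if updates > trainer_updates:
                return epoch, length - (updates - trainer_updates)
        return len(epoch_lengths) + 1, 0

    # 150 batches per epoch, training stopped after 200 updates:
    # epoch 1 is complete, plus 50 batches of epoch 2.
    assert resume_position([150, 150], 200) == (2, 50)

    # Exactly at an epoch boundary: resume epoch 2 from its first batch.
    assert resume_position([150, 150], 150) == (2, 0)

    # Skipping the completed batches on a plain iterator, as the new code does:
    ds = iter(range(150))
    next(islice(ds, 50, 50), None)   # advances the iterator by 50 items
    assert next(ds) == 50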