OpenDAS / Fairseq · Commits

Commit 820f796f, authored Oct 25, 2017 by Myle Ott
Parent: 3af8ec82

Add `--curriculum` option
Showing 3 changed files with 15 additions and 9 deletions:

    fairseq/data.py     (+8, -7)
    fairseq/options.py  (+2, -0)
    train.py            (+5, -2)
fairseq/data.py
@@ -94,7 +94,8 @@ class LanguageDatasets(object):
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
                    sample_without_replacement=0, max_positions=1024,
-                   skip_invalid_size_inputs_valid_test=False):
+                   skip_invalid_size_inputs_valid_test=False,
+                   sort_by_source_size=False):
         dataset = self.splits[split]
         if split.startswith('train'):
             with numpy_seed(seed):
@@ -102,7 +103,8 @@ class LanguageDatasets(object):
                     dataset.src, dataset.dst, max_tokens=max_tokens,
                     epoch=epoch, sample=sample_without_replacement,
-                    max_positions=max_positions)
+                    max_positions=max_positions,
+                    sort_by_source_size=sort_by_source_size)
         elif split.startswith('valid'):
             batch_sampler = list(batches_by_size(
                 dataset.src, batch_size, max_tokens, dst=dataset.dst,
                 max_positions=max_positions,
@@ -269,7 +271,8 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
         yield batch


-def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_positions=1024):
+def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
+                             max_positions=1024, sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
@@ -310,7 +313,8 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_p
                 "and will be ignored, sample ids={}".format(len(ignored), ignored))

     batches = list(make_batches())
-    np.random.shuffle(batches)
+    if not sort_by_source_size:
+        np.random.shuffle(batches)

     if sample:
         offset = (epoch - 1) * sample
@@ -327,9 +331,6 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_p
                 "batch length is not correct {}".format(len(result))

         batches = result
-    else:
-        for _ in range(epoch - 1):
-            np.random.shuffle(batches)

     return batches
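Net effect of the data.py changes: `shuffled_batches_by_size` still buckets indices by source length, but with `sort_by_source_size=True` both the inter-batch shuffle and the old per-epoch reshuffle loop are skipped, so an epoch visits the shortest source sentences first. Below is a minimal toy sketch of that behavior; the function name, lengths, and token budget are illustrative, not the fairseq implementation:

    import numpy as np

    def toy_shuffled_batches_by_size(src_lengths, max_tokens=8,
                                     sort_by_source_size=False):
        # Order sample indices by source length, shortest first.
        order = np.argsort(src_lengths).tolist()
        batches, cur, cur_tokens = [], [], 0
        for idx in order:
            if cur and cur_tokens + src_lengths[idx] > max_tokens:
                batches.append(cur)
                cur, cur_tokens = [], 0
            cur.append(idx)
            cur_tokens += src_lengths[idx]
        if cur:
            batches.append(cur)
        # The change in this commit: only shuffle batch order when
        # the curriculum is not active for this epoch.
        if not sort_by_source_size:
            np.random.shuffle(batches)
        return batches

    lengths = [5, 2, 7, 3, 1, 6]
    print(toy_shuffled_batches_by_size(lengths, sort_by_source_size=True))
    # -> [[4, 1, 3], [0], [5], [2]]  (shortest sources first)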
fairseq/options.py
@@ -67,6 +67,8 @@ def add_optimization_args(parser):
                        help='If bigger than 0, use that number of mini-batches for each epoch,'
                             ' where each sample is drawn randomly without replacement from the'
                             ' dataset')
+    group.add_argument('--curriculum', default=0, type=int, metavar='N',
+                       help='sort batches by source length for first N epochs')

     return group
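With the flag registered in `add_optimization_args`, a run can ask for length-sorted batches during its first N epochs. An illustrative invocation, assuming a prepared data directory; the path and `--max-tokens` value are placeholders, only `--curriculum 2` comes from this commit:

    python train.py data-bin/iwslt14.tokenized.de-en \
        --max-tokens 4000 --curriculum 2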
train.py
@@ -120,11 +120,15 @@ def get_perplexity(loss):
 def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     """Train the model for one epoch."""
     torch.manual_seed(args.seed + epoch)
+    trainer.set_seed(args.seed + epoch)
+
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
                              max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
                              max_positions=args.max_positions,
                              sample_without_replacement=args.sample_without_replacement,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+                             sort_by_source_size=(epoch <= args.curriculum))
+
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
     wpb_meter = AverageMeter()    # words per batch
@@ -133,7 +137,6 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     extra_meters = collections.defaultdict(lambda: AverageMeter())

     desc = '| epoch {:03d}'.format(epoch)
-    trainer.set_seed(args.seed + epoch)
     lr = trainer.get_lr()
     with progress_bar(itr, desc, leave=False) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
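The second hunk removes the later `trainer.set_seed` call that the first hunk moved up ahead of the dataloader construction. The curriculum itself is a simple epoch gate: with `--curriculum N`, `sort_by_source_size` evaluates to True for epochs 1 through N (epochs are 1-based here, matching the `epoch=1` defaults in data.py) and to False afterwards, when normal batch shuffling resumes. A quick check of the gating expression:

    curriculum = 2  # as if --curriculum 2 were passed
    for epoch in range(1, 5):
        print(epoch, epoch <= curriculum)
    # 1 True
    # 2 True
    # 3 False
    # 4 False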