OpenDAS / Fairseq · Commits

Commit 820f796f, authored Oct 25, 2017 by Myle Ott

Add `--curriculum` option

Parent: 3af8ec82
Showing 3 changed files with 15 additions and 9 deletions:

- fairseq/data.py (+8, -7)
- fairseq/options.py (+2, -0)
- train.py (+5, -2)
fairseq/data.py (+8, -7)

```diff
@@ -94,7 +94,8 @@ class LanguageDatasets(object):
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
                    sample_without_replacement=0, max_positions=1024,
-                   skip_invalid_size_inputs_valid_test=False):
+                   skip_invalid_size_inputs_valid_test=False,
+                   sort_by_source_size=False):
         dataset = self.splits[split]
         if split.startswith('train'):
             with numpy_seed(seed):
@@ -102,7 +103,8 @@ class LanguageDatasets(object):
                     dataset.src, dataset.dst,
                     max_tokens=max_tokens, epoch=epoch,
                     sample=sample_without_replacement,
-                    max_positions=max_positions)
+                    max_positions=max_positions,
+                    sort_by_source_size=sort_by_source_size)
         elif split.startswith('valid'):
             batch_sampler = list(batches_by_size(
                 dataset.src, batch_size, max_tokens, dst=dataset.dst,
                 max_positions=max_positions,
@@ -269,7 +271,8 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
         yield batch


-def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_positions=1024):
+def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
+                             max_positions=1024, sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
@@ -310,6 +313,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_p
                           "and will be ignored, sample ids={}".format(len(ignored), ignored))

     batches = list(make_batches())
-    np.random.shuffle(batches)
+    if not sort_by_source_size:
+        np.random.shuffle(batches)

     if sample:
@@ -327,9 +331,6 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_p
                 "batch length is not correct {}".format(len(result))

             batches = result
-    else:
-        for _ in range(epoch - 1):
-            np.random.shuffle(batches)

     return batches
```
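The substantive change is in `shuffled_batches_by_size`: batches are still built in source-length order, but the final `np.random.shuffle` is now skipped when `sort_by_source_size` is set, so an epoch visits the shortest source sentences first (the old `else` branch that re-shuffled per epoch is dropped along with it). A minimal standalone sketch of that pattern, using toy data and a hypothetical helper name rather than fairseq's real internals:

```python
import numpy as np

def toy_batches_by_length(src_lengths, batch_size, sort_by_source_size=False):
    """Toy illustration: bucket indices by source length, then shuffle the
    batches unless a length-sorted curriculum ordering is requested."""
    order = np.argsort(src_lengths)            # shortest sources first
    batches = [order[i:i + batch_size].tolist()
               for i in range(0, len(order), batch_size)]
    if not sort_by_source_size:                # mirrors the new guard in data.py
        np.random.shuffle(batches)
    return batches

lengths = [12, 3, 27, 8, 19, 5]
print(toy_batches_by_length(lengths, batch_size=2, sort_by_source_size=True))
# [[1, 5], [3, 0], [4, 2]] -- batches in increasing source-length order
```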
fairseq/options.py (+2, -0)

```diff
@@ -67,6 +67,8 @@ def add_optimization_args(parser):
                        help='If bigger than 0, use that number of mini-batches for each epoch,'
                             ' where each sample is drawn randomly without replacement from the'
                             ' dataset')
+    group.add_argument('--curriculum', default=0, type=int, metavar='N',
+                       help='sort batches by source length for first N epochs')

     return group
```
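The flag parses like any other integer option. A self-contained demo of just this argument (the parser here is deliberately simplified; fairseq registers it on the optimization argument group inside `add_optimization_args`):

```python
import argparse

# Demo of the new flag in isolation, not fairseq's full option setup.
parser = argparse.ArgumentParser()
parser.add_argument('--curriculum', default=0, type=int, metavar='N',
                    help='sort batches by source length for first N epochs')

args = parser.parse_args(['--curriculum', '2'])
print(args.curriculum)  # 2 -> epochs 1 and 2 train on length-sorted batches
```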
train.py (+5, -2)

```diff
@@ -120,11 +120,15 @@ def get_perplexity(loss):
 def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     """Train the model for one epoch."""
     torch.manual_seed(args.seed + epoch)
+    trainer.set_seed(args.seed + epoch)
+
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
                              max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
                              max_positions=args.max_positions,
                              sample_without_replacement=args.sample_without_replacement,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+                             sort_by_source_size=(epoch <= args.curriculum))
+
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()   # sentences per batch
     wpb_meter = AverageMeter()   # words per batch
@@ -133,7 +137,6 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     desc = '| epoch {:03d}'.format(epoch)
-    trainer.set_seed(args.seed + epoch)
     lr = trainer.get_lr()
     with progress_bar(itr, desc, leave=False) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
```
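The whole curriculum schedule reduces to the one boolean passed to the dataloader: batches are sorted by source size exactly while `epoch <= args.curriculum`. A quick sketch of that gate, assuming (as the `train(args, epoch, ...)` signature above suggests) that epochs count from 1:

```python
curriculum = 2  # as if run with --curriculum 2

for epoch in range(1, 6):
    sort_by_source_size = (epoch <= curriculum)  # the expression added in train.py
    print(epoch, 'sorted by source length' if sort_by_source_size else 'shuffled')
# epochs 1-2 print 'sorted by source length'; epochs 3-5 print 'shuffled'
```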