OpenDAS / Fairseq · Commits

Commit 2f781c5a
authored Oct 25, 2017 by Myle Ott

Support different max_source_positions and max_target_positions

parent 5fe8ea46
Showing 5 changed files with 49 additions and 33 deletions (+49, -33):
  fairseq/data.py                 +24  -14
  fairseq/models/fconv.py          +2   -2
  fairseq/options.py               +4   -2
  fairseq/sequence_generator.py    +1   -1
  train.py                        +18  -14
fairseq/data.py

@@ -8,6 +8,7 @@
 import contextlib
 import itertools
+import numbers
 import numpy as np
 import os
 import torch

@@ -93,7 +94,7 @@ class LanguageDatasets(object):
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
-                   sample_without_replacement=0, max_positions=1024,
+                   sample_without_replacement=0, max_positions=(1024, 1024),
                    skip_invalid_size_inputs_valid_test=False,
                    sort_by_source_size=False):
         dataset = self.splits[split]

@@ -205,8 +206,20 @@ class LanguagePairDataset(object):
         return res


+def _valid_size(src_size, dst_size, max_positions):
+    if isinstance(max_positions, numbers.Number):
+        max_src_positions, max_dst_positions = max_positions, max_positions
+    else:
+        max_src_positions, max_dst_positions = max_positions
+    if src_size < 2 or src_size > max_src_positions:
+        return False
+    if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
+        return False
+    return True
+
+
 def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
-                    max_positions=1024, ignore_invalid_inputs=False):
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
     """Returns batches of indices sorted by size. Sequences of different lengths
     are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset)

@@ -234,15 +247,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     cur_max_size = 0
     ignored = []
     for idx in indices:
-        if src.sizes[idx] < 2 or \
-                (False if dst is None else dst.sizes[idx] < 2) or \
-                sizes[idx] > max_positions:
+        if not _valid_size(src.sizes[idx],
+                           None if dst is None else dst.sizes[idx],
+                           max_positions):
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
-            raise Exception("Unable to handle input id {} of "
-                            "size {} / {}.".format(idx, src.sizes[idx],
-                                                   "none" if dst is None else dst.sizes[idx]))
+            raise Exception("Unable to handle input id {} of size {} / {}.".format(
+                idx, src.sizes[idx],
+                "none" if dst is None else dst.sizes[idx]))
         if yield_batch(idx, cur_max_size * (len(batch) + 1)):
             yield batch

@@ -253,14 +265,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
-              "and will be ignored, sample ids={}".format(len(ignored), ignored))
+              "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
     if len(batch) > 0:
         yield batch


 def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
-                             max_positions=1024, sort_by_source_size=False):
+                             max_positions=(1024, 1024), sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)

@@ -278,9 +290,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
     sample_len = 0
     ignored = []
     for idx in indices:
-        if src.sizes[idx] < 2 or dst.sizes[idx] < 2 or \
-                src.sizes[idx] > max_positions or \
-                dst.sizes[idx] > max_positions:
+        if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
             ignored.append(idx)
             continue
         sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])

@@ -296,7 +306,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
-              "and will be ignored, sample ids={}".format(len(ignored), ignored))
+              "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
     batches = list(make_batches())
     if not sort_by_source_size:
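Note (not part of the commit): the new _valid_size helper is what lets a single max_positions argument be either a scalar, as before, or a (source, target) pair. A minimal standalone sketch of that behaviour, restating the helper from the diff above with a few illustrative checks (the example sizes are made up):

import numbers

def _valid_size(src_size, dst_size, max_positions):
    # Same logic as the helper added to fairseq/data.py in this commit.
    if isinstance(max_positions, numbers.Number):
        max_src_positions, max_dst_positions = max_positions, max_positions
    else:
        max_src_positions, max_dst_positions = max_positions
    if src_size < 2 or src_size > max_src_positions:
        return False
    if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
        return False
    return True

# A scalar limit still applies to both sides, so existing callers keep working.
assert _valid_size(100, 200, 1024)
# A (source, target) tuple enforces the two limits independently.
assert not _valid_size(100, 200, (1024, 150))   # target side exceeds its limit
# dst_size may be None (e.g. no target dataset); only the source limit is checked then.
assert _valid_size(100, None, (1024, 150))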
fairseq/models/fconv.py

@@ -381,7 +381,7 @@ def build_model(args, src_dict, dst_dict):
         embed_dim=args.encoder_embed_dim,
         convolutions=eval(args.encoder_layers),
         dropout=args.dropout,
-        max_positions=args.max_positions,
+        max_positions=args.max_source_positions,
     )
     decoder = FConvDecoder(
         dst_dict,

@@ -390,6 +390,6 @@ def build_model(args, src_dict, dst_dict):
         out_embed_dim=args.decoder_out_embed_dim,
         attention=eval(args.decoder_attention),
         dropout=args.dropout,
-        max_positions=args.max_positions,
+        max_positions=args.max_target_positions,
     )
     return FConvModel(encoder, decoder)
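Illustrative sketch (not from the commit): with the limits decoupled, the encoder and decoder no longer have to share one positional bound, which is roughly what these values control in the convolutional models, e.g. the size of a learned position-embedding table. The values and embedding dimension below are made up:

import torch.nn as nn

max_source_positions, max_target_positions = 1024, 2048   # made-up limits
encoder_pos_embed = nn.Embedding(max_source_positions, 256)
decoder_pos_embed = nn.Embedding(max_target_positions, 256)
# The two sides can now be sized independently.
assert encoder_pos_embed.num_embeddings != decoder_pos_embed.num_embeddings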
fairseq/options.py

@@ -33,8 +33,10 @@ def add_dataset_args(parser):
                        help='target language')
     group.add_argument('-j', '--workers', default=1, type=int, metavar='N',
                        help='number of data loading workers (default: 1)')
-    group.add_argument('--max-positions', default=1024, type=int, metavar='N',
-                       help='max number of tokens in the sequence')
+    group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
+                       help='max number of tokens in the source sequence')
+    group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
+                       help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                        help='Ignore too long or too short lines in valid and test set')
     return group
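For reference, a rough sketch of how the new pair of flags parses, using a throwaway argparse parser that mirrors the two add_argument calls above (this is not the real fairseq.options module, and the 2048 value is just an example):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                    help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                    help='max number of tokens in the target sequence')

# The single --max-positions flag is gone; the two limits are set independently.
args = parser.parse_args(['--max-target-positions', '2048'])
assert (args.max_source_positions, args.max_target_positions) == (1024, 2048)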
fairseq/sequence_generator.py

@@ -40,7 +40,7 @@ class SequenceGenerator(object):
         self.vocab_size = len(models[0].dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = min(maxlen, *[m.decoder.max_positions() for m in self.models])
+        self.maxlen = min(maxlen, *[m.max_decoder_positions() for m in self.models])
         self.stop_early = stop_early
         self.normalize_scores = normalize_scores
         self.len_penalty = len_penalty
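The one-line change above matters mostly for ensembles: the requested generation length is capped by the smallest per-model decoder limit, now queried through the model-level max_decoder_positions() accessor instead of reaching into m.decoder directly. A minimal sketch with stand-in models (not fairseq classes; the limits are made up):

class StubModel:
    """Stand-in exposing only the accessor used by SequenceGenerator above."""
    def __init__(self, limit):
        self._limit = limit

    def max_decoder_positions(self):
        return self._limit

models = [StubModel(1024), StubModel(512)]
maxlen = min(2000, *[m.max_decoder_positions() for m in models])
assert maxlen == 512   # the most restrictive decoder in the ensemble wins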
train.py

@@ -66,6 +66,11 @@ def main():
     criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

+    # The max number of positions can be different for train and valid
+    # e.g., RNNs may support more positions at test time than seen in training
+    max_positions_train = (args.max_source_positions, args.max_target_positions)
+    max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())
+
     # Start multiprocessing
     trainer = MultiprocessingTrainer(args, model, criterion)

@@ -89,11 +94,11 @@ def main():
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, epoch, batch_offset, trainer, dataset, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)

         # evaluate on validate set
         for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
             if k == 0:
                 if not args.no_save:
                     # save checkpoint

@@ -117,15 +122,15 @@ def get_perplexity(loss):
         return float('inf')


-def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
+def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     """Train the model for one epoch."""
     torch.manual_seed(args.seed + epoch)
     trainer.set_seed(args.seed + epoch)

-    itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
-                             max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
-                             max_positions=args.max_positions,
-                             sample_without_replacement=args.sample_without_replacement,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
-                             sort_by_source_size=(epoch <= args.curriculum))
+    itr = dataset.dataloader(
+        args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
+        seed=args.seed, epoch=epoch, max_positions=max_positions,
+        sample_without_replacement=args.sample_without_replacement,
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+        sort_by_source_size=(epoch <= args.curriculum))

@@ -207,12 +212,11 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
     trainer.save_checkpoint(last_filename, extra_state)


-def validate(args, epoch, trainer, dataset, subset, ngpus):
+def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""

-    itr = dataset.dataloader(subset, batch_size=None,
-                             max_tokens=args.max_tokens,
-                             max_positions=args.max_positions,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    itr = dataset.dataloader(
+        subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
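A short sketch of the data flow the train.py changes set up (the numbers are made up; the attribute and method names are the ones used in the diff): training limits come from the command line, validation limits come from the model itself, and both tuples end up in dataset.dataloader(max_positions=...), where _valid_size() applies them.

# Made-up values, mirroring how main() builds the two tuples above.
max_source_positions, max_target_positions = 1024, 1024      # from args
max_encoder_positions, max_decoder_positions = 4096, 4096    # from the trained model

max_positions_train = (max_source_positions, max_target_positions)
max_positions_valid = (max_encoder_positions, max_decoder_positions)

# train(...) filters batches with the CLI limits; validate(...) only drops what
# the model genuinely cannot handle (e.g. an RNN may run longer at test time).
assert max_positions_valid[0] >= max_positions_train[0]
assert max_positions_valid[1] >= max_positions_train[1]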