OpenDAS / Fairseq

Commit 2f781c5a, authored Oct 25, 2017 by Myle Ott
Support different max_source_positions and max_target_positions
Parent: 5fe8ea46
Showing 5 changed files with 49 additions and 33 deletions (+49 / -33)
fairseq/data.py                  +24 / -14
fairseq/models/fconv.py           +2 /  -2
fairseq/options.py                +4 /  -2
fairseq/sequence_generator.py     +1 /  -1
train.py                         +18 / -14
fairseq/data.py

@@ -8,6 +8,7 @@
 import contextlib
 import itertools
+import numbers
 import numpy as np
 import os
 import torch
@@ -93,7 +94,7 @@ class LanguageDatasets(object):
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
-                   sample_without_replacement=0, max_positions=1024,
+                   sample_without_replacement=0, max_positions=(1024, 1024),
                    skip_invalid_size_inputs_valid_test=False,
                    sort_by_source_size=False):
         dataset = self.splits[split]
@@ -205,8 +206,20 @@ class LanguagePairDataset(object):
         return res


+def _valid_size(src_size, dst_size, max_positions):
+    if isinstance(max_positions, numbers.Number):
+        max_src_positions, max_dst_positions = max_positions, max_positions
+    else:
+        max_src_positions, max_dst_positions = max_positions
+    if src_size < 2 or src_size > max_src_positions:
+        return False
+    if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
+        return False
+    return True
+
+
 def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
-                    max_positions=1024, ignore_invalid_inputs=False):
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
     """Returns batches of indices sorted by size. Sequences of different lengths
     are not allowed in the same batch."""
     assert isinstance(src, IndexedDataset)
@@ -234,15 +247,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     cur_max_size = 0
     ignored = []
     for idx in indices:
-        if src.sizes[idx] < 2 or \
-                (False if dst is None else dst.sizes[idx] < 2) or \
-                sizes[idx] > max_positions:
+        if not _valid_size(src.sizes[idx],
+                           None if dst is None else dst.sizes[idx],
+                           max_positions):
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
-            raise Exception("Unable to handle input id {} of "
-                            "size {} / {}.".format(idx, src.sizes[idx],
-                                                   "none" if dst is None else dst.sizes[idx]))
+            raise Exception("Unable to handle input id {} of size {} / {}.".format(
+                idx, src.sizes[idx], "none" if dst is None else dst.sizes[idx]))

         if yield_batch(idx, cur_max_size * (len(batch) + 1)):
             yield batch
@@ -253,14 +265,14 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
-              "and will be ignored, sample ids={}".format(len(ignored), ignored))
+              "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))

     if len(batch) > 0:
         yield batch


 def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
-                             max_positions=1024, sort_by_source_size=False):
+                             max_positions=(1024, 1024), sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
@@ -278,9 +290,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
         sample_len = 0
         ignored = []
         for idx in indices:
-            if src.sizes[idx] < 2 or dst.sizes[idx] < 2 or \
-                    src.sizes[idx] > max_positions or \
-                    dst.sizes[idx] > max_positions:
+            if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
                 ignored.append(idx)
                 continue

             sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
@@ -296,7 +306,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
         if len(ignored) > 0:
             print("Warning! {} samples are either too short or too long "
-                  "and will be ignored, sample ids={}".format(len(ignored), ignored))
+                  "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))

     batches = list(make_batches())
     if not sort_by_source_size:
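For illustration, the new _valid_size helper can be exercised on its own. This is a copy of the function from the hunk above, plus a few example checks that are not part of the commit:

import numbers

def _valid_size(src_size, dst_size, max_positions):
    # A scalar limit applies to both sides; a pair gives per-side limits.
    if isinstance(max_positions, numbers.Number):
        max_src_positions, max_dst_positions = max_positions, max_positions
    else:
        max_src_positions, max_dst_positions = max_positions
    if src_size < 2 or src_size > max_src_positions:
        return False
    if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
        return False
    return True

assert _valid_size(10, 20, 1024)                 # scalar limit, both sides within range
assert not _valid_size(10, 2000, (1024, 1024))   # target side too long
assert _valid_size(10, 2000, (1024, 4096))       # per-side limits may differ
assert _valid_size(10, None, 1024)               # no target side provided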
fairseq/models/fconv.py

@@ -381,7 +381,7 @@ def build_model(args, src_dict, dst_dict):
         embed_dim=args.encoder_embed_dim,
         convolutions=eval(args.encoder_layers),
         dropout=args.dropout,
-        max_positions=args.max_positions,
+        max_positions=args.max_source_positions,
     )
     decoder = FConvDecoder(
         dst_dict,
@@ -390,6 +390,6 @@ def build_model(args, src_dict, dst_dict):
         out_embed_dim=args.decoder_out_embed_dim,
         attention=eval(args.decoder_attention),
         dropout=args.dropout,
-        max_positions=args.max_positions,
+        max_positions=args.max_target_positions,
     )
     return FConvModel(encoder, decoder)
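The practical effect in the fconv model is that the encoder and decoder can cap positions independently. A rough stand-in sketch, assuming (as in ConvS2S) that max_positions sizes a learned position-embedding table; these are plain nn.Embedding tables with placeholder dimensions, not the fairseq modules:

import torch.nn as nn

# Each side gets its own positional capacity after this change.
encoder_positions = nn.Embedding(1024, 256)  # roughly args.max_source_positions slots
decoder_positions = nn.Embedding(2048, 256)  # roughly args.max_target_positions slots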
fairseq/options.py

@@ -33,8 +33,10 @@ def add_dataset_args(parser):
                        help='target language')
     group.add_argument('-j', '--workers', default=1, type=int, metavar='N',
                        help='number of data loading workers (default: 1)')
-    group.add_argument('--max-positions', default=1024, type=int, metavar='N',
-                       help='max number of tokens in the sequence')
+    group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
+                       help='max number of tokens in the source sequence')
+    group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
+                       help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                        help='Ignore too long or too short lines in valid and test set')
     return group
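A minimal sketch of how the two new options behave once parsed, and how train.py combines them into a per-side pair; the standalone parser below is a simplification for illustration, not the fairseq options code:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                    help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                    help='max number of tokens in the target sequence')
args = parser.parse_args(['--max-target-positions', '2048'])

# train.py threads this pair through to the dataloader and batching helpers.
max_positions_train = (args.max_source_positions, args.max_target_positions)
assert max_positions_train == (1024, 2048)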
fairseq/sequence_generator.py

@@ -40,7 +40,7 @@ class SequenceGenerator(object):
         self.vocab_size = len(models[0].dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = min(maxlen, *[m.decoder.max_positions() for m in self.models])
+        self.maxlen = min(maxlen, *[m.max_decoder_positions() for m in self.models])
         self.stop_early = stop_early
         self.normalize_scores = normalize_scores
         self.len_penalty = len_penalty
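The generator now asks each model for its decoder-side limit through the model-level accessor instead of reaching into model.decoder. A toy sketch of the ensemble cap (ToyModel is a stand-in, not a fairseq class):

class ToyModel:
    def __init__(self, limit):
        self._limit = limit

    def max_decoder_positions(self):
        return self._limit

# The effective generation length is capped by the most restrictive model.
models = [ToyModel(1024), ToyModel(200)]
maxlen = min(512, *[m.max_decoder_positions() for m in models])
assert maxlen == 200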
train.py

@@ -66,6 +66,11 @@ def main():
     criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

+    # The max number of positions can be different for train and valid
+    # e.g., RNNs may support more positions at test time than seen in training
+    max_positions_train = (args.max_source_positions, args.max_target_positions)
+    max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())
+
     # Start multiprocessing
     trainer = MultiprocessingTrainer(args, model, criterion)
@@ -89,11 +94,11 @@ def main():
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, epoch, batch_offset, trainer, dataset, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)

         # evaluate on validate set
         for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
@@ -117,18 +122,18 @@ def get_perplexity(loss):
         return float('inf')


-def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
+def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     """Train the model for one epoch."""
     torch.manual_seed(args.seed + epoch)
     trainer.set_seed(args.seed + epoch)

-    itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
-                             max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
-                             max_positions=args.max_positions,
-                             sample_without_replacement=args.sample_without_replacement,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
-                             sort_by_source_size=(epoch <= args.curriculum))
+    itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
+                             max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
+                             max_positions=max_positions,
+                             sample_without_replacement=args.sample_without_replacement,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+                             sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
     bsz_meter = AverageMeter()    # sentences per batch
     wpb_meter = AverageMeter()    # words per batch
@@ -207,13 +212,12 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
     trainer.save_checkpoint(last_filename, extra_state)


-def validate(args, epoch, trainer, dataset, subset, ngpus):
+def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
-    itr = dataset.dataloader(subset, batch_size=None,
-                             max_tokens=args.max_tokens,
-                             max_positions=args.max_positions,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    itr = dataset.dataloader(subset, batch_size=None, max_tokens=args.max_tokens,
+                             max_positions=max_positions,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
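To make the split between training and validation limits concrete, a stand-in sketch following the comment added in main() above; ToyModel and the numbers are illustrative only:

class ToyModel:
    # Stand-in for the built model; it may accept longer sequences at
    # validation time than the limits used during training.
    def max_encoder_positions(self):
        return 4096

    def max_decoder_positions(self):
        return 4096

model = ToyModel()
max_positions_train = (1024, 1024)   # from --max-source-positions / --max-target-positions
max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())
assert max_positions_valid == (4096, 4096)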