OpenDAS / Fairseq / Commits

Commit f442f896, authored Nov 04, 2017 by Myle Ott

Add --max-sentences option for batching based on # sentences

Parent: 2ef422f6
Showing 8 changed files with 72 additions and 72 deletions:

fairseq/criterions/cross_entropy.py                  +3   -4
fairseq/criterions/fairseq_criterion.py              +3   -1
fairseq/criterions/label_smoothed_cross_entropy.py   +4   -5
fairseq/data.py                                      +45  -52
fairseq/options.py                                   +3   -0
fairseq/utils.py                                     +2   -3
generate.py                                          +1   -1
train.py                                             +11  -6
fairseq/criterions/cross_entropy.py

@@ -14,9 +14,8 @@ from .fairseq_criterion import FairseqCriterion
 class CrossEntropyCriterion(FairseqCriterion):
 
-    def __init__(self, padding_idx):
-        super().__init__()
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict):
+        super().__init__(args, dst_dict)
 
     def forward(self, model, sample):
         """Compute the loss for the given sample.
@@ -30,7 +29,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        sample_size = sample['ntokens']
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0],
             'sample_size': sample_size,
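The sample_size switch above is what the new --sentence-avg flag (defined in fairseq/options.py further down) controls: the criterion reports either the number of target tokens or the number of sentences in the batch, and that count is what the gradient is normalized by. A minimal, self-contained sketch of the selection logic, using a hand-built dict with hypothetical numbers rather than a real fairseq sample:

# Illustrative only: mirrors the sample_size choice in CrossEntropyCriterion.forward.
sample = {'ntokens': 57, 'nsentences': 4}  # hypothetical batch: 4 sentences, 57 target tokens

def pick_sample_size(sample, sentence_avg):
    # real code: sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
    return sample['nsentences'] if sentence_avg else sample['ntokens']

total_loss = 123.4  # summed (size_average=False) cross-entropy over the batch, hypothetical
for sentence_avg in (False, True):
    size = pick_sample_size(sample, sentence_avg)
    print('sentence_avg={}: divide by {} -> {:.3f}'.format(sentence_avg, size, total_loss / size))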
fairseq/criterions/fairseq_criterion.py

@@ -11,8 +11,10 @@ from torch.nn.modules.loss import _Loss
 class FairseqCriterion(_Loss):
 
-    def __init__(self):
+    def __init__(self, args, dst_dict):
         super().__init__()
+        self.args = args
+        self.padding_idx = dst_dict.pad()
 
     def forward(self, model, sample):
         """Compute the loss for the given sample.
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -43,10 +43,9 @@ class LabelSmoothedCrossEntropy(torch.autograd.Function):
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
 
-    def __init__(self, eps, padding_idx=None, weights=None):
-        super().__init__()
-        self.eps = eps
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict, weights=None):
+        super().__init__(args, dst_dict)
+        self.eps = args.label_smoothing
         self.weights = weights
 
     def forward(self, model, sample):
@@ -61,7 +60,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        sample_size = sample['ntokens']
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
             'loss': loss.data[0],
             'sample_size': sample_size,
fairseq/data.py

@@ -97,27 +97,26 @@ class LanguageDatasets(object):
         assert self.src_dict.unk() == self.dst_dict.unk()
 
     def train_dataloader(self, split, num_workers=0, max_tokens=None,
-                         max_positions=(1024, 1024), seed=None, epoch=1,
-                         sample_without_replacement=0,
+                         max_sentences=None, max_positions=(1024, 1024),
+                         seed=None, epoch=1, sample_without_replacement=0,
                          sort_by_source_size=False):
         dataset = self.splits[split]
         with numpy_seed(seed):
             batch_sampler = shuffled_batches_by_size(
-                dataset.src, dataset.dst, max_tokens=max_tokens, epoch=epoch,
-                sample=sample_without_replacement, max_positions=max_positions,
+                dataset.src, dataset.dst, max_tokens=max_tokens,
+                max_sentences=max_sentences, epoch=epoch,
+                sample=sample_without_replacement, max_positions=max_positions,
                 sort_by_source_size=sort_by_source_size)
         return torch.utils.data.DataLoader(
             dataset, num_workers=num_workers, collate_fn=dataset.collater,
             batch_sampler=batch_sampler)
 
-    def eval_dataloader(self, split, num_workers=0, batch_size=1, max_tokens=None,
-                        max_positions=(1024, 1024),
+    def eval_dataloader(self, split, num_workers=0, max_tokens=None,
+                        max_sentences=None, max_positions=(1024, 1024),
                         skip_invalid_size_inputs_valid_test=False):
         dataset = self.splits[split]
         batch_sampler = list(batches_by_size(
-            dataset.src, dataset.dst, batch_size, max_tokens,
+            dataset.src, dataset.dst, max_tokens, max_sentences,
             max_positions=max_positions,
             ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         return torch.utils.data.DataLoader(
@@ -220,29 +219,23 @@ def _valid_size(src_size, dst_size, max_positions):
     return True
 
-def batches_by_size(src, dst, batch_size=None, max_tokens=None,
-                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
-    """Returns batches of indices sorted by size. Sequences with different
-    source lengths are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
-    if max_tokens is None:
-        max_tokens = float('Inf')
-    indices = np.argsort(src.sizes, kind='mergesort')
-
+def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
+                  ignore_invalid_inputs=False, allow_different_src_lens=False):
     batch = []
 
     def yield_batch(next_idx, num_tokens):
         if len(batch) == 0:
             return False
-        if len(batch) == batch_size:
+        if len(batch) == max_sentences:
             return True
-        if src.sizes[batch[0]] != src.sizes[next_idx]:
-            return True
-        if num_tokens > max_tokens:
+        if num_tokens >= max_tokens:
             return True
+        if not allow_different_src_lens and \
+                (src.sizes[batch[0]] != src.sizes[next_idx]):
+            return True
         return False
 
-    cur_max_size = 0
+    sample_len = 0
     ignored = []
     for idx in indices:
         if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
@@ -253,28 +246,48 @@ def batches_by_size(src, dst, batch_size=None, max_tokens=None,
                 "Unable to handle input id {} of size {} / {}."
                 .format(idx, src.sizes[idx], dst.sizes[idx]))
-        if yield_batch(idx, cur_max_size * (len(batch) + 1)):
+        sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
+        num_tokens = (len(batch) + 1) * sample_len
+        if yield_batch(idx, num_tokens):
             yield batch
             batch = []
-            cur_max_size = 0
+            sample_len = max(src.sizes[idx], dst.sizes[idx])
         batch.append(idx)
-        cur_max_size = max(cur_max_size, src.sizes[idx], dst.sizes[idx])
-
-    if len(batch) > 0:
-        yield batch
 
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
               "and will be ignored, first few sample ids={}".format(
                   len(ignored), ignored[:10]))
 
+    if len(batch) > 0:
+        yield batch
+
+
+def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
+    """Returns batches of indices sorted by size. Sequences with different
+    source lengths are not allowed in the same batch."""
+    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
+    if max_tokens is None:
+        max_tokens = float('Inf')
+    if max_sentences is None:
+        max_sentences = float('Inf')
+    indices = np.argsort(src.sizes, kind='mergesort')
+    return _make_batches(
+        src, dst, indices, max_tokens, max_sentences, max_positions,
+        ignore_invalid_inputs, allow_different_src_lens=False)
+
+
-def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
-                             max_positions=(1024, 1024), sort_by_source_size=False):
+def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
+                             epoch=1, sample=0, max_positions=(1024, 1024),
+                             sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
+    if max_sentences is None:
+        max_sentences = float('Inf')
     indices = np.random.permutation(len(src))
@@ -282,30 +295,10 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
     indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
     indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
 
-    def make_batches():
-        batch = []
-        sample_len = 0
-        ignored = []
-        for idx in indices:
-            if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
-                ignored.append(idx)
-                continue
-            sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
-            if len(batch) > 0 and (len(batch) + 1) * sample_len > max_tokens:
-                yield batch
-                batch = []
-                sample_len = max(src.sizes[idx], dst.sizes[idx])
-            batch.append(idx)
-        if len(batch) > 0:
-            yield batch
-        if len(ignored) > 0:
-            print("Warning! {} samples are either too short or too long "
-                  "and will be ignored, first few sample ids={}".format(
-                      len(ignored), ignored[:10]))
-
-    batches = list(make_batches())
+    batches = list(_make_batches(
+        src, dst, indices, max_tokens, max_sentences, max_positions,
+        ignore_invalid_inputs=True, allow_different_src_lens=True))
 
     if not sort_by_source_size:
         np.random.shuffle(batches)
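Both batches_by_size and shuffled_batches_by_size now delegate to _make_batches, which closes a batch as soon as it would hit either cap: max_sentences sentences, or roughly max_tokens tokens, where the token count is estimated as (batch size + 1) times the longest sample seen so far, since a batch is padded to its longest member. A stripped-down, self-contained sketch of that rule over a plain list of hypothetical sentence lengths (the real code also handles source/target pairs, invalid sizes, and the allow_different_src_lens restriction):

# Standalone sketch of the capping rule in _make_batches above; inputs are made up.
def make_batches(lengths, max_tokens=float('Inf'), max_sentences=float('Inf')):
    batch, sample_len = [], 0
    for idx, n in enumerate(lengths):
        sample_len = max(sample_len, n)  # longest sample in the batch so far
        # a padded batch costs roughly (size + 1) * longest-member tokens
        if batch and (len(batch) == max_sentences or
                      (len(batch) + 1) * sample_len >= max_tokens):
            yield batch
            batch, sample_len = [], n
        batch.append(idx)
    if batch:
        yield batch

lengths = [5, 7, 7, 8, 20, 21, 40]  # hypothetical lengths, already sorted by size
print(list(make_batches(lengths, max_tokens=60)))                    # token cap only
print(list(make_batches(lengths, max_tokens=60, max_sentences=2)))   # plus sentence cap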
fairseq/options.py

@@ -71,6 +71,9 @@ def add_optimization_args(parser):
                            ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
+    group.add_argument('--sentence-avg', action='store_true',
+                       help='normalize gradients by the number of sentences in a batch'
+                            ' (default is to normalize by number of tokens)')
 
     return group
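Together with the --max-tokens and --max-sentences arguments registered in train.py below, this commit introduces three batching-related flags. A self-contained argparse sketch, reusing the add_argument calls from the diff so the defaults and types can be checked in isolation (the standalone parser itself is only for illustration; the real one is assembled across fairseq/options.py and train.py):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                    help='maximum number of tokens in a batch')
parser.add_argument('--max-sentences', type=int, metavar='N',
                    help='maximum number of sentences in a batch')
parser.add_argument('--sentence-avg', action='store_true',
                    help='normalize gradients by the number of sentences in a batch'
                         ' (default is to normalize by number of tokens)')

args = parser.parse_args(['--max-sentences', '64', '--sentence-avg'])
print(args.max_tokens, args.max_sentences, args.sentence_avg)  # -> 6000 64 True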
fairseq/utils.py

@@ -30,11 +30,10 @@ def build_model(args, src_dict, dst_dict):
 
 def build_criterion(args, src_dict, dst_dict):
-    padding_idx = dst_dict.pad()
     if args.label_smoothing > 0:
-        return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
+        return criterions.LabelSmoothedCrossEntropyCriterion(args, dst_dict)
     else:
-        return criterions.CrossEntropyCriterion(padding_idx)
+        return criterions.CrossEntropyCriterion(args, dst_dict)
 
 
 def torch_persistent_save(*args, **kwargs):
generate.py

@@ -68,7 +68,7 @@ def main():
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
     max_positions = min(model.max_encoder_positions() for model in models)
     itr = dataset.eval_dataloader(
-        args.gen_subset, batch_size=args.batch_size, max_positions=max_positions,
+        args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
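Note that generate.py keeps its existing --batch-size flag; the value is simply routed into the new max_sentences argument of eval_dataloader, so generation batches are now capped by sentence count through the same code path as training. A hedged sketch of that remapping, with a stub standing in for the real LanguageDatasets object (only the keyword change shown above is illustrated; the batch size value is hypothetical):

# Stub with the same keyword interface as the updated LanguageDatasets.eval_dataloader.
class DatasetStub:
    def eval_dataloader(self, split, num_workers=0, max_tokens=None, max_sentences=None,
                        max_positions=(1024, 1024), skip_invalid_size_inputs_valid_test=False):
        return 'split={}, max_sentences={}, max_positions={}'.format(
            split, max_sentences, max_positions)

batch_size = 32  # hypothetical value passed on the command line as --batch-size
dataset = DatasetStub()
print(dataset.eval_dataloader('test', max_sentences=batch_size, max_positions=(1024, 1024),
                              skip_invalid_size_inputs_valid_test=True))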
train.py

@@ -23,6 +23,8 @@ def main():
     dataset_args = options.add_dataset_args(parser)
     dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                               help='maximum number of tokens in a batch')
+    dataset_args.add_argument('--max-sentences', type=int, metavar='N',
+                              help='maximum number of sentences in a batch')
     dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                               choices=['train', 'valid', 'test'],
                               help='data subset to use for training (train, valid, test)')
@@ -59,7 +61,8 @@ def main():
         raise NotImplementedError('Training on CPU is not supported')
     num_gpus = torch.cuda.device_count()
 
-    print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
+    print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
+        num_gpus, args.max_tokens, args.max_sentences))
 
     # Build model and criterion
     model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
@@ -130,7 +133,8 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     trainer.set_seed(seed)
 
     itr = dataset.train_dataloader(
-        args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
+        args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
+        max_sentences=args.max_sentences,
         max_positions=max_positions, seed=seed, epoch=epoch,
         sample_without_replacement=args.sample_without_replacement,
         sort_by_source_size=(epoch <= args.curriculum))
@@ -150,9 +154,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
         del loss_dict['loss']  # don't include in extra_meters or extra_postfix
 
         ntokens = sum(s['ntokens'] for s in sample)
-        src_size = sum(s['src_tokens'].size(0) for s in sample)
-        loss_meter.update(loss, ntokens)
-        bsz_meter.update(src_size)
+        nsentences = sum(s['src_tokens'].size(0) for s in sample)
+        loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
+        bsz_meter.update(nsentences)
         wpb_meter.update(ntokens)
         wps_meter.update(ntokens)
         clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
@@ -216,7 +220,8 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
 
     itr = dataset.eval_dataloader(
-        subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
+        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
+        max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())