OpenDAS / Fairseq

Commit f442f896, authored Nov 04, 2017 by Myle Ott
Add --max-sentences option for batching based on # sentences
parent 2ef422f6

Showing 8 changed files with 72 additions and 72 deletions:
fairseq/criterions/cross_entropy.py                 +3   -4
fairseq/criterions/fairseq_criterion.py             +3   -1
fairseq/criterions/label_smoothed_cross_entropy.py  +4   -5
fairseq/data.py                                     +45  -52
fairseq/options.py                                  +3   -0
fairseq/utils.py                                    +2   -3
generate.py                                         +1   -1
train.py                                            +11  -6
fairseq/criterions/cross_entropy.py

@@ -14,9 +14,8 @@ from .fairseq_criterion import FairseqCriterion
 
 class CrossEntropyCriterion(FairseqCriterion):
 
-    def __init__(self, padding_idx):
-        super().__init__()
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict):
+        super().__init__(args, dst_dict)
 
     def forward(self, model, sample):
         """Compute the loss for the given sample.
@@ -30,7 +29,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        sample_size = sample['ntokens']
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0],
             'sample_size': sample_size,
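The sample_size switch above is the core of the new --sentence-avg behaviour: the summed cross-entropy is later divided by sample_size, so a batch can be normalized per sentence instead of per token. Below is a minimal, self-contained sketch of the two modes on made-up tensors; the args namespace and sizes are assumptions, and it uses the modern reduction='sum' in place of the deprecated size_average=False shown in the diff.

# Toy illustration (not from the commit) of token- vs. sentence-level normalization.
import argparse
import torch
import torch.nn.functional as F

args = argparse.Namespace(sentence_avg=False)   # hypothetical flag value
padding_idx = 1

net_output = torch.randn(2, 5, 10)              # (sentences, tokens, vocab), made up
target = torch.randint(2, 10, (2, 5))           # toy targets, no padding here
ntokens = target.ne(padding_idx).sum().item()

loss = F.cross_entropy(net_output.view(-1, net_output.size(-1)), target.view(-1),
                       reduction='sum', ignore_index=padding_idx)

# Mirrors the new line in the diff: sentences vs. tokens as the denominator.
sample_size = target.size(0) if args.sentence_avg else ntokens
print(sample_size, (loss / sample_size).item())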
fairseq/criterions/fairseq_criterion.py

@@ -11,8 +11,10 @@ from torch.nn.modules.loss import _Loss
 
 class FairseqCriterion(_Loss):
 
-    def __init__(self):
+    def __init__(self, args, dst_dict):
         super().__init__()
+        self.args = args
+        self.padding_idx = dst_dict.pad()
 
     def forward(self, model, sample):
         """Compute the loss for the given sample.
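For context, a self-contained sketch of the constructor contract the base class now provides: subclasses pass (args, dst_dict) and inherit args and padding_idx rather than storing them themselves. The _FakeDict stand-in and the argparse values are invented for illustration and are not fairseq APIs.

import argparse
from torch.nn.modules.loss import _Loss

class _FakeDict:
    """Stand-in for a fairseq Dictionary; only pad() is needed here."""
    def pad(self):
        return 1

class CriterionSketch(_Loss):
    # Same shape as the refactored FairseqCriterion.__init__ above.
    def __init__(self, args, dst_dict):
        super().__init__()
        self.args = args
        self.padding_idx = dst_dict.pad()

crit = CriterionSketch(argparse.Namespace(sentence_avg=True), _FakeDict())
print(crit.padding_idx, crit.args.sentence_avg)   # 1 True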
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -43,10 +43,9 @@ class LabelSmoothedCrossEntropy(torch.autograd.Function):
 
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
 
-    def __init__(self, eps, padding_idx=None, weights=None):
-        super().__init__()
-        self.eps = eps
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict, weights=None):
+        super().__init__(args, dst_dict)
+        self.eps = args.label_smoothing
         self.weights = weights
 
     def forward(self, model, sample):
@@ -61,7 +60,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        sample_size = sample['ntokens']
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0],
             'sample_size': sample_size,
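As a reminder of what self.eps controls, here is a rough sketch of label-smoothed cross-entropy on toy tensors. It is one common formulation written directly with tensor ops, not a reimplementation of the LabelSmoothedCrossEntropy autograd Function, and all numbers are made up.

import torch
import torch.nn.functional as F

eps, padding_idx = 0.1, 1                        # eps plays the role of args.label_smoothing
logits = torch.randn(6, 10)                      # (tokens, vocab), made up
target = torch.tensor([4, 2, 7, 1, 3, 5])        # the '1' is a padded position

lprobs = F.log_softmax(logits, dim=-1)
nll = -lprobs.gather(1, target.unsqueeze(1)).squeeze(1)   # standard NLL per token
smooth = -lprobs.mean(dim=-1)                             # uniform prior over the vocab
loss = (1.0 - eps) * nll + eps * smooth
print(loss[target.ne(padding_idx)].sum().item())          # drop padding, sum-reduce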
fairseq/data.py

@@ -97,27 +97,26 @@ class LanguageDatasets(object):
         assert self.src_dict.unk() == self.dst_dict.unk()
 
     def train_dataloader(self, split, num_workers=0, max_tokens=None,
-                         max_positions=(1024, 1024), seed=None, epoch=1,
-                         sample_without_replacement=0, sort_by_source_size=False):
+                         max_sentences=None, max_positions=(1024, 1024),
+                         seed=None, epoch=1, sample_without_replacement=0,
+                         sort_by_source_size=False):
         dataset = self.splits[split]
         with numpy_seed(seed):
             batch_sampler = shuffled_batches_by_size(
                 dataset.src, dataset.dst, max_tokens=max_tokens,
-                epoch=epoch, sample=sample_without_replacement,
+                max_sentences=max_sentences, epoch=epoch,
+                sample=sample_without_replacement,
                 max_positions=max_positions,
                 sort_by_source_size=sort_by_source_size)
         return torch.utils.data.DataLoader(
             dataset, num_workers=num_workers, collate_fn=dataset.collater,
             batch_sampler=batch_sampler)
 
-    def eval_dataloader(self, split, num_workers=0, batch_size=1,
-                        max_tokens=None, max_positions=(1024, 1024),
+    def eval_dataloader(self, split, num_workers=0, max_tokens=None,
+                        max_sentences=None, max_positions=(1024, 1024),
                         skip_invalid_size_inputs_valid_test=False):
         dataset = self.splits[split]
         batch_sampler = list(batches_by_size(
-            dataset.src, dataset.dst, batch_size, max_tokens,
+            dataset.src, dataset.dst, max_tokens, max_sentences,
             max_positions=max_positions,
             ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         return torch.utils.data.DataLoader(
@@ -220,29 +219,23 @@ def _valid_size(src_size, dst_size, max_positions):
     return True
 
 
-def batches_by_size(src, dst, batch_size=None, max_tokens=None,
-                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
-    """Returns batches of indices sorted by size. Sequences with different
-    source lengths are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
-    if max_tokens is None:
-        max_tokens = float('Inf')
-    indices = np.argsort(src.sizes, kind='mergesort')
-
+def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
+                  ignore_invalid_inputs=False, allow_different_src_lens=False):
     batch = []
 
     def yield_batch(next_idx, num_tokens):
         if len(batch) == 0:
             return False
-        if len(batch) == batch_size:
+        if len(batch) == max_sentences:
             return True
-        if src.sizes[batch[0]] != src.sizes[next_idx]:
+        if num_tokens > max_tokens:
             return True
-        if num_tokens >= max_tokens:
+        if not allow_different_src_lens and \
+                (src.sizes[batch[0]] != src.sizes[next_idx]):
             return True
         return False
 
-    cur_max_size = 0
+    sample_len = 0
     ignored = []
     for idx in indices:
         if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
@@ -253,28 +246,48 @@ def batches_by_size(src, dst, batch_size=None, max_tokens=None,
                     "Unable to handle input id {} of size {} / {}.".format(
                         idx, src.sizes[idx], dst.sizes[idx]))
-        if yield_batch(idx, cur_max_size * (len(batch) + 1)):
+        sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
+        num_tokens = (len(batch) + 1) * sample_len
+        if yield_batch(idx, num_tokens):
             yield batch
             batch = []
-            cur_max_size = 0
+            sample_len = max(src.sizes[idx], dst.sizes[idx])
 
         batch.append(idx)
-        cur_max_size = max(cur_max_size, src.sizes[idx], dst.sizes[idx])
-
-    if len(batch) > 0:
-        yield batch
 
     if len(ignored) > 0:
         print("Warning! {} samples are either too short or too long "
               "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
 
+    if len(batch) > 0:
+        yield batch
+
+
+def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
+                    max_positions=(1024, 1024), ignore_invalid_inputs=False):
+    """Returns batches of indices sorted by size. Sequences with different
+    source lengths are not allowed in the same batch."""
+    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
+    if max_tokens is None:
+        max_tokens = float('Inf')
+    if max_sentences is None:
+        max_sentences = float('Inf')
+    indices = np.argsort(src.sizes, kind='mergesort')
+    return _make_batches(
+        src, dst, indices, max_tokens, max_sentences, max_positions,
+        ignore_invalid_inputs, allow_different_src_lens=False)
+
 
-def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
-                             max_positions=(1024, 1024), sort_by_source_size=False):
+def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
+                             epoch=1, sample=0, max_positions=(1024, 1024),
+                             sort_by_source_size=False):
     """Returns batches of indices, bucketed by size and then shuffled. Batches
     may contain sequences of different lengths."""
     assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
+    if max_sentences is None:
+        max_sentences = float('Inf')
 
     indices = np.random.permutation(len(src))
@@ -282,30 +295,10 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
         indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
         indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
 
-    def make_batches():
-        batch = []
-        sample_len = 0
-        ignored = []
-        for idx in indices:
-            if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
-                ignored.append(idx)
-                continue
-
-            sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
-            if len(batch) > 0 and (len(batch) + 1) * sample_len > max_tokens:
-                yield batch
-                batch = []
-                sample_len = max(src.sizes[idx], dst.sizes[idx])
-
-            batch.append(idx)
-
-        if len(batch) > 0:
-            yield batch
-
-        if len(ignored) > 0:
-            print("Warning! {} samples are either too short or too long "
-                  "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
-
-    batches = list(make_batches())
+    batches = list(_make_batches(
+        src, dst, indices, max_tokens, max_sentences, max_positions,
+        ignore_invalid_inputs=True, allow_different_src_lens=True))
     if not sort_by_source_size:
         np.random.shuffle(batches)
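The heart of the change is _make_batches, which now closes a batch when either cap is hit: the padded-token estimate (len(batch) + 1) * sample_len exceeds max_tokens, or the batch already holds max_sentences sentences. The standalone sketch below reproduces that rule on made-up sentence lengths; it ignores the equal-source-length constraint and the invalid-size filtering that the real function also applies.

# Greedy batching under a token cap and/or a sentence cap (toy data, not fairseq code).
def make_batches_sketch(sizes, max_tokens=None, max_sentences=None):
    max_tokens = max_tokens or float('Inf')
    max_sentences = max_sentences or float('Inf')
    batch, sample_len, batches = [], 0, []
    for idx, size in enumerate(sizes):
        sample_len = max(sample_len, size)
        num_tokens = (len(batch) + 1) * sample_len      # padded-token estimate
        if batch and (len(batch) == max_sentences or num_tokens > max_tokens):
            batches.append(batch)
            batch, sample_len = [], size
        batch.append(idx)
    if batch:
        batches.append(batch)
    return batches

sizes = [5, 5, 6, 7, 9, 12, 13, 20]                      # sentence lengths, pre-sorted
print(make_batches_sketch(sizes, max_tokens=30))          # [[0, 1, 2, 3], [4, 5], [6], [7]]
print(make_batches_sketch(sizes, max_sentences=3))        # [[0, 1, 2], [3, 4, 5], [6, 7]]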
fairseq/options.py

@@ -71,6 +71,9 @@ def add_optimization_args(parser):
                            ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
+    group.add_argument('--sentence-avg', action='store_true',
+                       help='normalize gradients by the number of sentences in a batch'
+                            ' (default is to normalize by number of tokens)')
 
     return group
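A quick way to see how the new flags behave is to parse them with a bare ArgumentParser. This is a hypothetical check, not fairseq's options module, but the argument definitions mirror the diff: --max-sentences defaults to None and --sentence-avg to False.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-tokens', default=6000, type=int, metavar='N')
parser.add_argument('--max-sentences', type=int, metavar='N')      # None unless given
parser.add_argument('--sentence-avg', action='store_true')

args = parser.parse_args(['--max-sentences', '64', '--sentence-avg'])
print(args.max_tokens, args.max_sentences, args.sentence_avg)       # 6000 64 True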
fairseq/utils.py

@@ -30,11 +30,10 @@ def build_model(args, src_dict, dst_dict):
 
 def build_criterion(args, src_dict, dst_dict):
-    padding_idx = dst_dict.pad()
     if args.label_smoothing > 0:
-        return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
+        return criterions.LabelSmoothedCrossEntropyCriterion(args, dst_dict)
     else:
-        return criterions.CrossEntropyCriterion(padding_idx)
+        return criterions.CrossEntropyCriterion(args, dst_dict)
 
 
 def torch_persistent_save(*args, **kwargs):
generate.py

@@ -68,7 +68,7 @@ def main():
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
     max_positions = min(model.max_encoder_positions() for model in models)
     itr = dataset.eval_dataloader(
-        args.gen_subset, batch_size=args.batch_size, max_positions=max_positions,
+        args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
train.py

@@ -23,6 +23,8 @@ def main():
     dataset_args = options.add_dataset_args(parser)
     dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                               help='maximum number of tokens in a batch')
+    dataset_args.add_argument('--max-sentences', type=int, metavar='N',
+                              help='maximum number of sentences in a batch')
     dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                               choices=['train', 'valid', 'test'],
                               help='data subset to use for training (train, valid, test)')
@@ -59,7 +61,8 @@ def main():
         raise NotImplementedError('Training on CPU is not supported')
     num_gpus = torch.cuda.device_count()
 
-    print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
+    print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
+        num_gpus, args.max_tokens, args.max_sentences))
 
     # Build model and criterion
     model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
@@ -130,7 +133,8 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     trainer.set_seed(seed)
 
     itr = dataset.train_dataloader(
         args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens,
-        max_positions=max_positions, seed=seed, epoch=epoch,
+        max_sentences=args.max_sentences, max_positions=max_positions,
+        seed=seed, epoch=epoch,
         sample_without_replacement=args.sample_without_replacement,
         sort_by_source_size=(epoch <= args.curriculum))
@@ -150,9 +154,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
         del loss_dict['loss']  # don't include in extra_meters or extra_postfix
 
         ntokens = sum(s['ntokens'] for s in sample)
-        src_size = sum(s['src_tokens'].size(0) for s in sample)
-        loss_meter.update(loss, ntokens)
-        bsz_meter.update(src_size)
+        nsentences = sum(s['src_tokens'].size(0) for s in sample)
+        loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
+        bsz_meter.update(nsentences)
         wpb_meter.update(ntokens)
         wps_meter.update(ntokens)
         clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
@@ -216,7 +220,8 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
     itr = dataset.eval_dataloader(
-        subset, batch_size=None, max_tokens=args.max_tokens, max_positions=max_positions,
+        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
+        max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
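The meter changes in train() follow the same idea: bsz_meter now tracks sentences per batch, and loss_meter can weight its running average by sentences instead of tokens. The toy numbers and the minimal AverageMeter below are stand-ins for illustration, not fairseq's implementation.

class AverageMeter:
    """Minimal running weighted average, analogous in spirit to fairseq's meter."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
    @property
    def avg(self):
        return self.sum / max(self.count, 1)

sentence_avg = True                                   # hypothetical --sentence-avg setting
batches = [dict(loss=120.0, ntokens=400, nsentences=16),
           dict(loss=90.0, ntokens=300, nsentences=12)]

loss_meter, bsz_meter = AverageMeter(), AverageMeter()
for b in batches:
    denom = b['nsentences'] if sentence_avg else b['ntokens']
    loss_meter.update(b['loss'] / denom, denom)       # per-sentence (or per-token) loss
    bsz_meter.update(b['nsentences'])                 # sentences per batch
print(round(loss_meter.avg, 3), bsz_meter.avg)        # 7.5 14.0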