Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
d9f46c54
Commit
d9f46c54
authored
Jan 26, 2018
by
Sergey Edunov
Browse files
Merge branch 'master' of github.com:facebookresearch/fairseq-py into prepare_wmt
parents
4185d3ed
ee36a6f3
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
125 additions
and
43 deletions
+125
-43
fairseq/options.py
fairseq/options.py
+4
-0
fairseq/sequence_generator.py
fairseq/sequence_generator.py
+24
-16
fairseq/utils.py
fairseq/utils.py
+53
-18
generate.py
generate.py
+10
-0
requirements.txt
requirements.txt
+1
-1
setup.py
setup.py
+1
-1
train.py
train.py
+32
-7
No files found.
fairseq/options.py
View file @
d9f46c54
...
...
@@ -173,4 +173,8 @@ def add_model_args(parser):
help
=
'dropout probability'
)
group
.
add_argument
(
'--label-smoothing'
,
default
=
0
,
type
=
float
,
metavar
=
'D'
,
help
=
'epsilon for label smoothing, 0 means no label smoothing'
)
group
.
add_argument
(
'--share-input-output-embed'
,
action
=
'store_true'
,
help
=
"Share input and output embeddings, "
"requires --decoder-out-embed-dim and --decoder-embed-dim be equal "
)
return
group
fairseq/sequence_generator.py
View file @
d9f46c54
...
...
@@ -19,7 +19,7 @@ from fairseq.models import FairseqIncrementalDecoder
class
SequenceGenerator
(
object
):
def
__init__
(
self
,
models
,
beam_size
=
1
,
minlen
=
1
,
maxlen
=
200
,
stop_early
=
True
,
normalize_scores
=
True
,
len_penalty
=
1
,
unk_penalty
=
0
):
unk_penalty
=
0
,
retain_dropout
=
False
):
"""Generates translations of a given source sentence.
Args:
...
...
@@ -45,6 +45,7 @@ class SequenceGenerator(object):
self
.
normalize_scores
=
normalize_scores
self
.
len_penalty
=
len_penalty
self
.
unk_penalty
=
unk_penalty
self
.
retain_dropout
=
retain_dropout
def
cuda
(
self
):
for
model
in
self
.
models
:
...
...
@@ -65,19 +66,20 @@ class SequenceGenerator(object):
maxlen_b
=
self
.
maxlen
for
sample
in
data_itr
:
s
=
utils
.
prepare_samp
le
(
sample
,
volatile
=
True
,
cuda_device
=
cuda_device
)
s
=
utils
.
make_variab
le
(
sample
,
volatile
=
True
,
cuda_device
=
cuda_device
)
input
=
s
[
'net_input'
]
srclen
=
input
[
'src_tokens'
].
size
(
1
)
if
timer
is
not
None
:
timer
.
start
()
hypos
=
self
.
generate
(
input
[
'src_tokens'
],
beam_size
=
beam_size
,
maxlen
=
int
(
maxlen_a
*
srclen
+
maxlen_b
))
with
utils
.
maybe_no_grad
():
hypos
=
self
.
generate
(
input
[
'src_tokens'
],
beam_size
=
beam_size
,
maxlen
=
int
(
maxlen_a
*
srclen
+
maxlen_b
))
if
timer
is
not
None
:
timer
.
stop
(
s
[
'ntokens'
])
for
i
,
id
in
enumerate
(
s
[
'id'
]):
for
i
,
id
in
enumerate
(
s
[
'id'
]
.
data
):
src
=
input
[
'src_tokens'
].
data
[
i
,
:]
# remove padding from ref
ref
=
utils
.
r
strip_pad
(
s
[
'target'
].
data
[
i
,
:],
self
.
pad
)
ref
=
utils
.
strip_pad
(
s
[
'target'
].
data
[
i
,
:],
self
.
pad
)
yield
id
,
src
,
ref
,
hypos
[
i
]
def
generate
(
self
,
src_tokens
,
beam_size
=
None
,
maxlen
=
None
):
...
...
@@ -98,7 +100,8 @@ class SequenceGenerator(object):
encoder_outs
=
[]
for
model
in
self
.
models
:
model
.
eval
()
if
not
self
.
retain_dropout
:
model
.
eval
()
if
isinstance
(
model
.
decoder
,
FairseqIncrementalDecoder
):
model
.
decoder
.
set_beam_size
(
beam_size
)
...
...
@@ -269,7 +272,7 @@ class SequenceGenerator(object):
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
active_mask
=
buffer
(
'active_mask'
)
torch
.
add
(
(
eos_mask
*
cand_size
)
.
type_as
(
cand_offsets
),
cand_offsets
[:
eos_mask
.
size
(
1
)],
torch
.
add
(
eos_mask
.
type_as
(
cand_offsets
)
*
cand_size
,
cand_offsets
[:
eos_mask
.
size
(
1
)],
out
=
active_mask
)
# get the top beam_size active hypotheses, which are just the hypos
...
...
@@ -320,22 +323,27 @@ class SequenceGenerator(object):
def
_decode
(
self
,
tokens
,
encoder_outs
):
# wrap in Variable
tokens
=
Variable
(
tokens
,
volatile
=
True
)
tokens
=
utils
.
volatile_variable
(
tokens
)
avg_probs
=
None
avg_attn
=
None
for
model
,
encoder_out
in
zip
(
self
.
models
,
encoder_outs
):
decoder_out
,
attn
=
model
.
decoder
(
tokens
,
encoder_out
)
probs
=
F
.
softmax
(
decoder_out
[:,
-
1
,
:]).
data
attn
=
attn
[:,
-
1
,
:]
.
data
if
avg_probs
is
None
or
avg_attn
is
None
:
with
utils
.
maybe_no_grad
():
decoder_out
,
attn
=
model
.
decoder
(
tokens
,
encoder_out
)
probs
=
model
.
get_normalized_probs
(
decoder_out
[:,
-
1
,
:],
log_probs
=
False
)
.
data
if
avg_probs
is
None
:
avg_probs
=
probs
avg_attn
=
attn
else
:
avg_probs
.
add_
(
probs
)
avg_attn
.
add_
(
attn
)
if
attn
is
not
None
:
attn
=
attn
[:,
-
1
,
:].
data
if
avg_attn
is
None
:
avg_attn
=
attn
else
:
avg_attn
.
add_
(
attn
)
avg_probs
.
div_
(
len
(
self
.
models
))
avg_probs
.
log_
()
avg_attn
.
div_
(
len
(
self
.
models
))
if
avg_attn
is
not
None
:
avg_attn
.
div_
(
len
(
self
.
models
))
return
avg_probs
,
avg_attn
fairseq/utils.py
View file @
d9f46c54
...
...
@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
#
import
contextlib
import
logging
import
os
import
torch
...
...
@@ -15,10 +16,11 @@ import sys
from
torch.autograd
import
Variable
from
torch.serialization
import
default_restore_location
from
fairseq
import
criterions
,
data
,
models
,
progress_bar
,
tokenizer
from
fairseq
import
criterions
,
progress_bar
,
tokenizer
def
parse_args_and_arch
(
parser
):
from
fairseq
import
models
args
=
parser
.
parse_args
()
args
.
model
=
models
.
arch_model_map
[
args
.
arch
]
args
=
getattr
(
models
,
args
.
model
).
parse_arch
(
args
)
...
...
@@ -26,6 +28,7 @@ def parse_args_and_arch(parser):
def
build_model
(
args
,
src_dict
,
dst_dict
):
from
fairseq
import
models
assert
hasattr
(
models
,
args
.
model
),
'Missing model type'
return
getattr
(
models
,
args
.
model
).
build_model
(
args
,
src_dict
,
dst_dict
)
...
...
@@ -143,6 +146,8 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
The source and target dictionaries can be given explicitly, or loaded from
the `data_dir` directory.
"""
from
fairseq
import
data
# load model architectures and weights
states
=
[]
for
filename
in
filenames
:
...
...
@@ -172,26 +177,48 @@ def _upgrade_args(args):
if
not
hasattr
(
args
,
'max_source_positions'
):
args
.
max_source_positions
=
args
.
max_positions
args
.
max_target_positions
=
args
.
max_positions
if
not
hasattr
(
args
,
'share_input_output_embed'
):
args
.
share_input_output_embed
=
False
return
args
def
prepare_sample
(
sample
,
volatile
=
False
,
cuda_device
=
None
):
def
maybe_no_grad
(
condition
=
True
):
if
hasattr
(
torch
,
'no_grad'
)
and
condition
:
return
torch
.
no_grad
()
# no-op context manager
return
contextlib
.
ExitStack
()
def
volatile_variable
(
*
args
,
**
kwargs
):
if
hasattr
(
torch
,
'no_grad'
):
# volatile has been deprecated, use the no_grad context manager instead
return
Variable
(
*
args
,
**
kwargs
)
else
:
return
Variable
(
*
args
,
**
kwargs
,
volatile
=
True
)
def
make_variable
(
sample
,
volatile
=
False
,
cuda_device
=
None
):
"""Wrap input tensors in Variable class."""
def
make_variable
(
tensor
):
if
cuda_device
is
not
None
and
torch
.
cuda
.
is_available
():
tensor
=
tensor
.
cuda
(
async
=
True
,
device
=
cuda_device
)
return
Variable
(
tensor
,
volatile
=
volatile
)
return
{
'id'
:
sample
[
'id'
],
'ntokens'
:
sample
[
'ntokens'
],
'target'
:
make_variable
(
sample
[
'target'
]),
'net_input'
:
{
key
:
make_variable
(
sample
[
key
])
for
key
in
[
'src_tokens'
,
'input_tokens'
]
},
}
def
_make_variable
(
maybe_tensor
):
if
torch
.
is_tensor
(
maybe_tensor
):
if
cuda_device
is
not
None
and
torch
.
cuda
.
is_available
():
maybe_tensor
=
maybe_tensor
.
cuda
(
async
=
True
,
device
=
cuda_device
)
if
volatile
:
return
volatile_variable
(
maybe_tensor
)
else
:
return
Variable
(
maybe_tensor
)
elif
isinstance
(
maybe_tensor
,
dict
):
return
{
key
:
_make_variable
(
value
)
for
key
,
value
in
maybe_tensor
.
items
()
}
elif
isinstance
(
maybe_tensor
,
list
):
return
[
_make_variable
(
x
)
for
x
in
maybe_tensor
]
else
:
return
maybe_tensor
return
_make_variable
(
sample
)
def
load_align_dict
(
replace_unk
):
...
...
@@ -236,11 +263,19 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dic
def
lstrip_pad
(
tensor
,
pad
):
return
tensor
[
tensor
.
eq
(
pad
).
sum
():]
return
tensor
[
tensor
.
eq
(
pad
).
long
().
sum
():]
def
rstrip_pad
(
tensor
,
pad
):
strip
=
tensor
.
eq
(
pad
).
sum
()
strip
=
tensor
.
eq
(
pad
).
long
().
sum
()
if
strip
>
0
:
return
tensor
[:
-
strip
]
return
tensor
def
strip_pad
(
tensor
,
pad
):
if
tensor
[
0
]
==
pad
:
tensor
=
lstrip_pad
(
tensor
,
pad
)
if
tensor
[
-
1
]
==
pad
:
tensor
=
rstrip_pad
(
tensor
,
pad
)
return
tensor
generate.py
View file @
d9f46c54
...
...
@@ -23,6 +23,10 @@ def main():
help
=
'batch size'
)
dataset_args
.
add_argument
(
'--gen-subset'
,
default
=
'test'
,
metavar
=
'SPLIT'
,
help
=
'data subset to generate (train, valid, test)'
)
dataset_args
.
add_argument
(
'--num-shards'
,
default
=
1
,
type
=
int
,
metavar
=
'N'
,
help
=
'shard generation over N shards'
)
dataset_args
.
add_argument
(
'--shard-id'
,
default
=
0
,
type
=
int
,
metavar
=
'ID'
,
help
=
'id of the shard to generate (id < num_shards)'
)
options
.
add_generation_args
(
parser
)
args
=
parser
.
parse_args
()
...
...
@@ -31,6 +35,8 @@ def main():
print
(
args
)
use_cuda
=
torch
.
cuda
.
is_available
()
and
not
args
.
cpu
if
hasattr
(
torch
,
'set_grad_enabled'
):
torch
.
set_grad_enabled
(
False
)
# Load dataset
if
args
.
replace_unk
is
None
:
...
...
@@ -72,6 +78,10 @@ def main():
itr
=
dataset
.
eval_dataloader
(
args
.
gen_subset
,
max_sentences
=
args
.
batch_size
,
max_positions
=
max_positions
,
skip_invalid_size_inputs_valid_test
=
args
.
skip_invalid_size_inputs_valid_test
)
if
args
.
num_shards
>
1
:
if
args
.
shard_id
<
0
or
args
.
shard_id
>=
args
.
num_shards
:
raise
ValueError
(
'--shard-id must be between 0 and num_shards'
)
itr
=
data
.
sharded_iterator
(
itr
,
args
.
num_shards
,
args
.
shard_id
)
num_sentences
=
0
with
utils
.
build_progress_bar
(
args
,
itr
)
as
t
:
wps_meter
=
TimeMeter
()
...
...
requirements.txt
View file @
d9f46c54
cffi
numpy
torch
torch
>=0.3.0
tqdm
setup.py
View file @
d9f46c54
...
...
@@ -54,7 +54,7 @@ class build_py_hook(build_py):
setup
(
name
=
'fairseq'
,
version
=
'0.
2
.0'
,
version
=
'0.
3
.0'
,
description
=
'Facebook AI Research Sequence-to-Sequence Toolkit'
,
long_description
=
readme
,
license
=
license
,
...
...
train.py
View file @
d9f46c54
...
...
@@ -30,6 +30,8 @@ def main():
dataset_args
.
add_argument
(
'--valid-subset'
,
default
=
'valid'
,
metavar
=
'SPLIT'
,
help
=
'comma separated list of data subsets '
' to use for validation (train, valid, valid1,test, test1)'
)
dataset_args
.
add_argument
(
'--max-sentences-valid'
,
type
=
int
,
metavar
=
'N'
,
help
=
'maximum number of sentences in a validation batch'
)
options
.
add_optimization_args
(
parser
)
options
.
add_checkpoint_args
(
parser
)
options
.
add_model_args
(
parser
)
...
...
@@ -39,6 +41,9 @@ def main():
if
args
.
no_progress_bar
and
args
.
log_format
is
None
:
args
.
log_format
=
'simple'
if
args
.
max_sentences_valid
is
None
:
args
.
max_sentences_valid
=
args
.
max_sentences
if
not
os
.
path
.
exists
(
args
.
save_dir
):
os
.
makedirs
(
args
.
save_dir
)
torch
.
manual_seed
(
args
.
seed
)
...
...
@@ -70,14 +75,15 @@ def main():
model
=
utils
.
build_model
(
args
,
dataset
.
src_dict
,
dataset
.
dst_dict
)
criterion
=
utils
.
build_criterion
(
args
,
dataset
.
src_dict
,
dataset
.
dst_dict
)
print
(
'| model {}, criterion {}'
.
format
(
args
.
arch
,
criterion
.
__class__
.
__name__
))
print
(
'| num. model params: {}'
.
format
(
sum
(
p
.
data
.
numel
()
for
p
in
model
.
parameters
())))
# The max number of positions can be different for train and valid
# e.g., RNNs may support more positions at test time than seen in training
max_positions_train
=
(
args
.
max_source_positions
,
args
.
max_target_positions
)
max_positions_valid
=
(
max_positions_train
=
(
min
(
args
.
max_source_positions
,
model
.
max_encoder_positions
()),
min
(
args
.
max_target_positions
,
model
.
max_decoder_positions
())
)
max_positions_valid
=
(
model
.
max_encoder_positions
(),
model
.
max_decoder_positions
())
# Start multiprocessing
trainer
=
MultiprocessingTrainer
(
args
,
model
,
criterion
)
...
...
@@ -144,6 +150,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
sample_without_replacement
=
args
.
sample_without_replacement
,
sort_by_source_size
=
(
epoch
<=
args
.
curriculum
))
loss_meter
=
AverageMeter
()
nll_loss_meter
=
AverageMeter
()
bsz_meter
=
AverageMeter
()
# sentences per batch
wpb_meter
=
AverageMeter
()
# words per batch
wps_meter
=
TimeMeter
()
# words per second
...
...
@@ -158,7 +165,12 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
nsentences
=
sum
(
s
[
'src_tokens'
].
size
(
0
)
for
s
in
sample
)
if
'nll_loss'
in
loss_dict
:
nll_loss
=
loss_dict
[
'nll_loss'
]
nll_loss_meter
.
update
(
nll_loss
,
ntokens
)
nsentences
=
sum
(
s
[
'net_input'
][
'src_tokens'
].
size
(
0
)
for
s
in
sample
)
loss_meter
.
update
(
loss
,
nsentences
if
args
.
sentence_avg
else
ntokens
)
bsz_meter
.
update
(
nsentences
)
wpb_meter
.
update
(
ntokens
)
...
...
@@ -187,7 +199,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
t
.
print
(
collections
.
OrderedDict
([
(
'train loss'
,
round
(
loss_meter
.
avg
,
2
)),
(
'train ppl'
,
get_perplexity
(
loss_meter
.
avg
)),
(
'train ppl'
,
get_perplexity
(
nll_loss_meter
.
avg
if
nll_loss_meter
.
count
>
0
else
loss_meter
.
avg
)),
(
's/checkpoint'
,
round
(
wps_meter
.
elapsed_time
)),
(
'words/s'
,
round
(
wps_meter
.
avg
)),
(
'words/batch'
,
round
(
wpb_meter
.
avg
)),
...
...
@@ -217,6 +231,10 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
save_checkpoint
.
best
=
val_loss
best_filename
=
os
.
path
.
join
(
args
.
save_dir
,
'checkpoint_best.pt'
)
trainer
.
save_checkpoint
(
best_filename
,
extra_state
)
elif
not
args
.
no_epoch_checkpoints
:
epoch_filename
=
os
.
path
.
join
(
args
.
save_dir
,
'checkpoint{}_{}.pt'
.
format
(
epoch
,
batch_offset
))
trainer
.
save_checkpoint
(
epoch_filename
,
extra_state
)
last_filename
=
os
.
path
.
join
(
args
.
save_dir
,
'checkpoint_last.pt'
)
trainer
.
save_checkpoint
(
last_filename
,
extra_state
)
...
...
@@ -226,22 +244,27 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
"""Evaluate the model on the validation set and return the average loss."""
itr
=
dataset
.
eval_dataloader
(
subset
,
max_tokens
=
args
.
max_tokens
,
max_sentences
=
args
.
max_sentences
,
subset
,
max_tokens
=
args
.
max_tokens
,
max_sentences
=
args
.
max_sentences
_valid
,
max_positions
=
max_positions
,
skip_invalid_size_inputs_valid_test
=
args
.
skip_invalid_size_inputs_valid_test
,
descending
=
True
,
# largest batch first to warm the caching allocator
)
loss_meter
=
AverageMeter
()
nll_loss_meter
=
AverageMeter
()
extra_meters
=
collections
.
defaultdict
(
lambda
:
AverageMeter
())
prefix
=
'valid on
\'
{}
\'
subset'
.
format
(
subset
)
with
utils
.
build_progress_bar
(
args
,
itr
,
epoch
,
prefix
)
as
t
:
for
_
,
sample
in
data
.
skip_group_enumerator
(
t
,
args
.
num_gpus
):
loss_dict
=
trainer
.
valid_step
(
sample
)
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
loss
=
loss_dict
[
'loss'
]
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
if
'nll_loss'
in
loss_dict
:
nll_loss
=
loss_dict
[
'nll_loss'
]
nll_loss_meter
.
update
(
nll_loss
,
ntokens
)
loss_meter
.
update
(
loss
,
ntokens
)
extra_postfix
=
[]
...
...
@@ -255,7 +278,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
t
.
print
(
collections
.
OrderedDict
([
(
'valid loss'
,
round
(
loss_meter
.
avg
,
2
)),
(
'valid ppl'
,
get_perplexity
(
loss_meter
.
avg
)),
(
'valid ppl'
,
get_perplexity
(
nll_loss_meter
.
avg
if
nll_loss_meter
.
count
>
0
else
loss_meter
.
avg
)),
]
+
[
(
k
,
meter
.
avg
)
for
k
,
meter
in
extra_meters
.
items
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment