Fairseq (OpenDAS) commit b59815bc
Authored May 09, 2018 by Angela Fan; committed by Myle Ott on Jun 15, 2018
Parent commit: 50931d69

Commit message: added multiscale gated self attention layer with multiple heads, and pretrained fusion models

Showing 12 changed files with 896 additions and 23 deletions (+896, -23)
fairseq/models/__init__.py                            +1    -0
fairseq/models/composite_encoder.py                   +35   -0
fairseq/models/fconv_self_att.py                      +502  -0
fairseq/modules/__init__.py                           +4    -0
fairseq/modules/downsampled_multihead_attention.py    +272  -0
fairseq/modules/scalar_bias.py                        +33   -0
fairseq/options.py                                    +8    -2
fairseq/sequence_generator.py                         +27   -13
fairseq/utils.py                                      +1    -1
generate.py                                           +6    -3
interactive.py                                        +6    -3
train.py                                              +1    -1
fairseq/models/__init__.py

@@ -12,6 +12,7 @@ from .fairseq_decoder import FairseqDecoder  # noqa: F401
 from .fairseq_encoder import FairseqEncoder  # noqa: F401
 from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401
 from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel  # noqa: F401
+from .composite_encoder import CompositeEncoder  # noqa: F401

 MODEL_REGISTRY = {}
 ARCH_MODEL_REGISTRY = {}
fairseq/models/composite_encoder.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from . import FairseqEncoder


class CompositeEncoder(FairseqEncoder):
    """
    Encoder class that forwards on multiple encoders, for example for a fusion
    model or question answering.

    Accepts a dictionary of encoders; the first encoder's dictionary is used
    for initialization.
    """

    def __init__(self, encoders):
        super().__init__(next(iter(encoders.values())).dictionary)
        self.encoders = encoders
        for key in self.encoders:
            self.add_module(key, self.encoders[key])

    def forward(self, src_tokens, src_lengths):
        encoder_out = {}
        for key in self.encoders:
            encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
        return encoder_out

    def max_positions(self):
        return min([self.encoders[key].max_positions() for key in self.encoders])

    def upgrade_state_dict(self, state_dict):
        for key in self.encoders:
            self.encoders[key].upgrade_state_dict(state_dict)
        return state_dict
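As an aside, a minimal sketch of the forwarding contract CompositeEncoder implements (the dummy encoders below are hypothetical stand-ins, not part of this commit): every sub-encoder receives the same source tokens and lengths, and the caller later selects an output by its key.

import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    # stand-in for a FairseqEncoder; only the forward contract matters here
    def __init__(self, embed_dim=8):
        super().__init__()
        self.embed = nn.Embedding(100, embed_dim)

    def forward(self, src_tokens, src_lengths):
        return {'encoder_out': self.embed(src_tokens)}

encoders = {'encoder': DummyEncoder(), 'pretrained': DummyEncoder()}
src_tokens = torch.randint(0, 100, (2, 5))
src_lengths = torch.tensor([5, 5])

# CompositeEncoder.forward fans the same inputs out to every sub-encoder
encoder_out = {key: enc(src_tokens, src_lengths) for key, enc in encoders.items()}
assert set(encoder_out) == {'encoder', 'pretrained'}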
fairseq/models/fconv_self_att.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq.data import LanguagePairDataset
from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
from fairseq.modules import (
    GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution,
    DownsampledMultiHeadAttention,
)
from fairseq import utils

from . import (
    FairseqEncoder, CompositeEncoder, FairseqDecoder, FairseqModel,
    register_model, register_model_architecture,
)
@register_model('fconv_self_att')
class FConvModelSelfAtt(FairseqModel):
    def __init__(self, encoder, decoder, pretrained_encoder=None):
        super().__init__(encoder, decoder)
        self.encoder.num_attention_layers = sum(
            layer is not None for layer in decoder.attention
        )
        self.pretrained_encoder = pretrained_encoder
        if self.pretrained_encoder is None:
            encoders = {'encoder': encoder}
        else:
            encoders = {'encoder': encoder, 'pretrained': self.pretrained_encoder}
        # for fusion model, CompositeEncoder contains both pretrained and training encoders
        # these are forwarded and then combined in the decoder
        self.encoder = CompositeEncoder(encoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-layers', type=str, metavar='EXPR',
                            help='encoder layers [(dim, kernel_size), ...]')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
                            help='decoder layers [(dim, kernel_size), ...]')
        parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
                            help='decoder output embedding dimension')
        parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
                            help='decoder attention [True, ...]')
        parser.add_argument('--self-attention', default='False', type=str, metavar='EXPR',
                            help='decoder self-attention layers, ex: [True] + [False]*5')
        parser.add_argument('--multihead-attention-nheads', default=1, type=int,
                            help='Number of heads to use in attention')
        parser.add_argument('--multihead-self-attention-nheads', default=1, type=int,
                            help='Number of heads to use in self-attention')
        parser.add_argument('--encoder-attention', type=str, metavar='EXPR', default='False',
                            help='encoder attention [True, ...]')
        parser.add_argument('--encoder-attention-nheads', default=1, type=int,
                            help='Number of heads to use in encoder attention')
        parser.add_argument('--project-input', type=str, metavar='EXPR', default='False',
                            help='Use projections in self-attention [True, ...]')
        parser.add_argument('--gated-attention', type=str, metavar='EXPR', default='False',
                            help='Use GLU layers in self-attention projections [True, ...]')
        parser.add_argument('--downsample', type=str, metavar='EXPR', default='False',
                            help='Use downsampling in self-attention [True, ...]')
        parser.add_argument('--pretrained-checkpoint', metavar='DIR', default='',
                            help='path to load checkpoint from pretrained model')
        parser.add_argument('--pretrained', type=str, metavar='EXPR', default='False',
                            help='use pretrained model when training [True, ...]')

    @classmethod
    def build_model(cls, args, src_dict, dst_dict):
        """Build a new model instance."""
        trained_encoder, trained_decoder = None, None
        pretrained = eval(args.pretrained)
        if pretrained:
            print("| Loading pretrained model")
            state = torch.load(args.pretrained_checkpoint)
            trained_model = utils.load_ensemble_for_inference(
                # not actually for inference, but loads pretrained model parameters
                filenames=[args.pretrained_checkpoint],
                src_dict=src_dict,
                dst_dict=dst_dict,
            )[0][0]
            trained_decoder = list(trained_model.children())[1]
            trained_encoder = list(trained_model.children())[0]

            # freeze pretrained model
            for param in trained_decoder.parameters():
                param.requires_grad = False
            for param in trained_encoder.parameters():
                param.requires_grad = False

        encoder = FConvEncoder(
            src_dict,
            embed_dim=args.encoder_embed_dim,
            convolutions=eval(args.encoder_layers),
            dropout=args.dropout,
            max_positions=args.max_source_positions,
            attention=eval(args.encoder_attention),
            attention_nheads=args.encoder_attention_nheads,
        )

        decoder = FConvDecoder(
            dst_dict,
            embed_dim=args.decoder_embed_dim,
            convolutions=eval(args.decoder_layers),
            out_embed_dim=args.decoder_out_embed_dim,
            attention=eval(args.decoder_attention),
            dropout=args.dropout,
            max_positions=args.max_target_positions,
            selfattention=eval(args.self_attention),
            attention_nheads=args.multihead_attention_nheads,
            selfattention_nheads=args.multihead_self_attention_nheads,
            project_input=eval(args.project_input),
            gated_attention=eval(args.gated_attention),
            downsample=eval(args.downsample),
            pretrained=pretrained,
            trained_decoder=trained_decoder,
        )
        model = FConvModelSelfAtt(encoder, decoder, trained_encoder)

        return model

    @property
    def pretrained(self):
        return self.pretrained_encoder is not None
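As a quick aside on the EXPR-style flags registered above: build_model simply passes them through Python's eval, so a flag value is any expression that produces the expected list or tuple. A couple of concrete expansions, taken from the defaults and help strings above:

# '--decoder-layers' style expressions expand to (channels, kernel_size) tuples per layer
convolutions = eval('[(512, 3)] * 8')
assert len(convolutions) == 8 and convolutions[0] == (512, 3)

# '--self-attention' style expressions expand to one boolean per decoder layer
self_attention = eval('[True] + [False]*5')
assert self_attention == [True, False, False, False, False, False]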
class FConvEncoder(FairseqEncoder):
    """Convolutional encoder"""
    def __init__(self, dictionary, embed_dim=512, max_positions=1024,
                 convolutions=((512, 3),) * 20, dropout=0.1, attention=False,
                 attention_nheads=1):
        super().__init__(dictionary)
        self.dropout = dropout
        self.num_attention_layers = None

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            self.padding_idx,
            left_pad=LEFT_PAD_SOURCE,
        )

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)

        in_channels = convolutions[0][0]
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels else None
            )
            self.convolutions.append(
                ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout)
            )
            self.attention.append(
                SelfAttention(out_channels, embed_dim, attention_nheads)
                if attention[i] else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, embed_dim)

    def forward(self, src_tokens, src_lengths):
        # embed tokens and positions
        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)
        input_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
            residual = x if proj is None else proj(x)

            x = F.dropout(x, p=self.dropout, training=self.training)
            padding_l = (conv.kernel_size[0] - 1) // 2
            padding_r = conv.kernel_size[0] // 2
            x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
            x = conv(x)
            x = F.glu(x, dim=2)
            if attention is not None:
                x = attention(x)
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of embedding
        x = self.fc2(x)

        # scale gradients (this only affects backward, not forward)
        x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5)

        return {
            'encoder_out': (x, y),
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()
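For orientation, a small shape walk-through of FConvEncoder.forward with made-up sizes (B = batch, T = source length, C = embedding dim); it only mirrors the transpose and GLU bookkeeping above and does not touch fairseq itself:

import torch
import torch.nn.functional as F

B, T, C = 2, 7, 512
x = torch.randn(B, T, C)                 # embed_tokens(src_tokens) + embed_positions(src_tokens)
x = x.transpose(0, 1)                    # B x T x C -> T x B x C for the TBC convolutions
conv_out = torch.randn(T, B, 2 * C)      # ConvTBC produces 2*C channels...
x = F.glu(conv_out, dim=2)               # ...and GLU halves them back to C
assert x.shape == (T, B, C)

# encoder_out is a pair (x, y), both back in B x T x C; y mixes in the input embedding
x = x.transpose(1, 0)
y = (x + torch.randn(B, T, C)) * (0.5 ** 0.5)
assert x.shape == y.shape == (B, T, C)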
class FConvDecoder(FairseqDecoder):
    """Convolutional decoder"""
    def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                 max_positions=1024, convolutions=((512, 3),) * 8,
                 attention=True, dropout=0.1, selfattention=False,
                 attention_nheads=1, selfattention_nheads=1,
                 project_input=False, gated_attention=False, downsample=False,
                 pretrained=False, trained_decoder=None):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([2]))
        self.pretrained = pretrained
        self.pretrained_decoder = trained_decoder
        self.dropout = dropout
        in_channels = convolutions[0][0]

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)
        selfattention = expand_bool_array(selfattention)

        if not isinstance(attention, list) or len(attention) != len(convolutions):
            raise ValueError('Attention is expected to be a list of booleans of '
                             'length equal to the number of layers.')

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            padding_idx,
            left_pad=LEFT_PAD_TARGET,
        )

        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.selfattention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            pad = kernel_size - 1
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels else None
            )
            self.convolutions.append(
                LinearizedConv1d(in_channels, out_channels * 2, kernel_size,
                                 padding=(kernel_size - 1), dropout=dropout)
            )
            self.attention.append(
                DownsampledMultiHeadAttention(out_channels, embed_dim, attention_nheads,
                                              project_input=project_input,
                                              gated=False, downsample=False)
                if attention[i] else None
            )
            self.attproj.append(
                Linear(out_channels, embed_dim, dropout=dropout)
                if attention[i] else None
            )
            self.selfattention.append(
                SelfAttention(out_channels, embed_dim, selfattention_nheads,
                              project_input=project_input,
                              gated=gated_attention, downsample=downsample)
                if selfattention[i] else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, out_embed_dim)
        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

        # model fusion
        if self.pretrained:
            # independent gates are learned from the concatenated input
            self.gate1 = nn.Sequential(Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid())
            self.gate2 = nn.Sequential(Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid())
            # pretrained and trained models are joined
            self.joining = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim * 2),
                nn.LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim * 2),
                nn.LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim),
                nn.LayerNorm(out_embed_dim),
            )
            # pretrained model contains an output layer that is nhid -> vocab size
            # but the models are combined in their hidden state
            # the hook stores the output of the pretrained model forward
            self.pretrained_outputs = {}

            def save_output():
                def hook(a, b, output):
                    self.pretrained_outputs["out"] = output
                return hook

            self.pretrained_decoder.fc2.register_forward_hook(save_output())

    def forward(self, prev_output_tokens, encoder_out_dict):
        encoder_out = encoder_out_dict['encoder']['encoder_out']
        trained_encoder_out = encoder_out_dict['pretrained'] if self.pretrained else None

        encoder_a, encoder_b = self._split_encoder_out(encoder_out)

        # embed positions
        positions = self.embed_positions(prev_output_tokens)

        # embed tokens and positions
        x = self.embed_tokens(prev_output_tokens) + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        target_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        avg_attn_scores = None
        for proj, conv, attention, selfattention, attproj in zip(
            self.projections, self.convolutions, self.attention,
            self.selfattention, self.attproj
        ):
            residual = x if proj is None else proj(x)

            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x)
            x = F.glu(x, dim=2)

            # attention
            if attention is not None:
                r = x
                x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
                x = x + r
                if avg_attn_scores is None:
                    avg_attn_scores = attn_scores
                else:
                    avg_attn_scores.add_(attn_scores)

            if selfattention is not None:
                x = selfattention(x)

            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        if not self.pretrained:
            x = self.fc3(x)

        # fusion gating
        if self.pretrained:
            trained_x, _ = self.pretrained_decoder.forward(prev_output_tokens, trained_encoder_out)
            y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1)
            gate1 = self.gate1(y)
            gate2 = self.gate2(y)
            gated_x1 = gate1 * x
            gated_x2 = gate2 * self.pretrained_outputs["out"]
            fusion = torch.cat([gated_x1, gated_x2], dim=-1)
            fusion = self.joining(fusion)
            fusion_output = self.fc3(fusion)
            return fusion_output, avg_attn_scores
        else:
            return x, avg_attn_scores

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        super().reorder_incremental_state(incremental_state, new_order)

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        encoder_out_dict['encoder']['encoder_out'] = tuple(
            eo.index_select(0, new_order)
            for eo in encoder_out_dict['encoder']['encoder_out']
        )
        if 'pretrained' in encoder_out_dict:
            encoder_out_dict['pretrained']['encoder']['encoder_out'] = tuple(
                eo.index_select(0, new_order)
                for eo in encoder_out_dict['pretrained']['encoder']['encoder_out']
            )
        return encoder_out_dict

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def _split_encoder_out(self, encoder_out):
        """Split and transpose encoder outputs."""
        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(0, 1).contiguous()
        encoder_b = encoder_b.transpose(0, 1).contiguous()
        result = (encoder_a, encoder_b)
        return result
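The fusion branch at the end of FConvDecoder.forward reduces to two learned sigmoid gates over the concatenated hidden states; here is that arithmetic in isolation with toy dimensions (the real gate1/gate2/joining modules are the weight-normalized stacks defined in __init__ above):

import torch
import torch.nn as nn

out_embed_dim = 4                               # toy size; the architecture default is 256
x = torch.randn(2, 5, out_embed_dim)            # hidden state of the trainable decoder
pretrained = torch.randn(2, 5, out_embed_dim)   # hook output of the frozen pretrained decoder

gate1 = nn.Sequential(nn.Linear(2 * out_embed_dim, out_embed_dim), nn.Sigmoid())
gate2 = nn.Sequential(nn.Linear(2 * out_embed_dim, out_embed_dim), nn.Sigmoid())

y = torch.cat([x, pretrained], dim=-1)          # both gates see both representations
fusion = torch.cat([gate1(y) * x, gate2(y) * pretrained], dim=-1)
assert fusion.shape == (2, 5, 2 * out_embed_dim)
# the decoder then applies self.joining(...) and fc3 to map this to vocabulary logits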
class SelfAttention(nn.Module):
    def __init__(self, out_channels, embed_dim, num_heads, project_input=False,
                 gated=False, downsample=False):
        super().__init__()
        self.attention = DownsampledMultiHeadAttention(
            out_channels, embed_dim, num_heads, dropout=0, bias=True,
            project_input=project_input, gated=gated, downsample=downsample,
        )
        self.in_proj_q = Linear(out_channels, embed_dim)
        self.in_proj_k = Linear(out_channels, embed_dim)
        self.in_proj_v = Linear(out_channels, embed_dim)
        self.ln = nn.LayerNorm(out_channels)

    def forward(self, x):
        residual = x
        query = self.in_proj_q(x)
        key = self.in_proj_k(x)
        value = self.in_proj_v(x)
        x, _ = self.attention(query, key, value, mask_future_timesteps=True, use_scalar_bias=True)
        return self.ln(x + residual)
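A usage sketch for the SelfAttention block, assuming this commit's fairseq is on the path (the import location simply follows the file being added here); inputs and outputs are in the Time x Batch x Channel layout used by the convolutions:

import torch
from fairseq.models.fconv_self_att import SelfAttention

layer = SelfAttention(out_channels=512, embed_dim=256, num_heads=4,
                      project_input=True, gated=True, downsample=False)
x = torch.randn(10, 2, 512)   # T x B x C output of the preceding GLU
out = layer(x)                # causally masked, scalar-biased attention + residual + LayerNorm
assert out.shape == x.shape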
def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    m.weight.data.normal_(0, 0.1)
    return m


def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
    m.weight.data.normal_(0, 0.1)
    return m


def Linear(in_features, out_features, dropout=0.):
    """Weight-normalized Linear layer (input: N x T x C)"""
    m = nn.Linear(in_features, out_features)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    m.bias.data.zero_()
    return m


def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0., **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding"""
    m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return m


def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d layer"""
    from fairseq.modules import ConvTBC
    m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return m
@register_model_architecture('fconv_self_att', 'fconv_self_att')
def base_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 3')
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 8')
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
    args.decoder_attention = getattr(args, 'decoder_attention', 'True')


@register_model_architecture('fconv_self_att', 'fconv_self_att_wp')
def fconv_self_att_wp(args):
    base_architecture(args)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', '[(128, 3)] * 2 + [(512,3)] * 1')
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_layers = getattr(
        args, 'decoder_layers',
        '[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1 + [(2048,4)] * 1')
    args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
    args.decoder_attention = getattr(args, 'decoder_attention', 'True')
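A small sketch of what the architecture registration does at configuration time (the Namespace below stands in for parsed command-line arguments and is not part of this commit): base_architecture and fconv_self_att_wp only fill in attributes the user has not already set, via the getattr pattern above.

from argparse import Namespace

def fill_defaults(args):
    # mirrors the getattr(args, name, default) pattern of base_architecture
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 3')
    args.decoder_attention = getattr(args, 'decoder_attention', 'True')
    return args

args = Namespace(encoder_layers='[(256, 3)] * 4')   # hypothetical user override
fill_defaults(args)
assert args.encoder_layers == '[(256, 3)] * 4'      # explicit value wins
assert args.encoder_embed_dim == 512                # missing value gets the default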
fairseq/modules/__init__.py

@@ -8,19 +8,23 @@
 from .adaptive_softmax import AdaptiveSoftmax
 from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
+from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .grad_multiply import GradMultiply
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .linearized_convolution import LinearizedConvolution
 from .multihead_attention import MultiheadAttention
+from .scalar_bias import ScalarBias
 from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

 __all__ = [
     'AdaptiveSoftmax',
     'BeamableMM',
     'ConvTBC',
+    'DownsampledMultiHeadAttention',
     'GradMultiply',
     'LearnedPositionalEmbedding',
     'LinearizedConvolution',
     'MultiheadAttention',
+    'ScalarBias',
     'SinusoidalPositionalEmbedding',
 ]
fairseq/modules/downsampled_multihead_attention.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from fairseq.modules.scalar_bias import scalar_bias


class SingleHeadAttention(nn.Module):
    """
    Single-head attention that supports Gating and Downsampling
    """
    def __init__(self, out_channels, embed_dim, head_dim, head_index, dropout=0.,
                 bias=True, project_input=True, gated=False, downsample=False,
                 num_heads=1):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.head_index = head_index
        self.head_dim = head_dim
        self.project_input = project_input
        self.gated = gated
        self.downsample = downsample
        self.num_heads = num_heads
        self.projection = None

        k_layers = []
        v_layers = []
        if self.downsample:
            k_layers.append(Downsample(self.head_index))
            v_layers.append(Downsample(self.head_index))
            out_proj_size = self.head_dim
        else:
            out_proj_size = self.head_dim * self.num_heads
        if self.gated:
            k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
            self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias)
            v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
        else:
            k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
            self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias)
            v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))

        self.in_proj_k = nn.Sequential(*k_layers)
        self.in_proj_v = nn.Sequential(*v_layers)

        if self.downsample:
            self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias)
        else:
            self.out_proj = Linear(out_proj_size, out_channels, bias=bias)

        self.scaling = self.head_dim ** -0.5

    def forward(self, query, key, value, mask_future_timesteps=False,
                key_padding_mask=None, use_scalar_bias=False):
        """Input shape: Time x Batch x Channel
        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        src_len, bsz, out_channels = key.size()
        tgt_len = query.size(0)
        assert list(query.size()) == [tgt_len, bsz, out_channels]
        assert key.size() == value.size()

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.downsample:
            size = bsz
        else:
            size = bsz * self.num_heads

        k = key
        v = value
        q = query
        if self.project_input:
            q = self.in_proj_q(q)
            k = self.in_proj_k(k)
            v = self.in_proj_v(v)
            src_len = k.size()[0]
        q *= self.scaling

        if not self.downsample:
            q = q.view(tgt_len, size, self.head_dim)
            k = k.view(src_len, size, self.head_dim)
            v = v.view(src_len, size, self.head_dim)

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if mask_future_timesteps:
            assert query.size() == key.size(), \
                'mask_future_timesteps only applies to self-attention'
            attn_weights *= Variable(torch.tril(
                attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
                diagonal=-1,
            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0))
            attn_weights += Variable(torch.triu(
                attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
                diagonal=0,
            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0))
        tgt_size = tgt_len
        if use_scalar_bias:
            attn_weights = scalar_bias(attn_weights, 2)
            v = scalar_bias(v, 1)
            tgt_size += 1

        if key_padding_mask is not None:
            # don't attend to padding symbols
            if key_padding_mask.max() > 0:
                if self.downsample:
                    attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
                else:
                    attn_weights = attn_weights.view(size, self.num_heads, tgt_len, src_len)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    -math.inf,
                )
                attn_weights = attn_weights.view(size, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        if self.downsample:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        attn = self.out_proj(attn)

        return attn, attn_weights
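To make the future-timestep masking above concrete, here is the same tril/triu construction on a 4-step toy example (ignoring the per-head striding used when downsampling): each query position may attend strictly to earlier positions, and the diagonal itself is blocked, which is why the scalar bias column added below gives the first step something to attend to.

import math
import torch

tgt_len = 4
keep = torch.tril(torch.ones(tgt_len, tgt_len), diagonal=-1)              # 1s strictly below the diagonal
block = torch.triu(torch.full((tgt_len, tgt_len), -math.inf), diagonal=0)  # -inf on and above the diagonal

scores = torch.zeros(tgt_len, tgt_len)    # pretend raw attention logits
masked = scores * keep + block            # same multiply-then-add as in forward() above
print(masked)                             # row 0 is all -inf until the scalar bias column is prepended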
class DownsampledMultiHeadAttention(nn.ModuleList):
    """
    Multi-headed attention with Gating and Downsampling
    """
    def __init__(self, out_channels, embed_dim, num_heads, dropout=0., bias=True,
                 project_input=True, gated=False, downsample=False):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.downsample = downsample
        self.gated = gated
        self.project_input = project_input
        assert self.head_dim * num_heads == embed_dim

        if self.downsample:
            attention_heads = []
            for index in range(self.num_heads):
                attention_heads.append(
                    SingleHeadAttention(
                        out_channels, self.embed_dim, self.head_dim, index,
                        self.dropout, bias, self.project_input, self.gated,
                        self.downsample, self.num_heads,
                    )
                )
            super().__init__(modules=attention_heads)
            self.out_proj = Linear(embed_dim, out_channels, bias=bias)
        else:
            # either we have a list of attention heads, or just one attention head
            # if not being downsampled, we can do the heads with one linear layer instead of separate ones
            super().__init__()
            self.attention_module = SingleHeadAttention(
                out_channels, self.embed_dim, self.head_dim, 1, self.dropout,
                bias, self.project_input, self.gated, self.downsample, self.num_heads,
            )

    def forward(self, query, key, value, mask_future_timesteps=False,
                key_padding_mask=None, use_scalar_bias=False):
        src_len, bsz, embed_dim = key.size()
        tgt_len = query.size(0)
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        tgt_size = tgt_len
        if use_scalar_bias:
            tgt_size += 1

        attn = []
        attn_weights = []
        if self.downsample:
            for attention_head_number in range(self.num_heads):
                # call the forward of each attention head
                _attn, _attn_weight = self[attention_head_number](
                    query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
                )
                attn.append(_attn)
                attn_weights.append(_attn_weight)
            full_attn = torch.cat(attn, dim=2)
            full_attn = self.out_proj(full_attn)
            return full_attn, attn_weights[0].clone()
        else:
            _attn, _attn_weight = self.attention_module(
                query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
            )
            attn.append(_attn)
            attn_weights.append(_attn_weight)
            full_attn = torch.cat(attn, dim=2)
            full_attn_weights = torch.cat(attn_weights)
            full_attn_weights = full_attn_weights.view(bsz, self.num_heads, tgt_size, src_len)
            full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads
            return full_attn, full_attn_weights


class Downsample(nn.Module):
    """
    Selects every nth element, where n is the index
    """
    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        return x[::self.index + 1]


def Linear(in_features, out_features, dropout=0., bias=True):
    """Weight-normalized Linear layer (input: B x T x C)"""
    m = nn.Linear(in_features, out_features, bias=bias)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    m.bias.data.zero_()
    return nn.utils.weight_norm(m)


def GatedLinear(in_features, out_features, dropout=0., bias=True):
    """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units"""
    return nn.Sequential(
        Linear(in_features, out_features * 4, dropout, bias),
        nn.GLU(),
        Linear(out_features * 2, out_features * 2, dropout, bias),
        nn.GLU(),
        Linear(out_features, out_features, dropout, bias),
    )
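A usage sketch for DownsampledMultiHeadAttention, assuming this commit's fairseq is importable (the class is exported from fairseq.modules in the __init__.py change above); inputs are Time x Batch x embed_dim, and with downsample=True head i only attends over every (i + 1)-th key/value timestep:

import torch
from fairseq.modules import DownsampledMultiHeadAttention

attn = DownsampledMultiHeadAttention(
    out_channels=512, embed_dim=256, num_heads=4,
    dropout=0.1, project_input=True, gated=True, downsample=True,
)
q = torch.randn(10, 2, 256)   # T x B x embed_dim
k = torch.randn(10, 2, 256)
v = torch.randn(10, 2, 256)

out, attn_weights = attn(q, k, v, mask_future_timesteps=True, use_scalar_bias=True)
assert out.shape == (10, 2, 512)   # concatenated heads projected back to out_channels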
fairseq/modules/scalar_bias.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch


class ScalarBias(torch.autograd.Function):
    """
    Adds a vector of scalars, used in self-attention mechanism to allow
    the model to optionally attend to this vector instead of the past
    """

    @staticmethod
    def forward(ctx, input, dim, bias_init):
        size = list(input.size())
        size[dim] += 1
        output = input.new(*size).fill_(bias_init)
        output.narrow(dim, 1, size[dim] - 1).copy_(input)
        ctx.dim = dim
        return output

    @staticmethod
    def backward(ctx, grad):
        return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None


def scalar_bias(input, dim, bias_init=0):
    return ScalarBias.apply(input, dim, bias_init)
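A quick illustration of ScalarBias (assuming this commit is installed; the helper lives in fairseq.modules.scalar_bias, as imported by the attention module above): it prepends one constant slot along the chosen dimension, and its backward drops the gradient of that slot.

import torch
from fairseq.modules.scalar_bias import scalar_bias

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]], requires_grad=True)
y = scalar_bias(x, 1)        # (2, 3) -> (2, 4) with a zero column prepended
print(y)                     # [[0., 1., 2., 3.], [0., 4., 5., 6.]]
y.sum().backward()
print(x.grad)                # all ones; the bias slot contributes no gradient to the input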
fairseq/options.py

@@ -232,8 +232,8 @@ def add_checkpoint_args(parser):

 def add_common_eval_args(group):
-    group.add_argument('--path', metavar='FILE', action='append',
-                       help='path(s) to model file(s)')
+    group.add_argument('--path', metavar='FILE',
+                       help='path(s) to model file(s), comma separated')
     group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
                        help='remove BPE tokens before scoring')
     group.add_argument('--cpu', action='store_true', help='generate on CPU')

@@ -259,6 +259,8 @@ def add_generation_args(parser):
     group.add_argument('--max-len-b', default=200, type=int, metavar='N',
                        help=('generate sequences of maximum length ax + b, '
                              'where x is the source length'))
+    group.add_argument('--min-len', default=1, type=float, metavar='N',
+                       help=('minimum generation length'))
     group.add_argument('--no-early-stop', action='store_true',
                        help=('continue searching even after finalizing k=beam '
                              'hypotheses; this is more correct, but increases '

@@ -279,6 +281,10 @@ def add_generation_args(parser):
                        help='initialize generation by target prefix of given length')
     group.add_argument('--sampling', action='store_true',
                        help='sample hypotheses instead of using beam search')
+    group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
+                       help='sample from top K likely next words instead of all words')
+    group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
+                       help='temperature for random sampling')
     return group
fairseq/sequence_generator.py

@@ -15,9 +15,9 @@ from fairseq.models import FairseqIncrementalDecoder
 class SequenceGenerator(object):
     def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
                  stop_early=True, normalize_scores=True, len_penalty=1,
-                 unk_penalty=0, retain_dropout=False, sampling=False):
+                 unk_penalty=0, retain_dropout=False, sampling=False, sampling_topk=-1, sampling_temperature=1):
         """Generates translations of a given source sentence.

         Args:
             min/maxlen: The length of the generated output will be bounded by
                 minlen and maxlen (not including the end-of-sentence marker).

@@ -45,6 +45,8 @@ class SequenceGenerator(object):
         self.unk_penalty = unk_penalty
         self.retain_dropout = retain_dropout
         self.sampling = sampling
+        self.sampling_topk = sampling_topk
+        self.sampling_temperature = sampling_temperature

     def cuda(self):
         for model in self.models:

@@ -54,7 +56,6 @@ class SequenceGenerator(object):
     def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
                              cuda=False, timer=None, prefix_size=0):
         """Iterate over a batched dataset and yield individual translations.

         Args:
             maxlen_a/b: generate sequences of maximum length ax + b,
                 where x is the source sentence length.

@@ -169,11 +170,9 @@ class SequenceGenerator(object):
         """
         Finalize the given hypotheses at this step, while keeping the total
         number of finalized hypotheses per sentence <= beam_size.

         Note: the input must be in the desired finalization order, so that
         hypotheses that appear earlier in the input are preferred to those
         that appear later.

         Args:
             step: current time step
             bbsz_idx: A vector of indices in the range [0, bsz*beam_size),

@@ -221,7 +220,6 @@ class SequenceGenerator(object):
                     # remove padding tokens from attn scores
                     nonpad_idxs = src_tokens[sent].ne(self.pad)
                     hypo_attn = attn_clone[i][nonpad_idxs]
                     _, alignment = hypo_attn.max(dim=0)

                     return {

@@ -303,15 +301,29 @@ class SequenceGenerator(object):
                 cand_beams.resize_as_(cand_indices).fill_(0)
             elif self.sampling:
                 assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
-                exp_probs = probs.exp_().view(-1, self.vocab_size)
+                if self.sampling_topk > 0:
+                    values, indices = probs[:, 2:].topk(self.sampling_topk)
+                    exp_probs = values.div_(self.sampling_temperature).exp()
+                    if step == 0:
+                        torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
+                    else:
+                        torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
+                    torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
+                    torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
+                    cand_indices.add_(2)
+                else:
+                    exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)
                     if step == 0:
                         # we exclude the first two vocab items, one of which is pad
                         torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
+                        cand_indices.add_(2)
                     else:
                         torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
                         cand_indices.add_(2)
                     torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
                 cand_scores.log_()
                 cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
                 cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)

@@ -489,6 +501,7 @@ class SequenceGenerator(object):
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
             with utils.maybe_no_grad():
                 if incremental_states[model] is not None:

@@ -497,6 +510,7 @@ class SequenceGenerator(object):
                 decoder_out = list(model.decoder(tokens, encoder_out))
                 decoder_out[0] = decoder_out[0][:, -1, :]
                 attn = decoder_out[1]
                 probs = model.get_normalized_probs(decoder_out, log_probs=False).data
             if avg_probs is None:
                 avg_probs = probs
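Stripped of the pad-offset bookkeeping, the new top-k sampling path added above amounts to: keep the k most likely next words, temper their log-probabilities, sample with multinomial, then map the sampled slot back to a vocabulary index. A self-contained sketch with toy numbers:

import torch

torch.manual_seed(0)
vocab_size, k, temperature = 10, 3, 0.8
logprobs = torch.log_softmax(torch.randn(1, vocab_size), dim=-1)   # scores for one hypothesis

values, indices = logprobs.topk(k)                  # restrict to the k best words
exp_probs = values.div(temperature).exp()           # temperature-scaled probabilities
sample = torch.multinomial(exp_probs, num_samples=1)
next_word = indices.gather(1, sample)               # back to a real vocabulary index
score = exp_probs.gather(1, sample).log()           # hypothesis score, as in the diff above
print(next_word.item(), score.item())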
fairseq/utils.py

@@ -157,7 +157,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
     ensemble = []
     for state in states:
         model = models.build_model(args, src_dict, dst_dict)
-        model.load_state_dict(state['model'])
+        model.load_state_dict(state['model'], strict=True)
         ensemble.append(model)

     return ensemble, args
generate.py

@@ -31,8 +31,9 @@ def main(args):
     dataset = data_loaders.load_dataset(args, [args.gen_subset], args.replace_unk is not None)

     # Load ensemble
-    print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
+    print('| loading model(s) from {}'.format(args.path))
+    model_paths = args.path.split(',')
+    models, _ = utils.load_ensemble_for_inference(model_paths, dataset.src_dict, dataset.dst_dict)

     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))

@@ -70,7 +71,9 @@ def main(args):
     translator = SequenceGenerator(
         models, beam_size=args.beam, stop_early=(not args.no_early_stop),
         normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
-        unk_penalty=args.unkpen, sampling=args.sampling)
+        unk_penalty=args.unkpen, sampling=args.sampling,
+        sampling_topk=args.sampling_topk,
+        minlen=args.min_len)
     if use_cuda:
         translator.cuda()
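The user-visible effect of the generate.py and interactive.py changes is that an ensemble is now given as one comma-separated --path value instead of repeating the flag; a minimal sketch of what the scripts now do with it (the checkpoint names are hypothetical):

args_path = 'checkpoints/checkpoint_best.pt,checkpoints/checkpoint_last.pt'
model_paths = args_path.split(',')
assert model_paths == ['checkpoints/checkpoint_best.pt', 'checkpoints/checkpoint_last.pt']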
interactive.py

@@ -64,8 +64,9 @@ def main(args):
     use_cuda = torch.cuda.is_available() and not args.cpu

     # Load ensemble
-    print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
+    print('| loading model(s) from {}'.format(args.path))
+    model_paths = args.path.split(',')
+    models, model_args = utils.load_ensemble_for_inference(model_paths, data_dir=args.data)
     src_dict, dst_dict = models[0].src_dict, models[0].dst_dict

     print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))

@@ -81,7 +82,9 @@ def main(args):
     translator = SequenceGenerator(
         models, beam_size=args.beam, stop_early=(not args.no_early_stop),
         normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
-        unk_penalty=args.unkpen, sampling=args.sampling)
+        unk_penalty=args.unkpen, sampling=args.sampling,
+        sampling_topk=args.sampling_topk,
+        minlen=args.min_len)
     if use_cuda:
         translator.cuda()
train.py

@@ -40,8 +40,8 @@ def main(args):
     for split in splits:
         print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

+    # Build model and criterion
     model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
     criterion = criterions.build_criterion(args, dataset.src_dict, dataset.dst_dict)
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
     print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))