OpenDAS / Fairseq / Commits / 9438019f

Commit 9438019f, authored Feb 24, 2018 by Myle Ott, committed by Sergey Edunov on Feb 27, 2018
Parent: e7094b14

Refactor incremental generation to be more explicit and less magical (#222)

Showing 9 changed files with 145 additions and 219 deletions (+145, -219)
fairseq/models/fairseq_incremental_decoder.py    +9   -88
fairseq/models/fconv.py                          +29  -29
fairseq/models/lstm.py                           +24  -37
fairseq/modules/learned_positional_embedding.py  +2   -6
fairseq/modules/linearized_convolution.py        +32  -35
fairseq/sequence_generator.py                    +12  -13
fairseq/utils.py                                 +31  -0
tests/utils.py                                   +6   -9
train.py                                         +0   -2
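At a high level, this commit replaces the implicit per-module flags (`_is_incremental_eval`, per-module `_incremental_state` dicts, the `incremental_inference()` context manager) with an explicit `incremental_state` dictionary that the caller owns and threads through every decoder call. The following is a minimal sketch of the new calling convention, not code from the diff; `model`, `encoder_out`, `tokens`, and `maxlen` are assumed to exist:

```
# A minimal sketch of the new calling convention (illustrative, not from the diff):
# the caller owns a plain dict and threads it through every decoder call.
def greedy_decode_sketch(model, encoder_out, tokens, maxlen):
    incremental_state = {}  # caller-owned; decoders cache their buffers in here
    for step in range(maxlen):
        decoder_out, _ = model.decoder(tokens[:, :step + 1], encoder_out, incremental_state)
        probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False)
        # ...pick the next token from `probs` and append it to `tokens`...
        # if the batch order changes (e.g. during beam search), also call:
        # model.decoder.reorder_incremental_state(incremental_state, reorder_state)
    return tokens
```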
fairseq/models/fairseq_incremental_decoder.py

@@ -13,99 +13,20 @@ class FairseqIncrementalDecoder(FairseqDecoder):

     def __init__(self, dictionary):
         super().__init__(dictionary)
-        self._is_incremental_eval = False
-        self._incremental_state = {}

-    def forward(self, prev_output_tokens, encoder_out):
-        if self._is_incremental_eval:
-            raise NotImplementedError
-        else:
-            raise NotImplementedError
-
-    def incremental_inference(self):
-        """Context manager for incremental inference.
-
-        This provides an optimized forward pass for incremental inference
-        (i.e., it predicts one time step at a time). If the input order changes
-        between time steps, call reorder_incremental_state to update the
-        relevant buffers. To generate a fresh sequence, first call
-        clear_incremental_state.
-
-        Usage:
-        ```
-        with model.decoder.incremental_inference():
-            for step in range(maxlen):
-                out, _ = model.decoder(tokens[:, :step], encoder_out)
-                probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
-        ```
-        """
-        class IncrementalInference(object):
-            def __init__(self, decoder):
-                self.decoder = decoder
-
-            def __enter__(self):
-                self.decoder.incremental_eval(True)
-
-            def __exit__(self, *args):
-                self.decoder.incremental_eval(False)
-
-        return IncrementalInference(self)
-
-    def incremental_eval(self, mode=True):
-        """Sets the decoder and all children in incremental evaluation mode."""
-        assert self._is_incremental_eval != mode, \
-            'incremental_eval already set to mode {}'.format(mode)
-        self._is_incremental_eval = mode
-        if mode:
-            self.clear_incremental_state()
-
-        def apply_incremental_eval(module):
-            if module != self and hasattr(module, 'incremental_eval'):
-                module.incremental_eval(mode)
-        self.apply(apply_incremental_eval)
-
-    def get_incremental_state(self, key):
-        """Return cached state or None if not in incremental inference mode."""
-        if self._is_incremental_eval and key in self._incremental_state:
-            return self._incremental_state[key]
-        return None
-
-    def set_incremental_state(self, key, value):
-        """Cache state needed for incremental inference mode."""
-        if self._is_incremental_eval:
-            self._incremental_state[key] = value
-        return value
-
-    def clear_incremental_state(self):
-        """Clear all state used for incremental generation.
-
-        **For incremental inference only**
-
-        This should be called before generating a fresh sequence.
-        beam_size is required if using BeamableMM.
-        """
-        if self._is_incremental_eval:
-            del self._incremental_state
-            self._incremental_state = {}
-
-            def apply_clear_incremental_state(module):
-                if module != self and hasattr(module, 'clear_incremental_state'):
-                    module.clear_incremental_state()
-            self.apply(apply_clear_incremental_state)
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        raise NotImplementedError

-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation).
-
-        **For incremental inference only**
+    def reorder_incremental_state(self, incremental_state, new_order):
+        """Reorder incremental state.

         This should be called when the order of the input has changed from the
         previous time step. A typical use case is beam search, where the input
-        order changes between time steps based on the choice of beams.
+        order changes between time steps based on the selection of beams.
         """
-        if self._is_incremental_eval:
-            def apply_reorder_incremental_state(module):
-                if module != self and hasattr(module, 'reorder_incremental_state'):
-                    module.reorder_incremental_state(new_order)
-            self.apply(apply_reorder_incremental_state)
+        def apply_reorder_incremental_state(module):
+            if module != self and hasattr(module, 'reorder_incremental_state'):
+                module.reorder_incremental_state(incremental_state, new_order)
+        self.apply(apply_reorder_incremental_state)

     def set_beam_size(self, beam_size):
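Under the new base-class contract above, subclasses no longer keep state on `self`; they read and write it through the `fairseq.utils` helpers added later in this commit. A minimal, self-contained sketch of such a subclass (the `ToyDecoder` class and its cached step counter are invented for illustration, loosely mirroring the test decoder in tests/utils.py):

```
import torch

from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder


class ToyDecoder(FairseqIncrementalDecoder):
    """Illustrative subclass: caches a per-sequence step counter the new way."""

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            # incremental mode: only the newest token is fed through the model
            prev_output_tokens = prev_output_tokens[:, -1:]
        # read previously cached state (None on the first call or in training)
        step = utils.get_incremental_state(self, incremental_state, 'step') or 0
        # write updated state back; this is a no-op when incremental_state is None
        utils.set_incremental_state(self, incremental_state, 'step', step + 1)
        bsz, tgt_len = prev_output_tokens.size()
        # dummy output over a vocabulary of len(self.dictionary) entries
        out = torch.zeros(bsz, tgt_len, len(self.dictionary))
        return out, None
```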
fairseq/models/fconv.py

@@ -10,6 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from fairseq import utils
 from fairseq.data import LanguagePairDataset
 from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution

@@ -229,19 +230,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

-    def forward(self, prev_output_tokens, encoder_out):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # split and transpose encoder outputs
-        encoder_a, encoder_b = self._split_encoder_out(encoder_out)
-
-        # embed positions
-        positions = self.embed_positions(prev_output_tokens)
-
-        if self._is_incremental_eval:
-            # keep only the last token for incremental forward pass
-            prev_output_tokens = prev_output_tokens[:, -1:]
-
-        # embed tokens and positions
-        x = self.embed_tokens(prev_output_tokens) + positions
+        encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)
+
+        # embed tokens and combine with positional embeddings
+        x = self._embed_tokens(prev_output_tokens, incremental_state)
+        x += self.embed_positions(prev_output_tokens, incremental_state)
         x = F.dropout(x, p=self.dropout, training=self.training)
         target_embedding = x

@@ -249,7 +244,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         x = self.fc1(x)

         # B x T x C -> T x B x C
-        x = self._transpose_unless_incremental_eval(x)
+        x = self._transpose_if_training(x, incremental_state)

         # temporal convolutions
         avg_attn_scores = None

@@ -258,13 +253,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
             residual = x if proj is None else proj(x)

             x = F.dropout(x, p=self.dropout, training=self.training)
-            x = conv(x)
-            x = conv.remove_future_timesteps(x)
+            x = conv(x, incremental_state)
+            if incremental_state is None:
+                x = conv.remove_future_timesteps(x)
             x = F.glu(x, dim=2)

             # attention
             if attention is not None:
-                x = self._transpose_unless_incremental_eval(x)
+                x = self._transpose_if_training(x, incremental_state)

                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
                 attn_scores = attn_scores / num_attn_layers

@@ -273,13 +269,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 else:
                     avg_attn_scores.add_(attn_scores)

-                x = self._transpose_unless_incremental_eval(x)
+                x = self._transpose_if_training(x, incremental_state)

             # residual
             x = (x + residual) * math.sqrt(0.5)

         # T x B x C -> B x T x C
-        x = self._transpose_unless_incremental_eval(x)
+        x = self._transpose_if_training(x, incremental_state)

         # project back to size of vocabulary
         x = self.fc2(x)

@@ -288,10 +284,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
         return x, avg_attn_scores

-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(new_order)
-
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()

@@ -306,13 +298,19 @@ class FConvDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.version'] = torch.Tensor([1])
         return state_dict

-    def _split_encoder_out(self, encoder_out):
+    def _embed_tokens(self, tokens, incremental_state):
+        if incremental_state is not None:
+            # keep only the last token for incremental forward pass
+            tokens = tokens[:, -1:]
+        return self.embed_tokens(tokens)
+
+    def _split_encoder_out(self, encoder_out, incremental_state):
         """Split and transpose encoder outputs.

         This is cached when doing incremental inference.
         """
-        cached_result = self.get_incremental_state('encoder_out')
-        if cached_result:
+        cached_result = utils.get_incremental_state(self, incremental_state, 'encoder_out')
+        if cached_result is not None:
             return cached_result

         # transpose only once to speed up attention layers

@@ -320,12 +318,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
         encoder_a = encoder_a.transpose(1, 2).contiguous()
         result = (encoder_a, encoder_b)
-        return self.set_incremental_state('encoder_out', result)
+        if incremental_state is not None:
+            utils.set_incremental_state(self, incremental_state, 'encoder_out', result)
+        return result

-    def _transpose_unless_incremental_eval(self, x):
-        if self._is_incremental_eval:
-            return x
-        return x.transpose(0, 1)
+    def _transpose_if_training(self, x, incremental_state):
+        if incremental_state is None:
+            x = x.transpose(0, 1)
+        return x


 def Embedding(num_embeddings, embedding_dim, padding_idx):
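The caching pattern used by `_split_encoder_out` above generalizes: look the value up in `incremental_state`, compute it only on a miss, and store it back when actually decoding incrementally. A stripped-down sketch of that pattern as a standalone helper (this helper is not part of the diff; fconv inlines the same get/compute/set logic directly):

```
from fairseq import utils


def cached_value(module, incremental_state, key, compute_fn):
    """Compute `compute_fn()` at most once per sequence (illustrative helper)."""
    cached = utils.get_incremental_state(module, incremental_state, key)
    if cached is not None:
        return cached
    result = compute_fn()
    if incremental_state is not None:
        # only cache when we are actually doing incremental inference
        utils.set_incremental_state(module, incremental_state, key, result)
    return result
```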
fairseq/models/lstm.py

@@ -183,12 +183,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             self.additional_fc = Linear(embed_dim, out_embed_dim)
         self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

-    def forward(self, prev_output_tokens, encoder_out):
-        if self._is_incremental_eval:
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
-        return self._forward(prev_output_tokens, encoder_out)
-
-    def _forward(self, prev_output_tokens, encoder_out):
         bsz, seqlen = prev_output_tokens.size()

         # get outputs from encoder

@@ -204,15 +201,15 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(0, 1)

         # initialize previous states (or get from cache during incremental generation)
-        prev_hiddens = self.get_incremental_state('prev_hiddens')
-        if not prev_hiddens:
-            # first time step, initialize previous states
-            prev_hiddens, prev_cells = self._init_prev_states(encoder_out)
-            input_feed = Variable(x.data.new(bsz, embed_dim).zero_())
+        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
+        if cached_state is not None:
+            prev_hiddens, prev_cells, input_feed = cached_state
         else:
-            # previous states are cached
-            prev_cells = self.get_incremental_state('prev_cells')
-            input_feed = self.get_incremental_state('input_feed')
+            _, encoder_hiddens, encoder_cells = encoder_out
+            num_layers = len(self.layers)
+            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
+            prev_cells = [encoder_cells[i] for i in range(num_layers)]
+            input_feed = Variable(x.data.new(bsz, embed_dim).zero_())

         attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
         outs = []

@@ -242,9 +239,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             outs.append(out)

         # cache previous states (no-op except during incremental generation)
-        self.set_incremental_state('prev_hiddens', prev_hiddens)
-        self.set_incremental_state('prev_cells', prev_cells)
-        self.set_incremental_state('input_feed', input_feed)
+        utils.set_incremental_state(
+            self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed))

         # collect outputs across time steps
         x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)

@@ -263,34 +259,25 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         return x, attn_scores

-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(new_order)
-        new_order = Variable(new_order)
+    def reorder_incremental_state(self, incremental_state, new_order):
+        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
+        if cached_state is None:
+            return

-        def reorder_state(key):
-            old = self.get_incremental_state(key)
-            if isinstance(old, list):
-                new = [old_i.index_select(0, new_order) for old_i in old]
-            else:
-                new = old.index_select(0, new_order)
-            self.set_incremental_state(key, new)
+        def reorder_state(state):
+            if isinstance(state, list):
+                return [reorder_state(state_i) for state_i in state]
+            return state.index_select(0, new_order)

-        reorder_state('prev_hiddens')
-        reorder_state('prev_cells')
-        reorder_state('input_feed')
+        if not isinstance(new_order, Variable):
+            new_order = Variable(new_order)
+        new_state = tuple(map(reorder_state, cached_state))
+        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number

-    def _init_prev_states(self, encoder_out):
-        _, encoder_hiddens, encoder_cells = encoder_out
-        num_layers = len(self.layers)
-        prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
-        prev_cells = [encoder_cells[i] for i in range(num_layers)]
-        return prev_hiddens, prev_cells


 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
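The `reorder_state` helper in the LSTM hunk above simply applies `index_select` along the batch (beam) dimension of every cached tensor. A tiny standalone illustration of that operation (the tensors and beam choices here are made up):

```
import torch

# three "beams", each with a two-dimensional cached hidden state
hidden = torch.FloatTensor([[0, 0], [1, 1], [2, 2]])
# beam search kept beam 2 and duplicated beam 0
new_order = torch.LongTensor([2, 0, 0])
reordered = hidden.index_select(0, new_order)
# reordered == [[2, 2], [0, 0], [0, 0]]
```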
fairseq/modules/learned_positional_embedding.py

@@ -20,14 +20,10 @@ class LearnedPositionalEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
         self.left_pad = left_pad
-        self._is_incremental_eval = False
-
-    def incremental_eval(self, mode=True):
-        self._is_incremental_eval = mode

-    def forward(self, input):
+    def forward(self, input, incremental_state=None):
         """Input is expected to be of size [bsz x seqlen]."""
-        if self._is_incremental_eval:
+        if incremental_state is not None:
             # positions is the same for every token when decoding a single step
             positions = Variable(
                 input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
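A small worked example of the single-step position computed above, assuming the usual fairseq convention that real positions are numbered starting at `padding_idx + 1` (the numbers below are arbitrary):

```
# Worked example of the single-step position fill above (values are arbitrary).
padding_idx = 1        # positions are numbered from padding_idx + 1 upward
seqlen_so_far = 4      # input.size(1): tokens in the prefix, including the newest one
position = padding_idx + seqlen_so_far
# position == 5, i.e. the 4th real position; it is the same for every sequence
# in the batch, hence the single (1, 1) tensor broadcast in forward().
```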
fairseq/modules/linearized_convolution.py

@@ -22,35 +22,20 @@ class LinearizedConvolution(ConvTBC):
     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
         super().__init__(in_channels, out_channels, kernel_size, **kwargs)
-        self._is_incremental_eval = False
         self._linearized_weight = None
         self.register_backward_hook(self._clear_linearized_weight)

-    def remove_future_timesteps(self, x):
-        """Remove future time steps created by padding."""
-        if not self._is_incremental_eval and self.kernel_size[0] > 1 and self.padding[0] > 0:
-            x = x[:-self.padding[0], :, :]
-        return x
-
-    def incremental_eval(self, mode=True):
-        self._is_incremental_eval = mode
-        if mode:
-            self.clear_incremental_state()
-
-    def forward(self, input):
-        if self._is_incremental_eval:
-            return self.incremental_forward(input)
-        else:
-            return super().forward(input)
-
-    def incremental_forward(self, input):
-        """Forward convolution one time step at a time.
-
-        This function maintains an internal state to buffer signal and accepts
-        a single frame as input. If the input order changes between time steps,
-        call reorder_incremental_state. To apply to fresh inputs, call
-        clear_incremental_state.
-        """
+    def forward(self, input, incremental_state=None):
+        """
+        Input: Time x Batch x Channel.
+
+        Args:
+            incremental_state: Used to buffer signal; if not None, then input is
+                expected to contain a single frame. If the input order changes
+                between time steps, call reorder_incremental_state.
+        """
+        if incremental_state is None:
+            return super().forward(input)
+
         # reshape weight
         weight = self._get_linearized_weight()
         kw = self.kernel_size[0]

@@ -58,25 +43,37 @@ class LinearizedConvolution(ConvTBC):
         bsz = input.size(0)  # input: bsz x len x dim
         if kw > 1:
             input = input.data
-            if self.input_buffer is None:
-                self.input_buffer = input.new(bsz, kw, input.size(2))
-                self.input_buffer.zero_()
+            input_buffer = self._get_input_buffer(incremental_state)
+            if input_buffer is None:
+                input_buffer = input.new(bsz, kw, input.size(2)).zero_()
+                self._set_input_buffer(incremental_state, input_buffer)
             else:
                 # shift buffer
-                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
+                input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
             # append next input
-            self.input_buffer[:, -1, :] = input[:, -1, :]
-            input = utils.volatile_variable(self.input_buffer)
+            input_buffer[:, -1, :] = input[:, -1, :]
+            input = utils.volatile_variable(input_buffer)
         with utils.maybe_no_grad():
             output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)

-    def clear_incremental_state(self):
-        self.input_buffer = None
+    def remove_future_timesteps(self, x):
+        """Remove future time steps created by padding."""
+        if self.kernel_size[0] > 1 and self.padding[0] > 0:
+            x = x[:-self.padding[0], :, :]
+        return x

-    def reorder_incremental_state(self, new_order):
-        if self.input_buffer is not None:
-            self.input_buffer = self.input_buffer.index_select(0, new_order)
+    def reorder_incremental_state(self, incremental_state, new_order):
+        input_buffer = self._get_input_buffer(incremental_state)
+        if input_buffer is not None:
+            input_buffer = input_buffer.index_select(0, new_order)
+            self._set_input_buffer(incremental_state, input_buffer)
+
+    def _get_input_buffer(self, incremental_state):
+        return utils.get_incremental_state(self, incremental_state, 'input_buffer')
+
+    def _set_input_buffer(self, incremental_state, new_buffer):
+        return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)

     def _get_linearized_weight(self):
         if self._linearized_weight is None:
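The buffer logic above keeps the last `kw` input frames per sequence: shift the window left by one slot, write the newest frame into the last slot, then apply the linearized weight to the flattened window. A standalone illustration of just the buffer update (shapes and values are invented):

```
import torch

bsz, kw, dim = 2, 3, 4                    # batch size, kernel width, channels
input_buffer = torch.zeros(bsz, kw, dim)  # starts zeroed, as on the first step above


def push_frame(input_buffer, frame):
    """Shift the kw-frame window left and append the newest frame (illustrative)."""
    input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
    input_buffer[:, -1, :] = frame
    return input_buffer


for step in range(5):
    frame = torch.zeros(bsz, dim).fill_(step)
    push_frame(input_buffer, frame)
# after 5 steps the buffer holds frames 2, 3, 4 along dim=1
```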
fairseq/sequence_generator.py

@@ -5,7 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from contextlib import ExitStack
 import math

 import torch

@@ -87,10 +86,6 @@ class SequenceGenerator(object):
     def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
-        with ExitStack() as stack:
-            for model in self.models:
-                if isinstance(model.decoder, FairseqIncrementalDecoder):
-                    stack.enter_context(model.decoder.incremental_inference())
-            with utils.maybe_no_grad():
-                return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+        with utils.maybe_no_grad():
+            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

@@ -103,11 +98,14 @@ class SequenceGenerator(object):
         beam_size = min(beam_size, self.vocab_size - 1)

         encoder_outs = []
+        incremental_states = {}
         for model in self.models:
             if not self.retain_dropout:
                 model.eval()
             if isinstance(model.decoder, FairseqIncrementalDecoder):
-                model.decoder.set_beam_size(beam_size)
+                incremental_states[model] = {}
+            else:
+                incremental_states[model] = None

             # compute the encoder output for each beam
             encoder_out = model.encoder(

@@ -245,9 +243,11 @@ class SequenceGenerator(object):
             if reorder_state is not None:
                 for model in self.models:
                     if isinstance(model.decoder, FairseqIncrementalDecoder):
-                        model.decoder.reorder_incremental_state(reorder_state)
+                        model.decoder.reorder_incremental_state(
+                            incremental_states[model], reorder_state)

-            probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs)
+            probs, avg_attn_scores = self._decode(
+                tokens[:, :step + 1], encoder_outs, incremental_states)
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam

@@ -287,7 +287,6 @@ class SequenceGenerator(object):
                 )
                 torch.div(cand_indices, self.vocab_size, out=cand_beams)
                 cand_indices.fmod_(self.vocab_size)
-
             else:
                 # finalize all active hypotheses once we hit maxlen
                 # pick the hypothesis with the highest prob of EOS right now

@@ -403,7 +402,7 @@ class SequenceGenerator(object):
         return finalized

-    def _decode(self, tokens, encoder_outs):
+    def _decode(self, tokens, encoder_outs, incremental_states):
         # wrap in Variable
         tokens = utils.volatile_variable(tokens)

@@ -411,7 +410,7 @@ class SequenceGenerator(object):
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
             with utils.maybe_no_grad():
-                decoder_out, attn = model.decoder(tokens, encoder_out)
+                decoder_out, attn = model.decoder(tokens, encoder_out, incremental_states[model])
             probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
             if avg_probs is None:
                 avg_probs = probs
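Because generation may ensemble several models, the generator keeps one state dict per model and hands the right one to each decoder. The helpers below are a condensed restatement of that bookkeeping (the real code inlines this logic in `_generate` and `_decode` as shown in the hunks above):

```
from fairseq.models import FairseqIncrementalDecoder


def init_incremental_states(models):
    """One state dict per incremental decoder, None otherwise (condensed sketch)."""
    return {
        model: {} if isinstance(model.decoder, FairseqIncrementalDecoder) else None
        for model in models
    }


def reorder_all(models, incremental_states, reorder_state):
    """Propagate a beam reordering to every incremental decoder in the ensemble."""
    for model in models:
        if isinstance(model.decoder, FairseqIncrementalDecoder):
            model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
```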
fairseq/utils.py

@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from collections import defaultdict
 import contextlib
 import logging
 import os

@@ -198,6 +199,36 @@ def make_variable(sample, volatile=False, cuda=False):
     return _make_variable(sample)


+INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+
+def _get_full_incremental_state_key(module_instance, key):
+    module_name = module_instance.__class__.__name__
+
+    # assign a unique ID to each module instance, so that incremental state is
+    # not shared across module instances
+    if not hasattr(module_instance, '_fairseq_instance_id'):
+        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+        module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+
+    return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)
+
+
+def get_incremental_state(module, incremental_state, key):
+    """Helper for getting incremental state for an nn.Module."""
+    full_key = _get_full_incremental_state_key(module, key)
+    if incremental_state is None or full_key not in incremental_state:
+        return None
+    return incremental_state[full_key]
+
+
+def set_incremental_state(module, incremental_state, key, value):
+    """Helper for setting incremental state for an nn.Module."""
+    if incremental_state is not None:
+        full_key = _get_full_incremental_state_key(module, key)
+        incremental_state[full_key] = value
+
+
 def load_align_dict(replace_unk):
     if replace_unk is None:
         align_dict = None
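These helpers namespace each cached value by module class name, a per-instance ID, and the caller's key, so two instances of the same module type never collide inside one shared `incremental_state` dict. A small usage sketch (the `nn.Linear` modules and string values are arbitrary stand-ins):

```
import torch.nn as nn

from fairseq import utils

layer_a = nn.Linear(4, 4)
layer_b = nn.Linear(4, 4)
incremental_state = {}

utils.set_incremental_state(layer_a, incremental_state, 'buf', 'state of a')
utils.set_incremental_state(layer_b, incremental_state, 'buf', 'state of b')

# both values coexist: keys are namespaced like 'Linear.<instance id>.buf'
assert utils.get_incremental_state(layer_a, incremental_state, 'buf') == 'state of a'
assert utils.get_incremental_state(layer_b, incremental_state, 'buf') == 'state of b'
```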
tests/utils.py

@@ -9,7 +9,7 @@
 import torch
 from torch.autograd import Variable

-from fairseq import data, dictionary
+from fairseq import data, dictionary, utils
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,

@@ -96,24 +96,21 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
         self.args = args

-    def forward(self, prev_output_tokens, encoder_out):
-        if self._is_incremental_eval:
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
-        return self._forward(prev_output_tokens, encoder_out)
-
-    def _forward(self, prev_output_tokens, encoder_out):
         bbsz = prev_output_tokens.size(0)
         vocab = len(self.dictionary)
         src_len = encoder_out.size(1)
         tgt_len = prev_output_tokens.size(1)

         # determine number of steps
-        if self._is_incremental_eval:
+        if incremental_state is not None:
             # cache step number
-            step = self.get_incremental_state('step')
+            step = utils.get_incremental_state(self, incremental_state, 'step')
             if step is None:
                 step = 0
-            self.set_incremental_state('step', step + 1)
+            utils.set_incremental_state(self, incremental_state, 'step', step + 1)
             steps = [step]
         else:
             steps = list(range(tgt_len))
train.py

@@ -6,8 +6,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-import torch
-
 from fairseq import options
 from distributed_train import main as distributed_main