9438019f
Commit
9438019f
authored
Feb 24, 2018
by
Myle Ott
Committed by
Sergey Edunov
Feb 27, 2018
Browse files
Refactor incremental generation to be more explicit and less magical (#222)
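The change removes the stateful `incremental_eval()` / `incremental_inference()` machinery from the decoders and instead threads an explicit `incremental_state` dictionary through every decoder call. A rough sketch of the new calling convention, pieced together from the `SequenceGenerator` changes below (`model`, `tokens`, `encoder_out`, and `maxlen` are illustrative placeholders, not part of this commit):

```
from fairseq import utils

incremental_state = {}  # one dict per incremental decoder; passing None disables caching

with utils.maybe_no_grad():
    for step in range(maxlen):
        # with cached state, the decoder internally only needs the newest token
        decoder_out, attn = model.decoder(tokens[:, :step + 1], encoder_out, incremental_state)
        probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False)
        # if beam search reorders hypotheses, the cached buffers must follow:
        # model.decoder.reorder_incremental_state(incremental_state, reorder_state)
```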
parent e7094b14
Showing 9 changed files with 145 additions and 219 deletions (+145, -219)
fairseq/models/fairseq_incremental_decoder.py   +9   -88
fairseq/models/fconv.py                         +29  -29
fairseq/models/lstm.py                          +24  -37
fairseq/modules/learned_positional_embedding.py +2   -6
fairseq/modules/linearized_convolution.py       +32  -35
fairseq/sequence_generator.py                   +12  -13
fairseq/utils.py                                +31  -0
tests/utils.py                                  +6   -9
train.py                                        +0   -2
fairseq/models/fairseq_incremental_decoder.py
View file @ 9438019f
...
@@ -13,100 +13,21 @@ class FairseqIncrementalDecoder(FairseqDecoder):

     def __init__(self, dictionary):
         super().__init__(dictionary)
-        self._is_incremental_eval = False
-        self._incremental_state = {}

-    def forward(self, prev_output_tokens, encoder_out):
-        if self._is_incremental_eval:
-            raise NotImplementedError
-        else:
-            raise NotImplementedError
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        raise NotImplementedError

-    def incremental_inference(self):
-        """Context manager for incremental inference.
-
-        This provides an optimized forward pass for incremental inference
-        (i.e., it predicts one time step at a time). If the input order changes
-        between time steps, call reorder_incremental_state to update the
-        relevant buffers. To generate a fresh sequence, first call
-        clear_incremental_state.
-
-        Usage:
-        ```
-        with model.decoder.incremental_inference():
-            for step in range(maxlen):
-                out, _ = model.decoder(tokens[:, :step], encoder_out)
-                probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
-        ```
-        """
-        class IncrementalInference(object):
-
-            def __init__(self, decoder):
-                self.decoder = decoder
-
-            def __enter__(self):
-                self.decoder.incremental_eval(True)
-
-            def __exit__(self, *args):
-                self.decoder.incremental_eval(False)
-
-        return IncrementalInference(self)
-
-    def incremental_eval(self, mode=True):
-        """Sets the decoder and all children in incremental evaluation mode."""
-        assert self._is_incremental_eval != mode, \
-            'incremental_eval already set to mode {}'.format(mode)
-        self._is_incremental_eval = mode
-        if mode:
-            self.clear_incremental_state()
-
-        def apply_incremental_eval(module):
-            if module != self and hasattr(module, 'incremental_eval'):
-                module.incremental_eval(mode)
-        self.apply(apply_incremental_eval)
-
-    def get_incremental_state(self, key):
-        """Return cached state or None if not in incremental inference mode."""
-        if self._is_incremental_eval and key in self._incremental_state:
-            return self._incremental_state[key]
-        return None
-
-    def set_incremental_state(self, key, value):
-        """Cache state needed for incremental inference mode."""
-        if self._is_incremental_eval:
-            self._incremental_state[key] = value
-        return value
-
-    def clear_incremental_state(self):
-        """Clear all state used for incremental generation.
-
-        **For incremental inference only**
-
-        This should be called before generating a fresh sequence.
-        beam_size is required if using BeamableMM.
-        """
-        if self._is_incremental_eval:
-            del self._incremental_state
-            self._incremental_state = {}
-
-            def apply_clear_incremental_state(module):
-                if module != self and hasattr(module, 'clear_incremental_state'):
-                    module.clear_incremental_state()
-            self.apply(apply_clear_incremental_state)
-
-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation).
-
-        **For incremental inference only**
+    def reorder_incremental_state(self, incremental_state, new_order):
+        """Reorder incremental state.

         This should be called when the order of the input has changed from the
         previous time step. A typical use case is beam search, where the input
-        order changes between time steps based on the choice of beams.
+        order changes between time steps based on the selection of beams.
         """
-        if self._is_incremental_eval:
-            def apply_reorder_incremental_state(module):
-                if module != self and hasattr(module, 'reorder_incremental_state'):
-                    module.reorder_incremental_state(new_order)
-            self.apply(apply_reorder_incremental_state)
+        def apply_reorder_incremental_state(module):
+            if module != self and hasattr(module, 'reorder_incremental_state'):
+                module.reorder_incremental_state(incremental_state, new_order)
+        self.apply(apply_reorder_incremental_state)

     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
...
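Subclasses now implement `forward(prev_output_tokens, encoder_out, incremental_state=None)` and cache whatever they need through the new helpers in fairseq/utils.py. A minimal, hypothetical sketch of a decoder written against the new interface (modeled on `TestIncrementalDecoder` in tests/utils.py further down; the actual computation is elided):

```
from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder


class ToyIncrementalDecoder(FairseqIncrementalDecoder):
    """Hypothetical decoder illustrating the explicit incremental_state API."""

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        if incremental_state is not None:
            # incremental decoding: only the newest target token is needed
            prev_output_tokens = prev_output_tokens[:, -1:]
            # read and advance cached state keyed on this module instance
            step = utils.get_incremental_state(self, incremental_state, 'step')
            if step is None:
                step = 0
            utils.set_incremental_state(self, incremental_state, 'step', step + 1)
        # ... compute and return (decoder_out, attn_scores) as before ...
```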
fairseq/models/fconv.py
View file @ 9438019f
...
@@ -10,6 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from fairseq import utils
 from fairseq.data import LanguagePairDataset
 from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
...
@@ -229,19 +230,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

-    def forward(self, prev_output_tokens, encoder_out):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # split and transpose encoder outputs
-        encoder_a, encoder_b = self._split_encoder_out(encoder_out)
+        encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)

-        # embed positions
-        positions = self.embed_positions(prev_output_tokens)
-
-        if self._is_incremental_eval:
-            # keep only the last token for incremental forward pass
-            prev_output_tokens = prev_output_tokens[:, -1:]
-
-        # embed tokens and positions
-        x = self.embed_tokens(prev_output_tokens) + positions
+        # embed tokens and combine with positional embeddings
+        x = self._embed_tokens(prev_output_tokens, incremental_state)
+        x += self.embed_positions(prev_output_tokens, incremental_state)
         x = F.dropout(x, p=self.dropout, training=self.training)
         target_embedding = x
...
@@ -249,7 +244,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         x = self.fc1(x)

         # B x T x C -> T x B x C
-        x = self._transpose_unless_incremental_eval(x)
+        x = self._transpose_if_training(x, incremental_state)

         # temporal convolutions
         avg_attn_scores = None
...
@@ -258,13 +253,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
             residual = x if proj is None else proj(x)

             x = F.dropout(x, p=self.dropout, training=self.training)
-            x = conv(x)
-            x = conv.remove_future_timesteps(x)
+            x = conv(x, incremental_state)
+            if incremental_state is None:
+                x = conv.remove_future_timesteps(x)
             x = F.glu(x, dim=2)

             # attention
             if attention is not None:
-                x = self._transpose_unless_incremental_eval(x)
+                x = self._transpose_if_training(x, incremental_state)

                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
                 attn_scores = attn_scores / num_attn_layers
...
@@ -273,13 +269,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 else:
                     avg_attn_scores.add_(attn_scores)

-                x = self._transpose_unless_incremental_eval(x)
+                x = self._transpose_if_training(x, incremental_state)

             # residual
             x = (x + residual) * math.sqrt(0.5)

         # T x B x C -> B x T x C
-        x = self._transpose_unless_incremental_eval(x)
+        x = self._transpose_if_training(x, incremental_state)

         # project back to size of vocabulary
         x = self.fc2(x)
...
@@ -288,10 +284,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
         return x, avg_attn_scores

-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(new_order)
-
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
...
@@ -306,13 +298,19 @@ class FConvDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.version'] = torch.Tensor([1])
         return state_dict

-    def _split_encoder_out(self, encoder_out):
+    def _embed_tokens(self, tokens, incremental_state):
+        if incremental_state is not None:
+            # keep only the last token for incremental forward pass
+            tokens = tokens[:, -1:]
+        return self.embed_tokens(tokens)
+
+    def _split_encoder_out(self, encoder_out, incremental_state):
         """Split and transpose encoder outputs.

         This is cached when doing incremental inference.
         """
-        cached_result = self.get_incremental_state('encoder_out')
-        if cached_result:
+        cached_result = utils.get_incremental_state(self, incremental_state, 'encoder_out')
+        if cached_result is not None:
             return cached_result

         # transpose only once to speed up attention layers
...
@@ -320,12 +318,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
         encoder_a = encoder_a.transpose(1, 2).contiguous()
         result = (encoder_a, encoder_b)

-        return self.set_incremental_state('encoder_out', result)
+        if incremental_state is not None:
+            utils.set_incremental_state(self, incremental_state, 'encoder_out', result)
+        return result

-    def _transpose_unless_incremental_eval(self, x):
-        if self._is_incremental_eval:
-            return x
-        return x.transpose(0, 1)
+    def _transpose_if_training(self, x, incremental_state):
+        if incremental_state is None:
+            x = x.transpose(0, 1)
+        return x


 def Embedding(num_embeddings, embedding_dim, padding_idx):
...
fairseq/models/lstm.py
View file @ 9438019f
...
@@ -183,12 +183,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             self.additional_fc = Linear(embed_dim, out_embed_dim)
         self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

-    def forward(self, prev_output_tokens, encoder_out):
-        if self._is_incremental_eval:
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
-        return self._forward(prev_output_tokens, encoder_out)
-
-    def _forward(self, prev_output_tokens, encoder_out):
         bsz, seqlen = prev_output_tokens.size()

         # get outputs from encoder
...
@@ -204,15 +201,15 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(0, 1)

         # initialize previous states (or get from cache during incremental generation)
-        prev_hiddens = self.get_incremental_state('prev_hiddens')
-        if not prev_hiddens:
-            # first time step, initialize previous states
-            prev_hiddens, prev_cells = self._init_prev_states(encoder_out)
-            input_feed = Variable(x.data.new(bsz, embed_dim).zero_())
+        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
+        if cached_state is not None:
+            prev_hiddens, prev_cells, input_feed = cached_state
         else:
-            # previous states are cached
-            prev_cells = self.get_incremental_state('prev_cells')
-            input_feed = self.get_incremental_state('input_feed')
+            _, encoder_hiddens, encoder_cells = encoder_out
+            num_layers = len(self.layers)
+            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
+            prev_cells = [encoder_cells[i] for i in range(num_layers)]
+            input_feed = Variable(x.data.new(bsz, embed_dim).zero_())

         attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
         outs = []
...
@@ -242,9 +239,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             outs.append(out)

         # cache previous states (no-op except during incremental generation)
-        self.set_incremental_state('prev_hiddens', prev_hiddens)
-        self.set_incremental_state('prev_cells', prev_cells)
-        self.set_incremental_state('input_feed', input_feed)
+        utils.set_incremental_state(
+            self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed))

         # collect outputs across time steps
         x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
...
@@ -263,34 +259,25 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         return x, attn_scores

-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(new_order)
-        new_order = Variable(new_order)
+    def reorder_incremental_state(self, incremental_state, new_order):
+        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
+        if cached_state is None:
+            return

-        def reorder_state(key):
-            old = self.get_incremental_state(key)
-            if isinstance(old, list):
-                new = [old_i.index_select(0, new_order) for old_i in old]
-            else:
-                new = old.index_select(0, new_order)
-            self.set_incremental_state(key, new)
+        def reorder_state(state):
+            if isinstance(state, list):
+                return [reorder_state(state_i) for state_i in state]
+            return state.index_select(0, new_order)

-        reorder_state('prev_hiddens')
-        reorder_state('prev_cells')
-        reorder_state('input_feed')
+        if not isinstance(new_order, Variable):
+            new_order = Variable(new_order)
+        new_state = tuple(map(reorder_state, cached_state))
+        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number

-    def _init_prev_states(self, encoder_out):
-        _, encoder_hiddens, encoder_cells = encoder_out
-        num_layers = len(self.layers)
-        prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
-        prev_cells = [encoder_cells[i] for i in range(num_layers)]
-        return prev_hiddens, prev_cells


 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
...
fairseq/modules/learned_positional_embedding.py
View file @ 9438019f
...
@@ -20,14 +20,10 @@ class LearnedPositionalEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
         self.left_pad = left_pad
-        self._is_incremental_eval = False

-    def incremental_eval(self, mode=True):
-        self._is_incremental_eval = mode
-
-    def forward(self, input):
+    def forward(self, input, incremental_state=None):
         """Input is expected to be of size [bsz x seqlen]."""
-        if self._is_incremental_eval:
+        if incremental_state is not None:
             # positions is the same for every token when decoding a single step
             positions = Variable(input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
...
fairseq/modules/linearized_convolution.py
View file @ 9438019f
...
@@ -22,35 +22,20 @@ class LinearizedConvolution(ConvTBC):
     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
         super().__init__(in_channels, out_channels, kernel_size, **kwargs)
-        self._is_incremental_eval = False
         self._linearized_weight = None
         self.register_backward_hook(self._clear_linearized_weight)

-    def remove_future_timesteps(self, x):
-        """Remove future time steps created by padding."""
-        if not self._is_incremental_eval and self.kernel_size[0] > 1 and self.padding[0] > 0:
-            x = x[:-self.padding[0], :, :]
-        return x
-
-    def incremental_eval(self, mode=True):
-        self._is_incremental_eval = mode
-        if mode:
-            self.clear_incremental_state()
-
-    def forward(self, input):
-        if self._is_incremental_eval:
-            return self.incremental_forward(input)
-        else:
+    def forward(self, input, incremental_state=None):
+        """
+        Input: Time x Batch x Channel.
+
+        Args:
+            incremental_state: Used to buffer signal; if not None, then input is
+                expected to contain a single frame. If the input order changes
+                between time steps, call reorder_incremental_state.
+        """
+        if incremental_state is None:
             return super().forward(input)

-    def incremental_forward(self, input):
-        """Forward convolution one time step at a time.
-
-        This function maintains an internal state to buffer signal and accepts
-        a single frame as input. If the input order changes between time steps,
-        call reorder_incremental_state. To apply to fresh inputs, call
-        clear_incremental_state.
-        """
         # reshape weight
         weight = self._get_linearized_weight()
         kw = self.kernel_size[0]
...
@@ -58,25 +43,37 @@ class LinearizedConvolution(ConvTBC):
         bsz = input.size(0)
         # input: bsz x len x dim
         if kw > 1:
             input = input.data
-            if self.input_buffer is None:
-                self.input_buffer = input.new(bsz, kw, input.size(2))
-                self.input_buffer.zero_()
+            input_buffer = self._get_input_buffer(incremental_state)
+            if input_buffer is None:
+                input_buffer = input.new(bsz, kw, input.size(2)).zero_()
+                self._set_input_buffer(incremental_state, input_buffer)
             else:
                 # shift buffer
-                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
+                input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
             # append next input
-            self.input_buffer[:, -1, :] = input[:, -1, :]
-            input = utils.volatile_variable(self.input_buffer)
+            input_buffer[:, -1, :] = input[:, -1, :]
+            input = utils.volatile_variable(input_buffer)
         with utils.maybe_no_grad():
             output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)

-    def clear_incremental_state(self):
-        self.input_buffer = None
+    def remove_future_timesteps(self, x):
+        """Remove future time steps created by padding."""
+        if self.kernel_size[0] > 1 and self.padding[0] > 0:
+            x = x[:-self.padding[0], :, :]
+        return x

-    def reorder_incremental_state(self, new_order):
-        if self.input_buffer is not None:
-            self.input_buffer = self.input_buffer.index_select(0, new_order)
+    def reorder_incremental_state(self, incremental_state, new_order):
+        input_buffer = self._get_input_buffer(incremental_state)
+        if input_buffer is not None:
+            input_buffer = input_buffer.index_select(0, new_order)
+            self._set_input_buffer(incremental_state, input_buffer)
+
+    def _get_input_buffer(self, incremental_state):
+        return utils.get_incremental_state(self, incremental_state, 'input_buffer')
+
+    def _set_input_buffer(self, incremental_state, new_buffer):
+        return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)

     def _get_linearized_weight(self):
         if self._linearized_weight is None:
...
fairseq/sequence_generator.py
View file @ 9438019f
...
@@ -5,7 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from contextlib import ExitStack
 import math

 import torch
...
@@ -87,12 +86,8 @@ class SequenceGenerator(object):
     def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
-        with ExitStack() as stack:
-            for model in self.models:
-                if isinstance(model.decoder, FairseqIncrementalDecoder):
-                    stack.enter_context(model.decoder.incremental_inference())
-            with utils.maybe_no_grad():
-                return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+        with utils.maybe_no_grad():
+            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

     def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         bsz, srclen = src_tokens.size()
...
@@ -103,11 +98,14 @@ class SequenceGenerator(object):
         beam_size = min(beam_size, self.vocab_size - 1)

         encoder_outs = []
+        incremental_states = {}
         for model in self.models:
             if not self.retain_dropout:
                 model.eval()
             if isinstance(model.decoder, FairseqIncrementalDecoder):
                 model.decoder.set_beam_size(beam_size)
+                incremental_states[model] = {}
+            else:
+                incremental_states[model] = None

             # compute the encoder output for each beam
             encoder_out = model.encoder(
...
@@ -245,9 +243,11 @@ class SequenceGenerator(object):
             if reorder_state is not None:
                 for model in self.models:
                     if isinstance(model.decoder, FairseqIncrementalDecoder):
-                        model.decoder.reorder_incremental_state(reorder_state)
+                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)

-            probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs)
+            probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam
...
@@ -287,7 +287,6 @@ class SequenceGenerator(object):
                 )
                 torch.div(cand_indices, self.vocab_size, out=cand_beams)
                 cand_indices.fmod_(self.vocab_size)
             else:
                 # finalize all active hypotheses once we hit maxlen
                 # pick the hypothesis with the highest prob of EOS right now
...
@@ -403,7 +402,7 @@ class SequenceGenerator(object):
         return finalized

-    def _decode(self, tokens, encoder_outs):
+    def _decode(self, tokens, encoder_outs, incremental_states):
         # wrap in Variable
         tokens = utils.volatile_variable(tokens)
...
@@ -411,7 +410,7 @@ class SequenceGenerator(object):
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
             with utils.maybe_no_grad():
-                decoder_out, attn = model.decoder(tokens, encoder_out)
+                decoder_out, attn = model.decoder(tokens, encoder_out, incremental_states[model])
             probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
             if avg_probs is None:
                 avg_probs = probs
...
fairseq/utils.py
View file @ 9438019f
...
@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from collections import defaultdict
 import contextlib
 import logging
 import os
...
@@ -198,6 +199,36 @@ def make_variable(sample, volatile=False, cuda=False):
     return _make_variable(sample)


+INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+
+def _get_full_incremental_state_key(module_instance, key):
+    module_name = module_instance.__class__.__name__
+    # assign a unique ID to each module instance, so that incremental state is
+    # not shared across module instances
+    if not hasattr(module_instance, '_fairseq_instance_id'):
+        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+        module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+    return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)
+
+
+def get_incremental_state(module, incremental_state, key):
+    """Helper for getting incremental state for an nn.Module."""
+    full_key = _get_full_incremental_state_key(module, key)
+    if incremental_state is None or full_key not in incremental_state:
+        return None
+    return incremental_state[full_key]
+
+
+def set_incremental_state(module, incremental_state, key, value):
+    """Helper for setting incremental state for an nn.Module."""
+    if incremental_state is not None:
+        full_key = _get_full_incremental_state_key(module, key)
+        incremental_state[full_key] = value
+
+
 def load_align_dict(replace_unk):
     if replace_unk is None:
         align_dict = None
...
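The helpers above key each cache entry on the module's class name, a per-instance counter, and the caller-supplied key, so multiple module instances can share a single `incremental_state` dictionary without clobbering each other, and passing `incremental_state=None` turns both calls into no-ops. A small sketch of that behaviour (the `nn.Linear` modules are stand-ins for any decoder submodule):

```
import torch.nn as nn

from fairseq import utils

m1, m2 = nn.Linear(4, 4), nn.Linear(4, 4)
incremental_state = {}

utils.set_incremental_state(m1, incremental_state, 'buf', 'state-of-m1')
utils.set_incremental_state(m2, incremental_state, 'buf', 'state-of-m2')

# each instance gets its own namespaced key, e.g. 'Linear.1.buf' vs 'Linear.2.buf'
assert utils.get_incremental_state(m1, incremental_state, 'buf') == 'state-of-m1'
assert utils.get_incremental_state(m2, incremental_state, 'buf') == 'state-of-m2'

# with incremental_state=None the helpers do nothing / return None
assert utils.get_incremental_state(m1, None, 'buf') is None
```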
tests/utils.py
View file @ 9438019f
...
@@ -9,7 +9,7 @@
 import torch
 from torch.autograd import Variable

-from fairseq import data, dictionary
+from fairseq import data, dictionary, utils
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,
...
@@ -96,24 +96,21 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
         self.args = args

-    def forward(self, prev_output_tokens, encoder_out):
-        if self._is_incremental_eval:
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
-        return self._forward(prev_output_tokens, encoder_out)
-
-    def _forward(self, prev_output_tokens, encoder_out):
         bbsz = prev_output_tokens.size(0)
         vocab = len(self.dictionary)
         src_len = encoder_out.size(1)
         tgt_len = prev_output_tokens.size(1)

         # determine number of steps
-        if self._is_incremental_eval:
+        if incremental_state is not None:
             # cache step number
-            step = self.get_incremental_state('step')
+            step = utils.get_incremental_state(self, incremental_state, 'step')
             if step is None:
                 step = 0
-            self.set_incremental_state('step', step + 1)
+            utils.set_incremental_state(self, incremental_state, 'step', step + 1)
             steps = [step]
         else:
             steps = list(range(tgt_len))
...
train.py
View file @ 9438019f
...
@@ -6,8 +6,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-import torch
-
 from fairseq import options
 from distributed_train import main as distributed_main
...