OpenDAS / Fairseq · Commits

Commit 6ec5022e, authored Jun 21, 2018 by Myle Ott
Parent: e9967cd3

Move reorder_encoder_out to FairseqEncoder and fix non-incremental decoding
Showing 9 changed files with 58 additions and 45 deletions (+58 -45).
fairseq/models/composite_encoder.py            +6   -0
fairseq/models/fairseq_encoder.py              +4   -0
fairseq/models/fairseq_incremental_decoder.py  +0   -3
fairseq/models/fconv.py                        +11  -6
fairseq/models/fconv_self_att.py               +14  -19
fairseq/models/lstm.py                         +10  -10
fairseq/models/transformer.py                  +9   -6
fairseq/sequence_generator.py                  +1   -1
tests/utils.py                                 +3   -0
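A note on what this change is about: during beam search, SequenceGenerator periodically re-sorts the active hypotheses, and any cached encoder output has to be shuffled along its batch dimension so it stays aligned with the surviving beams. Until now the reorder_encoder_out hook lived on the decoder side (FairseqIncrementalDecoder and per-model decoder overrides), so models without an incremental decoder had no clean way to participate. This commit moves the hook onto FairseqEncoder, next to the code that actually knows each encoder's output layout. A minimal standalone sketch of the contract (plain PyTorch, not fairseq code; the dict layout and batch-first shape here are assumptions for illustration):

    import torch
    from torch import nn

    class ToyEncoder(nn.Module):
        """Toy encoder whose output is batch-first: B x T x C."""
        def __init__(self, vocab=100, dim=8):
            super().__init__()
            self.embed = nn.Embedding(vocab, dim)

        def forward(self, src_tokens, src_lengths):
            return {'encoder_out': self.embed(src_tokens)}  # B x T x C

        def reorder_encoder_out(self, encoder_out, new_order):
            # slot i of the result takes the batch row new_order[i]
            encoder_out['encoder_out'] = \
                encoder_out['encoder_out'].index_select(0, new_order)
            return encoder_out

    enc = ToyEncoder()
    out = enc(torch.randint(0, 100, (3, 5)), src_lengths=None)
    # beam search kept hypotheses drawn from rows 2, 0 and 0:
    out = enc.reorder_encoder_out(out, torch.tensor([2, 0, 0]))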
fairseq/models/composite_encoder.py

@@ -26,6 +26,12 @@ class CompositeEncoder(FairseqEncoder):
             encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
         return encoder_out
 
+    def reorder_encoder_out(self, encoder_out, new_order):
+        """Reorder encoder output according to new_order."""
+        for key in self.encoders:
+            encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order)
+        return encoder_out
+
     def max_positions(self):
         return min([self.encoders[key].max_positions() for key in self.encoders])
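CompositeEncoder just forwards the call to each named sub-encoder, so any nesting of encoders reorders correctly as long as the leaves implement the hook. A self-contained toy illustration of that delegation pattern (SubEnc is a hypothetical stand-in, not a fairseq class):

    import torch

    class SubEnc:
        def __call__(self, src_tokens):
            return src_tokens.float().unsqueeze(-1)   # B x T x 1

        def reorder_encoder_out(self, eo, new_order):
            return eo.index_select(0, new_order)

    encoders = {'src': SubEnc(), 'ctx': SubEnc()}
    encoder_out = {key: enc(torch.arange(6).view(3, 2))
                   for key, enc in encoders.items()}
    new_order = torch.tensor([2, 2, 0])               # beams after re-sorting
    for key in encoders:
        encoder_out[key] = encoders[key].reorder_encoder_out(encoder_out[key], new_order)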
fairseq/models/fairseq_encoder.py

@@ -18,6 +18,10 @@ class FairseqEncoder(nn.Module):
     def forward(self, src_tokens, src_lengths):
         raise NotImplementedError
 
+    def reorder_encoder_out(self, encoder_out, new_order):
+        """Reorder encoder output according to new_order."""
+        raise NotImplementedError
+
     def max_positions(self):
         """Maximum input length supported by the encoder."""
         raise NotImplementedError
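The base class raises NotImplementedError rather than providing a no-op default, so an encoder that forgets to implement reordering fails loudly instead of silently de-syncing its cached output from the beams. Every override in this commit reduces to Tensor.index_select along the batch dimension; a short refresher on its semantics:

    import torch

    x = torch.tensor([[10, 11], [20, 21], [30, 31]])
    new_order = torch.tensor([1, 1, 0])   # row i of result = row new_order[i] of x
    print(x.index_select(0, new_order))
    # tensor([[20, 21],
    #         [20, 21],
    #         [10, 11]])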
fairseq/models/fairseq_incremental_decoder.py

@@ -32,9 +32,6 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         )
         self.apply(apply_reorder_incremental_state)
 
-    def reorder_encoder_out(self, encoder_out, new_order):
-        return encoder_out
-
     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
         if getattr(self, '_beam_size', -1) != beam_size:
fairseq/models/fconv.py

@@ -268,6 +268,17 @@ class FConvEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask,  # B x T
         }
 
+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        if encoder_out_dict['encoder_out'] is not None:
+            encoder_out_dict['encoder_out'] = (
+                encoder_out_dict['encoder_out'][0].index_select(0, new_order),
+                encoder_out_dict['encoder_out'][1].index_select(0, new_order),
+            )
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out_dict
+
     def max_positions(self):
         """Maximum input length supported by the encoder."""
         return self.embed_positions.max_positions()

@@ -496,12 +507,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
             encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
             utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)
 
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
-
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions() if self.embed_positions is not None else float('inf')
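FConvEncoder is fully batch-first: 'encoder_out' is a pair of B x T x C tensors and 'encoder_padding_mask' is B x T, so every index_select above uses dim 0. Contrast this with lstm.py and transformer.py below, where the batch sits in dim 1 for some tensors. A quick shape check under those batch-first assumptions:

    import torch

    B, T, C = 5, 7, 4
    encoder_out_dict = {
        'encoder_out': (torch.randn(B, T, C), torch.randn(B, T, C)),
        'encoder_padding_mask': torch.zeros(B, T, dtype=torch.bool),
    }
    new_order = torch.tensor([0, 0, 3, 3, 4])   # batch is dim 0 for every tensor here
    reordered = tuple(eo.index_select(0, new_order)
                      for eo in encoder_out_dict['encoder_out'])
    assert reordered[0].shape == (B, T, C)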
fairseq/models/fconv_self_att.py

@@ -226,6 +226,19 @@ class FConvEncoder(FairseqEncoder):
             'encoder_out': (x, y),
         }
 
+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        encoder_out_dict['encoder_out'] = tuple(
+            eo.index_select(0, new_order)
+            for eo in encoder_out_dict['encoder_out']
+        )
+        if 'pretrained' in encoder_out_dict:
+            encoder_out_dict['pretrained']['encoder_out'] = tuple(
+                eo.index_select(0, new_order)
+                for eo in encoder_out_dict['pretrained']['encoder_out']
+            )
+        return encoder_out_dict
+
     def max_positions(self):
         """Maximum input length supported by the encoder."""
         return self.embed_positions.max_positions()

@@ -409,30 +422,12 @@ class FConvDecoder(FairseqDecoder):
         else:
             return x, avg_attn_scores
 
-    def reorder_incremental_state(self, incremental_state, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(incremental_state, new_order)
-
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder']['encoder_out'] = tuple(
-            eo.index_select(0, new_order)
-            for eo in encoder_out_dict['encoder']['encoder_out']
-        )
-        if 'pretrained' in encoder_out_dict:
-            encoder_out_dict['pretrained']['encoder']['encoder_out'] = tuple(
-                eo.index_select(0, new_order)
-                for eo in encoder_out_dict['pretrained']['encoder']['encoder_out']
-            )
-        return encoder_out_dict
-
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
 
     def _split_encoder_out(self, encoder_out):
-        """Split and transpose encoder outputs.
-        """
+        """Split and transpose encoder outputs."""
         # transpose only once to speed up attention layers
         encoder_a, encoder_b = encoder_out
         encoder_a = encoder_a.transpose(0, 1).contiguous()
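The fconv_self_att outputs are nested (optionally under a 'pretrained' key, and on the old decoder path under an extra 'encoder' level), which is why the deleted and added implementations differ only in how deep they index. A hypothetical recursive helper makes the shared pattern visible; this is illustration only, not what the commit does, and it assumes every tensor shares one batch dim, which real fairseq outputs do not:

    import torch

    def reorder_nested(obj, new_order, dim=0):
        # index_select every tensor found in a nest of dicts/tuples
        if torch.is_tensor(obj):
            return obj.index_select(dim, new_order)
        if isinstance(obj, dict):
            return {k: reorder_nested(v, new_order, dim) for k, v in obj.items()}
        if isinstance(obj, tuple):
            return tuple(reorder_nested(v, new_order, dim) for v in obj)
        return obj

    nested = {'encoder_out': (torch.randn(3, 4), torch.randn(3, 4)),
              'pretrained': {'encoder_out': (torch.randn(3, 4),)}}
    nested = reorder_nested(nested, torch.tensor([2, 0, 1]))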
fairseq/models/lstm.py

@@ -197,6 +197,16 @@ class LSTMEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
         }
 
+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        encoder_out_dict['encoder_out'] = tuple(
+            eo.index_select(1, new_order)
+            for eo in encoder_out_dict['encoder_out']
+        )
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
+        return encoder_out_dict
+
     def max_positions(self):
         """Maximum input length supported by the encoder."""
         return int(1e5)  # an arbitrary large number

@@ -366,16 +376,6 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         new_state = tuple(map(reorder_state, cached_state))
         utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
 
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder_out'] = tuple(
-            eo.index_select(1, new_order)
-            for eo in encoder_out_dict['encoder_out']
-        )
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
-        return encoder_out_dict
-
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number
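The LSTM model keeps batch in dim 1: its encoder outputs are time-first (and the final hidden/cell states are layer-first), and its padding mask is likewise reordered along dim 1 here. Picking the wrong dim would not raise when B happens to equal T; it would just quietly scramble beams, which is why these per-encoder overrides matter. A shape sketch under the time-first assumption:

    import torch

    T, B, C = 7, 5, 4
    outs = (torch.randn(T, B, C),    # encoder states, time-first
            torch.randn(1, B, C),    # final hidden states
            torch.randn(1, B, C))    # final cell states
    new_order = torch.tensor([4, 4, 2, 1, 0])
    outs = tuple(eo.index_select(1, new_order) for eo in outs)   # batch is dim 1
    assert all(eo.shape[1] == B for eo in outs)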
fairseq/models/transformer.py

@@ -164,6 +164,15 @@ class TransformerEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask,  # B x T
         }
 
+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        if encoder_out_dict['encoder_out'] is not None:
+            encoder_out_dict['encoder_out'] = \
+                encoder_out_dict['encoder_out'].index_select(1, new_order)
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out_dict
+
     def max_positions(self):
         """Maximum input length supported by the encoder."""
         return self.embed_positions.max_positions()

@@ -245,12 +254,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         return x, attn
 
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
-
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
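TransformerEncoder mixes layouts: 'encoder_out' is T x B x C (batch in dim 1) while 'encoder_padding_mask' is B x T (batch in dim 0), hence the two different index_select dims in the added method. Note also that the deleted decoder-side version reordered only the padding mask; the encoder-side replacement reorders 'encoder_out' as well. A shape sketch under those layouts:

    import torch

    T, B, C = 7, 5, 4
    encoder_out = torch.randn(T, B, C)                    # batch in dim 1
    padding_mask = torch.zeros(B, T, dtype=torch.bool)    # batch in dim 0
    new_order = torch.tensor([3, 3, 1, 0, 0])
    encoder_out = encoder_out.index_select(1, new_order)
    padding_mask = padding_mask.index_select(0, new_order)
    assert encoder_out.shape == (T, B, C) and padding_mask.shape == (B, T)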
fairseq/sequence_generator.py

@@ -268,7 +268,7 @@ class SequenceGenerator(object):
                 for i, model in enumerate(self.models):
                     if isinstance(model.decoder, FairseqIncrementalDecoder):
                         model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
-                    encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)
+                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
             probs, avg_attn_scores = self._decode(
                 tokens[:, :step + 1], encoder_outs, incremental_states)
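This one-line change is the "fix non-incremental decoding" half of the commit: the generator now reorders encoder outputs through model.encoder, where the method lives after this refactor, instead of model.decoder. Before the change the hook hung off the decoder, so a model with a plain, non-incremental decoder that defined no reorder_encoder_out of its own could fail or silently skip reordering. A runnable toy showing why dispatching through the encoder works for every model (toy classes, not fairseq code):

    import torch
    from torch import nn

    class Enc(nn.Module):
        def forward(self, x):
            return {'encoder_out': x}
        def reorder_encoder_out(self, eo, new_order):
            eo['encoder_out'] = eo['encoder_out'].index_select(0, new_order)
            return eo

    class NonIncrementalDec(nn.Module):
        pass   # no reorder_encoder_out here, and no incremental state

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder, self.decoder = Enc(), NonIncrementalDec()

    m = Model()
    eo = m.encoder(torch.randn(3, 4))
    eo = m.encoder.reorder_encoder_out(eo, torch.tensor([1, 0, 2]))   # fine
    # m.decoder.reorder_encoder_out(eo, ...)  # AttributeError: the decoder
    #                                         # never had this method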
tests/utils.py

@@ -108,6 +108,9 @@ class TestEncoder(FairseqEncoder):
     def forward(self, src_tokens, src_lengths):
         return src_tokens
 
+    def reorder_encoder_out(self, encoder_out, new_order):
+        return encoder_out.index_select(0, new_order)
+
 
 class TestIncrementalDecoder(FairseqIncrementalDecoder):
     def __init__(self, args, dictionary):