GitLab · OpenDAS / Fairseq / Commits

Commit 03c4a716, authored Sep 26, 2017 by Myle Ott, committed via GitHub on Sep 26, 2017. Parent: 2d3161da.

Fix generation when vocabulary is small relative to beam size (fixes #7)

Showing 5 changed files with 40 additions and 26 deletions (+40 -26):
fairseq/models/fconv.py         +17 -12
fairseq/modules/__init__.py      +2  -2
fairseq/modules/beamable_mm.py   +9  -6
fairseq/sequence_generator.py   +11  -5
generate.py                      +1  -1
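The failure this commit addresses: at each search step the generator asks torch.topk for cand_size = 2 * beam_size candidates per sentence, and pad must never be selected, so a beam that is large relative to the vocabulary can request more candidates than exist. A minimal sketch of the failure and of the clamp the commit applies, using toy sizes rather than code from the repository:

```python
import torch

V = 5                      # tiny vocabulary, pad included
beam = 10                  # requested beam size, larger than the vocabulary
probs = torch.randn(1, V)  # first-step scores: one hypothesis per sentence

cand_size = 2 * beam       # the generator keeps 2x beam candidates
# probs.topk(cand_size)    # would raise: k is larger than the dimension size

# The commit's remedy, in spirit: clamp the beam to the usable vocabulary
# (everything except pad), and never ask topk for more scores than exist.
beam = min(beam, V - 1)
k = min(cand_size, probs.size(1) - 1)  # -1 so pad can never be selected
scores, indices = probs.topk(k)
```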
fairseq/models/fconv.py

@@ -28,7 +28,7 @@ class FConvModel(nn.Module):
         decoder_out = self.decoder(input_tokens, input_positions, encoder_out)
         return decoder_out.view(-1, decoder_out.size(-1))

-    def make_generation_fast_(self, beam_size, use_beamable_mm=False):
+    def make_generation_fast_(self, use_beamable_mm=False):
         """Optimize model for faster generation.

         Optimizations include:
@@ -54,7 +54,7 @@ class FConvModel(nn.Module):
         # use BeamableMM in attention layers
         if use_beamable_mm:
-            self.decoder._use_beamable_mm(beam_size)
+            self.decoder._use_beamable_mm()

         def train(mode):
             if mode:
@@ -243,14 +243,14 @@ class Decoder(nn.Module):
             context += conv.kernel_size[0] - 1
         return context

-    def incremental_inference(self):
+    def incremental_inference(self, beam_size=None):
         """Context manager for incremental inference.

         This provides an optimized forward pass for incremental inference
         (i.e., it predicts one time step at a time). If the input order changes
         between time steps, call model.decoder.reorder_incremental_state to
         update the relevant buffers. To generate a fresh sequence, first call
-        model.decoder.clear_incremental_state.
+        model.decoder.start_fresh_sequence.

         Usage:
         ```
@@ -263,18 +263,19 @@ class Decoder(nn.Module):
         """
         class IncrementalInference(object):
-            def __init__(self, decoder):
+            def __init__(self, decoder, beam_size):
                 self.decoder = decoder
+                self.beam_size = beam_size

             def __enter__(self):
-                self.decoder._start_incremental_inference()
+                self.decoder._start_incremental_inference(self.beam_size)

             def __exit__(self, *args):
                 self.decoder._stop_incremental_inference()

-        return IncrementalInference(self)
+        return IncrementalInference(self, beam_size)

-    def _start_incremental_inference(self):
+    def _start_incremental_inference(self, beam_size):
         assert not self._is_inference_incremental, \
             'already performing incremental inference'
         self._is_inference_incremental = True
@@ -287,7 +288,7 @@ class Decoder(nn.Module):
         self.forward = self._incremental_forward

         # start a fresh sequence
-        self.clear_incremental_state()
+        self.start_fresh_sequence(beam_size)

     def _stop_incremental_inference(self):
         # restore original forward and convolution layers
@@ -348,17 +349,21 @@ class Decoder(nn.Module):
         return x, avg_attn_scores

-    def clear_incremental_state(self):
+    def start_fresh_sequence(self, beam_size=None):
         """Clear all state used for incremental generation.

         **For incremental inference only**

         This should be called before generating a fresh sequence.
+        beam_size is required if using BeamableMM.
         """
         if self._is_inference_incremental:
             self.prev_state = None
             for conv in self.convolutions:
                 conv.clear_buffer()
+            for attn in self.attention:
+                if isinstance(attn.bmm, BeamableMM):
+                    attn.bmm.set_beam_size(beam_size)

     def reorder_incremental_state(self, new_order):
         """Reorder buffered internal state (for incremental generation).
@@ -373,9 +378,9 @@ class Decoder(nn.Module):
         for conv in self.convolutions:
             conv.reorder_buffer(new_order)

-    def _use_beamable_mm(self, beam_size):
+    def _use_beamable_mm(self):
         """Replace torch.bmm with BeamableMM in attention layers."""
-        beamable_mm = BeamableMM(beam_size)
+        beamable_mm = BeamableMM()
         for attn in self.attention:
             attn.bmm = beamable_mm
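With this change the beam size travels through the incremental-inference context manager rather than being fixed when the model is optimized. A self-contained toy mirroring the new plumbing; the names echo the diff, but this is an illustration, not the fairseq classes:

```python
class ToyDecoder:
    """Mimics the Decoder entry points touched by this diff (illustrative)."""

    def _start_incremental_inference(self, beam_size):
        # the real decoder swaps in an incremental forward pass here and
        # calls start_fresh_sequence(beam_size)
        print('start, beam_size =', beam_size)

    def _stop_incremental_inference(self):
        print('stop')

    def incremental_inference(self, beam_size=None):
        decoder = self

        class IncrementalInference(object):
            def __enter__(self):
                # beam_size is captured at creation, handed over on entry
                decoder._start_incremental_inference(beam_size)

            def __exit__(self, *args):
                decoder._stop_incremental_inference()

        return IncrementalInference()


with ToyDecoder().incremental_inference(beam_size=5):
    pass  # decode one time step at a time inside the context
```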
fairseq/modules/__init__.py

@@ -6,9 +6,9 @@
 # can be found in the PATENTS file in the same directory.
 #

-from .beamable_mm import *
-from .linearized_convolution import *
+from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
+from .linearized_convolution import LinearizedConvolution

 __all__ = [
     'BeamableMM', 'LinearizedConvolution', 'ConvTBC',
fairseq/modules/beamable_mm.py

@@ -18,16 +18,16 @@ class BeamableMM(nn.Module):
     inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
     with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
     """
-    def __init__(self, beam_size):
+    def __init__(self):
         super(BeamableMM, self).__init__()
-        self.beam_size = beam_size
+        self.beam_size = None

     def forward(self, input1, input2):
         if (
             not self.training and            # test mode
-            self.beam_size > 0 and           # beam size is set
+            self.beam_size is not None and   # beam size is set
             input1.dim() == 3 and            # only support batched input
             input1.size(1) == 1              # single time step update
         ):
             bsz, beam = input1.size(0), self.beam_size
@@ -45,3 +45,6 @@ class BeamableMM(nn.Module):
             return output.view(bsz, 1, -1)
         else:
             return input1.bmm(input2)
+
+    def set_beam_size(self, beam_size):
+        self.beam_size = beam_size
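For context, the optimization BeamableMM provides: during beam search every hypothesis of a sentence attends over the same encoder states, so the bsz batched matrix products can be folded into bsz/beam larger ones, exactly as the docstring above describes. A minimal numeric sketch of that identity under the shared-source assumption (toy shapes, not the module itself):

```python
import torch

sents, beam, sz2, nhu = 2, 3, 7, 4
bsz = sents * beam                           # hypotheses flattened into the batch

input1 = torch.randn(bsz, 1, nhu)            # one decoder state per hypothesis
src = torch.randn(sents, nhu, sz2)           # per-sentence encoder states
input2 = src.repeat_interleave(beam, dim=0)  # replicated across the beam

naive = input1.bmm(input2)                   # bsz small matrix products

folded = input1.view(sents, beam, nhu)       # fold the beam into dim 1
fast = folded.bmm(src).view(bsz, 1, sz2)     # only sents larger products

print(torch.allclose(naive, fast))           # True: same values, fewer bmms
```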
fairseq/sequence_generator.py

@@ -87,13 +87,16 @@ class SequenceGenerator(object):
     def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
         bsz = src_tokens.size(0)
-        beam_size = beam_size if beam_size is not None else self.beam_size
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

+        # the max beam size is the dictionary size - 1, since we never select pad
+        beam_size = beam_size if beam_size is not None else self.beam_size
+        beam_size = min(beam_size, len(self.dict) - 1)
+
         encoder_outs = []
         for model in self.models:
             model.eval()
-            model.decoder.clear_incremental_state()  # start a fresh sequence
+            model.decoder.start_fresh_sequence(beam_size)  # start a fresh sequence

             # compute the encoder output and expand to beam size
             encoder_out = model.encoder(src_tokens, src_positions)
@@ -172,7 +175,7 @@ class SequenceGenerator(object):
                 sents_seen.add(sent)

                 def get_hypo():
-                    hypo = tokens[idx, 1:step+2].clone()
+                    hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                     hypo[step] = self.eos
                     alignment = align[idx, 1:step+2].clone()
                     return {
@@ -219,6 +222,7 @@ class SequenceGenerator(object):
             else:
                 # make probs contain cumulative scores for each hypothesis
                 probs.add_(scores.view(-1, 1))
+            probs[:, self.pad] = -math.inf  # never select pad

             # record alignment to source tokens, based on attention
             _ignore_scores = buffer('_ignore_scores', type_of=scores)
@@ -229,7 +233,9 @@ class SequenceGenerator(object):
             cand_scores = buffer('cand_scores', type_of=scores)
             cand_indices = buffer('cand_indices')
             cand_beams = buffer('cand_beams')
-            probs.view(bsz, -1).topk(cand_size, out=(cand_scores, cand_indices))
+            probs.view(bsz, -1).topk(
+                min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
+                out=(cand_scores, cand_indices))
             torch.div(cand_indices, self.vocab_size, out=cand_beams)
             cand_indices.fmod_(self.vocab_size)
@@ -256,7 +262,7 @@ class SequenceGenerator(object):
             # and values < cand_size indicate candidate active hypos.
             # After, the min values per row are the top candidate active hypos
             active_mask = buffer('active_mask')
-            torch.add((eos_mask * cand_size).type_as(cand_offsets), cand_offsets,
+            torch.add((eos_mask * cand_size).type_as(cand_offsets), cand_offsets[:eos_mask.size(1)],
                       out=active_mask)

             # get the top beam_size active hypotheses, which are just the hypos
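The two changes above cooperate: writing -inf into the pad column means pad can only win a topk slot once every other score is taken, and capping k at one less than the number of scores makes that impossible. A small sketch of the invariant with toy values:

```python
import math
import torch

pad = 1
probs = torch.randn(2, 6)      # 2 sentences x 6 flattened candidate scores
probs[:, pad] = -math.inf      # never select pad

k = min(4, probs.size(1) - 1)  # -1 so we never select pad
scores, indices = probs.topk(k)
assert not (indices == pad).any()  # pad, the unique -inf, is never chosen
```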
generate.py

@@ -47,7 +47,7 @@ def main():
     # Optimize model for generation
     for model in models:
-        model.make_generation_fast_(args.beam, not args.no_beamable_mm)
+        model.make_generation_fast_(not args.no_beamable_mm)

     # Initialize generator
     translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,
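The upshot of the signature change here: with beam_size dropped from make_generation_fast_, the beam size becomes a property of a single generation run rather than of the optimized model. The value given to SequenceGenerator reaches BeamableMM through start_fresh_sequence, so one optimized model can serve runs with different beam sizes.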