OpenDAS / Fairseq · commit 9f1b37dd

fix alignment when using uneven batches and left pad

Authored May 11, 2018 by Alexei Baevski; committed by Myle Ott on Jun 15, 2018.
Parent: 663fd806
1 changed file with 23 additions and 15 deletions.

fairseq/sequence_generator.py (+23, -15)
@@ -9,6 +9,7 @@ import math

 import torch

 from fairseq import utils
+from fairseq.data import LanguagePairDataset
 from fairseq.models import FairseqIncrementalDecoder
@@ -135,11 +136,12 @@ class SequenceGenerator(object):
         cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

         # offset arrays for converting between different indexing schemes
         bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
         cand_offsets = torch.arange(0, cand_size).type_as(tokens)

         # helper function for allocating buffers on the fly
         buffers = {}

         def buffer(name, type_of=tokens):  # noqa
             if name not in buffers:
                 buffers[name] = type_of.new()
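
The `buffer` helper above lazily allocates one reusable tensor per name: `type_of.new()` creates an empty tensor of matching dtype and device, and later `out=` arguments resize and fill it in place, so the generator avoids reallocating scratch tensors on every decoding step. A minimal standalone sketch of the same pattern (names and shapes are illustrative, not from fairseq):

import torch

tokens = torch.zeros(4, 10, dtype=torch.long)  # stand-in for the generator's token tensor
buffers = {}

def buffer(name, type_of=tokens):
    # allocate an empty tensor of matching dtype/device once, then reuse it
    if name not in buffers:
        buffers[name] = type_of.new()
    return buffers[name]

cand_indices = buffer('cand_indices')
torch.arange(0, 8, out=cand_indices)  # out= resizes the empty buffer in place
print(cand_indices.shape)             # torch.Size([8]); later calls reuse this storage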
@@ -186,7 +188,7 @@ class SequenceGenerator(object):
             # clone relevant token and attention tensors
             tokens_clone = tokens.index_select(0, bbsz_idx)
             tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
             tokens_clone[:, step] = self.eos
             attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2]
@@ -198,7 +200,7 @@ class SequenceGenerator(object):
             # normalize sentence-level scores
             if self.normalize_scores:
                 eos_scores /= (step + 1) ** self.len_penalty

             cum_unfin = []
             prev = 0
@@ -216,11 +218,17 @@ class SequenceGenerator(object):
                 sents_seen.add((sent, unfin_idx))

                 def get_hypo():
-                    _, alignment = attn_clone[i].max(dim=0)
+                    # remove padding tokens from attn scores
+                    nonpad_idxs = src_tokens[sent].ne(self.pad)
+                    hypo_attn = attn_clone[i][nonpad_idxs]
+                    _, alignment = hypo_attn.max(dim=0)
                     return {
                         'tokens': tokens_clone[i],
                         'score': score,
-                        'attention': attn_clone[i],  # src_len x tgt_len
+                        'attention': hypo_attn,  # src_len x tgt_len
                         'alignment': alignment,
                         'positional_scores': pos_scores[i],
                     }
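
This hunk is the substantive fix. With left-padded batches, the leading rows of the src_len x tgt_len attention matrix belong to pad tokens, so the old `attn_clone[i].max(dim=0)` could align target words to padding and shift every alignment index. Masking with `nonpad_idxs` first makes the argmax range over real source positions only. A self-contained sketch of the before/after behavior (the pad index, token IDs, and shapes are made up):

import torch

pad = 1
src_tokens = torch.tensor([pad, pad, 11, 12, 13])  # left-padded source sentence
attn = torch.rand(5, 4)                            # src_len x tgt_len attention scores
attn[:2] += 1.0                                    # pretend pad rows drew spurious mass

_, bad_alignment = attn.max(dim=0)                 # old behavior: may point at rows 0-1 (pad)

nonpad_idxs = src_tokens.ne(pad)                   # the fix: drop pad rows first
hypo_attn = attn[nonpad_idxs]                      # (src_len - n_pad) x tgt_len
_, alignment = hypo_attn.max(dim=0)                # indices into the unpadded source
print(bad_alignment, alignment)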
@@ -263,7 +271,7 @@ class SequenceGenerator(object):
                     encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)

             probs, avg_attn_scores = self._decode(
                 tokens[:, :step + 1], encoder_outs, incremental_states)
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam
@@ -272,13 +280,13 @@ class SequenceGenerator(object):
                 scores_buf = scores_buf.type_as(probs)
             elif not self.sampling:
                 # make probs contain cumulative scores for each hypothesis
                 probs.add_(scores[:, step - 1].view(-1, 1))

             probs[:, self.pad] = -math.inf  # never select pad
             probs[:, self.unk] -= self.unk_penalty  # apply unk penalty

             # Record attention scores
             attn[:, :, step + 1].copy_(avg_attn_scores)

             cand_scores = buffer('cand_scores', type_of=scores)
             cand_indices = buffer('cand_indices')
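
For orientation, the block above turns per-step log-probabilities into cumulative hypothesis scores and hard-masks pad (and penalizes unk) before candidate selection. A rough standalone sketch of that flow under assumed sizes; it simplifies away sampling, the step-0 special case, and the buffered `out=` writes:

import math
import torch

bsz, beam_size, vocab = 2, 3, 7
pad, unk, unk_penalty = 1, 3, 0.5

probs = torch.log_softmax(torch.rand(bsz * beam_size, vocab), dim=-1)
prev_scores = torch.rand(bsz * beam_size, 1)   # stands in for scores[:, step - 1]

probs = probs + prev_scores                    # cumulative score of each extended hypo
probs[:, pad] = -math.inf                      # pad can never be selected
probs[:, unk] -= unk_penalty                   # discourage unk

# pick cand_size = 2 * beam_size candidates per sentence, as the surrounding code does
cand_scores, cand_flat = probs.view(bsz, -1).topk(2 * beam_size, dim=1)
cand_beams = cand_flat // vocab                # which beam each candidate extends
cand_indices = cand_flat % vocab               # which vocabulary token it adds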
@@ -315,7 +323,7 @@ class SequenceGenerator(object):
                 # make scores cumulative
                 cand_scores.add_(
                     torch.gather(
                         scores[:, step - 1].view(bsz, beam_size), dim=1,
                         index=cand_beams,
                     )
                 )
@@ -406,7 +414,7 @@ class SequenceGenerator(object):
             # After, the min values per row are the top candidate active hypos
             active_mask = buffer('active_mask')
             torch.add(
                 eos_mask.type_as(cand_offsets) * cand_size,
                 cand_offsets[:eos_mask.size(1)],
                 out=active_mask,
             )
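
The `torch.add(...)` above encodes a small trick worth spelling out: multiplying the EOS mask by `cand_size` before adding the column offsets pushes every finished candidate's value above every unfinished one, so taking the `beam_size` smallest entries per row selects exactly the top non-EOS ("active") hypotheses. A worked toy example with hand-chosen values:

import torch

cand_size = 6
eos_mask = torch.tensor([[1, 0, 0, 1, 0, 0]])        # candidates 0 and 3 end in EOS
cand_offsets = torch.arange(0, cand_size)

active_mask = eos_mask * cand_size + cand_offsets    # -> [[6, 1, 2, 9, 4, 5]]
_, active_hypos = active_mask.topk(3, largest=False) # -> [[1, 2, 4]]: best non-EOS slots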
@@ -433,12 +441,12 @@ class SequenceGenerator(object):
             # copy tokens and scores for active hypotheses
             torch.index_select(
                 tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                 out=tokens_buf[:, :step + 1],
             )
             torch.gather(
                 cand_indices, dim=1, index=active_hypos,
                 out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
             )
             if step > 0:
                 torch.index_select(
@@ -452,8 +460,8 @@ class SequenceGenerator(object):
             # copy attention for active hypotheses
             torch.index_select(
                 attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                 out=attn_buf[:, :, :step + 2],
             )

             # swap buffers
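
The loop body ends by swapping each tensor with its `_buf` twin: each step writes the reordered hypotheses into the spare buffer via `out=`, then the two references trade places, so no new memory is allocated per step. A tiny runnable sketch of that double-buffering pattern (shapes and the permutation are arbitrary):

import torch

tokens = torch.zeros(6, 10, dtype=torch.long)
tokens_buf = tokens.clone()

for step in range(3):
    reorder = torch.randperm(6)  # stands in for active_bbsz_idx
    # write this step's reordered rows into the spare buffer...
    torch.index_select(tokens, 0, reorder, out=tokens_buf)
    # ...then swap, so the next step reads the freshly written tensor
    tokens, tokens_buf = tokens_buf, tokens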