Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
84754894
Commit
84754894
authored
Oct 03, 2017
by
Louis Martin
Committed by
Myle Ott
Oct 19, 2017
Browse files
Add attention matrix to output of SequenceGenerator
parent
376c265f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
12 deletions
+13
-12
fairseq/sequence_generator.py
fairseq/sequence_generator.py
+13
-12
No files found.
fairseq/sequence_generator.py
View file @
84754894
...
...
@@ -108,8 +108,8 @@ class SequenceGenerator(object):
tokens
=
src_tokens
.
data
.
new
(
bsz
*
beam_size
,
maxlen
+
2
).
fill_
(
self
.
pad
)
tokens_buf
=
tokens
.
clone
()
tokens
[:,
0
]
=
self
.
eos
a
lig
n
=
s
rc_tokens
.
data
.
new
(
bsz
*
beam_size
,
maxlen
+
2
)
.
fill_
(
-
1
)
a
lig
n_buf
=
a
lig
n
.
clone
()
a
tt
n
=
s
cores
.
new
(
bsz
*
beam_size
,
src_tokens
.
size
(
1
),
maxlen
+
2
)
a
tt
n_buf
=
a
tt
n
.
clone
()
# list of completed sentences
finalized
=
[[]
for
i
in
range
(
bsz
)]
...
...
@@ -177,10 +177,12 @@ class SequenceGenerator(object):
def
get_hypo
():
hypo
=
tokens
[
idx
,
1
:
step
+
2
].
clone
()
# skip the first index, which is EOS
hypo
[
step
]
=
self
.
eos
alignment
=
align
[
idx
,
1
:
step
+
2
].
clone
()
attention
=
attn
[
idx
,
:,
1
:
step
+
2
].
clone
()
_
,
alignment
=
attention
.
max
(
dim
=
0
)
return
{
'tokens'
:
hypo
,
'score'
:
score
,
'attention'
:
attention
,
'alignment'
:
alignment
,
}
...
...
@@ -224,9 +226,8 @@ class SequenceGenerator(object):
probs
.
add_
(
scores
.
view
(
-
1
,
1
))
probs
[:,
self
.
pad
]
=
-
math
.
inf
# never select pad
# record alignment to source tokens, based on attention
_ignore_scores
=
buffer
(
'_ignore_scores'
,
type_of
=
scores
)
avg_attn_scores
.
topk
(
1
,
out
=
(
_ignore_scores
,
align
[:,
step
+
1
].
unsqueeze
(
1
)))
# Record attention scores
attn
[:,
:,
step
+
1
].
copy_
(
avg_attn_scores
)
# take the best 2 x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
...
...
@@ -290,17 +291,17 @@ class SequenceGenerator(object):
cand_indices
.
gather
(
1
,
active_hypos
,
out
=
tokens_buf
.
view
(
bsz
,
beam_size
,
-
1
)[:,
:,
step
+
1
])
# copy attention
/alignment
for active hypotheses
torch
.
index_select
(
a
lign
[
:,
:
step
+
2
],
dim
=
0
,
index
=
active_bbsz_idx
,
out
=
a
lig
n_buf
[:,
:
step
+
2
])
# copy attention for active hypotheses
torch
.
index_select
(
a
ttn
[:,
:,
:
step
+
2
],
dim
=
0
,
index
=
active_bbsz_idx
,
out
=
a
tt
n_buf
[:,
:,
:
step
+
2
])
# swap buffers
old_tokens
=
tokens
tokens
=
tokens_buf
tokens_buf
=
old_tokens
old_a
lig
n
=
a
lig
n
a
lig
n
=
a
lig
n_buf
a
lig
n_buf
=
old_a
lig
n
old_a
tt
n
=
a
tt
n
a
tt
n
=
a
tt
n_buf
a
tt
n_buf
=
old_a
tt
n
# reorder incremental state in decoder
reorder_state
=
active_bbsz_idx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment