OpenDAS / Fairseq · Commits

Commit 89e19d42
Authored Jul 10, 2018 by Alexei Baevski; committed by Myle Ott, Jul 25, 2018

disable printing alignment by default (for perf) and add a flag to enable it

Parent: f472d141
Showing 9 changed files with 57 additions and 39 deletions:

fairseq/models/fairseq_incremental_decoder.py  +1 -1
fairseq/models/fairseq_model.py                +2 -2
fairseq/models/fconv.py                        +8 -6
fairseq/models/fconv_self_att.py               +6 -5
fairseq/models/lstm.py                         +2 -2
fairseq/models/transformer.py                  +3 -2
fairseq/options.py                             +2 -0
fairseq/sequence_generator.py                  +25 -16
generate.py                                    +8 -5
fairseq/models/fairseq_incremental_decoder.py
@@ -14,7 +14,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
     def __init__(self, dictionary):
         super().__init__(dictionary)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
         raise NotImplementedError
 
     def reorder_incremental_state(self, incremental_state, new_order):
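The hunk above establishes the convention this commit threads through every model: decoder forward methods accept a need_attn flag and only compute and return attention weights when it is set. A minimal sketch of a decoder following that contract, using a toy PyTorch module rather than the real fairseq classes (the class name, shapes, and attention computation below are illustrative assumptions):

    import torch
    import torch.nn as nn

    class ToyDecoder(nn.Module):
        """Illustrative decoder following the need_attn convention (not a fairseq class)."""

        def __init__(self, vocab_size=100, embed_dim=16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_dim)
            self.proj = nn.Linear(embed_dim, vocab_size)

        def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
            x = self.embed(prev_output_tokens)            # (bsz, tgt_len, embed_dim)
            # Attention over source positions is only materialized when requested,
            # mirroring the perf motivation of this commit.
            attn = None
            if need_attn:
                scores = torch.einsum('bte,bse->bts', x, encoder_out)
                attn = torch.softmax(scores, dim=-1)      # (bsz, tgt_len, src_len)
            return self.proj(x), attn

    decoder = ToyDecoder()
    tokens = torch.randint(0, 100, (2, 5))                # (bsz=2, tgt_len=5)
    enc = torch.randn(2, 7, 16)                           # (bsz=2, src_len=7, embed_dim=16)
    logits, attn = decoder(tokens, enc)                   # default: no attention returned
    assert attn is None
    logits, attn = decoder(tokens, enc, need_attn=True)
    assert attn.shape == (2, 5, 7)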
fairseq/models/fairseq_model.py
@@ -104,9 +104,9 @@ class FairseqModel(BaseFairseqModel):
         assert isinstance(self.encoder, FairseqEncoder)
         assert isinstance(self.decoder, FairseqDecoder)
 
-    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+    def forward(self, src_tokens, src_lengths, prev_output_tokens, need_attn):
         encoder_out = self.encoder(src_tokens, src_lengths)
-        decoder_out = self.decoder(prev_output_tokens, encoder_out)
+        decoder_out = self.decoder(prev_output_tokens, encoder_out, need_attn=need_attn)
         return decoder_out
 
     def max_positions(self):
fairseq/models/fconv.py
@@ -417,7 +417,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
 
-    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None, need_attn=False):
         if encoder_out_dict is not None:
             encoder_out = encoder_out_dict['encoder_out']
             encoder_padding_mask = encoder_out_dict['encoder_padding_mask']

@@ -466,6 +466,8 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 x = self._transpose_if_training(x, incremental_state)
 
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
-                attn_scores = attn_scores / num_attn_layers
-                if avg_attn_scores is None:
-                    avg_attn_scores = attn_scores
+                if need_attn:
+                    attn_scores = attn_scores / num_attn_layers
+                    if avg_attn_scores is None:
+                        avg_attn_scores = attn_scores
fairseq/models/fconv_self_att.py
@@ -352,7 +352,7 @@ class FConvDecoder(FairseqDecoder):
             self.pretrained_decoder.fc2.register_forward_hook(save_output())
 
-    def forward(self, prev_output_tokens, encoder_out_dict):
+    def forward(self, prev_output_tokens, encoder_out_dict, need_attn=False):
         encoder_out = encoder_out_dict['encoder']['encoder_out']
         trained_encoder_out = encoder_out_dict['pretrained'] if self.pretrained else None

@@ -388,6 +388,7 @@ class FConvDecoder(FairseqDecoder):
                 r = x
                 x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
                 x = x + r
-                if avg_attn_scores is None:
-                    avg_attn_scores = attn_scores
-                else:
+                if need_attn:
+                    if avg_attn_scores is None:
+                        avg_attn_scores = attn_scores
+                    else:
fairseq/models/lstm.py
@@ -320,7 +320,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         if not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
-    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None, need_attn=False):
         encoder_out = encoder_out_dict['encoder_out']
         encoder_padding_mask = encoder_out_dict['encoder_padding_mask']

@@ -391,7 +391,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(1, 0)
 
         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        attn_scores = attn_scores.transpose(0, 2)
+        attn_scores = attn_scores.transpose(0, 2) if need_attn else None
 
         # project back to size of vocabulary
         if hasattr(self, 'additional_fc'):
fairseq/models/transformer.py
@@ -215,7 +215,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens,

@@ -340,7 +340,7 @@ class TransformerDecoderLayer(nn.Module):
         self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
         self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
 
-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, need_attn=False):
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
         x, _ = self.self_attn(

@@ -364,6 +364,7 @@ class TransformerDecoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             incremental_state=incremental_state,
             static_kv=True,
+            need_weights=need_attn,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
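In the transformer decoder layer, the flag is forwarded as need_weights to the encoder-decoder attention call, so averaged attention weights are only produced when alignments were requested. For intuition, the same switch exists on torch.nn.MultiheadAttention, shown below as an analogy only (fairseq uses its own MultiheadAttention implementation):

    import torch
    import torch.nn as nn

    mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
    q = torch.randn(5, 2, 16)   # (tgt_len, bsz, embed_dim)
    kv = torch.randn(7, 2, 16)  # (src_len, bsz, embed_dim)

    out, weights = mha(q, kv, kv, need_weights=False)
    print(weights)              # None: no attention weights materialized

    out, weights = mha(q, kv, kv, need_weights=True)
    print(weights.shape)        # torch.Size([2, 5, 7]): weights averaged over heads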
fairseq/options.py
@@ -290,6 +290,8 @@ def add_generation_args(parser):
                        help='sample from top K likely next words instead of all words')
     group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
                        help='temperature for random sampling')
+    group.add_argument('--print-alignment', action='store_true',
+                       help='if set, uses attention feedback to compute and print alignment to source tokens')
     group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                        help='a dictionary used to override model args at generation that were used during model training')
     return group
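The new generation option is an ordinary store_true flag, off by default, so attention and alignment work is skipped unless explicitly requested. A minimal standalone sketch of the flag's behaviour, using plain argparse to mirror the add_argument call above rather than going through fairseq.options:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--print-alignment', action='store_true',
                        help='if set, uses attention feedback to compute and print alignment to source tokens')

    args = parser.parse_args([])                     # default: alignment printing disabled
    assert args.print_alignment is False

    args = parser.parse_args(['--print-alignment'])  # opt in to attention/alignment output
    assert args.print_alignment is True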
fairseq/sequence_generator.py
@@ -54,7 +54,7 @@ class SequenceGenerator(object):
     def generate_batched_itr(
         self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-        cuda=False, timer=None, prefix_size=0,
+        cuda=False, timer=None, prefix_size=0, with_attention=False,
     ):
         """Iterate over a batched dataset and yield individual translations.
 
         Args:

@@ -81,6 +81,7 @@ class SequenceGenerator(object):
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
+                    with_attention=with_attention,
                 )
             if timer is not None:
                 timer.stop(sum(len(h[0]['tokens']) for h in hypos))

@@ -90,12 +91,12 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]
 
-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
         """Generate a batch of translations."""
         with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens, with_attention)
 
-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

@@ -128,6 +129,7 @@ class SequenceGenerator(object):
         tokens[:, 0] = self.eos
         attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
         attn_buf = attn.clone()
+        nonpad_idxs = src_tokens.ne(self.pad) if with_attention else None
 
         # list of completed sentences
         finalized = [[] for i in range(bsz)]

@@ -220,10 +222,13 @@ class SequenceGenerator(object):
             def get_hypo():
 
-                # remove padding tokens from attn scores
-                nonpad_idxs = src_tokens[sent].ne(self.pad)
-                hypo_attn = attn_clone[i][nonpad_idxs]
-                _, alignment = hypo_attn.max(dim=0)
+                if with_attention:
+                    # remove padding tokens from attn scores
+                    hypo_attn = attn_clone[i][nonpad_idxs[sent]]
+                    _, alignment = hypo_attn.max(dim=0)
+                else:
+                    hypo_attn = None
+                    alignment = None
 
                 return {
                     'tokens': tokens_clone[i],

@@ -271,7 +276,7 @@ class SequenceGenerator(object):
                     encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
             probs, avg_attn_scores = self._decode(
-                tokens[:, :step + 1], encoder_outs, incremental_states)
+                tokens[:, :step + 1], encoder_outs, incremental_states, with_attention)
 
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam

@@ -286,6 +291,7 @@ class SequenceGenerator(object):
             probs[:, self.unk] -= self.unk_penalty  # apply unk penalty
 
             # Record attention scores
-            attn[:, :, step + 1].copy_(avg_attn_scores)
+            if avg_attn_scores is not None:
+                attn[:, :, step + 1].copy_(avg_attn_scores)
 
             cand_scores = buffer('cand_scores', type_of=scores)

@@ -492,14 +498,16 @@ class SequenceGenerator(object):
         return finalized
 
-    def _decode(self, tokens, encoder_outs, incremental_states):
+    def _decode(self, tokens, encoder_outs, incremental_states, with_attention):
         if len(self.models) == 1:
-            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)
+            return self._decode_one(
+                tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True,
+                with_attention=with_attention,
+            )
 
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
+            probs, attn = self._decode_one(
+                tokens, model, encoder_out, incremental_states, log_probs=False,
+                with_attention=with_attention,
+            )
             if avg_probs is None:
                 avg_probs = probs
             else:

@@ -515,12 +523,13 @@ class SequenceGenerator(object):
             avg_attn.div_(len(self.models))
         return avg_probs, avg_attn
 
-    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
+    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs, with_attention):
         with torch.no_grad():
             if incremental_states[model] is not None:
-                decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
+                decoder_out = list(model.decoder(
+                    tokens, encoder_out, incremental_state=incremental_states[model], need_attn=with_attention))
             else:
-                decoder_out = list(model.decoder(tokens, encoder_out))
+                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=None, need_attn=with_attention))
             decoder_out[0] = decoder_out[0][:, -1, :]
             attn = decoder_out[1]
             if attn is not None:
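When attention is requested, get_hypo() derives a hard alignment by taking, for each target position, the source position with the highest attention weight. A small sketch of that reduction on a toy attention matrix (the shapes and values are illustrative, not taken from fairseq):

    import torch

    # Toy attention for one hypothesis: rows = source positions, columns = target positions,
    # matching the (src_len, tgt_len) layout that hypo_attn.max(dim=0) reduces over.
    hypo_attn = torch.softmax(torch.randn(5, 7), dim=0)

    _, alignment = hypo_attn.max(dim=0)   # best-attended source index per target token
    print(alignment.tolist())             # e.g. [3, 0, 4, 4, 1, 2, 0]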
generate.py
@@ -88,6 +88,7 @@ def main(args):
     translations = translator.generate_batched_itr(
         t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
         cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
+        with_attention=args.print_alignment,
     )
     wps_meter = TimeMeter()

@@ -115,7 +116,7 @@ def main(args):
                 hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                     hypo_tokens=hypo['tokens'].int().cpu(),
                     src_str=src_str,
-                    alignment=hypo['alignment'].int().cpu(),
+                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                     align_dict=align_dict,
                     tgt_dict=tgt_dict,
                     remove_bpe=args.remove_bpe,

@@ -130,6 +131,8 @@ def main(args):
                         hypo['positional_scores'].tolist(),
                     ))
                 ))
-                print('A-{}\t{}'.format(
-                    sample_id,
-                    ' '.join(map(lambda x: str(utils.item(x)), alignment))
+
+                if args.print_alignment:
+                    print('A-{}\t{}'.format(
+                        sample_id,
+                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
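With --print-alignment set, generate.py emits one extra line per hypothesis: A-<sample_id>, a tab, and space-separated source indices, one per output token. A tiny sketch of that format (the sample id and indices below are made-up values):

    # Hypothetical values, formatted the same way as the print() call above.
    sample_id = 0
    alignment = [0, 1, 2, 2, 4]
    print('A-{}\t{}'.format(sample_id, ' '.join(str(i) for i in alignment)))
    # A-0	0 1 2 2 4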