Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
59d599a2
Commit
59d599a2
authored
Sep 25, 2017
by
Myle Ott
Browse files
Move helper functions from generate.py to fairseq/dictionary.py
parent
af86c1ac
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
43 deletions
+41
-43
fairseq/dictionary.py
fairseq/dictionary.py
+24
-6
generate.py
generate.py
+17
-37
No files found.
fairseq/dictionary.py
View file @
59d599a2
...
@@ -38,13 +38,31 @@ class Dictionary(object):
...
@@ -38,13 +38,31 @@ class Dictionary(object):
return
self
.
indices
[
sym
]
return
self
.
indices
[
sym
]
return
self
.
unk_index
return
self
.
unk_index
def
string
(
self
,
tensor
):
def
string
(
self
,
tensor
,
bpe_symbol
=
None
,
escape_unk
=
False
):
"""Helper for converting a tensor of token indices to a string.
Can optionally remove BPE symbols or escape <unk> words.
"""
if
torch
.
is_tensor
(
tensor
)
and
tensor
.
dim
()
==
2
:
if
torch
.
is_tensor
(
tensor
)
and
tensor
.
dim
()
==
2
:
sentences
=
[
self
.
string
(
line
)
for
line
in
tensor
]
return
'
\n
'
.
join
(
self
.
to_string
(
t
)
for
t
in
tensor
)
return
'
\n
'
.
join
(
sentences
)
def
token_string
(
i
):
if
i
==
self
.
unk
():
return
self
.
unk_string
(
escape_unk
)
else
:
return
self
[
i
]
eos
=
self
.
eos
()
sent
=
' '
.
join
(
token_string
(
i
)
for
i
in
tensor
if
i
!=
self
.
eos
())
return
' '
.
join
([
self
[
i
]
for
i
in
tensor
if
i
!=
eos
])
if
bpe_symbol
is
not
None
:
sent
=
sent
.
replace
(
bpe_symbol
,
''
)
return
sent
def
unk_string
(
self
,
escape
=
False
):
"""Return unknown string, optionally escaped as: <<unk>>"""
if
escape
:
return
'<{}>'
.
format
(
self
.
unk_word
)
else
:
return
self
.
unk_word
def
add_symbol
(
self
,
word
,
n
=
1
):
def
add_symbol
(
self
,
word
,
n
=
1
):
"""Adds a word to the dictionary"""
"""Adds a word to the dictionary"""
...
...
generate.py
View file @
59d599a2
...
@@ -10,7 +10,7 @@ import sys
...
@@ -10,7 +10,7 @@ import sys
import
torch
import
torch
from
torch.autograd
import
Variable
from
torch.autograd
import
Variable
from
fairseq
import
bleu
,
options
,
utils
,
tokenizer
from
fairseq
import
bleu
,
options
,
tokenizer
,
utils
from
fairseq.meters
import
StopwatchMeter
,
TimeMeter
from
fairseq.meters
import
StopwatchMeter
,
TimeMeter
from
fairseq.progress_bar
import
progress_bar
from
fairseq.progress_bar
import
progress_bar
from
fairseq.sequence_generator
import
SequenceGenerator
from
fairseq.sequence_generator
import
SequenceGenerator
...
@@ -54,14 +54,18 @@ def main():
...
@@ -54,14 +54,18 @@ def main():
model
.
make_generation_fast_
(
not
args
.
no_beamable_mm
)
model
.
make_generation_fast_
(
not
args
.
no_beamable_mm
)
# Initialize generator
# Initialize generator
translator
=
SequenceGenerator
(
models
,
dataset
.
dst_dict
,
beam_size
=
args
.
beam
,
translator
=
SequenceGenerator
(
stop_early
=
(
not
args
.
no_early_stop
),
models
,
dataset
.
dst_dict
,
beam_size
=
args
.
beam
,
stop_early
=
(
not
args
.
no_early_stop
),
normalize_scores
=
(
not
args
.
unnormalized
),
normalize_scores
=
(
not
args
.
unnormalized
),
len_penalty
=
args
.
lenpen
len_penalty
=
args
.
lenpen
)
)
if
use_cuda
:
translator
.
cuda
()
# Load alignment dictionary for unknown word replacement
align_dict
=
{}
align_dict
=
{}
if
args
.
unk_replace_dict
!=
''
:
if
args
.
unk_replace_dict
!=
''
:
assert
args
.
interactive
,
"Unkown words replacing requires access to original source and is only"
\
assert
args
.
interactive
,
\
"
supported in interactive mode
"
'Unknown word replacement requires access to original source and is only
supported in interactive mode
'
with
open
(
args
.
unk_replace_dict
,
'r'
)
as
f
:
with
open
(
args
.
unk_replace_dict
,
'r'
)
as
f
:
for
line
in
f
:
for
line
in
f
:
l
=
line
.
split
()
l
=
line
.
split
()
...
@@ -80,27 +84,23 @@ def main():
...
@@ -80,27 +84,23 @@ def main():
hypo_tokens
[
i
]
=
src_token
hypo_tokens
[
i
]
=
src_token
return
' '
.
join
(
hypo_tokens
)
return
' '
.
join
(
hypo_tokens
)
if
use_cuda
:
translator
.
cuda
()
bpe_symbol
=
'@@ '
if
args
.
remove_bpe
else
None
bpe_symbol
=
'@@ '
if
args
.
remove_bpe
else
None
def
display_hypotheses
(
id
,
src
,
orig
,
ref
,
hypos
):
def
display_hypotheses
(
id
,
src
,
orig
,
ref
,
hypos
):
if
args
.
quiet
:
if
args
.
quiet
:
return
return
id_str
=
''
if
id
is
None
else
'-{}'
.
format
(
id
)
id_str
=
''
if
id
is
None
else
'-{}'
.
format
(
id
)
src_str
=
to_sentence
(
dataset
.
src_dict
,
src
,
bpe_symbol
)
src_str
=
dataset
.
src_dict
.
string
(
src
,
bpe_symbol
)
print
(
'S{}
\t
{}'
.
format
(
id_str
,
src_str
))
print
(
'S{}
\t
{}'
.
format
(
id_str
,
src_str
))
if
orig
is
not
None
:
if
orig
is
not
None
:
print
(
'O{}
\t
{}'
.
format
(
id_str
,
orig
.
strip
()))
print
(
'O{}
\t
{}'
.
format
(
id_str
,
orig
.
strip
()))
if
ref
is
not
None
:
if
ref
is
not
None
:
print
(
'T{}
\t
{}'
.
format
(
id_str
,
to_sentence
(
dataset
.
dst_dict
,
ref
,
bpe_symbol
,
ref
_unk
=
True
)))
print
(
'T{}
\t
{}'
.
format
(
id_str
,
dataset
.
dst_dict
.
string
(
ref
,
bpe_symbol
,
escape
_unk
=
True
)))
for
hypo
in
hypos
:
for
hypo
in
hypos
:
hypo_str
=
to_sentence
(
dataset
.
dst_dict
,
hypo
[
'tokens'
],
bpe_symbol
)
hypo_str
=
dataset
.
dst_dict
.
string
(
hypo
[
'tokens'
],
bpe_symbol
)
align_str
=
' '
.
join
(
map
(
str
,
hypo
[
'alignment'
]))
align_str
=
' '
.
join
(
map
(
str
,
hypo
[
'alignment'
]))
if
args
.
unk_replace_dict
!=
''
:
if
args
.
unk_replace_dict
!=
''
:
hypo_str
=
replace_unk
(
hypo_str
,
align_str
,
orig
,
unk_symbol
(
dataset
.
dst_dict
))
hypo_str
=
replace_unk
(
hypo_str
,
align_str
,
orig
,
dataset
.
dst_dict
.
unk_string
())
print
(
'H{}
\t
{}
\t
{}'
.
format
(
print
(
'H{}
\t
{}
\t
{}'
.
format
(
id_str
,
hypo
[
'score'
],
hypo_str
))
id_str
,
hypo
[
'score'
],
hypo_str
))
print
(
'A{}
\t
{}'
.
format
(
id_str
,
align_str
))
print
(
'A{}
\t
{}'
.
format
(
id_str
,
align_str
))
if
args
.
interactive
:
if
args
.
interactive
:
...
@@ -121,7 +121,7 @@ def main():
...
@@ -121,7 +121,7 @@ def main():
if
not
args
.
remove_bpe
:
if
not
args
.
remove_bpe
:
return
tokens
return
tokens
assert
(
tokens
==
dataset
.
dst_dict
.
pad
()).
sum
()
==
0
assert
(
tokens
==
dataset
.
dst_dict
.
pad
()).
sum
()
==
0
hypo_minus_bpe
=
to_sentence
(
dataset
.
dst_dict
,
tokens
,
bpe_symbol
)
hypo_minus_bpe
=
dataset
.
dst_dict
.
string
(
tokens
,
bpe_symbol
)
return
tokenizer
.
Tokenizer
.
tokenize
(
hypo_minus_bpe
,
dataset
.
dst_dict
,
add_if_not_exist
=
True
)
return
tokenizer
.
Tokenizer
.
tokenize
(
hypo_minus_bpe
,
dataset
.
dst_dict
,
add_if_not_exist
=
True
)
# Generate and compute BLEU score
# Generate and compute BLEU score
...
@@ -151,25 +151,5 @@ def main():
...
@@ -151,25 +151,5 @@ def main():
print
(
'| Generate {} with beam={}: {}'
.
format
(
args
.
gen_subset
,
args
.
beam
,
scorer
.
result_string
()))
print
(
'| Generate {} with beam={}: {}'
.
format
(
args
.
gen_subset
,
args
.
beam
,
scorer
.
result_string
()))
def
to_token
(
dict
,
i
,
runk
):
return
runk
if
i
==
dict
.
unk
()
else
dict
[
i
]
def
unk_symbol
(
dict
,
ref_unk
=
False
):
return
'<{}>'
.
format
(
dict
.
unk_word
)
if
ref_unk
else
dict
.
unk_word
def
to_sentence
(
dict
,
tokens
,
bpe_symbol
=
None
,
ref_unk
=
False
):
if
torch
.
is_tensor
(
tokens
)
and
tokens
.
dim
()
==
2
:
sentences
=
[
to_sentence
(
dict
,
token
)
for
token
in
tokens
]
return
'
\n
'
.
join
(
sentences
)
eos
=
dict
.
eos
()
runk
=
unk_symbol
(
dict
,
ref_unk
=
ref_unk
)
sent
=
' '
.
join
([
to_token
(
dict
,
i
,
runk
)
for
i
in
tokens
if
i
!=
eos
])
if
bpe_symbol
is
not
None
:
sent
=
sent
.
replace
(
bpe_symbol
,
''
)
return
sent
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment