OpenDAS / Fairseq · Commits

Commit 663fd806
Authored May 11, 2018 by Alexei Baevski; committed by Myle Ott, Jun 15, 2018

implement batching in interactive mode

Parent: 4ce453b1
Showing 2 changed files with 93 additions and 18 deletions:

- fairseq/options.py (+9, -1)
- interactive.py (+84, -17)
fairseq/options.py
```diff
...
@@ -25,10 +25,12 @@ def get_training_parser():
     return parser
 
 
-def get_generation_parser():
+def get_generation_parser(interactive=False):
     parser = get_parser('Generation')
     add_dataset_args(parser, gen=True)
     add_generation_args(parser)
+    if interactive:
+        add_interactive_args(parser)
     return parser
...
@@ -242,6 +244,12 @@ def add_generation_args(parser):
     return group
 
 
+def add_interactive_args(parser):
+    group = parser.add_argument_group('Interactive')
+    group.add_argument('--buffer-size', default=0, type=int, metavar='N',
+                       help='read this many sentences into a buffer before processing them')
+
+
 def add_model_args(parser):
     group = parser.add_argument_group('Model configuration')
...
```
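The new `--buffer-size` flag follows the same argument-group pattern used throughout `fairseq/options.py`. As a minimal standalone sketch (a plain `argparse.ArgumentParser` stands in for fairseq's real generation parser), the option parses like this:

```python
import argparse

# Minimal stand-in for fairseq's generation parser; the real one is built
# by options.get_parser('Generation') and carries many more options.
parser = argparse.ArgumentParser('Generation')
group = parser.add_argument_group('Interactive')
group.add_argument('--buffer-size', default=0, type=int, metavar='N',
                   help='read this many sentences into a buffer before processing them')

args = parser.parse_args(['--buffer-size', '10'])
print(args.buffer_size)  # -> 10
```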
interactive.py
```diff
...
@@ -6,20 +6,60 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
+import numpy as np
 import sys
 
 import torch
+
+from collections import namedtuple
 from torch.autograd import Variable
 
 from fairseq import options, tokenizer, utils
+from fairseq.data import LanguagePairDataset
 from fairseq.sequence_generator import SequenceGenerator
 
 
+Batch = namedtuple('Batch', 'srcs tokens lengths')
+Translation = namedtuple('Translation', 'src_str hypos alignments')
+
+
+def buffered_read(buffer_size):
+    buffer = []
+    for src_str in sys.stdin:
+        buffer.append(src_str.strip())
+        if len(buffer) >= buffer_size:
+            yield buffer
+            buffer = []
+
+    if len(buffer) > 0:
+        yield buffer
+
+
+def make_batches(lines, batch_size, src_dict):
+    tokens = [
+        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
+        for src_str in lines
+    ]
+    lengths = [t.numel() for t in tokens]
+    indices = np.argsort(lengths)
+    num_batches = np.ceil(len(indices) / batch_size)
+    batches = np.array_split(indices, num_batches)
+    for batch_idxs in batches:
+        batch_toks = [tokens[i] for i in batch_idxs]
+        batch_toks = LanguagePairDataset.collate_tokens(
+            batch_toks, src_dict.pad(), src_dict.eos(),
+            LanguagePairDataset.LEFT_PAD_SOURCE, move_eos_to_beginning=False)
+        yield Batch(
+            srcs=[lines[i] for i in batch_idxs],
+            tokens=batch_toks,
+            lengths=tokens[0].new([lengths[i] for i in batch_idxs]),
+        ), batch_idxs
+
+
 def main(args):
     print(args)
     assert not args.sampling or args.nbest == args.beam, \
         '--sampling requires --nbest to be equal to --beam'
-    assert not args.max_sentences, \
-        '--max-sentences/--batch-size is not supported in interactive mode'
+    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
+        '--max-sentences/--batch-size cannot be larger than --buffer-size'
+
+    if args.buffer_size < 1:
+        args.buffer_size = 1
 
     use_cuda = torch.cuda.is_available() and not args.cpu
...
```
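The two new helpers carry the batching logic: `buffered_read` chunks stdin into lists of at most `buffer_size` sentences, and `make_batches` sorts each chunk by token length before splitting it into batches, so that sentences of similar length are padded together. A self-contained sketch of the same chunk-and-sort flow on toy data (plain word lists stand in for the tokenized tensors, and `batch_size` is fixed at 2):

```python
import numpy as np

def buffered_read_demo(lines, buffer_size):
    # Same control flow as buffered_read, but over a list instead of sys.stdin.
    buffer = []
    for line in lines:
        buffer.append(line.strip())
        if len(buffer) >= buffer_size:
            yield buffer
            buffer = []
    if len(buffer) > 0:
        yield buffer

batch_size = 2
for chunk in buffered_read_demo(['a b c', 'd', 'e f', 'g h i j', 'k'], buffer_size=2):
    # Sort the chunk by length so each batch holds similarly sized sentences
    # (minimizing padding), mirroring make_batches' argsort/array_split.
    lengths = [len(s.split()) for s in chunk]
    order = np.argsort(lengths)
    num_batches = int(np.ceil(len(order) / batch_size))
    for batch_idxs in np.array_split(order, num_batches):
        print([chunk[i] for i in batch_idxs])
# -> ['d', 'a b c'], then ['e f', 'g h i j'], then ['k']
```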
```diff
...
@@ -49,19 +89,12 @@ def main(args):
     # (None if no unknown word replacement, empty if no path to align dictionary)
     align_dict = utils.load_align_dict(args.replace_unk)
 
-    print('| Type the input sentence and press return:')
-    for src_str in sys.stdin:
-        src_str = src_str.strip()
-        src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
-        if use_cuda:
-            src_tokens = src_tokens.cuda()
-        src_lengths = src_tokens.new([src_tokens.numel()])
-        translations = translator.generate(
-            Variable(src_tokens.view(1, -1)),
-            Variable(src_lengths.view(-1)),
-        )
-        hypos = translations[0]
-        print('O\t{}'.format(src_str))
+    def make_result(src_str, hypos):
+        result = Translation(
+            src_str='O\t{}'.format(src_str),
+            hypos=[],
+            alignments=[],
+        )
 
         # Process top predictions
         for hypo in hypos[:min(len(hypos), args.nbest)]:
...
```
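`make_result` replaces the old print-as-you-go loop: the formatted `O`/`H`/`A` lines are accumulated in a `Translation` namedtuple so they can be reordered before printing. A toy illustration of that accumulate-then-return pattern (the scores, hypotheses, and alignments here are invented):

```python
from collections import namedtuple

Translation = namedtuple('Translation', 'src_str hypos alignments')

def make_result_demo(src_str, scored_hypos):
    # Accumulate formatted output lines instead of printing them immediately.
    result = Translation(src_str='O\t{}'.format(src_str), hypos=[], alignments=[])
    for score, hypo_str, alignment in scored_hypos:
        result.hypos.append('H\t{}\t{}'.format(score, hypo_str))
        result.alignments.append('A\t{}'.format(' '.join(map(str, alignment))))
    return result

r = make_result_demo('hello world', [(-0.31, 'hallo welt', [0, 1])])
print(r.src_str)        # O	hello world
print(r.hypos[0])       # H	-0.31	hallo welt
print(r.alignments[0])  # A	0 1
```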
```diff
...
@@ -73,11 +106,45 @@ def main(args):
                 dst_dict=dst_dict,
                 remove_bpe=args.remove_bpe,
             )
-            print('H\t{}\t{}'.format(hypo['score'], hypo_str))
-            print('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
+            result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
+            result.alignments.append('A\t{}'.format(
+                ' '.join(map(lambda x: str(utils.item(x)), alignment))))
+
+        return result
+
+    def process_batch(batch):
+        tokens = batch.tokens
+        lengths = batch.lengths
+
+        if use_cuda:
+            tokens = tokens.cuda()
+            lengths = lengths.cuda()
+
+        translations = translator.generate(
+            Variable(tokens),
+            Variable(lengths),
+            maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
+        )
+
+        return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
+
+    if args.buffer_size > 1:
+        print('| Sentence buffer size:', args.buffer_size)
+    print('| Type the input sentence and press return:')
+    for inputs in buffered_read(args.buffer_size):
+        indices = []
+        results = []
+        for batch, batch_indices in make_batches(inputs, max(1, args.max_sentences or 1), src_dict):
+            indices.extend(batch_indices)
+            results += process_batch(batch)
+
+        for i in np.argsort(indices):
+            result = results[i]
+            print(result.src_str)
+            for hypo, align in zip(result.hypos, result.alignments):
+                print(hypo)
+                print(align)
 
 
 if __name__ == '__main__':
-    parser = options.get_generation_parser()
+    parser = options.get_generation_parser(interactive=True)
     args = parser.parse_args()
     main(args)
```
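A detail that is easy to miss: `make_batches` yields batches in length-sorted order, so the driver loop records each batch's original indices and applies `np.argsort(indices)` at the end to print results in the order the sentences were typed. A standalone sketch of that order-restoring trick (strings stand in for the `Translation` results):

```python
import numpy as np

inputs = ['third longest one', 'a', 'two words']
# Pretend make_batches sorted the inputs by length: indices into `inputs`.
sorted_order = list(np.argsort([len(s.split()) for s in inputs]))  # [1, 2, 0]

indices = []
results = []
for i in sorted_order:  # processed in length-sorted batch order
    indices.append(i)
    results.append('result for: ' + inputs[i])

# np.argsort(indices) maps each results position back to input order.
for i in np.argsort(indices):
    print(results[i])
# Prints the results for inputs[0], inputs[1], inputs[2], in that order.
```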