Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
663fd806
Commit
663fd806
authored
May 11, 2018
by
Alexei Baevski
Committed by
Myle Ott
Jun 15, 2018
Browse files
implement batching in interactive mode
parent
4ce453b1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
93 additions
and
18 deletions
+93
-18
fairseq/options.py
fairseq/options.py
+9
-1
interactive.py
interactive.py
+84
-17
No files found.
fairseq/options.py
View file @
663fd806
...
...
@@ -25,10 +25,12 @@ def get_training_parser():
return
parser
def
get_generation_parser
():
def
get_generation_parser
(
interactive
=
False
):
parser
=
get_parser
(
'Generation'
)
add_dataset_args
(
parser
,
gen
=
True
)
add_generation_args
(
parser
)
if
interactive
:
add_interactive_args
(
parser
)
return
parser
...
...
@@ -242,6 +244,12 @@ def add_generation_args(parser):
return
group
def
add_interactive_args
(
parser
):
group
=
parser
.
add_argument_group
(
'Interactive'
)
group
.
add_argument
(
'--buffer-size'
,
default
=
0
,
type
=
int
,
metavar
=
'N'
,
help
=
'read this many sentences into a buffer before processing them'
)
def
add_model_args
(
parser
):
group
=
parser
.
add_argument_group
(
'Model configuration'
)
...
...
interactive.py
View file @
663fd806
...
...
@@ -6,20 +6,60 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
numpy
as
np
import
sys
import
torch
from
collections
import
namedtuple
from
torch.autograd
import
Variable
from
fairseq
import
options
,
tokenizer
,
utils
from
fairseq.data
import
LanguagePairDataset
from
fairseq.sequence_generator
import
SequenceGenerator
Batch
=
namedtuple
(
'Batch'
,
'srcs tokens lengths'
)
Translation
=
namedtuple
(
'Translation'
,
'src_str hypos alignments'
)
def
buffered_read
(
buffer_size
):
buffer
=
[]
for
src_str
in
sys
.
stdin
:
buffer
.
append
(
src_str
.
strip
())
if
len
(
buffer
)
>=
buffer_size
:
yield
buffer
buffer
=
[]
if
len
(
buffer
)
>
0
:
yield
buffer
def
make_batches
(
lines
,
batch_size
,
src_dict
):
tokens
=
[
tokenizer
.
Tokenizer
.
tokenize
(
src_str
,
src_dict
,
add_if_not_exist
=
False
).
long
()
for
src_str
in
lines
]
lengths
=
[
t
.
numel
()
for
t
in
tokens
]
indices
=
np
.
argsort
(
lengths
)
num_batches
=
np
.
ceil
(
len
(
indices
)
/
batch_size
)
batches
=
np
.
array_split
(
indices
,
num_batches
)
for
batch_idxs
in
batches
:
batch_toks
=
[
tokens
[
i
]
for
i
in
batch_idxs
]
batch_toks
=
LanguagePairDataset
.
collate_tokens
(
batch_toks
,
src_dict
.
pad
(),
src_dict
.
eos
(),
LanguagePairDataset
.
LEFT_PAD_SOURCE
,
move_eos_to_beginning
=
False
)
yield
Batch
(
srcs
=
[
lines
[
i
]
for
i
in
batch_idxs
],
tokens
=
batch_toks
,
lengths
=
tokens
[
0
].
new
([
lengths
[
i
]
for
i
in
batch_idxs
]),
),
batch_idxs
def
main
(
args
):
print
(
args
)
assert
not
args
.
sampling
or
args
.
nbest
==
args
.
beam
,
\
'--sampling requires --nbest to be equal to --beam'
assert
not
args
.
max_sentences
,
\
'--max-sentences/--batch-size is not supported in interactive mode'
assert
not
args
.
max_sentences
or
args
.
max_sentences
<=
args
.
buffer_size
,
\
'--max-sentences/--batch-size cannot be larger than --buffer-size'
if
args
.
buffer_size
<
1
:
args
.
buffer_size
=
1
use_cuda
=
torch
.
cuda
.
is_available
()
and
not
args
.
cpu
...
...
@@ -49,19 +89,12 @@ def main(args):
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict
=
utils
.
load_align_dict
(
args
.
replace_unk
)
print
(
'| Type the input sentence and press return:'
)
for
src_str
in
sys
.
stdin
:
src_str
=
src_str
.
strip
()
src_tokens
=
tokenizer
.
Tokenizer
.
tokenize
(
src_str
,
src_dict
,
add_if_not_exist
=
False
).
long
()
if
use_cuda
:
src_tokens
=
src_tokens
.
cuda
()
src_lengths
=
src_tokens
.
new
([
src_tokens
.
numel
()])
translations
=
translator
.
generate
(
Variable
(
src_tokens
.
view
(
1
,
-
1
)),
Variable
(
src_lengths
.
view
(
-
1
)),
def
make_result
(
src_str
,
hypos
):
result
=
Translation
(
src_str
=
'O
\t
{}'
.
format
(
src_str
),
hypos
=
[],
alignments
=
[],
)
hypos
=
translations
[
0
]
print
(
'O
\t
{}'
.
format
(
src_str
))
# Process top predictions
for
hypo
in
hypos
[:
min
(
len
(
hypos
),
args
.
nbest
)]:
...
...
@@ -73,11 +106,45 @@ def main(args):
dst_dict
=
dst_dict
,
remove_bpe
=
args
.
remove_bpe
,
)
print
(
'H
\t
{}
\t
{}'
.
format
(
hypo
[
'score'
],
hypo_str
))
print
(
'A
\t
{}'
.
format
(
' '
.
join
(
map
(
lambda
x
:
str
(
utils
.
item
(
x
)),
alignment
))))
result
.
hypos
.
append
(
'H
\t
{}
\t
{}'
.
format
(
hypo
[
'score'
],
hypo_str
))
result
.
alignments
.
append
(
'A
\t
{}'
.
format
(
' '
.
join
(
map
(
lambda
x
:
str
(
utils
.
item
(
x
)),
alignment
))))
return
result
def
process_batch
(
batch
):
tokens
=
batch
.
tokens
lengths
=
batch
.
lengths
if
use_cuda
:
tokens
=
tokens
.
cuda
()
lengths
=
lengths
.
cuda
()
translations
=
translator
.
generate
(
Variable
(
tokens
),
Variable
(
lengths
),
maxlen
=
int
(
args
.
max_len_a
*
tokens
.
size
(
1
)
+
args
.
max_len_b
),
)
return
[
make_result
(
batch
.
srcs
[
i
],
t
)
for
i
,
t
in
enumerate
(
translations
)]
if
args
.
buffer_size
>
1
:
print
(
'| Sentence buffer size:'
,
args
.
buffer_size
)
print
(
'| Type the input sentence and press return:'
)
for
inputs
in
buffered_read
(
args
.
buffer_size
):
indices
=
[]
results
=
[]
for
batch
,
batch_indices
in
make_batches
(
inputs
,
max
(
1
,
args
.
max_sentences
or
1
),
src_dict
):
indices
.
extend
(
batch_indices
)
results
+=
process_batch
(
batch
)
for
i
in
np
.
argsort
(
indices
):
result
=
results
[
i
]
print
(
result
.
src_str
)
for
hypo
,
align
in
zip
(
result
.
hypos
,
result
.
alignments
):
print
(
hypo
)
print
(
align
)
if
__name__
==
'__main__'
:
parser
=
options
.
get_generation_parser
()
parser
=
options
.
get_generation_parser
(
interactive
=
True
)
args
=
parser
.
parse_args
()
main
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment