OpenDAS / Fairseq · commit 864b89d0

Online backtranslation module

Authored Sep 25, 2018 by Myle Ott
Co-authored-by: liezl200 <lie@fb.com>
Parent: a4fe8c99
Changes: 4 files changed, 276 additions and 72 deletions

  fairseq/data/backtranslation_dataset.py  +131  -0
  tests/test_backtranslation_dataset.py     +70  -0
  tests/test_sequence_generator.py          +10  -72
  tests/utils.py                            +65  -0
fairseq/data/backtranslation_dataset.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from fairseq import sequence_generator

from . import FairseqDataset, language_pair_dataset


class BacktranslationDataset(FairseqDataset):

    def __init__(self, args, tgt_dataset, tgt_dict, backtranslation_model):
        """
        Sets up a backtranslation dataset which takes a tgt batch, generates
        a src using a tgt-src backtranslation_model, and returns the
        corresponding {generated src, input tgt} batch.

        Args:
            args: generation args for the backtranslation SequenceGenerator.
                Note that there is no equivalent argparse code for these args
                anywhere in our top-level train scripts yet. Integration is
                still in progress. You can still, however, test out this
                dataset functionality with the appropriate args in the
                corresponding unittest: test_backtranslation_dataset.
            tgt_dataset: dataset which will be used to build self.tgt_dataset --
                a LanguagePairDataset with the tgt dataset as the source
                dataset and None as the target dataset. We use
                language_pair_dataset here to encapsulate the tgt_dataset so
                we can re-use the LanguagePairDataset collater to format the
                batches in the structure that SequenceGenerator expects.
            tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
            backtranslation_model: tgt-src model to use in the
                SequenceGenerator to generate backtranslations from tgt
                batches
        """
        self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
            src=tgt_dataset,
            src_sizes=None,
            src_dict=tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )
        self.backtranslation_generator = sequence_generator.SequenceGenerator(
            [backtranslation_model],
            tgt_dict,
            unk_penalty=args.backtranslation_unkpen,
            sampling=args.backtranslation_sampling,
            beam_size=args.backtranslation_beam,
        )
        self.backtranslation_max_len_a = args.backtranslation_max_len_a
        self.backtranslation_max_len_b = args.backtranslation_max_len_b
        self.backtranslation_beam = args.backtranslation_beam

    def __getitem__(self, index):
        """
        Returns a single sample. Multiple samples are fed to the collater to
        create a backtranslation batch. Note you should always use the
        collate_fn BacktranslationDataset.collater() below if given the
        option to specify which collate_fn to use (e.g. in a dataloader which
        uses this BacktranslationDataset -- see the corresponding unittest
        for an example).
        """
        return self.tgt_dataset[index]

    def __len__(self):
        """
        The length of the backtranslation dataset is the length of tgt.
        """
        return len(self.tgt_dataset)

    def collater(self, samples):
        """
        Using the samples from the tgt dataset, load a collated tgt sample to
        feed to the backtranslation model. Then take the generated translation
        with the best score as the source and the original net input as the
        target.
        """
        collated_tgt_only_sample = self.tgt_dataset.collater(samples)
        backtranslation_hypos = self._generate_hypotheses(
            collated_tgt_only_sample
        )

        # Go through each tgt sentence in the batch and its corresponding
        # best generated hypothesis and create a backtranslation data pair:
        # {id: id, source: generated backtranslation, target: original tgt}
        generated_samples = []
        for input_sample, hypos in zip(samples, backtranslation_hypos):
            generated_samples.append(
                {
                    "id": input_sample["id"],
                    "source": hypos[0]["tokens"],  # first hypo is best hypo
                    "target": input_sample["source"],
                }
            )

        return language_pair_dataset.collate(
            samples=generated_samples,
            pad_idx=self.tgt_dataset.src_dict.pad(),
            eos_idx=self.tgt_dataset.src_dict.eos(),
        )

    def get_dummy_batch(self, num_tokens, max_positions):
        """Just use the tgt dataset get_dummy_batch"""
        return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)

    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens"""
        return self.tgt_dataset.num_tokens(index)

    def ordered_indices(self):
        """Just use the tgt dataset ordered_indices"""
        return self.tgt_dataset.ordered_indices()

    def valid_size(self, index, max_positions):
        """Just use the tgt dataset size"""
        return self.tgt_dataset.valid_size(index, max_positions)

    def _generate_hypotheses(self, sample):
        """
        Generates hypotheses from a LanguagePairDataset collated / batched
        sample. Note in this case, sample["target"] is None and
        sample["net_input"]["src_tokens"] is really in the tgt language.
        """
        self.backtranslation_generator.cuda()
        input = sample["net_input"]
        srclen = input["src_tokens"].size(1)
        hypos = self.backtranslation_generator.generate(
            input,
            maxlen=int(
                self.backtranslation_max_len_a * srclen
                + self.backtranslation_max_len_b
            ),
        )
        return hypos
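
A minimal usage sketch (not part of the commit itself), assuming the dummy model and dataset helpers from tests/utils.py below; in real training these would be a monolingual tgt dataset, its dictionary, and a trained tgt-src model. Since the generation args have no argparse integration yet, they are assembled by hand, exactly as the unittest below does. Note that collater() moves the generator to CUDA, so actually running this requires a GPU:

import argparse

import torch

import tests.utils as test_utils
from fairseq.data.backtranslation_dataset import BacktranslationDataset

args = argparse.Namespace()
args.backtranslation_unkpen = 0
args.backtranslation_sampling = False
args.backtranslation_max_len_a = 0
args.backtranslation_max_len_b = 200
args.backtranslation_beam = 2

# Dummy tgt-side data and a dummy tgt-src model from tests/utils.py.
tgt_dict, w1, w2, src_tokens, src_lengths, model = test_utils.sequence_generator_setup()
tgt_dataset = test_utils.TestDataset(data=src_tokens)

dataset = BacktranslationDataset(args, tgt_dataset, tgt_dict, model)

# Always pass dataset.collater as collate_fn: it is what backtranslates
# each tgt batch into {generated src, original tgt} pairs on the fly.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=2, collate_fn=dataset.collater,
)
batch = next(iter(dataloader))
generated_src = batch["net_input"]["src_tokens"]  # backtranslated sources
original_tgt = batch["target"]                    # original tgt sentences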
tests/test_backtranslation_dataset.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import unittest

import tests.utils as test_utils
import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset


class TestBacktranslationDataset(unittest.TestCase):

    def setUp(self):
        self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
            test_utils.sequence_generator_setup()
        )

        backtranslation_args = argparse.Namespace()
        # Same as defaults from fairseq/options.py
        backtranslation_args.backtranslation_unkpen = 0
        backtranslation_args.backtranslation_sampling = False
        backtranslation_args.backtranslation_max_len_a = 0
        backtranslation_args.backtranslation_max_len_b = 200
        backtranslation_args.backtranslation_beam = 2
        self.backtranslation_args = backtranslation_args

        dummy_src_samples = self.src_tokens

        self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)

    def test_backtranslation_dataset(self):
        backtranslation_dataset = BacktranslationDataset(
            args=self.backtranslation_args,
            tgt_dataset=self.tgt_dataset,
            tgt_dict=self.tgt_dict,
            backtranslation_model=self.model,
        )
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))

        eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # Note that we sort by src_lengths and add left padding, so actually
        # ids will look like: [1, 0]
        expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == "__main__":
    unittest.main()
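
The expected_src tensor above follows from the collation convention mentioned in the comment: samples are sorted by descending source length and source sequences are left-padded. A standalone sketch of just that padding rule (not fairseq's actual collate implementation), using the dummy ids pad=1, eos=2, w1=4, w2=5:

import torch

pad, eos, w1, w2 = 1, 2, 4, 5

# The two generated hypotheses in the test have lengths 4 and 2, so the
# shorter one is left-padded up to the batch maximum.
hyps = [[w1, w2, w1, eos], [w1, eos]]
max_len = max(len(h) for h in hyps)
batch = torch.LongTensor([[pad] * (max_len - len(h)) + h for h in hyps])
print(batch)  # tensor([[4, 5, 4, 2], [1, 1, 4, 2]]) == expected_src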
tests/test_sequence_generator.py (modified)

@@ -18,79 +18,17 @@ import tests.utils as test_utils
 class TestSequenceGenerator(unittest.TestCase):
 
     def setUp(self):
-        # construct dummy dictionary
-        d = test_utils.dummy_dictionary(vocab_size=2)
-        self.assertEqual(d.pad(), 1)
-        self.assertEqual(d.eos(), 2)
-        self.assertEqual(d.unk(), 3)
-        self.eos = d.eos()
-        self.w1 = 4
-        self.w2 = 5
-
-        # construct source data
-        self.src_tokens = torch.LongTensor([
-            [self.w1, self.w2, self.eos],
-            [self.w1, self.w2, self.eos],
-        ])
-        self.src_lengths = torch.LongTensor([2, 2])
-
+        self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = (
+            test_utils.sequence_generator_setup()
+        )
         self.encoder_input = {
-            'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths,
+            'src_tokens': src_tokens, 'src_lengths': src_lengths,
         }
 
-        args = argparse.Namespace()
-        unk = 0.
-        args.beam_probs = [
-            # step 0:
-            torch.FloatTensor([
-                # eos      w1   w2
-                # sentence 1:
-                [0.0, unk, 0.9, 0.1],  # beam 1
-                [0.0, unk, 0.9, 0.1],  # beam 2
-                # sentence 2:
-                [0.0, unk, 0.7, 0.3],
-                [0.0, unk, 0.7, 0.3],
-            ]),
-            # step 1:
-            torch.FloatTensor([
-                # eos      w1   w2       prefix
-                # sentence 1:
-                [1.0, unk, 0.0, 0.0],  # w1: 0.9  (emit: w1 <eos>: 0.9*1.0)
-                [0.0, unk, 0.9, 0.1],  # w2: 0.1
-                # sentence 2:
-                [0.25, unk, 0.35, 0.4],  # w1: 0.7  (don't emit: w1 <eos>: 0.7*0.25)
-                [0.00, unk, 0.10, 0.9],  # w2: 0.3
-            ]),
-            # step 2:
-            torch.FloatTensor([
-                # eos      w1   w2       prefix
-                # sentence 1:
-                [0.0, unk, 0.1, 0.9],  # w2 w1: 0.1*0.9
-                [0.6, unk, 0.2, 0.2],  # w2 w2: 0.1*0.1  (emit: w2 w2 <eos>: 0.1*0.1*0.6)
-                # sentence 2:
-                [0.60, unk, 0.4, 0.00],  # w1 w2: 0.7*0.4  (emit: w1 w2 <eos>: 0.7*0.4*0.6)
-                [0.01, unk, 0.0, 0.99],  # w2 w2: 0.3*0.9
-            ]),
-            # step 3:
-            torch.FloatTensor([
-                # eos      w1   w2       prefix
-                # sentence 1:
-                [1.0, unk, 0.0, 0.0],  # w2 w1 w2: 0.1*0.9*0.9  (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
-                [1.0, unk, 0.0, 0.0],  # w2 w1 w1: 0.1*0.9*0.1  (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
-                # sentence 2:
-                [0.1, unk, 0.5, 0.4],  # w2 w2 w2: 0.3*0.9*0.99  (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
-                [1.0, unk, 0.0, 0.0],  # w1 w2 w1: 0.7*0.4*0.4  (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
-            ]),
-        ]
-
-        task = test_utils.TestTranslationTask.setup_task(args, d, d)
-        self.model = task.build_model(args)
-        self.tgt_dict = task.target_dictionary
-
     def test_with_normalization(self):
         generator = SequenceGenerator([self.model], self.tgt_dict)
         hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
         self.assertHypoScore(hypos[0][0], [0.9, 1.0])
@@ -109,7 +47,7 @@ class TestSequenceGenerator(unittest.TestCase):
         # Sentence 2: beams swap order
         generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
         hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
         self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
@@ -127,7 +65,7 @@ class TestSequenceGenerator(unittest.TestCase):
         lenpen = 0.6
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
         hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
         self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen)
@@ -145,7 +83,7 @@ class TestSequenceGenerator(unittest.TestCase):
         lenpen = 5.0
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
         hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
         self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
@@ -162,7 +100,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_maxlen(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
         hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
         self.assertHypoScore(hypos[0][0], [0.9, 1.0])
@@ -179,7 +117,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_no_stop_early(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
         hypos = generator.generate(self.encoder_input, beam_size=2)
-        eos, w1, w2 = self.eos, self.w1, self.w2
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
         self.assertHypoScore(hypos[0][0], [0.9, 1.0])
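
The two length-penalty hunks can be checked by hand from the beam_probs comments: for sentence 1, the short hypothesis w1 <eos> has stepwise probabilities [0.9, 1.0] and the long hypothesis w2 w1 w2 <eos> has [0.1, 0.9, 0.9, 1.0]. A small sketch, assuming the normalized scoring rule sum(log p) / len ** lenpen that assertHypoScore checks against:

import math

# Stepwise emission probabilities for sentence 1, read off the beam_probs
# comments: "emit: w1 <eos>: 0.9*1.0" and "emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0".
hypos = {
    "w1 <eos>": [0.9, 1.0],
    "w2 w1 w2 <eos>": [0.1, 0.9, 0.9, 1.0],
}

for lenpen in (0.6, 5.0):
    scores = {
        name: sum(math.log(p) for p in probs) / len(probs) ** lenpen
        for name, probs in hypos.items()
    }
    best = max(scores, key=scores.get)
    print(f"lenpen={lenpen}: best hypothesis = {best}")

# lenpen=0.6 favors the short hypothesis; lenpen=5.0 flips the ranking to
# the longer one, matching the assertHypoTokens expectations above.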
tests/utils.py (modified)

@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
+import argparse
 import torch
 
 from fairseq import utils
@@ -51,6 +52,70 @@ def dummy_dataloader(
     return iter(dataloader)
 
 
+def sequence_generator_setup():
+    # construct dummy dictionary
+    d = dummy_dictionary(vocab_size=2)
+
+    eos = d.eos()
+    w1 = 4
+    w2 = 5
+
+    # construct source data
+    src_tokens = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
+    src_lengths = torch.LongTensor([2, 2])
+
+    args = argparse.Namespace()
+    unk = 0.
+    args.beam_probs = [
+        # step 0:
+        torch.FloatTensor([
+            # eos      w1   w2
+            # sentence 1:
+            [0.0, unk, 0.9, 0.1],  # beam 1
+            [0.0, unk, 0.9, 0.1],  # beam 2
+            # sentence 2:
+            [0.0, unk, 0.7, 0.3],
+            [0.0, unk, 0.7, 0.3],
+        ]),
+        # step 1:
+        torch.FloatTensor([
+            # eos      w1   w2       prefix
+            # sentence 1:
+            [1.0, unk, 0.0, 0.0],  # w1: 0.9  (emit: w1 <eos>: 0.9*1.0)
+            [0.0, unk, 0.9, 0.1],  # w2: 0.1
+            # sentence 2:
+            [0.25, unk, 0.35, 0.4],  # w1: 0.7  (don't emit: w1 <eos>: 0.7*0.25)
+            [0.00, unk, 0.10, 0.9],  # w2: 0.3
+        ]),
+        # step 2:
+        torch.FloatTensor([
+            # eos      w1   w2       prefix
+            # sentence 1:
+            [0.0, unk, 0.1, 0.9],  # w2 w1: 0.1*0.9
+            [0.6, unk, 0.2, 0.2],  # w2 w2: 0.1*0.1  (emit: w2 w2 <eos>: 0.1*0.1*0.6)
+            # sentence 2:
+            [0.60, unk, 0.4, 0.00],  # w1 w2: 0.7*0.4  (emit: w1 w2 <eos>: 0.7*0.4*0.6)
+            [0.01, unk, 0.0, 0.99],  # w2 w2: 0.3*0.9
+        ]),
+        # step 3:
+        torch.FloatTensor([
+            # eos      w1   w2       prefix
+            # sentence 1:
+            [1.0, unk, 0.0, 0.0],  # w2 w1 w2: 0.1*0.9*0.9  (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
+            [1.0, unk, 0.0, 0.0],  # w2 w1 w1: 0.1*0.9*0.1  (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
+            # sentence 2:
+            [0.1, unk, 0.5, 0.4],  # w2 w2 w2: 0.3*0.9*0.99  (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
+            [1.0, unk, 0.0, 0.0],  # w1 w2 w1: 0.7*0.4*0.4  (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
+        ]),
+    ]
+
+    task = TestTranslationTask.setup_task(args, d, d)
+    model = task.build_model(args)
+    tgt_dict = task.target_dictionary
+
+    return tgt_dict, w1, w2, src_tokens, src_lengths, model
+
+
 class TestDataset(torch.utils.data.Dataset):
 
     def __init__(self, data):
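
A quick sanity-check sketch of the new helper's fixed ids, mirroring the assertions that previously lived in test_sequence_generator.py (pad=1, eos=2, unk=3, word types w1=4 and w2=5):

import tests.utils as test_utils

tgt_dict, w1, w2, src_tokens, src_lengths, model = test_utils.sequence_generator_setup()
assert (tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) == (1, 2, 3)
assert (w1, w2) == (4, 5)
assert src_tokens.tolist() == [[w1, w2, tgt_dict.eos()]] * 2
assert src_lengths.tolist() == [2, 2]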