Commit 3c19878f authored by Myle Ott, committed by Facebook GitHub Bot

Refactor BacktranslationDataset to be more reusable (#354)

Summary:
- generalize AppendEosDataset -> TransformEosDataset
- remove EOS logic from BacktranslationDataset (use TransformEosDataset instead)
- BacktranslationDataset takes a backtranslation_fn instead of building the SequenceGenerator itself (see the usage sketch below)
Pull Request resolved: https://github.com/pytorch/fairseq/pull/354

Reviewed By: liezl200

Differential Revision: D12970233

Pulled By: myleott

fbshipit-source-id: d5c5b0e0a75eca1bd3a50382ac24621f35c32f36
parent a442244d
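
A minimal sketch of the new wiring (illustrative only, not part of the diff below; it assumes a trained tgt->src `model`, a monolingual target-side `tgt_dataset`, and its `tgt_dict`, mirroring the updated unit test):

from fairseq.data import (
    BacktranslationDataset,
    LanguagePairDataset,
    TransformEosDataset,
)
from fairseq.sequence_generator import SequenceGenerator

# Assumptions (hypothetical): `model` is a trained tgt->src model,
# `tgt_dataset` is a monolingual target-side dataset, `tgt_dict` its Dictionary.
generator = SequenceGenerator(models=[model], tgt_dict=tgt_dict, beam_size=2)

# Wrap the monolingual data so the LanguagePairDataset collater produces
# batches in the format SequenceGenerator expects.
lang_pair = LanguagePairDataset(
    src=tgt_dataset, src_sizes=tgt_dataset.sizes, src_dict=tgt_dict,
    tgt=None, tgt_sizes=None, tgt_dict=None,
)

dataset = BacktranslationDataset(
    # EOS handling now lives in TransformEosDataset rather than in
    # BacktranslationDataset itself:
    tgt_dataset=TransformEosDataset(
        lang_pair, tgt_dict.eos(), remove_eos_from_src=True,
    ),
    backtranslation_fn=generator.generate,  # any compatible callable works
    max_len_a=0,
    max_len_b=200,
    output_collater=TransformEosDataset(
        lang_pair, tgt_dict.eos(), append_eos_to_tgt=True,
    ).collater,
)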
fairseq/data/__init__.py
@@ -7,7 +7,6 @@
from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
from .append_eos_dataset import AppendEosDataset
from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
@@ -15,6 +14,7 @@ from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .iterators import (
CountingIterator,
@@ -24,7 +24,6 @@ from .iterators import (
)
__all__ = [
'AppendEosDataset',
'BacktranslationDataset',
'ConcatDataset',
'CountingIterator',
@@ -41,4 +40,5 @@ __all__ = [
'RoundRobinZipDatasets',
'ShardedIterator',
'TokenBlockDataset',
'TransformEosDataset',
]
fairseq/data/append_eos_dataset.py (deleted)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
class AppendEosDataset(torch.utils.data.Dataset):
"""A dataset wrapper that appends EOS to each item."""
def __init__(self, dataset, eos):
self.dataset = dataset
self.eos = eos
def __getitem__(self, index):
item = torch.cat([self.dataset[index], torch.LongTensor([self.eos])])
return item
def __len__(self):
return len(self.dataset)
fairseq/data/backtranslation_dataset.py
@@ -7,182 +7,162 @@
import torch
from fairseq import sequence_generator
from fairseq import utils
from . import FairseqDataset, language_pair_dataset
from . import FairseqDataset
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
"""Backtranslate a list of samples.
Given an input (*samples*) of the form:
[{'id': 1, 'source': 'hallo welt'}]
this will return:
[{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
Args:
samples (List[dict]): samples to backtranslate. Individual samples are
expected to have a 'source' key, which will become the 'target'
after backtranslation.
collate_fn (callable): function to collate samples into a mini-batch
generate_fn (callable): function to generate backtranslations
cuda (bool): use GPU for generation (default: ``True``)
Returns:
List[dict]: an updated list of samples with a backtranslated source
"""
collated_samples = collate_fn(samples)
s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
generated_sources = generate_fn(s['net_input'])
def update_sample(sample, generated_source):
sample['target'] = sample['source'] # the original source becomes the target
sample['source'] = generated_source
return sample
# Go through each tgt sentence in batch and its corresponding best
# generated hypothesis and create a backtranslation data pair
# {id: id, source: generated backtranslation, target: original tgt}
return [
update_sample(
sample=input_sample,
generated_source=hypos[0]['tokens'].cpu(), # highest scoring hypo is first
)
for input_sample, hypos in zip(samples, generated_sources)
]
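# Illustrative call (not part of this commit; `generator` and `tgt_dataset`
# are hypothetical stand-ins for a SequenceGenerator and a dataset with a
# collater()):
#
#   bt_samples = backtranslate_samples(
#       samples=[tgt_dataset[0], tgt_dataset[1]],
#       collate_fn=tgt_dataset.collater,
#       generate_fn=lambda net_input: generator.generate(net_input, maxlen=200),
#       cuda=torch.cuda.is_available(),
#   )
#
# Each returned sample now has 'source' set to the backtranslation and
# 'target' set to the original sentence.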
class BacktranslationDataset(FairseqDataset):
def __init__(
self,
tgt_dataset,
tgt_dict,
backtranslation_model,
backtranslation_fn,
max_len_a,
max_len_b,
remove_eos_at_src=False,
generator_class=sequence_generator.SequenceGenerator,
output_collater=None,
cuda=True,
**kwargs
):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation_model, and returns the
corresponding {generated src, input tgt} batch
a src using a tgt-src backtranslation function (*backtranslation_fn*),
and returns the corresponding `{generated src, input tgt}` batch.
Args:
tgt_dataset: dataset which will be used to build self.tgt_dataset --
a LanguagePairDataset with tgt dataset as the source dataset and
None as the target dataset. Should NOT have padding so that
src_lengths are accurately calculated by language_pair_dataset
collate function.
We use language_pair_dataset here to encapsulate the tgt_dataset
so we can re-use the LanguagePairDataset collater to format the
batches in the structure that SequenceGenerator expects.
Note: tgt_dataset samples should not have EOS at end if
the tgt-src model expects an input without EOS. This dataset
does not enforce this, you should enforce that in preprocessing.
tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
backtranslation_model: tgt-src model to use in the SequenceGenerator
to generate backtranslations from tgt batches
max_len_a, max_len_b: args passed into generate() function of
the backtranslation SequenceGenerator
remove_eos_at_src: whether we should remove EOS from the source
dialect text generated by the backtranslation model.
generator_class: which SequenceGenerator class to use for
backtranslation. Output of generate() should be the same format
as fairseq's SequenceGenerator
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be
used. After backtranslation, the source sentences in this
dataset will be returned as the targets.
backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be
passed into *backtranslation_fn*.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch (default:
``tgt_dataset.collater``)
cuda: use GPU for generation
kwargs: generation args to init the backtranslation
SequenceGenerator
"""
self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
src=tgt_dataset,
src_sizes=tgt_dataset.sizes,
src_dict=tgt_dict,
tgt=None,
tgt_sizes=None,
tgt_dict=None,
)
self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.remove_eos_at_src = remove_eos_at_src
self.backtranslation_generator = generator_class(
models=[backtranslation_model],
tgt_dict=tgt_dict,
**kwargs,
)
self.output_collater = output_collater if output_collater is not None \
else tgt_dataset.collater
self.cuda = cuda if torch.cuda.is_available() else False
if self.cuda:
self.backtranslation_generator.cuda()
def __getitem__(self, index):
"""
Returns a single sample. Multiple samples are fed to the collater to
create a backtranslation batch. Note you should always use collate_fn
BacktranslationDataset.collater() below if given the option to
specify which collate_fn to use (e.g. in a dataloader which uses this
BacktranslationDataset -- see corresponding unittest for an example).
Returns a single sample from *tgt_dataset*. Note that backtranslation is
not applied in this step; use :func:`collater` instead to backtranslate
a batch of samples.
"""
return self.tgt_dataset[index]
def __len__(self):
"""
The length of the backtranslation dataset is the length of tgt.
"""
return len(self.tgt_dataset)
def collater(self, samples):
"""
Using the samples from the tgt dataset, load a collated tgt sample to
feed to the backtranslation model. Then take the generated translation
with best score as the source and the original net input as the target.
"""
collated_tgt_only_sample = self.tgt_dataset.collater(samples=samples)
backtranslation_hypos = self._generate_hypotheses(
sample=collated_tgt_only_sample
)
"""Merge and backtranslate a list of samples to form a mini-batch.
Using the samples from *tgt_dataset*, load a collated target sample to
feed to the backtranslation model. Then take the backtranslation with
the best score as the source and the original input as the target.
Note: we expect *tgt_dataset* to provide a function `collater()` that
will collate samples into the format expected by *backtranslation_fn*.
After backtranslation, we will feed the new list of samples (i.e., the
`(backtranslated source, original source)` pairs) to *output_collater*
and return the result.
Args:
samples (List[dict]): samples to backtranslate and collate
# Go through each tgt sentence in batch and its corresponding best
# generated hypothesis and create a backtranslation data pair
# {id: id, source: generated backtranslation, target: original tgt}
generated_samples = []
for input_sample, hypos in zip(samples, backtranslation_hypos):
original_tgt = input_sample["source"].cpu()
generated_source = hypos[0]["tokens"].cpu() # first hypo is best hypo
# Append EOS to the tgt sentence if it does not have an EOS
# This is the case if the samples in monolingual tgt_dataset don't
# have an EOS appended to the end of each sentence.
eos = self.tgt_dataset.src_dict.eos()
if original_tgt[-1] != eos:
original_tgt = torch.cat([original_tgt, torch.LongTensor([eos])])
# The generated source dialect backtranslation will have an EOS.
# If we want our parallel data source to not have an EOS, we will
# have to remove it.
if self.remove_eos_at_src:
assert generated_source[-1] == eos, (
"Expected generated backtranslation to have eos (id: "
"{eos}) at end, but instead found token id "
"{generated_source[-1]} at end."
).format(eos=eos, generated_source=generated_source)
generated_source = generated_source[:-1]
generated_samples.append(
{
"id": input_sample["id"],
"source": generated_source,
"target": original_tgt,
}
)
return language_pair_dataset.collate(
samples=generated_samples,
pad_idx=self.tgt_dataset.src_dict.pad(),
eos_idx=self.tgt_dataset.src_dict.eos(),
Returns:
dict: a mini-batch with keys coming from *output_collater*
"""
samples = backtranslate_samples(
samples=samples,
collate_fn=self.tgt_dataset.collater,
generate_fn=(
lambda net_input: self.backtranslation_fn(
net_input,
maxlen=int(
self.max_len_a * net_input['src_tokens'].size(1) + self.max_len_b
),
)
),
cuda=self.cuda,
)
return self.output_collater(samples)
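# Worked example of the maxlen computation above (illustrative): with
# max_len_a=0 and max_len_b=200, a batch whose src_tokens have size(1) == 10
# yields maxlen = int(0 * 10 + 200) = 200, i.e. a generation budget that is
# independent of the source length.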
def get_dummy_batch(self, num_tokens, max_positions):
""" Just use the tgt dataset get_dummy_batch """
"""Just use the tgt dataset get_dummy_batch"""
return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)
def num_tokens(self, index):
""" Just use the tgt dataset num_tokens """
"""Just use the tgt dataset num_tokens"""
return self.tgt_dataset.num_tokens(index)
def ordered_indices(self):
""" Just use the tgt dataset ordered_indices """
"""Just use the tgt dataset ordered_indices"""
return self.tgt_dataset.ordered_indices()
def valid_size(self, index, max_positions):
""" Just use the tgt dataset size """
"""Just use the tgt dataset size"""
return self.tgt_dataset.valid_size(index, max_positions)
def _generate_hypotheses(self, sample):
"""
Generates hypotheses from a LanguagePairDataset collated / batched
sample. Note in this case, sample["target"] is None, and
sample["net_input"]["src_tokens"] is really in tgt language.
"""
s = utils.move_to_cuda(sample) if self.cuda else sample
input = s["net_input"]
srclen = input["src_tokens"].size(1)
hypos = self.backtranslation_generator.generate(
input,
maxlen=int(
self.max_len_a * srclen + self.max_len_b
),
)
return hypos
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``.
Here, we return src dataset size as tgt dataset size as an approximation.
We do not know src size until we backtranslate and generate src sentences.
"""Return an example's size as a float or tuple. This value is used
when filtering a dataset with ``--max-positions``.
Note: we use *tgt_dataset* to approximate the length of the source
sentence, since we do not know the actual length until after
backtranslation.
"""
return (self.tgt_dataset.size(index)[0], self.tgt_dataset.size(index)[0])
tgt_size = self.tgt_dataset.size(index)[0]
return (tgt_size, tgt_size)
fairseq/data/concat_dataset.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import bisect
import numpy as np
......
fairseq/data/language_pair_dataset.py
@@ -55,13 +55,13 @@ def collate(
batch = {
'id': id,
'nsentences': len(samples),
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
},
'target': target,
'nsentences': samples[0]['source'].size(0),
}
if prev_output_tokens is not None:
batch['net_input']['prev_output_tokens'] = prev_output_tokens
......
fairseq/data/monolingual_dataset.py
@@ -33,6 +33,7 @@ def collate(samples, pad_idx, eos_idx):
return {
'id': torch.LongTensor([s['id'] for s in samples]),
'nsentences': len(samples),
'ntokens': sum(len(s['source']) for s in samples),
'net_input': {
'src_tokens': merge('source'),
@@ -41,7 +42,6 @@ def collate(samples, pad_idx, eos_idx):
]),
},
'target': merge('target', is_target_list),
'nsentences': samples[0]['source'].size(0),
}
......
fairseq/data/transform_eos_dataset.py (new file)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from . import FairseqDataset
class TransformEosDataset(FairseqDataset):
"""A dataset wrapper that appends/prepends/strips EOS.
Note that the transformation is applied in :func:`collater`.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to wrap
eos (int): index of the end-of-sentence symbol
append_eos_to_src (bool, optional): append EOS to the end of src
remove_eos_from_src (bool, optional): remove EOS from the end of src
append_eos_to_tgt (bool, optional): append EOS to the end of tgt
remove_eos_from_tgt (bool, optional): remove EOS from the end of tgt
"""
def __init__(
self,
dataset,
eos,
append_eos_to_src=False,
remove_eos_from_src=False,
append_eos_to_tgt=False,
remove_eos_from_tgt=False,
):
if not isinstance(dataset, FairseqDataset):
raise ValueError('dataset must be an instance of FairseqDataset')
if append_eos_to_src and remove_eos_from_src:
raise ValueError('cannot combine append_eos_to_src and remove_eos_from_src')
if append_eos_to_tgt and remove_eos_from_tgt:
raise ValueError('cannot combine append_eos_to_tgt and remove_eos_from_tgt')
self.dataset = dataset
self.eos = torch.LongTensor([eos])
self.append_eos_to_src = append_eos_to_src
self.remove_eos_from_src = remove_eos_from_src
self.append_eos_to_tgt = append_eos_to_tgt
self.remove_eos_from_tgt = remove_eos_from_tgt
# precompute how we should adjust the reported sizes
self._src_delta = 0
self._src_delta += 1 if append_eos_to_src else 0
self._src_delta -= 1 if remove_eos_from_src else 0
self._tgt_delta = 0
self._tgt_delta += 1 if append_eos_to_tgt else 0
self._tgt_delta -= 1 if remove_eos_from_tgt else 0
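# e.g., append_eos_to_src=True gives _src_delta = +1, so size() below
# reports src_len + 1 without materializing any items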
self._checked_src = False
self._checked_tgt = False
def _check_src(self, src, expect_eos):
if not self._checked_src:
assert (src[-1] == self.eos[0]) == expect_eos
self._checked_src = True
def _check_tgt(self, tgt, expect_eos):
if not self._checked_tgt:
assert (tgt[-1] == self.eos[0]) == expect_eos
self._checked_tgt = True
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
def transform(item):
if self.append_eos_to_src:
self._check_src(item['source'], expect_eos=False)
item['source'] = torch.cat([item['source'], self.eos])
if self.remove_eos_from_src:
self._check_src(item['source'], expect_eos=True)
item['source'] = item['source'][:-1]
if self.append_eos_to_tgt:
self._check_tgt(item['target'], expect_eos=False)
item['target'] = torch.cat([item['target'], self.eos])
if self.remove_eos_from_tgt:
self._check_tgt(item['target'], expect_eos=True)
item['target'] = item['target'][:-1]
return item
samples = list(map(transform, samples))
return self.dataset.collater(samples)
def get_dummy_batch(self, *args, **kwargs):
return self.dataset.get_dummy_batch(*args, **kwargs)
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
src_len, tgt_len = self.dataset.size(index)
return (src_len + self._src_delta, tgt_len + self._tgt_delta)
def ordered_indices(self):
# NOTE: we assume that the ordering does not change based on the
# addition or removal of eos
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
# supports_prefetch is a property on FairseqDataset, not a method
return self.dataset.supports_prefetch
def prefetch(self, indices):
return self.dataset.prefetch(indices)
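# Example usage (illustrative; `lang_pair` is assumed to be a
# LanguagePairDataset and `d` its source Dictionary):
#
#   dataset = TransformEosDataset(lang_pair, eos=d.eos(), append_eos_to_tgt=True)
#   batch = dataset.collater([dataset[i] for i in range(2)])
#
# Note that the EOS transform is applied in collater(), so dataset[i] still
# returns the untransformed item.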
fairseq/tasks/language_modeling.py
@@ -25,14 +25,15 @@ class LanguageModelingTask(FairseqTask):
Train a language model.
Args:
dictionary (Dictionary): the dictionary for the input of the language model
output_dictionary (Dictionary): the dictionary for the output of the language model.
In most cases it will be the same as dictionary, but could possibly be a more limited
version of the dictionary (if --output-dictionary-size is used).
targets (List[str]): list of the target types that the language model should predict.
Can be one of "self", "future", and "past". Defaults to "future".
dictionary (~fairseq.data.Dictionary): the dictionary for the input of
the language model
output_dictionary (~fairseq.data.Dictionary): the dictionary for the
output of the language model. In most cases it will be the same as
*dictionary*, but could possibly be a more limited version of the
dictionary (if ``--output-dictionary-size`` is used).
targets (List[str]): list of the target types that the language model
should predict. Can be one of "self", "future", and "past".
Defaults to "future".
.. note::
......
tests/test_backtranslation_dataset.py
@@ -7,13 +7,20 @@
import unittest
import tests.utils as test_utils
import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset
from fairseq import sequence_generator
from fairseq.data import (
BacktranslationDataset,
LanguagePairDataset,
TransformEosDataset,
)
from fairseq.sequence_generator import SequenceGenerator
import tests.utils as test_utils
class TestBacktranslationDataset(unittest.TestCase):
def setUp(self):
self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
test_utils.sequence_generator_setup()
@@ -22,22 +29,49 @@ class TestBacktranslationDataset(unittest.TestCase):
dummy_src_samples = self.src_tokens
self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
self.cuda = torch.cuda.is_available()
def _backtranslation_dataset_helper(self, remove_eos_at_src):
"""
SequenceGenerator kwargs are same as defaults from fairseq/options.py
"""
backtranslation_dataset = BacktranslationDataset(
tgt_dataset=self.tgt_dataset,
def _backtranslation_dataset_helper(
self, remove_eos_from_input_src, remove_eos_from_output_src,
):
tgt_dataset = LanguagePairDataset(
src=self.tgt_dataset,
src_sizes=self.tgt_dataset.sizes,
src_dict=self.tgt_dict,
tgt=None,
tgt_sizes=None,
tgt_dict=None,
)
generator = SequenceGenerator(
models=[self.model],
tgt_dict=self.tgt_dict,
backtranslation_model=self.model,
max_len_a=0,
max_len_b=200,
beam_size=2,
unk_penalty=0,
sampling=False,
remove_eos_at_src=remove_eos_at_src,
generator_class=sequence_generator.SequenceGenerator,
)
if self.cuda:
generator.cuda()
backtranslation_dataset = BacktranslationDataset(
tgt_dataset=TransformEosDataset(
dataset=tgt_dataset,
eos=self.tgt_dict.eos(),
# remove eos from the input src
remove_eos_from_src=remove_eos_from_input_src,
),
backtranslation_fn=generator.generate,
max_len_a=0,
max_len_b=200,
output_collater=TransformEosDataset(
dataset=tgt_dataset,
eos=self.tgt_dict.eos(),
# if we remove eos from the input src, then we need to add it
# back to the output tgt
append_eos_to_tgt=remove_eos_from_input_src,
remove_eos_from_src=remove_eos_from_output_src,
).collater,
cuda=self.cuda,
)
dataloader = torch.utils.data.DataLoader(
backtranslation_dataset,
@@ -51,7 +85,7 @@ class TestBacktranslationDataset(unittest.TestCase):
# Note that we sort by src_lengths and add left padding, so actually
# ids will look like: [1, 0]
expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
if remove_eos_at_src:
if remove_eos_from_output_src:
expected_src = expected_src[:, :-1]
expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
@@ -60,11 +94,20 @@ class TestBacktranslationDataset(unittest.TestCase):
self.assertTensorEqual(expected_src, generated_src)
self.assertTensorEqual(expected_tgt, tgt_tokens)
def test_backtranslation_dataset_no_eos_at_src(self):
self._backtranslation_dataset_helper(remove_eos_at_src=True)
def test_backtranslation_dataset_no_eos_in_output_src(self):
self._backtranslation_dataset_helper(
remove_eos_from_input_src=False, remove_eos_from_output_src=True,
)
def test_backtranslation_dataset_with_eos_in_output_src(self):
self._backtranslation_dataset_helper(
remove_eos_from_input_src=False, remove_eos_from_output_src=False,
)
def test_backtranslation_dataset_with_eos_at_src(self):
self._backtranslation_dataset_helper(remove_eos_at_src=False)
def test_backtranslation_dataset_no_eos_in_input_src(self):
self._backtranslation_dataset_helper(
remove_eos_from_input_src=True, remove_eos_from_output_src=False,
)
def assertTensorEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
......
tests/test_noising.py
@@ -12,9 +12,9 @@ import tests.utils as test_utils
import torch
from fairseq import utils
from fairseq.data import (
AppendEosDataset,
Dictionary,
LanguagePairDataset,
TransformEosDataset,
data_utils,
noising,
)
@@ -410,15 +410,14 @@ class TestDataNoising(unittest.TestCase):
self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def _get_noising_dataset_batch(
self, src_tokens_no_pad, src_dict, use_append_eos_dataset=False
self, src_tokens_no_pad, src_dict, append_eos_to_tgt=False,
):
"""
Constructs a NoisingDataset and the corresponding
LanguagePairDataset(NoisingDataset(src), src). If we set
use_append_eos_dataset to True, wrap the source dataset in
AppendEosDataset to append EOS to the clean source when using it as the
target. In practice, we should use AppendEosDataset because our models
usually have source without EOS but target with EOS.
``LanguagePairDataset(NoisingDataset(src), src)``. If
*append_eos_to_tgt* is True, wrap the source dataset in
:class:`TransformEosDataset` to append EOS to the clean source when
using it as the target.
"""
src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)
@@ -432,11 +431,13 @@ class TestDataNoising(unittest.TestCase):
noising_class=noising.UnsupervisedMTNoising,
)
tgt = src_dataset
if use_append_eos_dataset:
tgt = AppendEosDataset(src_dataset, src_dict.eos())
language_pair_dataset = LanguagePairDataset(
src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
)
language_pair_dataset = TransformEosDataset(
language_pair_dataset, src_dict.eos(),
append_eos_to_tgt=append_eos_to_tgt,
)
dataloader = torch.utils.data.DataLoader(
dataset=language_pair_dataset,
@@ -481,8 +482,7 @@ class TestDataNoising(unittest.TestCase):
def test_noising_dataset_without_eos(self):
"""
Similar to test noising dataset with eos except that we have to set
use_append_eos_dataset=True so that we wrap the source dataset in the
AppendEosDataset when using it as the target in LanguagePairDataset.
*append_eos_to_tgt* to ``True``.
"""
src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
@@ -499,7 +499,7 @@ class TestDataNoising(unittest.TestCase):
denoising_batch_result = self._get_noising_dataset_batch(
src_tokens_no_pad=src_tokens_no_pad,
src_dict=src_dict,
use_append_eos_dataset=True,
append_eos_to_tgt=True,
)
eos, pad = src_dict.eos(), src_dict.pad()
......