Commit 0d63cf03 authored by Deepak Gopinath, committed by Facebook Github Bot

LanguagePairDataset and BacktranslationDataset changes for semi supervised task setup (#330)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/330

As part of the semi-supervised task setup (https://github.com/pytorch/translate/pull/243), this diff adds the ability for LanguagePairDataset to remove the EOS token from the end of the source or to append it to the end of the target. BacktranslationDataset needs this functionality in order to use translations as source data.
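For illustration, a minimal sketch of how the new flags behave when an existing parallel dataset is reused in the opposite direction. The toy `Dictionary` and the plain-list dataset stand-ins are assumptions for this example, not part of the diff; `LanguagePairDataset` only needs indexing and per-example sizes here:

```python
import torch
from fairseq.data import Dictionary, LanguagePairDataset

d = Dictionary()  # comes with pad/eos/unk symbols built in
hello, world = d.add_symbol("hello"), d.add_symbol("world")

# Toy examples: target-side sentences end with EOS, source-side ones do not.
tgt_side = [torch.LongTensor([hello, world, d.eos()])]
src_side = [torch.LongTensor([world, hello])]

# Reverse the direction: the old target becomes the new source (EOS must be
# stripped) and the old source becomes the new target (EOS must be appended).
reversed_pair = LanguagePairDataset(
    src=tgt_side, src_sizes=[3], src_dict=d,
    tgt=src_side, tgt_sizes=[2], tgt_dict=d,
    remove_eos_from_source=True,
    append_eos_to_target=True,
)

item = reversed_pair[0]
assert item["source"][-1] != d.eos()   # EOS removed from the new source
assert item["target"][-1] == d.eos()   # EOS appended to the new target
```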

This diff also changes BacktranslationDataset to work on GPU: generation runs on GPU when one is available, and the back-translated sentences are then moved back to CPU so that LanguagePairDataset can collate them.
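Sketched in isolation, the device handling looks roughly like the hypothetical helper below; only `utils.move_to_cuda`, the `net_input` key, and the `.cpu()` calls mirror the diff, and the hypothesis structure assumed is fairseq's usual list of hypothesis dicts per input sentence:

```python
import torch
from fairseq import utils

def backtranslate_batch(generator, sample):
    # Move the collated batch to GPU for generation when one is available.
    s = utils.move_to_cuda(sample) if torch.cuda.is_available() else sample
    hypos = generator.generate(s["net_input"])
    # Collation of the new parallel pairs happens on CPU, so the generated
    # tokens must leave the GPU before being handed back.
    return [h[0]["tokens"].cpu() for h in hypos]
```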

Reviewed By: liezl200

Differential Revision: D10846294

fbshipit-source-id: b015ecb5fcef26fba507c30f8a4992bdbc54899f
parent 4afa455e
fairseq/data/backtranslation_dataset.py
@@ -8,6 +8,7 @@
 import torch
 from fairseq import sequence_generator
+from fairseq import utils
 from . import FairseqDataset, language_pair_dataset
@@ -111,7 +112,7 @@ class BacktranslationDataset(FairseqDataset):
         # have an EOS appended to the end of each sentence.
         original_tgt = input_sample["source"]
         if original_tgt[-1] != eos:
-            original_tgt = torch.cat([original_tgt, torch.LongTensor(eos)])
+            original_tgt = torch.cat([original_tgt, torch.LongTensor([eos])])
         # The generated source dialect backtranslation will have an EOS.
         # If we want our parallel data source to not have an EOS, we will
@@ -128,8 +129,8 @@ class BacktranslationDataset(FairseqDataset):
         generated_samples.append(
             {
                 "id": input_sample["id"],
-                "source": generated_source,
-                "target": original_tgt,
+                "source": generated_source.cpu(),
+                "target": original_tgt.cpu(),
             }
         )
@@ -161,8 +162,12 @@ class BacktranslationDataset(FairseqDataset):
         sample. Note in this case, sample["target"] is None, and
         sample["net_input"]["src_tokens"] is really in tgt language.
         """
+        if torch.cuda.is_available():
+            s = utils.move_to_cuda(sample)
+        else:
+            s = sample
         self.backtranslation_generator.cuda()
-        input = sample["net_input"]
+        input = s["net_input"]
         srclen = input["src_tokens"].size(1)
         hypos = self.backtranslation_generator.generate(
             input,
@@ -178,4 +183,4 @@ class BacktranslationDataset(FairseqDataset):
         Here, we return src dataset size as tgt dataset size as an approximation.
         We do not know src size until we backtranslate and generate src sentences.
         """
-        return (self.tgt_dataset.size(index), self.tgt_dataset.size(index))
+        return (self.tgt_dataset.size(index)[0], self.tgt_dataset.size(index)[0])
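Why the `[0]`: if the wrapped dataset follows the FairseqDataset convention where `size(index)` returns a `(src_len, tgt_len)` tuple, the old line nested tuples instead of returning lengths. A toy illustration with invented numbers:

```python
def approx_size(wrapped_size):
    # wrapped_size is the (src_len, tgt_len) tuple from the wrapped
    # dataset's size(index). Only its first entry is a known length here,
    # so it stands in for both sides until back-translation produces the
    # real source sentences.
    return (wrapped_size[0], wrapped_size[0])

assert approx_size((7, 9)) == (7, 7)   # old code gave ((7, 9), (7, 9))
```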
fairseq/data/language_pair_dataset.py
@@ -92,6 +92,10 @@ class LanguagePairDataset(FairseqDataset):
         input_feeding (bool, optional): create a shifted version of the targets
             to be passed into the model for input feeding/teacher forcing.
             Default: ``True``
+        remove_eos_from_source (bool, optional): if set, removes eos from end of
+            source if it's present. Default: ``False``
+        append_eos_to_target (bool, optional): if set, appends eos to end of
+            target if it's absent. Default: ``False``
     """

     def __init__(
@@ -99,7 +103,7 @@ class LanguagePairDataset(FairseqDataset):
         tgt=None, tgt_sizes=None, tgt_dict=None,
         left_pad_source=True, left_pad_target=False,
         max_source_positions=1024, max_target_positions=1024,
-        shuffle=True, input_feeding=True,
+        shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False,
     ):
         if tgt_dict is not None:
             assert src_dict.pad() == tgt_dict.pad()
@@ -117,12 +121,30 @@ class LanguagePairDataset(FairseqDataset):
         self.max_target_positions = max_target_positions
         self.shuffle = shuffle
         self.input_feeding = input_feeding
+        self.remove_eos_from_source = remove_eos_from_source
+        self.append_eos_to_target = append_eos_to_target

     def __getitem__(self, index):
+        tgt_item = self.tgt[index] if self.tgt is not None else None
+        src_item = self.src[index]
+        # Append EOS to end of tgt sentence if it does not have an EOS and
+        # remove EOS from end of src sentence if it exists. This is useful
+        # when we use existing datasets for opposite directions, i.e., when
+        # we want to use tgt_dataset as src_dataset and vice versa.
+        if self.append_eos_to_target:
+            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
+            if self.tgt and self.tgt[index][-1] != eos:
+                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])
+        if self.remove_eos_from_source:
+            eos = self.src_dict.eos()
+            if self.src[index][-1] == eos:
+                src_item = self.src[index][:-1]
         return {
             'id': index,
-            'source': self.src[index],
-            'target': self.tgt[index] if self.tgt is not None else None,
+            'source': src_item,
+            'target': tgt_item,
         }

     def __len__(self):