Commit a2f5361d authored by alexeib, committed by Facebook Github Bot

Multiset (#838)

Summary:
Adds the ability to tag individual examples with the names of their datasets, along with some minor miscellaneous fixes and improvements.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/838

Differential Revision: D16919175

Pulled By: alexeib

fbshipit-source-id: 4bf493299645bae63f3ee6382e15f18a9f73666c
parent 7a31fe06
fairseq/data/__init__.py
@@ -27,11 +27,15 @@ from .numel_dataset import NumelDataset
 from .num_samples_dataset import NumSamplesDataset
 from .offset_tokens_dataset import OffsetTokensDataset
 from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
+from .prepend_dataset import PrependDataset
 from .prepend_token_dataset import PrependTokenDataset
 from .raw_label_dataset import RawLabelDataset
+from .replace_dataset import ReplaceDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
+from .sharded_dataset import ShardedDataset
 from .sort_dataset import SortDataset
 from .strip_token_dataset import StripTokenDataset
+from .subsample_dataset import SubsampleDataset
 from .token_block_dataset import TokenBlockDataset
 from .transform_eos_dataset import TransformEosDataset
 from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
@@ -72,14 +76,18 @@ __all__ = [
     'NumSamplesDataset',
     "OffsetTokensDataset",
     'PadDataset',
+    'PrependDataset',
     'PrependTokenDataset',
     'RawAudioDataset',
-    "RawLabelDataset",
+    'RawLabelDataset',
+    'ReplaceDataset',
     'RightPadDataset',
     'RoundRobinZipDatasets',
+    'ShardedDataset',
     'ShardedIterator',
     'SortDataset',
-    "StripTokenDataset",
+    'StripTokenDataset',
+    'SubsampleDataset',
     'TokenBlockDataset',
     'TransformEosDataset',
     'TransformEosLangPairDataset',
...
fairseq/data/concat_dataset.py
@@ -64,6 +64,10 @@ class ConcatDataset(FairseqDataset):
     def num_tokens(self, index: int):
         return np.max(self.size(index))

+    def attr(self, attr: str, index: int):
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
+        return getattr(self.datasets[dataset_idx], attr, None)
+
     @property
     def sizes(self):
         return np.concatenate(
...
fairseq/data/fairseq_dataset.py
@@ -47,6 +47,9 @@ class FairseqDataset(torch.utils.data.Dataset):
         """Whether this dataset supports prefetching."""
         return False

+    def attr(self, attr: str, index: int):
+        return getattr(self, attr, None)
+
     def prefetch(self, indices):
         """Prefetch the data required for this epoch."""
         raise NotImplementedError
...
fairseq/data/prepend_dataset.py (new file)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class PrependDataset(BaseWrapperDataset):
    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        super().__init__(dataset)
        self.prepend_getter = prepend_getter
        self.ensure_first_token = ensure_first_token_is

    def __getitem__(self, idx):
        item = self.dataset[idx]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        assert self.ensure_first_token is None or src[0] == self.ensure_first_token
        prepend_idx = self.prepend_getter(self.dataset, idx)
        assert isinstance(prepend_idx, int)
        src[0] = prepend_idx
        item = tuple((src,) + item[1:]) if is_tuple else src
        return item
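
To show how the new wrapper is meant to be used, here is a minimal hypothetical sketch (not part of the commit). Note that PrependDataset does not make items longer: it overwrites position 0 of each item, which ensure_first_token_is asserts was the expected eos/bos placeholder, with whatever id prepend_getter returns for that example. The token ids and the plain list standing in for a real FairseqDataset are invented.

import torch
from fairseq.data import PrependDataset

EOS = 2  # hypothetical eos id
items = [torch.LongTensor([EOS, 10, 11]),  # plain list standing in for a FairseqDataset
         torch.LongTensor([EOS, 12])]

tagged = PrependDataset(
    items,
    prepend_getter=lambda ds, i: 100 + i,  # any function of (dataset, index) returning an int
    ensure_first_token_is=EOS,             # asserts the slot being overwritten held eos
)
print(tagged[0])  # tensor([100,  10,  11])
print(tagged[1])  # tensor([101,  12])
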
fairseq/data/replace_dataset.py (new file)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class ReplaceDataset(BaseWrapperDataset):
    def __init__(self, dataset, replace_map, offset=0):
        super().__init__(dataset)
        assert len(replace_map) > 0
        self.replace_map = replace_map
        self.offset = offset

    def __getitem__(self, index):
        item = self.dataset[index]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        for k, v in self.replace_map.items():
            src_off = src[self.offset:]
            src_off.masked_fill_(src_off == k, v)

        item = tuple((src,) + item[1:]) if is_tuple else src
        return item
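
A minimal hypothetical sketch (not part of the commit) of ReplaceDataset: every occurrence of a mapped token id at or after offset is rewritten in place, which is how the tagged language modeling task later in this commit swaps eos for a newline token while leaving the prepended tag at position 0 untouched. The token ids and the plain list standing in for a real FairseqDataset are invented.

import torch
from fairseq.data import ReplaceDataset

EOS, NEWLINE = 2, 15  # hypothetical token ids
items = [torch.LongTensor([EOS, 7, EOS, 9, EOS])]  # plain list standing in for a FairseqDataset

ds = ReplaceDataset(items, replace_map={EOS: NEWLINE}, offset=1)
print(ds[0])  # tensor([ 2,  7, 15,  9, 15]) -- position 0 kept, later eos ids replaced
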
fairseq/data/sharded_dataset.py (new file)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
import random

from . import BaseWrapperDataset
from fairseq.data import data_utils


class ShardedDataset(BaseWrapperDataset):
    """A :class:`~fairseq.data.FairseqDataset` wrapper that loads one shard of a
    dataset which has been sharded into multiple files.

    For the train split, a single shard (``path/shard<i>/<split>``) is chosen
    pseudo-randomly from the seed and epoch, so each epoch sees only one shard;
    other splits are loaded directly from ``path/<split>``.
    """

    def __init__(
        self,
        dictionary,
        dataset_impl: str,
        path: str,
        split: str,
        epoch: int,
        name: str = None,
        combine: bool = False,
        seed: int = 0,
    ):
        self._name = name if name is not None else os.path.basename(path)
        num_shards = 0
        for i in itertools.count():
            if not os.path.exists(os.path.join(path, "shard" + str(i))):
                break
            num_shards += 1

        if num_shards > 0 and split == "train":
            random.seed(seed ^ epoch)
            shard = random.randint(0, num_shards - 1)
            split_path = os.path.join(path, "shard" + str(shard), split)
        else:
            split_path = os.path.join(path, split)
            if os.path.isdir(split_path):
                split_path = os.path.join(split_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path, dictionary, dataset_impl, combine=combine
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        super().__init__(dataset)

    @property
    def name(self):
        return self._name
fairseq/data/subsample_dataset.py (new file)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from . import BaseWrapperDataset


class SubsampleDataset(BaseWrapperDataset):
    def __init__(self, dataset, size_ratio):
        super().__init__(dataset)
        assert size_ratio < 1
        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
        self.indices = np.random.choice(
            range(len(self.dataset)), self.actual_size, replace=False
        )
        print(
            f"subsampled dataset from {len(self.dataset)} to {self.actual_size} (ratio={size_ratio})"
        )

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]

    def __len__(self):
        return self.actual_size

    def collater(self, samples):
        return self.dataset.collater(samples)

    @property
    def sizes(self):
        return self.dataset.sizes[self.indices]

    @property
    def name(self):
        return self.dataset.name

    def num_tokens(self, index):
        return self.dataset.num_tokens(self.indices[index])

    def size(self, index):
        return self.dataset.size(self.indices[index])

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        order.append(self.sizes)
        return np.lexsort(order)

    def prefetch(self, indices):
        self.dataset.prefetch(self.indices[indices])
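
A minimal hypothetical sketch (not part of the commit) of SubsampleDataset, which the tagged language modeling task uses to shrink oversized splits toward the size of the smallest one; the 0.25 ratio and the plain list standing in for a real FairseqDataset are invented.

import torch
from fairseq.data import SubsampleDataset

items = [torch.LongTensor([i]) for i in range(100)]  # stand-in for a FairseqDataset
sub = SubsampleDataset(items, size_ratio=0.25)  # keeps a random 25 of the 100 examples
print(len(sub))  # 25
print(sub[0])    # one of the original examples, chosen without replacement
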
fairseq/data/token_block_dataset.py
@@ -35,8 +35,15 @@ class TokenBlockDataset(FairseqDataset):
    """

    def __init__(
-        self, dataset, sizes, block_size, pad, eos, break_mode=None,
-        include_targets=False, document_sep_len=1,
+        self,
+        dataset,
+        sizes,
+        block_size,
+        pad,
+        eos,
+        break_mode=None,
+        include_targets=False,
+        document_sep_len=1,
    ):
        super().__init__()
        self.dataset = dataset
@@ -49,13 +56,7 @@ class TokenBlockDataset(FairseqDataset):
        assert len(dataset) > 0
        sizes = np.array(sizes, dtype=int)

-        assert break_mode != 'complete_doc' or np.all(np.diff((sizes == document_sep_len).nonzero()) != 1), \
-            (
-                "Found multiple blank lines in the dataset, please remove them"
-                " (eg. cat -s raw.txt) and preprocess the data again."
-            )
-
-        if break_mode is None or break_mode == 'none':
+        if break_mode is None or break_mode == "none":
            total_size = sum(sizes)
            length = math.ceil(total_size / block_size)
@@ -65,7 +66,7 @@ class TokenBlockDataset(FairseqDataset):
                return (start, end)

            slice_indices = [block_at(i) for i in range(length)]
-        elif break_mode == 'complete':
+        elif break_mode == "complete":
            tok_idx = 0
            sz_idx = 0
            curr_size = 0
@@ -79,7 +80,7 @@ class TokenBlockDataset(FairseqDataset):
                    curr_size = 0
            if curr_size > 0:
                slice_indices.append((tok_idx, tok_idx + curr_size))
-        elif break_mode == 'complete_doc':
+        elif break_mode == "complete_doc":
            tok_idx = 0
            sz_idx = 0
            curr_size = 0
@@ -92,15 +93,16 @@ class TokenBlockDataset(FairseqDataset):
                    curr_size += sizes[sz_idx]
                    sz_idx += 1
                else:
-                    slice_indices.append((tok_idx, tok_idx + curr_size))
+                    if curr_size > 1:
+                        slice_indices.append((tok_idx, tok_idx + curr_size))
                    tok_idx += curr_size
                    curr_size = 0
                    if sizes[sz_idx] == document_sep_len:
                        tok_idx += sizes[sz_idx]
                        sz_idx += 1
-            if curr_size > 0:
+            if curr_size > 1:
                slice_indices.append((tok_idx, tok_idx + curr_size))
-        elif break_mode == 'eos':
+        elif break_mode == "eos":
            slice_indices = np.empty((len(sizes), 2), dtype=int)
            if not torch.is_tensor(sizes):
                sizes = torch.tensor(sizes)
@@ -109,19 +111,21 @@ class TokenBlockDataset(FairseqDataset):
            if len(cumsum) > 1:
                slice_indices[1:] = cumsum.unfold(0, 2, 1)
        else:
-            raise ValueError('Invalid break_mode: ' + break_mode)
+            raise ValueError("Invalid break_mode: " + break_mode)

        slice_indices = np.array(slice_indices, dtype=int)
        self._sizes = slice_indices[:, 1] - slice_indices[:, 0]

        # build index mapping block indices to the underlying dataset indices
-        if break_mode == 'eos':
+        if break_mode == "eos":
            # much faster version for eos break mode
            block_to_dataset_index = np.stack(
                [
                    np.arange(len(sizes)),  # starting index in dataset
-                    np.zeros(len(sizes), dtype=np.long),  # starting offset within starting index
-                    np.arange(len(sizes))  # ending index in dataset
+                    np.zeros(
+                        len(sizes), dtype=np.long
+                    ),  # starting offset within starting index
+                    np.arange(len(sizes)),  # ending index in dataset
                ],
                1,
            )
@@ -133,9 +137,10 @@ class TokenBlockDataset(FairseqDataset):
                start_ds_idx = ds.current_index
                start_offset = ds.current_offset
                if e <= s:
-                    continue
-                ds.seek(e - 1)
-                end_ds_idx = ds.current_index
+                    end_ds_idx = start_ds_idx
+                else:
+                    ds.seek(e - 1)
+                    end_ds_idx = ds.current_index
                block_to_dataset_index[i] = (
                    start_ds_idx,  # starting index in dataset
                    start_offset,  # starting offset within starting index
@@ -158,11 +163,17 @@ class TokenBlockDataset(FairseqDataset):
    def block_to_dataset_index(self):
        return self._block_to_dataset_index.array

+    def attr(self, attr: str, index: int):
+        start_ds_idx, _, _ = self.block_to_dataset_index[index]
+        return self.dataset.attr(attr, start_ds_idx)
+
    def __getitem__(self, index):
        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
-        buffer = torch.cat([
-            self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)
-        ])
+
+        buffer = torch.cat(
+            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
+        )
+
        slice_s, slice_e = self.slice_indices[index]
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
@@ -173,16 +184,19 @@ class TokenBlockDataset(FairseqDataset):
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            if s == 0:
-                source = torch.cat([item.new([self.eos]), buffer[0:e - 1]])
-                past_target = torch.cat([item.new([self.pad, self.eos]), buffer[0:e - 2]])
+                source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
+                past_target = torch.cat(
+                    [item.new([self.pad, self.eos]), buffer[0 : e - 2]]
+                )
            else:
-                source = buffer[s - 1:e - 1]
+                source = buffer[s - 1 : e - 1]
                if s == 1:
-                    past_target = torch.cat([item.new([self.eos]), buffer[0:e - 2]])
+                    past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
                else:
-                    past_target = buffer[s - 2:e - 2]
+                    past_target = buffer[s - 2 : e - 2]

            return source, item, past_target

        return item

    def __len__(self):
@@ -190,15 +204,17 @@ class TokenBlockDataset(FairseqDataset):
    @property
    def supports_prefetch(self):
-        return getattr(self.dataset, 'supports_prefetch', False)
+        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
-        self.dataset.prefetch({
-            ds_idx
-            for index in indices
-            for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
-            for ds_idx in range(start_ds_idx, end_ds_idx + 1)
-        })
+        self.dataset.prefetch(
+            {
+                ds_idx
+                for index in indices
+                for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
+                for ds_idx in range(start_ds_idx, end_ds_idx + 1)
+            }
+        )


class DatasetSearcher(object):
@@ -216,17 +232,25 @@ class DatasetSearcher(object):
    def seek(self, i):
        assert i >= 0
-        if i < self.current_i:
-            self.reset()
-        if i > self.current_i:
-            to_consume = i - self.current_i
-            remaining = self.sizes[self.current_index] - self.current_offset
-            if remaining > to_consume:
-                self.current_offset += to_consume
-                self.current_i += to_consume
-            else:
-                self.current_i += remaining
-                self.current_index += 1
-                self.current_offset = 0
-            self.seek(i)
+
+        def step():
+            if i < self.current_i:
+                self.reset()
+            if i > self.current_i:
+                to_consume = i - self.current_i
+                remaining = self.sizes[self.current_index] - self.current_offset
+                if remaining > to_consume:
+                    self.current_offset += to_consume
+                    self.current_i += to_consume
+                else:
+                    assert remaining > 0
+                    self.current_i += remaining
+                    self.current_index += 1
+                    self.current_offset = 0
+                return True
+            return False
+
+        not_done = True
+        while not_done:
+            not_done = step()
+
        assert self.current_i == i
...
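
The rewrite of DatasetSearcher.seek above replaces the old tail call self.seek(i), which recursed once per consumed sentence and could get very deep for blocks spanning many sentences, with a plain loop. A small hypothetical sketch (not part of the commit) of what the helper computes; the sizes are invented:

from fairseq.data.token_block_dataset import DatasetSearcher

searcher = DatasetSearcher([3, 2, 4])  # three sentences of 3, 2 and 4 tokens
searcher.seek(4)                       # flat token position 4 falls inside the second sentence
print(searcher.current_index, searcher.current_offset)  # 1 1
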
fairseq/tasks/language_modeling.py
@@ -19,7 +19,7 @@ from fairseq.data import (
from fairseq.tasks import FairseqTask, register_task


-@register_task('language_modeling')
+@register_task("language_modeling")
class LanguageModelingTask(FairseqTask):
    """
    Train a language model.
@@ -87,7 +87,7 @@ class LanguageModelingTask(FairseqTask):
        self.output_dictionary = output_dictionary or dictionary

        if targets is None:
-            targets = ['future']
+            targets = ["future"]
        self.targets = targets

    @classmethod
@@ -97,38 +97,44 @@ class LanguageModelingTask(FairseqTask):
        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
-        if getattr(args, 'raw_text', False):
-            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
-            args.dataset_impl = 'raw'
-        elif getattr(args, 'lazy_load', False):
-            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
-            args.dataset_impl = 'lazy'
+        if getattr(args, "raw_text", False):
+            utils.deprecation_warning(
+                "--raw-text is deprecated, please use --dataset-impl=raw"
+            )
+            args.dataset_impl = "raw"
+        elif getattr(args, "lazy_load", False):
+            utils.deprecation_warning(
+                "--lazy-load is deprecated, please use --dataset-impl=lazy"
+            )
+            args.dataset_impl = "lazy"

        dictionary = None
        output_dictionary = None
        if args.data:
-            paths = args.data.split(':')
+            paths = args.data.split(":")
            assert len(paths) > 0
-            dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
-            print('| dictionary: {} types'.format(len(dictionary)))
+            dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
+            print("| dictionary: {} types".format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
-                output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)
+                output_dictionary = TruncatedDictionary(
+                    dictionary, args.output_dictionary_size
+                )

        # upgrade old checkpoints
-        if hasattr(args, 'exclude_self_target'):
+        if hasattr(args, "exclude_self_target"):
            args.self_target = not args.exclude_self_target

        targets = []
-        if getattr(args, 'self_target', False):
-            targets.append('self')
-        if getattr(args, 'future_target', False):
-            targets.append('future')
-        if getattr(args, 'past_target', False):
-            targets.append('past')
+        if getattr(args, "self_target", False):
+            targets.append("self")
+        if getattr(args, "future_target", False):
+            targets.append("future")
+        if getattr(args, "past_target", False):
+            targets.append("past")
        if len(targets) == 0:
            # standard language modeling
-            targets = ['future']
+            targets = ["future"]

        return cls(args, dictionary, output_dictionary, targets=targets)
@@ -137,7 +143,9 @@ class LanguageModelingTask(FairseqTask):
        for target in self.targets:
            if target not in model.supported_targets:
-                raise ValueError('Unsupported language modeling target: {}'.format(target))
+                raise ValueError(
+                    "Unsupported language modeling target: {}".format(target)
+                )

        return model
@@ -147,32 +155,44 @@ class LanguageModelingTask(FairseqTask):
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
-        paths = self.args.data.split(':')
+        paths = self.args.data.split(":")
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
-            split_path,
-            self.dictionary,
-            self.args.dataset_impl,
-            combine=combine,
+            split_path, self.dictionary, self.args.dataset_impl, combine=combine
        )
        if dataset is None:
-            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
+            raise FileNotFoundError(
+                "Dataset not found: {} ({})".format(split, split_path)
+            )

        dataset = TokenBlockDataset(
-            dataset, dataset.sizes, self.args.tokens_per_sample,
-            pad=self.dictionary.pad(), eos=self.dictionary.eos(),
-            break_mode=self.args.sample_break_mode, include_targets=True,
+            dataset,
+            dataset.sizes,
+            self.args.tokens_per_sample,
+            pad=self.dictionary.pad(),
+            eos=self.dictionary.eos(),
+            break_mode=self.args.sample_break_mode,
+            include_targets=True,
        )

-        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'
+        add_eos_for_other_targets = (
+            self.args.sample_break_mode is not None
+            and self.args.sample_break_mode != "none"
+        )

        self.datasets[split] = MonolingualDataset(
-            dataset, dataset.sizes, self.dictionary, self.output_dictionary,
-            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
-            targets=self.targets, add_bos_token=self.args.add_bos_token,
+            dataset,
+            dataset.sizes,
+            self.dictionary,
+            self.output_dictionary,
+            add_eos_for_other_targets=add_eos_for_other_targets,
+            shuffle=True,
+            targets=self.targets,
+            add_bos_token=self.args.add_bos_token,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths):
@@ -184,7 +204,7 @@ class LanguageModelingTask(FairseqTask):
                block_size=None,
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
-                break_mode='eos',
+                break_mode="eos",
                include_targets=False,
            ),
            src_lengths,
@@ -202,9 +222,9 @@ class LanguageModelingTask(FairseqTask):
    def inference_step(self, generator, models, sample, prefix_tokens=None):
        with torch.no_grad():
-            if prefix_tokens is None and sample['net_input']['src_tokens'].nelement():
+            if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
                # note: EOS has already been removed in build_dataset_for_inference
-                prefix_tokens = sample['net_input']['src_tokens']
+                prefix_tokens = sample["net_input"]["src_tokens"]
            return generator.generate(models, sample, prefix_tokens=prefix_tokens)

    @property
...
fairseq/tasks/tagged_language_modeling.py (new file)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import os

from fairseq.data import (
    ConcatDataset,
    data_utils,
    MonolingualDataset,
    PrependDataset,
    ReplaceDataset,
    ShardedDataset,
    SubsampleDataset,
    TokenBlockDataset,
)
from fairseq.tasks import register_task
from fairseq.tasks.language_modeling import LanguageModelingTask


@register_task("tagged_language_modeling")
class TaggedLanguageModelingTask(LanguageModelingTask):
    """
    Like the language modeling task, but prepends tags to each sample
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        LanguageModelingTask.add_args(parser)
        parser.add_argument(
            "--multiple-datasets",
            action="store_true",
            help="if set, treats paths in data as separate datasets to be combined, "
            "rather than as splits of a single dataset",
        )
        parser.add_argument(
            "--prepend-ds-name",
            action="store_true",
            help="if set and multiple-datasets is also set, prepends the name of the ds instead of "
            "bos/eos token",
        )
        parser.add_argument(
            "--generic-ds-name-chance",
            type=float,
            metavar="P",
            default=0,
            help='if multiple datasets is used, sets the prepended ds name to "generic" '
            "this percentage of time",
        )
        parser.add_argument(
            "--subsample-splits",
            type=str,
            metavar="SPLITS",
            default="valid",
            help="if multiple datasets is used, subsamples specified split(colon separated) to "
            "the size of the smallest split",
        )

    def __init__(self, args, dictionary, output_dictionary=None, targets=None):
        super().__init__(args, dictionary, output_dictionary, targets)
        self.subsample_splits = (
            set()
            if args.subsample_splits is None
            else set(args.subsample_splits.split(":"))
        )

    def make_prepended_ds(self, dataset):
        def ds_name(dataset, index):
            if (
                self.args.generic_ds_name_chance > 0
                and np.random.rand() <= self.args.generic_ds_name_chance
            ):
                ds_name = "generic"
            else:
                ds_name = dataset.attr("name", index)
            assert ds_name is not None
            return self.dictionary.indices[ds_name]

        dataset = PrependDataset(
            dataset, prepend_getter=ds_name, ensure_first_token_is=self.dictionary.eos()
        )
        return dataset

    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(":")
        assert len(paths) > 0

        if self.args.multiple_datasets:
            if len(paths) == 1:
                paths = [os.path.join(paths[0], p) for p in next(os.walk(paths[0]))[1]]
            datasets = [
                ShardedDataset(
                    self.dictionary,
                    self.args.dataset_impl,
                    path,
                    split,
                    epoch,
                    combine=combine,
                )
                for path in paths
            ]

            if split in self.subsample_splits:
                sizes = [sum(d.sizes) for d in datasets]
                min_sz = min(sizes)
                ratios = [min_sz / sz for sz in sizes]
                datasets = [
                    SubsampleDataset(d, r) if r < 1 else d
                    for d, r in zip(datasets, ratios)
                ]

            dataset = ConcatDataset(datasets)
        else:
            data_path = paths[epoch % len(paths)]
            split_path = os.path.join(data_path, split)

            dataset = data_utils.load_indexed_dataset(
                split_path, self.dictionary, self.args.dataset_impl, combine=combine
            )
            if dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, split_path)
                )

        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
        )

        if self.args.prepend_ds_name:
            dataset = self.make_prepended_ds(dataset)
            dataset = ReplaceDataset(dataset, { self.dictionary.eos(): self.dictionary.indices['\\n'] }, offset=1)

        add_eos_for_other_targets = (
            self.args.sample_break_mode is not None
            and self.args.sample_break_mode != "none"
        )

        self.datasets[split] = MonolingualDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
        )
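
To make the tagging chain above concrete, here is a hypothetical, self-contained sketch (not part of the commit) that mimics make_prepended_ds on toy in-memory data: ConcatDataset remembers which sub-dataset an example came from, the new attr("name", index) hook recovers that dataset's name, and PrependDataset writes the matching tag id over the leading eos. The NamedToyDataset class, token ids, and tag-id table are invented; in the real task the names come from ShardedDataset and the ids from the task dictionary.

import torch
from fairseq.data import ConcatDataset, FairseqDataset, PrependDataset

class NamedToyDataset(FairseqDataset):
    """Toy stand-in for ShardedDataset: a list of items plus a dataset-level name."""
    def __init__(self, name, items):
        self.name = name
        self.items = items
    def __getitem__(self, idx):
        return self.items[idx]
    def __len__(self):
        return len(self.items)

EOS = 2
tag_ids = {"wiki": 10, "books": 11}  # pretend dictionary indices for the tag tokens

concat = ConcatDataset([
    NamedToyDataset("wiki", [torch.LongTensor([EOS, 5, 6])]),
    NamedToyDataset("books", [torch.LongTensor([EOS, 7, 8])]),
])
tagged = PrependDataset(
    concat,
    prepend_getter=lambda ds, i: tag_ids[ds.attr("name", i)],
    ensure_first_token_is=EOS,
)
print(tagged[0])  # tensor([10,  5,  6]) -- came from "wiki", tagged with id 10
print(tagged[1])  # tensor([11,  7,  8]) -- came from "books", tagged with id 11
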
train.py
@@ -73,7 +73,6 @@ def main(args, init_distributed=False):
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
-    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
...