# Copyright (c) 2017 Elad Hoffer
# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
from operator import itemgetter

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import seq2seq.data.config as config
from seq2seq.data.sampler import BucketingSampler
from seq2seq.data.sampler import DistributedSampler
from seq2seq.data.sampler import ShardingSampler
from seq2seq.data.sampler import StaticDistributedSampler


def build_collate_fn(batch_first=False, parallel=True, sort=False):
    """
    Factory for collate_fn functions.

    :param batch_first: if True returns batches in (batch, seq) format, if
        False returns in (seq, batch) format
    :param parallel: if True builds batches from parallel corpus (src, tgt)
    :param sort: if True sorts by src sequence length within each batch

    :returns: `parallel_collate` if `parallel` is True, otherwise
        `single_collate`
    """
    def collate_seq(seq):
        """
        Builds batches for training or inference.
        Batches are returned as pytorch tensors, with padding.

        :param seq: list of sequences

        :returns: tuple (padded int64 tensor, int64 tensor of lengths)
        """
        lengths = torch.tensor([len(s) for s in seq], dtype=torch.int64)
        # Reduce on the tensor and convert to a plain int so that the shape
        # handed to torch.full consists of Python integers only.
        batch_length = int(lengths.max())

        shape = (len(seq), batch_length)
        seq_tensor = torch.full(shape, config.PAD, dtype=torch.int64)

        for i, s in enumerate(seq):
            end_seq = lengths[i]
            seq_tensor[i, :end_seq].copy_(s[:end_seq])

        if not batch_first:
            seq_tensor = seq_tensor.t()

        return (seq_tensor, lengths)

    def parallel_collate(seqs):
        """
        Builds batches from parallel dataset (src, tgt), optionally sorts
        batch by src sequence length.

        :param seqs: tuple of (src, tgt) sequences

        :returns: tuple (collated src, collated tgt), each element being the
            (tensor, lengths) pair produced by `collate_seq`
        """
        src_seqs, tgt_seqs = zip(*seqs)
        if sort:
            # Sort by src length (longest first) and apply the same
            # permutation to tgt so that the pairs stay aligned.
            indices, src_seqs = zip(*sorted(enumerate(src_seqs),
                                            key=lambda item: len(item[1]),
                                            reverse=True))
            tgt_seqs = [tgt_seqs[idx] for idx in indices]

        return tuple([collate_seq(s) for s in [src_seqs, tgt_seqs]])

    def single_collate(src_seqs):
        """
        Builds batches from text dataset, optionally sorts batch by src
        sequence length.

        :param src_seqs: source sequences

        :returns: tuple (collated batch, indices) where `indices` holds, for
            every position of the returned batch, the position of that sample
            in the input `src_seqs`
        """
        if sort:
            indices, src_seqs = zip(*sorted(enumerate(src_seqs),
                                            key=lambda item: len(item[1]),
                                            reverse=True))
        else:
            indices = range(len(src_seqs))

        return collate_seq(src_seqs), tuple(indices)

    if parallel:
        return parallel_collate
    else:
        return single_collate


class SyntheticDataset(Dataset):
    """
    Dataset of randomly generated fixed-length sequences of token ids.
    """
    def __init__(self, vocab_size, seq_len, nsamples):
        """
        Constructor for the SyntheticDataset.

        :param vocab_size: exclusive upper bound for the generated token ids
        :param seq_len: length of every generated sequence
        :param nsamples: number of samples reported by `__len__`
        """
        self.vocab_size = vocab_size
        self.nsamples = nsamples
        self.seq_len = seq_len

    def __getitem__(self, idx):
        # `idx` is not used: every access draws a fresh random sequence.
        rand = torch.randint(0, self.vocab_size, size=(self.seq_len,))
        return rand

    def unsort(self, array):
        """
        Returns `array` unchanged (this dataset is never sorted); provided so
        that the class offers the same `unsort` method as the text datasets.

        :param array: array to be "unsorted"
        """
        return array

    def get_loader(self, batch_size=1, num_workers=0, batch_first=False,
                   pad=False, repeat=1):
        """
        Builds a DataLoader over this dataset using a
        StaticDistributedSampler.

        :param batch_size: batch size, passed to both sampler and DataLoader
        :param num_workers: number of DataLoader workers
        :param batch_first: if True batches are in (batch, seq) format
        :param pad: forwarded to StaticDistributedSampler
        :param repeat: forwarded to StaticDistributedSampler
        """
        collate_fn = build_collate_fn(batch_first, parallel=False,
                                      sort=True)
        sampler = StaticDistributedSampler(self, batch_size, pad, repeat)

        return DataLoader(self,
                          batch_size=batch_size,
                          collate_fn=collate_fn,
                          sampler=sampler,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=False)

    def __len__(self):
        return self.nsamples


class RawTextDataset(Dataset):
    """
    Dataset of raw text lines, tokenized on access.
    """
    def __init__(self, raw_data=None, raw_datafile=None, tokenizer=None,
                 sort=False, max_size=None):
        """
        Constructor for the RawTextDataset.

        :param raw_data: sequence of raw text lines, used only if
            `raw_datafile` is not given
        :param raw_datafile: path to a text file; if given, its lines are
            used instead of `raw_data`
        :param tokenizer: object providing `tokenize(str)`, applied in
            `__getitem__`
        :param sort: if True sorts the dataset by number of
            whitespace-separated words (ascending)
        :param max_size: keeps at most 'max_size' leading samples, if None
            keeps everything
        """
        self.tokenizer = tokenizer
        self.sorted = False

        if raw_datafile:
            with open(raw_datafile, 'r') as f:
                self.raw_data = f.readlines()
        else:
            self.raw_data = raw_data

        if max_size:
            self.raw_data = self.raw_data[:max_size]

        # Lengths are measured in whitespace-separated words of the raw
        # text, not in tokens produced by the tokenizer.
        self.lengths = [len(s.split()) for s in self.raw_data]

        if sort:
            self.sort_by_length()

    def __getitem__(self, idx):
        raw = self.raw_data[idx]
        tokenized = self.tokenizer.tokenize(raw)
        return tokenized

    def unsort(self, array):
        """
        "Unsorts" given array (restores original order of elements before
        dataset was sorted by sequence length).

        :param array: array to be "unsorted"
        """
        if self.sorted:
            # self.indices[k] is the original position of the sample stored
            # at sorted position k; ordering the pairs by original position
            # yields the inverse permutation.
            inverse = sorted(enumerate(self.indices), key=itemgetter(1))
            array = [array[i[0]] for i in inverse]
        return array

    def sort_by_length(self):
        """
        Sorts dataset by the number of whitespace-separated words in each
        line (ascending) and records the permutation in `self.indices`.
        """
        output = sorted(
            enumerate(self.raw_data),
            key=lambda x: len(x[1].split()),
            )
        self.indices, self.raw_data = zip(*output)
        self.lengths = [self.lengths[idx] for idx in self.indices]
        self.sorted = True

    def __len__(self):
        return len(self.raw_data)

    def get_loader(self, batch_size=1, num_workers=0, batch_first=False,
                   pad=False, repeat=1):
        """
        Builds a DataLoader over this dataset using a
        StaticDistributedSampler.

        :param batch_size: batch size, passed to both sampler and DataLoader
        :param num_workers: number of DataLoader workers
        :param batch_first: if True batches are in (batch, seq) format
        :param pad: forwarded to StaticDistributedSampler
        :param repeat: forwarded to StaticDistributedSampler
        """
        collate_fn = build_collate_fn(batch_first, parallel=False,
                                      sort=True)
        sampler = StaticDistributedSampler(self, batch_size, pad, repeat)

        return DataLoader(self,
                          batch_size=batch_size,
                          collate_fn=collate_fn,
                          sampler=sampler,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=False)


class TextDataset(Dataset):
    def __init__(self, src_fname, tokenizer, min_len=None, max_len=None,
                 sort=False, max_size=None):
        """
        Constructor for the TextDataset. Builds monolingual dataset.

        :param src_fname: path to the file with data
        :param tokenizer: tokenizer
        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        :param sort: sorts dataset by sequence length
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """

        self.min_len = min_len
        self.max_len = max_len
        self.parallel = False
        self.sorted = False

        self.src = self.process_data(src_fname, tokenizer, max_size)

        # Filtering is applied only when both bounds are provided.
        if min_len is not None and max_len is not None:
            self.filter_data(min_len, max_len)

        lengths = [len(s) for s in self.src]
        self.lengths = torch.tensor(lengths)

        if sort:
            self.sort_by_length()

    def sort_by_length(self):
        """
        Sorts dataset by the sequence length (descending) and records the
        permutation in `self.indices`.
        """
        self.lengths, indices = self.lengths.sort(descending=True)

        self.indices = indices.tolist()
        self.src = [self.src[idx] for idx in self.indices]
        self.sorted = True

    def unsort(self, array):
        """
        "Unsorts" given array (restores original order of elements before
        dataset was sorted by sequence length).

        :param array: array to be "unsorted"
        """
        if self.sorted:
            # self.indices[k] is the original position of the sample stored
            # at sorted position k; ordering the pairs by original position
            # yields the inverse permutation.
            inverse = sorted(enumerate(self.indices), key=itemgetter(1))
            array = [array[i[0]] for i in inverse]
        return array

    def filter_data(self, min_len, max_len):
        """
        Preserves only samples which satisfy the following inequality:
            min_len <= sample sequence length <= max_len

        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        """
        logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}')

        initial_len = len(self.src)
        self.src = [src for src in self.src
                    if min_len <= len(src) <= max_len]
        filtered_len = len(self.src)

        logging.info(f'Pairs before: {initial_len}, after: {filtered_len}')

    def process_data(self, fname, tokenizer, max_size):
        """
        Loads data from the input file.

        :param fname: input file name
        :param tokenizer: tokenizer
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset

        :returns: list of tensors, one per loaded line, built from the output
            of `tokenizer.segment(line)`
        """
        logging.info(f'Processing data from {fname}')
        data = []
        with open(fname) as dfile:
            for idx, line in enumerate(dfile):
                if max_size and idx == max_size:
                    break
                entry = tokenizer.segment(line)
                entry = torch.tensor(entry)
                data.append(entry)
        return data

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return self.src[idx]

    def get_loader(self, batch_size=1, seeds=None, shuffle=False,
                   num_workers=0, batch_first=False, pad=False,
                   batching=None, batching_opt=None):
        """
        Builds a DataLoader over this dataset.

        :param batch_size: batch size, passed to both sampler and DataLoader
        :param seeds: forwarded to the shuffling samplers
        :param shuffle: if True one of the shuffling samplers is selected by
            `batching`, otherwise a StaticDistributedSampler is used
        :param num_workers: number of DataLoader workers
        :param batch_first: if True batches are in (batch, seq) format
        :param pad: forwarded to StaticDistributedSampler (only used when
            `shuffle` is False)
        :param batching: 'random', 'sharding' or 'bucketing' (only used when
            `shuffle` is True)
        :param batching_opt: dict with sampler options: 'shard_size' is
            required for 'sharding', 'num_buckets' for 'bucketing'; None is
            treated as an empty dict

        :raises NotImplementedError: if `shuffle` is True and `batching` is
            not one of the supported modes
        :raises KeyError: if the option required by the selected batching
            mode is missing from `batching_opt`
        """
        # A None sentinel replaces the former `{}` default so that no mutable
        # object is shared between calls.
        if batching_opt is None:
            batching_opt = {}

        collate_fn = build_collate_fn(batch_first, parallel=self.parallel,
                                      sort=True)

        if shuffle:
            if batching == 'random':
                sampler = DistributedSampler(self, batch_size, seeds)
            elif batching == 'sharding':
                sampler = ShardingSampler(self, batch_size, seeds,
                                          batching_opt['shard_size'])
            elif batching == 'bucketing':
                sampler = BucketingSampler(self, batch_size, seeds,
                                           batching_opt['num_buckets'])
            else:
                raise NotImplementedError(
                    f'Unsupported batching mode: {batching}'
                )
        else:
            sampler = StaticDistributedSampler(self, batch_size, pad)

        return DataLoader(self,
                          batch_size=batch_size,
                          collate_fn=collate_fn,
                          sampler=sampler,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=False)


class ParallelDataset(TextDataset):
    def __init__(self, src_fname, tgt_fname, tokenizer,
                 min_len, max_len, sort=False, max_size=None):
        """
        Constructor for the ParallelDataset.
        Tokenization is done when the data is loaded from the disk.

        :param src_fname: path to the file with src language data
        :param tgt_fname: path to the file with tgt language data
        :param tokenizer: tokenizer
        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        :param sort: sorts dataset by sequence length
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """

        self.min_len = min_len
        self.max_len = max_len
        self.parallel = True
        self.sorted = False

        self.src = self.process_data(src_fname, tokenizer, max_size)
        self.tgt = self.process_data(tgt_fname, tokenizer, max_size)
        assert len(self.src) == len(self.tgt)

        self.filter_data(min_len, max_len)
        assert len(self.src) == len(self.tgt)

        src_lengths = [len(s) for s in self.src]
        tgt_lengths = [len(t) for t in self.tgt]
        self.src_lengths = torch.tensor(src_lengths)
        self.tgt_lengths = torch.tensor(tgt_lengths)
        # Sorting key: combined length of the src and tgt sequence.
        self.lengths = self.src_lengths + self.tgt_lengths

        if sort:
            self.sort_by_length()

    def sort_by_length(self):
        """
        Sorts dataset by the sequence length (src length + tgt length,
        descending) and records the permutation in `self.indices`.
        """
        self.lengths, indices = self.lengths.sort(descending=True)

        self.indices = indices.tolist()
        self.src = [self.src[idx] for idx in self.indices]
        self.tgt = [self.tgt[idx] for idx in self.indices]
        # Permute with tensor indexing so that src_lengths / tgt_lengths stay
        # 1-D tensors, exactly as they are when the dataset is not sorted.
        self.src_lengths = self.src_lengths[indices]
        self.tgt_lengths = self.tgt_lengths[indices]
        self.sorted = True

    def filter_data(self, min_len, max_len):
        """
        Preserves only samples which satisfy the following inequality:
            min_len <= src sample sequence length <= max_len AND
            min_len <= tgt sample sequence length <= max_len

        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        """
        logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}')

        initial_len = len(self.src)
        filtered_src = []
        filtered_tgt = []
        for src, tgt in zip(self.src, self.tgt):
            if min_len <= len(src) <= max_len and \
                    min_len <= len(tgt) <= max_len:
                filtered_src.append(src)
                filtered_tgt.append(tgt)

        self.src = filtered_src
        self.tgt = filtered_tgt
        filtered_len = len(self.src)

        logging.info(f'Pairs before: {initial_len}, after: {filtered_len}')

    def __getitem__(self, idx):
        return self.src[idx], self.tgt[idx]


class LazyParallelDataset(TextDataset):
    def __init__(self, src_fname, tgt_fname, tokenizer,
                 min_len, max_len, sort=False, max_size=None):
        """
        Constructor for the LazyParallelDataset.
        Tokenization is done on the fly.

        NOTE: `sort` is accepted but not used by this constructor; the
        dataset is never sorted here.

        :param src_fname: path to the file with src language data
        :param tgt_fname: path to the file with tgt language data
        :param tokenizer: tokenizer
        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        :param sort: sorts dataset by sequence length
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """
        self.min_len = min_len
        self.max_len = max_len
        self.parallel = True
        self.sorted = False
        self.tokenizer = tokenizer

        self.raw_src = self.process_raw_data(src_fname, max_size)
        self.raw_tgt = self.process_raw_data(tgt_fname, max_size)
        assert len(self.raw_src) == len(self.raw_tgt)

        logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}')
        # Subtracting 2 because EOS and BOS are added later during tokenization
        self.filter_raw_data(min_len - 2, max_len - 2)
        assert len(self.raw_src) == len(self.raw_tgt)

        # Adding 2 because EOS and BOS are added later during tokenization
        src_lengths = [i + 2 for i in self.src_len]
        tgt_lengths = [i + 2 for i in self.tgt_len]
        self.src_lengths = torch.tensor(src_lengths)
        self.tgt_lengths = torch.tensor(tgt_lengths)
        self.lengths = self.src_lengths + self.tgt_lengths

    def process_raw_data(self, fname, max_size):
        """
        Loads data from the input file.

        :param fname: input file name
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset

        :returns: list of raw (untokenized) lines
        """
        logging.info(f'Processing data from {fname}')
        data = []
        with open(fname) as dfile:
            for idx, line in enumerate(dfile):
                if max_size and idx == max_size:
                    break
                data.append(line)
        return data

    def filter_raw_data(self, min_len, max_len):
        """
        Preserves only samples which satisfy the following inequality:
            min_len <= src sample sequence length <= max_len AND
            min_len <= tgt sample sequence length <= max_len

        Sequence length is computed from the raw line as the number of
        space characters plus one. The lengths of the preserved samples are
        stored in `self.src_len` and `self.tgt_len`.

        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        """
        initial_len = len(self.raw_src)
        filtered_src = []
        filtered_tgt = []
        filtered_src_len = []
        filtered_tgt_len = []
        for src, tgt in zip(self.raw_src, self.raw_tgt):
            src_len = src.count(' ') + 1
            tgt_len = tgt.count(' ') + 1
            if min_len <= src_len <= max_len and \
                    min_len <= tgt_len <= max_len:
                filtered_src.append(src)
                filtered_tgt.append(tgt)
                filtered_src_len.append(src_len)
                filtered_tgt_len.append(tgt_len)

        self.raw_src = filtered_src
        self.raw_tgt = filtered_tgt
        self.src_len = filtered_src_len
        self.tgt_len = filtered_tgt_len
        filtered_len = len(self.raw_src)

        logging.info(f'Pairs before: {initial_len}, after: {filtered_len}')

    def __getitem__(self, idx):
        # Tokenization happens here, on access, rather than at load time.
        src = torch.tensor(self.tokenizer.segment(self.raw_src[idx]))
        tgt = torch.tensor(self.tokenizer.segment(self.raw_tgt[idx]))
        return src, tgt

    def __len__(self):
        return len(self.raw_src)