# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for splitting one large dataset into multiple smaller datasets."""

import torch
import numpy as np


def should_split(split):
    """
    Given split proportions, check whether a split is actually needed
    (i.e. whether more than one proportion is nonzero).
    Examples:
        >>> should_split([10, 0, 0])
        False
        >>> should_split([1, .1, .2])
        True
    """
    return max(split) / sum(split) != 1.


def get_split(args):
    """
    Get dataset split proportions from a comma- or slash-separated
    string (e.g. '.8,.1,.1' or '900/50/50'), normalized to sum to 1.
    """
    if ',' in args.split:
        splits = [float(s) for s in args.split.split(',')]
    elif '/' in args.split:
        splits = [float(s) for s in args.split.split('/')]
    else:
        splits = [float(args.split)]
    split_total = sum(splits)
    # If the proportions sum to less than 1, allocate the remainder
    # to the next (validation) split.
    if split_total < 1.:
        splits.append(1. - split_total)
    # Pad or truncate to exactly three splits: train/valid/test.
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    # Separate validation/test files make those splits unnecessary.
    if args.valid_data is not None:
        splits[1] = 0.
    if args.test_data is not None:
        splits[2] = 0.
    final_sum = sum(splits)
    return [s / final_sum for s in splits]


class SplitDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper to access a subset of another dataset.
    Purpose: useful to index into existing datasets, possibly
    large-scale datasets, as the subindexing operation is done
    on the fly.
    Arguments:
        ds (Dataset or array-like): Dataset to be subindexed
        split_inds (1D array-like): List of indices part of subset
    """
    def __init__(self, ds, split_inds, **kwargs):
        self.split_inds = list(split_inds)
        self.wrapped_data = ds

    def __len__(self):
        return len(self.split_inds)

    def __getitem__(self, index):
        # Map the subset index to an index in the wrapped dataset.
        return self.wrapped_data[self.split_inds[index]]

    def num_tokens(self):
        return self.wrapped_data.num_tokens()

    def __iter__(self):
        for idx in self.split_inds:
            yield self.wrapped_data[idx]


def split_ds(ds, split=[.8, .2, .0], shuffle=True):
    """
    Split a dataset into subsets given proportions of how much to
    allocate per split. If a split is 0%, return None for that split.
    Purpose: useful for creating train/val/test splits.
    Arguments:
        ds (Dataset or array-like): Data to be split.
        split (1D array-like): Proportions to split `ds`. `sum(split) != 0`.
        shuffle (boolean): Randomly split dataset. Default: True.
    """
    split_sum = sum(split)
    if split_sum == 0:
        raise ValueError('Split cannot sum to 0.')
    # Normalize proportions; cast to float so integer inputs such as
    # [8, 2, 0] do not break the in-place division below.
    split = np.array(split, dtype=np.float64)
    split /= split_sum
    ds_len = len(ds)
    inds = np.arange(ds_len)
    if shuffle:
        np.random.shuffle(inds)
    start_idx = 0
    # Fractional sample counts left over from earlier splits are carried
    # forward so no samples are lost to rounding.
    residual_idx = 0
    rtn_ds = [None] * len(split)
    for i, f in enumerate(split):
        if f != 0:
            proportion = ds_len * f
            residual_idx += proportion % 1
            split_ = int(int(proportion) + residual_idx)
            # Take at least one sample so a tiny nonzero split is not empty.
            split_inds = inds[start_idx:start_idx + max(split_, 1)]
            rtn_ds[i] = SplitDataset(ds, split_inds)
            start_idx += split_
            residual_idx %= 1
    return rtn_ds
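

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): `get_split`
# only needs an object carrying `split`, `valid_data`, and `test_data`
# attributes; the `argparse.Namespace` below is a hypothetical stand-in for
# the project's parsed command-line arguments.

if __name__ == '__main__':
    from argparse import Namespace

    # Comma-separated proportions are normalized to fractions summing to 1.
    args = Namespace(split='900,50,50', valid_data=None, test_data=None)
    print(get_split(args))  # [0.9, 0.05, 0.05]

    # A single proportion below 1 has its remainder assigned to validation.
    args = Namespace(split='.9', valid_data=None, test_data=None)
    print(get_split(args))  # [0.9, 0.1, 0.0] (up to float rounding)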
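

# ---------------------------------------------------------------------------
# Minimal sketch of `split_ds` on a toy dataset (illustrative). A plain list
# stands in for a real dataset here; any indexable object with `__len__`
# works, although `SplitDataset.num_tokens` is only usable when the wrapped
# dataset provides `num_tokens` itself.

if __name__ == '__main__':
    train, val, test = split_ds(list(range(10)), split=[.8, .2, .0],
                                shuffle=False)
    # With shuffle=False the first 8 samples land in train, the remaining
    # 2 in validation, and the 0% test split stays None.
    print(list(train))  # [0, 1, 2, 3, 4, 5, 6, 7]
    print(list(val))    # [8, 9]
    print(test)         # None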