# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for splitting one large dataset into multiple smaller datasets."""

import numpy as np
import torch


def get_train_valid_test_split(splits_string, size):
    """Get dataset split boundaries from a comma- or '/'-separated string
    of proportions (e.g. '98,1,1'), normalized over `size` samples."""
    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    # Pad with zeros so there are always exactly three proportions.
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    # Normalize the proportions so they sum to 1.
    splits = [split / splits_sum for split in splits]
    # Accumulate cumulative boundary indices, one per split.
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    # Correct for rounding error so the last boundary equals `size` exactly.
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    return splits_index


class SplitDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper to access a subset of another dataset.

    Purpose: useful for indexing into existing, possibly large-scale,
    datasets, since the sub-indexing is done on the fly.

    Arguments:
        ds (Dataset or array-like): dataset to be subindexed
        split_inds (1D array-like): indices belonging to the subset
    """
    def __init__(self, ds, split_inds, **kwargs):
        self.split_inds = list(split_inds)
        self.wrapped_data = ds

    def __len__(self):
        return len(self.split_inds)

    def __getitem__(self, index):
        return self.wrapped_data[self.split_inds[index]]

    def num_tokens(self):
        return self.wrapped_data.num_tokens()

    def __iter__(self):
        for idx in self.split_inds:
            yield self.wrapped_data[idx]


def split_ds(ds, split=[.8, .2, .0], shuffle=True):
    """
    Split a dataset into subsets given the proportion to allocate to each
    split. If a split's proportion is 0, returns None for that split.

    Purpose: useful for creating train/val/test splits.

    Arguments:
        ds (Dataset or array-like): data to be split
        split (1D array-like): proportions to split `ds` into; must not sum to 0
        shuffle (boolean): randomly shuffle before splitting. Default: True
    """
    split_sum = sum(split)
    if split_sum == 0:
        raise ValueError('Split cannot sum to 0.')
    # Use a float array so the in-place normalization below works even
    # when integer proportions are passed.
    split = np.array(split, dtype=float)
    split /= split_sum
    ds_len = len(ds)
    inds = np.arange(ds_len)
    if shuffle:
        np.random.shuffle(inds)
    start_idx = 0
    residual_idx = 0
    rtn_ds = [None] * len(split)
    for i, f in enumerate(split):
        if f != 0:
            # Carry the fractional remainder across splits so the split
            # sizes add up to the full dataset length.
            proportion = ds_len * split[i]
            residual_idx += proportion % 1
            split_ = int(int(proportion) + residual_idx)
            split_inds = inds[start_idx:start_idx + max(split_, 1)]
            rtn_ds[i] = SplitDataset(ds, split_inds)
            start_idx += split_
            residual_idx %= 1
    return rtn_ds
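

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. The '98,1,1' split
# string and the list-of-ints stand-in dataset are illustrative assumptions;
# any object supporting len() and integer indexing works with split_ds.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Parse a comma-separated split string into cumulative boundary
    # indices over a dataset of 1000 samples.
    boundaries = get_train_valid_test_split('98,1,1', 1000)
    assert boundaries == [0, 980, 990, 1000]

    # Proportionally split an array-like; the zero proportion yields None.
    data = list(range(10))  # hypothetical stand-in dataset
    train, valid, test = split_ds(data, split=[.8, .2, .0], shuffle=False)
    assert len(train) == 8 and len(valid) == 2 and test is None
    # SplitDataset supports both indexing and iteration over the subset.
    assert train[0] == data[0]
    assert list(train) == data[:8]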