Commit 8179ebd3 authored by Mohammad Shoeybi
Browse files

removed split dataset

parent 5300c69f
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for splitting one large dataset into multiple smaller datasets."""
import torch
import numpy as np
def get_train_valid_test_split(splits_string, size):
    """Get dataset split boundaries from a comma or '/' separated string.

    The string holds up to three non-negative weights (for example
    "949,50,1" or "8/2"). Missing weights are treated as 0 and any weight
    past the third is ignored. The weights are normalized by their sum and
    converted into cumulative boundaries over `size` samples.

    Arguments:
        splits_string (str): weights separated by ',' or '/', or a single
            number.
        size (int): total number of samples to distribute.

    Returns:
        list of four ints `[0, a, b, size]`; split `i` covers the
        half-open range `[result[i], result[i + 1])`.

    Raises:
        ValueError: if a weight is not a number, a weight is negative, or
            the weights do not sum to a positive value.
    """
    if ',' in splits_string:
        splits = [float(s) for s in splits_string.split(',')]
    elif '/' in splits_string:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]

    # Pad with zeros up to three entries, then drop anything past the third.
    splits = (splits + [0.] * 3)[:3]

    # Validate with explicit exceptions: `assert` is stripped under `-O`.
    if any(s < 0.0 for s in splits):
        raise ValueError(
            'split weights must be non-negative, got {!r}'.format(
                splits_string))
    splits_sum = sum(splits)
    if not splits_sum > 0.0:  # written this way so NaN is rejected as well
        raise ValueError(
            'split weights must sum to a positive value, got {!r}'.format(
                splits_string))

    fractions = [s / splits_sum for s in splits]

    # Cumulative boundaries; each split size is rounded independently, so
    # the last boundary may differ slightly from `size`.
    splits_index = [0]
    for fraction in fractions:
        splits_index.append(
            splits_index[-1] + int(round(fraction * float(size))))

    # Shift every boundary by the rounding error so the last one lands on
    # `size`. The shift alone can push a boundary below its predecessor
    # (e.g. "0,1,1" with size 3 would give [0, -1, 1, 3]), so clamp each
    # boundary to stay non-decreasing and within [0, size].
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        shifted = splits_index[index] - diff
        splits_index[index] = min(size, max(splits_index[index - 1], shifted))
    return splits_index
class SplitDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper exposing a subset of another dataset.
    Purpose: index into an existing (possibly large-scale) dataset
    without copying it; each access is translated to an index of the
    wrapped dataset on the fly.
    Arguments:
        ds (Dataset or array-like): dataset to be subindexed
        split_inds (1D array-like): indices of `ds` that form the subset
    """

    def __init__(self, ds, split_inds, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        # The indices are copied into a list owned by this wrapper.
        self.split_inds = list(split_inds)
        self.wrapped_data = ds

    def __len__(self):
        """Return the number of indices in the subset."""
        return len(self.split_inds)

    def __getitem__(self, index):
        """Return the wrapped item stored at subset position `index`."""
        wrapped_index = self.split_inds[index]
        return self.wrapped_data[wrapped_index]

    def num_tokens(self):
        """Return whatever the wrapped dataset's `num_tokens()` returns."""
        return self.wrapped_data.num_tokens()

    def __iter__(self):
        """Yield the wrapped items in the order of the subset indices."""
        yield from (self.wrapped_data[i] for i in self.split_inds)
def split_ds(ds, split=(.8, .2, .0), shuffle=True):
    """
    Split a dataset into subsets given proportions of how
    much to allocate per split. If a split is 0% returns None for that split.
    Purpose: Useful for creating train/val/test splits
    Arguments:
        ds (Dataset or array-like): Data to be split.
        split (1D array-like): proportions to split `ds`; ints or floats.
            `sum(split) != 0`
        shuffle (boolean): Randomly split dataset. Default: True
    Returns:
        list with one entry per proportion: a `SplitDataset` over `ds`
        for every non-zero proportion and None for every zero one.
    Raises:
        ValueError: if the proportions sum to 0.
    """
    split_sum = sum(split)
    if split_sum == 0:
        raise ValueError('Split cannot sum to 0.')
    # Build a float array before normalizing: dividing an integer array
    # in place (e.g. split=[8, 2, 0]) raises a numpy casting error.
    fractions = np.array(split, dtype=float) / split_sum

    ds_len = len(ds)
    inds = np.arange(ds_len)
    if shuffle:
        # Uses numpy's global random state.
        np.random.shuffle(inds)

    start_idx = 0
    # Running sum of the fractional sample counts that truncation dropped;
    # once it reaches 1 the current split receives one extra sample.
    residual = 0
    rtn_ds = [None] * len(fractions)
    for i, fraction in enumerate(fractions):
        if fraction == 0:
            continue
        proportion = ds_len * fraction
        residual += proportion % 1
        count = int(int(proportion) + residual)
        # NOTE: the slice takes at least one sample even when `count` is 0,
        # while `start_idx` only advances by `count`; in that case the
        # sample is also the first one of the next non-empty split.
        split_inds = inds[start_idx:start_idx + max(count, 1)]
        rtn_ds[i] = SplitDataset(ds, split_inds)
        start_idx += count
        residual %= 1
    return rtn_ds
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment