Unverified commit d6186284, authored by chicm-ms and committed by GitHub
Browse files

update gradient_selector dataloader iterator import (#2690)

parent 717877d0
......@@ -31,7 +31,7 @@ from sklearn.datasets import load_svmlight_file
import torch
from torch.utils.data import DataLoader, Dataset
# pylint: disable=E0611
from torch.utils.data.dataloader import _DataLoaderIter, _utils
from torch.utils.data.dataloader import _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter, _utils
from . import constants
from . import syssettings
......@@ -585,39 +585,27 @@ class ChunkDataLoader(DataLoader):
return _ChunkDataLoaderIter(self)
class _ChunkDataLoaderIter(_DataLoaderIter):
class _ChunkDataLoaderIter:
"""
DataLoaderIter class used to more quickly load a batch of indices at once.
"""
def __init__(self, dataloader):
if dataloader.num_workers == 0:
self.iter = _SingleProcessDataLoaderIter(dataloader)
else:
self.iter = _MultiProcessingDataLoaderIter(dataloader)
def __next__(self):
# only chunk that is edited from base
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
if self.iter._num_workers == 0: # same-process loading
indices = next(self.iter._sampler_iter) # may raise StopIteration
if len(indices) > 1:
batch = self.dataset[np.array(indices)]
batch = self.iter._dataset[np.array(indices)]
else:
batch = self.collate_fn([self.dataset[i] for i in indices])
batch = self.iter._collate_fn([self.iter._dataset[i] for i in indices])
if self.pin_memory:
if self.iter._pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
while True:
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self._get_batch()
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
return self._process_next_batch(batch)
else:
return next(self.iter)
......@@ -287,6 +287,11 @@ class Solver(nn.Module):
else:
pin_memory = False
if num_workers == 0:
timeout = 0
else:
timeout = 60
self.ds_train = ChunkDataLoader(
PreparedData,
batch_size=self.Nminibatch,
......@@ -294,7 +299,7 @@ class Solver(nn.Module):
drop_last=True,
num_workers=num_workers,
pin_memory=pin_memory,
timeout=60)
timeout=timeout)
self.f_train = LearnabilityMB(self.Nminibatch, self.D,
constants.Coefficients.SLE[order],
self.groups,
......@@ -338,7 +343,7 @@ class Solver(nn.Module):
Completes the forward operation and computes gradients for learnability and penalty.
"""
f_train = self.f_train(s, xsub, ysub)
pen = self.penalty(s)
pen = self.penalty(s).unsqueeze(0).unsqueeze(0)
# pylint: disable=E1102
grad_outputs = torch.tensor([[1]], dtype=torch.get_default_dtype(),
device=self.device)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment