Unverified Commit d6186284 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

update gradient_selector dataloader iterator import (#2690)

parent 717877d0
...@@ -31,7 +31,7 @@ from sklearn.datasets import load_svmlight_file ...@@ -31,7 +31,7 @@ from sklearn.datasets import load_svmlight_file
import torch import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
# pylint: disable=E0611 # pylint: disable=E0611
from torch.utils.data.dataloader import _DataLoaderIter, _utils from torch.utils.data.dataloader import _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter, _utils
from . import constants from . import constants
from . import syssettings from . import syssettings
...@@ -585,39 +585,27 @@ class ChunkDataLoader(DataLoader): ...@@ -585,39 +585,27 @@ class ChunkDataLoader(DataLoader):
return _ChunkDataLoaderIter(self) return _ChunkDataLoaderIter(self)
class _ChunkDataLoaderIter:
    """
    DataLoaderIter wrapper used to more quickly load a batch of indices at once.

    Instead of subclassing the removed private ``_DataLoaderIter`` (dropped in
    newer torch), this wraps the appropriate built-in iterator and overrides
    only the single-process fast path: when the whole batch of indices is
    available, the dataset is indexed once with an index array rather than
    once per sample.
    """

    def __init__(self, dataloader):
        # Mirror DataLoader's own dispatch: num_workers == 0 means samples are
        # loaded in this process; otherwise worker processes are used.
        if dataloader.num_workers == 0:
            self.iter = _SingleProcessDataLoaderIter(dataloader)
        else:
            self.iter = _MultiProcessingDataLoaderIter(dataloader)

    def __next__(self):
        # only chunk that is edited from base
        if self.iter._num_workers == 0:  # same-process loading
            # NOTE(review): relies on private attributes of torch's
            # _BaseDataLoaderIter (_sampler_iter, _dataset, _collate_fn,
            # _pin_memory) — these names are version-dependent; confirm
            # against the pinned torch release.
            indices = next(self.iter._sampler_iter)  # may raise StopIteration
            if len(indices) > 1:
                # Fast path: one fancy-indexing call for the whole batch
                # (assumes the dataset supports ndarray indexing).
                batch = self.iter._dataset[np.array(indices)]
            else:
                batch = self.iter._collate_fn(
                    [self.iter._dataset[i] for i in indices])
            if self.iter._pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch
        else:
            # Multi-process case: defer entirely to torch's iterator.
            return next(self.iter)
...@@ -287,6 +287,11 @@ class Solver(nn.Module): ...@@ -287,6 +287,11 @@ class Solver(nn.Module):
else: else:
pin_memory = False pin_memory = False
if num_workers == 0:
timeout = 0
else:
timeout = 60
self.ds_train = ChunkDataLoader( self.ds_train = ChunkDataLoader(
PreparedData, PreparedData,
batch_size=self.Nminibatch, batch_size=self.Nminibatch,
...@@ -294,7 +299,7 @@ class Solver(nn.Module): ...@@ -294,7 +299,7 @@ class Solver(nn.Module):
drop_last=True, drop_last=True,
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
timeout=60) timeout=timeout)
self.f_train = LearnabilityMB(self.Nminibatch, self.D, self.f_train = LearnabilityMB(self.Nminibatch, self.D,
constants.Coefficients.SLE[order], constants.Coefficients.SLE[order],
self.groups, self.groups,
...@@ -338,7 +343,7 @@ class Solver(nn.Module): ...@@ -338,7 +343,7 @@ class Solver(nn.Module):
Completes the forward operation and computes gradients for learnability and penalty. Completes the forward operation and computes gradients for learnability and penalty.
""" """
f_train = self.f_train(s, xsub, ysub) f_train = self.f_train(s, xsub, ysub)
pen = self.penalty(s) pen = self.penalty(s).unsqueeze(0).unsqueeze(0)
# pylint: disable=E1102 # pylint: disable=E1102
grad_outputs = torch.tensor([[1]], dtype=torch.get_default_dtype(), grad_outputs = torch.tensor([[1]], dtype=torch.get_default_dtype(),
device=self.device) device=self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment