# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397
"""dataset_dict: dictionary mapping from index to batch
length is used in the case of DistributedSampler: e.g. the dataset could have size 1k, but
with 8 GPUs the dataset_dict would only have 125 items.
"""
        super().__init__()
        self.dataset_dict = dataset_dict
        self.length = length or len(self.dataset_dict)

    def __getitem__(self, index):
        return self.dataset_dict[index]

    def __len__(self):
        return self.length
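
# Illustrative sketch of the DistributedSampler case described in the docstring above.
# It is an assumption for illustration only: `DictDataset` stands in for the enclosing
# class (whose real name is not shown here), and `batches`, `rank`, and `world_size`
# are hypothetical. Without shuffling, rank r of an 8-GPU job is asked for the strided
# indices r, r+8, r+16, ... of a 1000-sample dataset, so its dataset_dict holds 125
# entries while length=1000 keeps len(dataset) reporting the global size:
#
#   shard = {idx: batches[idx] for idx in range(rank, 1000, world_size)}
#   dataset = DictDataset(shard, length=1000)
#   assert len(dataset) == 1000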
# From https://github.com/PyTorchLightning/lightning-bolts/blob/2415b49a2b405693cd499e09162c89f807abbdc4/pl_bolts/transforms/dataset_normalizations.py#L10
class SHMArray(np.ndarray):  # copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array