Unverified Commit 01c11a05 authored by Erjia Guan's avatar Erjia Guan Committed by GitHub
Browse files

[DataPipe] Properly cleanup unclosed files within generator function (#6997)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 790f1cdc
...@@ -30,25 +30,26 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]): ...@@ -30,25 +30,26 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
    """Yield ``(image_id, row_dict)`` pairs parsed from each CSV file in the datapipe.

    Each item from ``self.datapipe`` is a ``(path, file_handle)`` pair; the handle is
    always closed, even if parsing raises or the generator is closed early.
    """
    for _, file in self.datapipe:
        # try/finally (rather than closing after the loop body) guarantees the file
        # handle is released even when iteration stops early or an error occurs.
        try:
            # Decode lazily, line by line; files may be large.
            lines = (line.decode() for line in file)

            if self.fieldnames:
                fieldnames = self.fieldnames
            else:
                # The first row is skipped, because it only contains the number of samples
                next(lines)

                # Empty field names are filtered out, because some files have an extra white space after the header
                # line, which is recognized as extra column
                fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
                # Some files do not include a label for the image ID column
                if fieldnames[0] != "image_id":
                    fieldnames.insert(0, "image_id")

            for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
                # Split the ID off so the remaining dict maps field name -> value.
                yield line.pop("image_id"), line
        finally:
            file.close()
NAME = "celeba" NAME = "celeba"
......
...@@ -37,27 +37,28 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]): ...@@ -37,27 +37,28 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
def __iter__(self) -> Iterator[torch.Tensor]:
    """Yield one tensor per sample decoded from each MNIST-format (IDX) file.

    Each item from ``self.datapipe`` is a ``(path, file_handle)`` pair. The IDX
    header encodes dtype and rank in a magic number, followed by the sample count
    and per-dimension sizes. The handle is always closed via ``finally``, even if
    decoding raises or the generator is closed early.
    """
    for _, file in self.datapipe:
        try:
            # IDX files are big-endian throughout.
            read = functools.partial(fromfile, file, byte_order="big")

            magic = int(read(dtype=torch.int32, count=1))
            # High bytes of the magic number select the element dtype ...
            dtype = self._DTYPE_MAP[magic // 256]
            # ... and the low byte is rank + 1 (the leading dim is the sample count).
            ndim = magic % 256 - 1

            num_samples = int(read(dtype=torch.int32, count=1))
            shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
            count = prod(shape) if shape else 1

            start = self.start or 0
            stop = min(self.stop, num_samples) if self.stop else num_samples

            if start:
                # Skip the first `start` samples without decoding them.
                num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
                file.seek(num_bytes_per_value * count * start, 1)

            for _ in range(stop - start):
                yield read(dtype=dtype, count=count).reshape(shape)
        finally:
            file.close()
class _MNISTBase(Dataset): class _MNISTBase(Dataset):
......
...@@ -28,12 +28,13 @@ class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]): ...@@ -28,12 +28,13 @@ class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
    """Yield entries from each HDF5 file in the datapipe.

    Each item from ``self.datapipe`` is a ``(path, handle)`` pair. If ``self.key``
    is set, only that dataset within the HDF5 file is iterated; otherwise the file
    object itself is iterated. The underlying handle is always closed via
    ``finally``, even if reading raises or the generator is closed early.
    """
    # NOTE(review): signature reconstructed from the class declaration
    # `IterDataPipe[Tuple[str, io.IOBase]]` — confirm against the original file.
    import h5py

    for _, handle in self.datapipe:
        try:
            # h5py.File is a context manager; `with` closes the HDF5 wrapper,
            # while the finally below closes the raw underlying handle.
            with h5py.File(handle) as data:
                if self.key is not None:
                    data = data[self.key]
                yield from data
        finally:
            handle.close()
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256")) _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment