Commit 01c11a05 (unverified), authored by Erjia Guan, committed by GitHub
Browse files

[DataPipe] Properly cleanup unclosed files within generator function (#6997)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 790f1cdc
...@@ -30,6 +30,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]): ...@@ -30,6 +30,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]: def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for _, file in self.datapipe: for _, file in self.datapipe:
try:
lines = (line.decode() for line in file) lines = (line.decode() for line in file)
if self.fieldnames: if self.fieldnames:
...@@ -47,7 +48,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]): ...@@ -47,7 +48,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"): for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
yield line.pop("image_id"), line yield line.pop("image_id"), line
finally:
file.close() file.close()
......
...@@ -37,6 +37,7 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]): ...@@ -37,6 +37,7 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
def __iter__(self) -> Iterator[torch.Tensor]: def __iter__(self) -> Iterator[torch.Tensor]:
for _, file in self.datapipe: for _, file in self.datapipe:
try:
read = functools.partial(fromfile, file, byte_order="big") read = functools.partial(fromfile, file, byte_order="big")
magic = int(read(dtype=torch.int32, count=1)) magic = int(read(dtype=torch.int32, count=1))
...@@ -56,7 +57,7 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]): ...@@ -56,7 +57,7 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
for _ in range(stop - start): for _ in range(stop - start):
yield read(dtype=dtype, count=count).reshape(shape) yield read(dtype=dtype, count=count).reshape(shape)
finally:
file.close() file.close()
......
...@@ -28,11 +28,12 @@ class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]): ...@@ -28,11 +28,12 @@ class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
import h5py import h5py
for _, handle in self.datapipe: for _, handle in self.datapipe:
try:
with h5py.File(handle) as data: with h5py.File(handle) as data:
if self.key is not None: if self.key is not None:
data = data[self.key] data = data[self.key]
yield from data yield from data
finally:
handle.close() handle.close()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment