Unverified commit 01c11a05, authored by Erjia Guan and committed by GitHub.
Browse files

[DataPipe] Properly cleanup unclosed files within generator function (#6997)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 790f1cdc
......@@ -30,25 +30,26 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
    """Parse every CSV file in the datapipe and yield ``(image_id, row)`` pairs.

    Yields:
        A tuple of the sample's image ID and a dict mapping the remaining
        field names to their (string) values for that row.

    The file handle is closed in a ``finally`` block so it is released even
    when the consumer closes the generator early (``GeneratorExit``) or an
    error is raised mid-file — closing only after the loop would leak the
    handle in those cases.
    """
    for _, file in self.datapipe:
        try:
            # Decode the raw byte lines lazily; the csv module needs text input.
            lines = (line.decode() for line in file)
            if self.fieldnames:
                fieldnames = self.fieldnames
            else:
                # The first row is skipped, because it only contains the number of samples
                next(lines)
                # Empty field names are filtered out, because some files have an extra white space after the header
                # line, which is recognized as extra column
                fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
                # Some files do not include a label for the image ID column
                if fieldnames[0] != "image_id":
                    fieldnames.insert(0, "image_id")
            for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
                yield line.pop("image_id"), line
        finally:
            # Always release the handle, even on early generator shutdown.
            file.close()
NAME = "celeba"
......
......@@ -37,27 +37,28 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
def __iter__(self) -> Iterator[torch.Tensor]:
    """Decode samples from each IDX-format MNIST file and yield them as tensors.

    Yields:
        One tensor per sample, reshaped to the per-sample shape declared in
        the file header (a scalar tensor for label files).

    The file handle is closed in a ``finally`` block so it is released even
    when the consumer closes the generator early or decoding fails — closing
    only after the loop would leak the handle in those cases.
    """
    for _, file in self.datapipe:
        try:
            read = functools.partial(fromfile, file, byte_order="big")

            # The magic number encodes the dtype (upper bytes, looked up via
            # _DTYPE_MAP) and the total number of dimensions (lowest byte);
            # subtracting 1 drops the leading sample-count dimension.
            magic = int(read(dtype=torch.int32, count=1))
            dtype = self._DTYPE_MAP[magic // 256]
            ndim = magic % 256 - 1

            num_samples = int(read(dtype=torch.int32, count=1))
            # Per-sample shape; empty for 0-d samples such as labels.
            shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
            count = prod(shape) if shape else 1

            start = self.start or 0
            stop = min(self.stop, num_samples) if self.stop else num_samples

            if start:
                num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
                # Seek relative to the current position (whence=1) to skip the
                # first `start` samples without decoding them.
                file.seek(num_bytes_per_value * count * start, 1)

            for _ in range(stop - start):
                yield read(dtype=dtype, count=count).reshape(shape)
        finally:
            # Always release the handle, even on early generator shutdown.
            file.close()
class _MNISTBase(Dataset):
......
......@@ -28,12 +28,13 @@ class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
# NOTE(review): this hunk is the interior of PCAMH5Reader.__iter__; the `def`
# line sits outside the visible diff, so the code below is annotated in place
# rather than restructured.
import h5py
for _, handle in self.datapipe:
# --- pre-fix version: handle.close() runs only after normal completion, so
# the file is leaked if the generator is closed early or an error escapes.
with h5py.File(handle) as data:
if self.key is not None:
data = data[self.key]
yield from data
handle.close()
# --- post-fix version: try/finally guarantees handle.close() runs even on
# early generator shutdown (GeneratorExit) or an exception mid-iteration.
try:
with h5py.File(handle) as data:
if self.key is not None:
data = data[self.key]
yield from data
finally:
handle.close()
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.