Unverified Commit 7eb5d7fc authored by Philip Meier, committed by GitHub

close streams in prototype datasets (#6647)

* close streams in prototype datasets

* refactor prototype SBD to avoid closing demux streams at construction time

* mypy
parent 7d2de404
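
For orientation, the changes below repeat one pattern: once a datapipe has read everything it needs from a file handle, it closes the handle instead of leaving the stream open for the rest of the session. A minimal sketch of that pattern (the class and names here are illustrative, not taken from the diff):

    from torchdata.datapipes.iter import IterDataPipe

    class LineYielder(IterDataPipe):
        # Illustrative datapipe: takes (path, file) pairs, yields decoded lines,
        # and closes each file as soon as it has been fully read.
        def __init__(self, datapipe):
            self.datapipe = datapipe

        def __iter__(self):
            for _, file in self.datapipe:
                for line in file:
                    yield line.decode().strip()
                file.close()  # the stream is exhausted, so release it right away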
@@ -661,15 +661,15 @@ class SBDMockData:
     _NUM_CATEGORIES = 20
 
     @classmethod
-    def _make_split_files(cls, root_map):
-        ids_map = {
-            split: [f"2008_{idx:06d}" for idx in idcs]
-            for split, idcs in (
-                ("train", [0, 1, 2]),
-                ("train_noval", [0, 2]),
-                ("val", [3]),
-            )
-        }
+    def _make_split_files(cls, root_map, *, split):
+        splits_and_idcs = [
+            ("train", [0, 1, 2]),
+            ("val", [3]),
+        ]
+        if split == "train_noval":
+            splits_and_idcs.append(("train_noval", [0, 2]))
+
+        ids_map = {split: [f"2008_{idx:06d}" for idx in idcs] for split, idcs in splits_and_idcs}
 
         for split, ids in ids_map.items():
             with open(root_map[split] / f"{split}.txt", "w") as fh:
@@ -710,12 +710,14 @@ class SBDMockData:
         return torch.randint(0, cls._NUM_CATEGORIES + 1, size=size, dtype=torch.uint8).numpy()
 
     @classmethod
-    def generate(cls, root):
+    def generate(cls, root, *, split):
         archive_folder = root / "benchmark_RELEASE"
         dataset_folder = archive_folder / "dataset"
         dataset_folder.mkdir(parents=True, exist_ok=True)
 
-        ids, num_samples_map = cls._make_split_files(defaultdict(lambda: dataset_folder, {"train_noval": root}))
+        ids, num_samples_map = cls._make_split_files(
+            defaultdict(lambda: dataset_folder, {"train_noval": root}), split=split
+        )
         sizes = cls._make_anns_folder(dataset_folder, "cls", ids)
         create_image_folder(
             dataset_folder, "img", lambda idx: f"{ids[idx]}.jpg", num_examples=len(ids), size=lambda idx: sizes[idx]
@@ -723,12 +725,12 @@ class SBDMockData:
         make_tar(root, "benchmark.tgz", archive_folder, compression="gz")
 
-        return num_samples_map
+        return num_samples_map[split]
 
 
 @register_mock(configs=combinations_grid(split=("train", "val", "train_noval")))
 def sbd(root, config):
-    return SBDMockData.generate(root)[config["split"]]
+    return SBDMockData.generate(root, split=config["split"])
 
 
 @register_mock(configs=[dict()])
...
 import functools
 import io
 import pickle
+from collections import deque
 from pathlib import Path
 
 import pytest
@@ -11,10 +12,11 @@ from torch.utils.data import DataLoader
 from torch.utils.data.graph import traverse_dps
 from torch.utils.data.graph_settings import get_all_graph_pipes
 from torchdata.datapipes.iter import ShardingFilter, Shuffler
+from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import datasets, transforms
+from torchvision.prototype import datasets, features, transforms
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.prototype.features import Image, Label
 
 assert_samples_equal = functools.partial(
     assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
@@ -25,6 +27,17 @@ def extract_datapipes(dp):
     return get_all_graph_pipes(traverse_dps(dp))
 
 
+def consume(iterator):
+    # Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes
+    deque(iterator, maxlen=0)
+
+
+def next_consume(iterator):
+    item = next(iterator)
+    consume(iterator)
+    return item
+
+
 @pytest.fixture(autouse=True)
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -66,7 +79,7 @@ class TestCommon:
         dataset, _ = dataset_mock.load(config)
 
         try:
-            sample = next(iter(dataset))
+            sample = next_consume(iter(dataset))
         except StopIteration:
             raise AssertionError("Unable to draw any sample.") from None
         except Exception as error:
@@ -84,22 +97,53 @@ class TestCommon:
         assert len(list(dataset)) == mock_info["num_samples"]
 
+    @pytest.fixture
+    def log_session_streams(self):
+        debug_unclosed_streams = StreamWrapper.debug_unclosed_streams
+        try:
+            StreamWrapper.debug_unclosed_streams = True
+            yield
+        finally:
+            StreamWrapper.debug_unclosed_streams = debug_unclosed_streams
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_vanilla_tensors(self, dataset_mock, config):
+    def test_stream_closing(self, log_session_streams, dataset_mock, config):
+        def make_msg_and_close(head):
+            unclosed_streams = []
+            for stream in StreamWrapper.session_streams.keys():
+                unclosed_streams.append(repr(stream.file_obj))
+                stream.close()
+            unclosed_streams = "\n".join(unclosed_streams)
+            return f"{head}\n\n{unclosed_streams}"
+
+        if StreamWrapper.session_streams:
+            raise pytest.UsageError(make_msg_and_close("A previous test did not close the following streams:"))
+
         dataset, _ = dataset_mock.load(config)
 
-        vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
-        if vanilla_tensors:
+        consume(iter(dataset))
+
+        if StreamWrapper.session_streams:
+            raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
+
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_no_simple_tensors(self, dataset_mock, config):
+        dataset, _ = dataset_mock.load(config)
+
+        simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)}
+        if simple_tensors:
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
+                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
             )
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
-        next(iter(dataset.map(transforms.Identity())))
+        dataset = dataset.map(transforms.Identity())
+
+        consume(iter(dataset))
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_traversable(self, dataset_mock, config):
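
test_stream_closing relies on two StreamWrapper class attributes that also appear in the hunk above: debug_unclosed_streams switches the bookkeeping on, and session_streams holds every wrapper that was opened but not yet closed. A minimal sketch of that bookkeeping, assuming the torchdata behaviour the fixture depends on:

    import io
    from torchdata.datapipes.utils import StreamWrapper

    StreamWrapper.debug_unclosed_streams = True
    try:
        stream = StreamWrapper(io.BytesIO(b"payload"))
        assert stream in StreamWrapper.session_streams      # tracked while open
        stream.close()
        assert stream not in StreamWrapper.session_streams  # dropped once closed
    finally:
        StreamWrapper.debug_unclosed_streams = False  # restore the default, like the fixture does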
@@ -131,7 +175,7 @@ class TestCommon:
             collate_fn=self._collate_fn,
         )
 
-        next(iter(dl))
+        consume(dl)
 
     # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
     # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
@@ -148,7 +192,7 @@ class TestCommon:
     def test_save_load(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))
 
         with io.BytesIO() as buffer:
             torch.save(sample, buffer)
@@ -177,7 +221,7 @@ class TestQMNIST:
     def test_extra_label(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))
         for key, type in (
             ("nist_hsf_series", int),
             ("nist_writer_id", int),
@@ -214,7 +258,7 @@ class TestUSPS:
         assert "image" in sample
         assert "label" in sample
 
-        assert isinstance(sample["image"], Image)
-        assert isinstance(sample["label"], Label)
+        assert isinstance(sample["image"], features.Image)
+        assert isinstance(sample["label"], features.Label)
         assert sample["image"].shape == (1, 16, 16)
@@ -30,24 +30,26 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
     def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
         for _, file in self.datapipe:
-            file = (line.decode() for line in file)
+            lines = (line.decode() for line in file)
 
             if self.fieldnames:
                 fieldnames = self.fieldnames
             else:
                 # The first row is skipped, because it only contains the number of samples
-                next(file)
+                next(lines)
 
                 # Empty field names are filtered out, because some files have an extra white space after the header
                 # line, which is recognized as extra column
-                fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
+                fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
                 # Some files do not include a label for the image ID column
                 if fieldnames[0] != "image_id":
                     fieldnames.insert(0, "image_id")
 
-            for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
+            for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
                 yield line.pop("image_id"), line
 
+            file.close()
+
 
 NAME = "celeba"
...
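
The CelebA change keeps the binary file around under its original name so it can be closed after the csv reader is exhausted. A simplified, self-contained sketch of that decode-read-close pattern; the space-delimited toy data and the plain delimiter stand in for the real CelebA files and their registered "celeba" csv dialect:

    import csv
    import io

    file = io.BytesIO(b"image_id Male Young\n000001.jpg 1 -1\n")
    lines = (line.decode() for line in file)

    # read the header from the decoded lines, then feed the rest to the csv reader
    fieldnames = [name for name in next(csv.reader([next(lines)], delimiter=" ")) if name]
    for row in csv.DictReader(lines, fieldnames=fieldnames, delimiter=" "):
        print(row.pop("image_id"), row)

    file.close()  # the csv reader is exhausted, so the underlying stream can be closed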
@@ -62,7 +62,9 @@ class _CifarBase(Dataset):
     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
         _, file = data
-        return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
+        content = cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
+        file.close()
+        return content
 
     def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
         image_array, category_idx = data
...
@@ -97,6 +97,8 @@ class CLEVR(Dataset):
                 buffer_size=INFINITE_BUFFER_SIZE,
             )
         else:
+            for _, file in scenes_dp:
+                file.close()
             dp = Mapper(images_dp, self._add_empty_anns)
 
         return Mapper(dp, self._prepare_sample)
...
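
The CLEVR branch above drains and closes a datapipe of (path, file) pairs that is known to be unused for the requested split; otherwise its streams would stay open behind the scenes. A tiny sketch of that cleanup with in-memory stand-ins:

    import io

    # stand-in for the scenes branch that is not needed for the requested split
    scenes_dp = [("scenes/CLEVR_val_scenes.json", io.BytesIO(b"{}"))]

    for _, file in scenes_dp:
        file.close()

    assert all(file.closed for _, file in scenes_dp)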
@@ -57,6 +57,8 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
             for _ in range(stop - start):
                 yield read(dtype=dtype, count=count).reshape(shape)
 
+            file.close()
+
 
 class _MNISTBase(Dataset):
     _URL_BASE: Union[str, Sequence[str]]
...
@@ -33,6 +33,8 @@ class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
                     data = data[self.key]
                 yield from data
 
+            handle.close()
+
 
 _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
...
@@ -49,31 +49,35 @@ class SBD(Dataset):
         super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)
 
     def _resources(self) -> List[OnlineResource]:
-        archive = HttpResource(
-            "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
-            sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
-        )
-        extra_split = HttpResource(
-            "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
-            sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
-        )
-        return [archive, extra_split]
+        resources = [
+            HttpResource(
+                "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
+                sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
+            )
+        ]
+        if self._split == "train_noval":
+            resources.append(
+                HttpResource(
+                    "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
+                    sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
+                )
+            )
+        return resources  # type: ignore[return-value]
 
     def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         path = pathlib.Path(data[0])
         parent, grandparent, *_ = path.parents
 
-        if parent.name == "dataset":
-            return 0
-        elif grandparent.name == "dataset":
+        if grandparent.name == "dataset":
             if parent.name == "img":
-                return 1
+                return 0
             elif parent.name == "cls":
-                return 2
-            else:
-                return None
-        else:
-            return None
+                return 1
+
+        if parent.name == "dataset" and self._split != "train_noval":
+            return 2
+
+        return None
 
     def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
@@ -93,18 +97,24 @@ class SBD(Dataset):
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
-        archive_dp, extra_split_dp = resource_dps
-
-        archive_dp = resource_dps[0]
-        split_dp, images_dp, anns_dp = Demultiplexer(
-            archive_dp,
-            3,
-            self._classify_archive,
-            buffer_size=INFINITE_BUFFER_SIZE,
-            drop_none=True,
-        )
         if self._split == "train_noval":
-            split_dp = extra_split_dp
+            archive_dp, split_dp = resource_dps
+            images_dp, anns_dp = Demultiplexer(
+                archive_dp,
+                2,
+                self._classify_archive,
+                buffer_size=INFINITE_BUFFER_SIZE,
+                drop_none=True,
+            )
+        else:
+            archive_dp = resource_dps[0]
+            images_dp, anns_dp, split_dp = Demultiplexer(
+                archive_dp,
+                3,
+                self._classify_archive,
+                buffer_size=INFINITE_BUFFER_SIZE,
+                drop_none=True,
+            )
 
         split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
...
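
The reworked _classify_archive drives a Demultiplexer: whichever branch index the classifier returns decides where an archive member goes, and members mapped to None are discarded because of drop_none=True. A hedged, self-contained sketch of that routing with made-up paths (buffer_size=-1 mirrors INFINITE_BUFFER_SIZE):

    from torchdata.datapipes.iter import Demultiplexer, IterableWrapper

    paths = IterableWrapper(
        [
            "benchmark_RELEASE/dataset/img/2008_000002.jpg",
            "benchmark_RELEASE/dataset/cls/2008_000002.mat",
            "benchmark_RELEASE/dataset/train.txt",
            "benchmark_RELEASE/README",  # classified as None and therefore dropped
        ]
    )

    def classify(path):
        if "/img/" in path:
            return 0
        if "/cls/" in path:
            return 1
        if path.endswith(".txt"):
            return 2
        return None

    images_dp, anns_dp, split_dp = Demultiplexer(paths, 3, classify, drop_none=True, buffer_size=-1)
    print(list(images_dp), list(anns_dp), list(split_dp))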
@@ -94,7 +94,9 @@ class VOC(Dataset):
         return None
 
     def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
-        return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
+        ann = cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
+        buffer.close()
+        return ann
 
     def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
         anns = self._parse_detection_ann(buffer)
...
@@ -8,7 +8,6 @@ import torch
 import torch.distributed as dist
 import torch.utils.data
 from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
-from torchdata.datapipes.utils import StreamWrapper
 from torchvision.prototype.utils._internal import fromfile
@@ -40,10 +39,9 @@ def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
     except ImportError as error:
         raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error
 
-    if isinstance(buffer, StreamWrapper):
-        buffer = buffer.file_obj
-
-    return sio.loadmat(buffer, **kwargs)
+    data = sio.loadmat(buffer, **kwargs)
+    buffer.close()
+    return data
 
 
 class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
...
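
read_mat now hands the stream to scipy directly and closes it afterwards, instead of unwrapping the StreamWrapper first. A small self-contained check of that flow, relying only on the documented scipy.io behaviour that loadmat and savemat accept file-like objects:

    import io

    import numpy as np
    import scipy.io as sio

    buffer = io.BytesIO()
    sio.savemat(buffer, {"segmentation": np.zeros((2, 2), dtype=np.uint8)})
    buffer.seek(0)

    data = sio.loadmat(buffer)  # loadmat works on any file-like object
    buffer.close()              # the arrays are in memory, so the stream can go
    print(data["segmentation"].shape)  # (2, 2)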
@@ -27,7 +27,9 @@ class EncodedData(_Feature):
     @classmethod
     def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
-        return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
+        encoded_data = cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
+        file.close()
+        return encoded_data
 
     @classmethod
     def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
...