"src/array/vscode:/vscode.git/clone" did not exist on "8e525dad71c0857dd93d8ed5eaa7aa57516acbc0"
Unverified Commit 7eb5d7fc authored by Philip Meier, committed by GitHub

close streams in prototype datasets (#6647)

* close streams in prototype datasets

* refactor prototype SBD to avoid closing demux streams at construction time

* mypy
parent 7d2de404
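
The pattern applied throughout this diff is straightforward: whenever a datapipe or helper has fully consumed a file handle, it closes the handle explicitly instead of leaving it to the garbage collector. A minimal sketch of that pattern (the datapipe below is illustrative only and not part of the change):

from typing import BinaryIO, Iterator, Tuple

from torchdata.datapipes.iter import IterDataPipe


class _CloseAfterRead(IterDataPipe[bytes]):
    # Illustrative datapipe: read each payload fully, then close the stream
    # before yielding, mirroring the changes made to the dataset datapipes.
    def __init__(self, source: IterDataPipe[Tuple[str, BinaryIO]]) -> None:
        self.source = source

    def __iter__(self) -> Iterator[bytes]:
        for _, file in self.source:
            data = file.read()  # consume the stream completely ...
            file.close()  # ... then close it so no handle leaks
            yield data
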
......@@ -661,15 +661,15 @@ class SBDMockData:
_NUM_CATEGORIES = 20
@classmethod
def _make_split_files(cls, root_map):
ids_map = {
split: [f"2008_{idx:06d}" for idx in idcs]
for split, idcs in (
("train", [0, 1, 2]),
("train_noval", [0, 2]),
("val", [3]),
)
}
def _make_split_files(cls, root_map, *, split):
splits_and_idcs = [
("train", [0, 1, 2]),
("val", [3]),
]
if split == "train_noval":
splits_and_idcs.append(("train_noval", [0, 2]))
ids_map = {split: [f"2008_{idx:06d}" for idx in idcs] for split, idcs in splits_and_idcs}
for split, ids in ids_map.items():
with open(root_map[split] / f"{split}.txt", "w") as fh:
......@@ -710,12 +710,14 @@ class SBDMockData:
return torch.randint(0, cls._NUM_CATEGORIES + 1, size=size, dtype=torch.uint8).numpy()
@classmethod
def generate(cls, root):
def generate(cls, root, *, split):
archive_folder = root / "benchmark_RELEASE"
dataset_folder = archive_folder / "dataset"
dataset_folder.mkdir(parents=True, exist_ok=True)
ids, num_samples_map = cls._make_split_files(defaultdict(lambda: dataset_folder, {"train_noval": root}))
ids, num_samples_map = cls._make_split_files(
defaultdict(lambda: dataset_folder, {"train_noval": root}), split=split
)
sizes = cls._make_anns_folder(dataset_folder, "cls", ids)
create_image_folder(
dataset_folder, "img", lambda idx: f"{ids[idx]}.jpg", num_examples=len(ids), size=lambda idx: sizes[idx]
......@@ -723,12 +725,12 @@ class SBDMockData:
make_tar(root, "benchmark.tgz", archive_folder, compression="gz")
return num_samples_map
return num_samples_map[split]
@register_mock(configs=combinations_grid(split=("train", "val", "train_noval")))
def sbd(root, config):
return SBDMockData.generate(root)[config["split"]]
return SBDMockData.generate(root, split=config["split"])
@register_mock(configs=[dict()])
......
import functools
import io
import pickle
from collections import deque
from pathlib import Path
import pytest
......@@ -11,10 +12,11 @@ from torch.utils.data import DataLoader
from torch.utils.data.graph import traverse_dps
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision._utils import sequence_to_str
from torchvision.prototype import datasets, transforms
from torchvision.prototype import datasets, features, transforms
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.prototype.features import Image, Label
assert_samples_equal = functools.partial(
assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
......@@ -25,6 +27,17 @@ def extract_datapipes(dp):
return get_all_graph_pipes(traverse_dps(dp))
def consume(iterator):
# Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes
deque(iterator, maxlen=0)
def next_consume(iterator):
item = next(iterator)
consume(iterator)
return item
@pytest.fixture(autouse=True)
def test_home(mocker, tmp_path):
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
......@@ -66,7 +79,7 @@ class TestCommon:
dataset, _ = dataset_mock.load(config)
try:
sample = next(iter(dataset))
sample = next_consume(iter(dataset))
except StopIteration:
raise AssertionError("Unable to draw any sample.") from None
except Exception as error:
......@@ -84,22 +97,53 @@ class TestCommon:
assert len(list(dataset)) == mock_info["num_samples"]
@pytest.fixture
def log_session_streams(self):
debug_unclosed_streams = StreamWrapper.debug_unclosed_streams
try:
StreamWrapper.debug_unclosed_streams = True
yield
finally:
StreamWrapper.debug_unclosed_streams = debug_unclosed_streams
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, dataset_mock, config):
def test_stream_closing(self, log_session_streams, dataset_mock, config):
def make_msg_and_close(head):
unclosed_streams = []
for stream in StreamWrapper.session_streams.keys():
unclosed_streams.append(repr(stream.file_obj))
stream.close()
unclosed_streams = "\n".join(unclosed_streams)
return f"{head}\n\n{unclosed_streams}"
if StreamWrapper.session_streams:
raise pytest.UsageError(make_msg_and_close("A previous test did not close the following streams:"))
dataset, _ = dataset_mock.load(config)
vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
if vanilla_tensors:
consume(iter(dataset))
if StreamWrapper.session_streams:
raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_simple_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)}
if simple_tensors:
raise AssertionError(
f"The values of key(s) "
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
)
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
next(iter(dataset.map(transforms.Identity())))
dataset = dataset.map(transforms.Identity())
consume(iter(dataset))
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_traversable(self, dataset_mock, config):
......@@ -131,7 +175,7 @@ class TestCommon:
collate_fn=self._collate_fn,
)
next(iter(dl))
consume(dl)
# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
......@@ -148,7 +192,7 @@ class TestCommon:
def test_save_load(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
sample = next(iter(dataset))
sample = next_consume(iter(dataset))
with io.BytesIO() as buffer:
torch.save(sample, buffer)
......@@ -177,7 +221,7 @@ class TestQMNIST:
def test_extra_label(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
sample = next(iter(dataset))
sample = next_consume(iter(dataset))
for key, type in (
("nist_hsf_series", int),
("nist_writer_id", int),
......@@ -214,7 +258,7 @@ class TestUSPS:
assert "image" in sample
assert "label" in sample
assert isinstance(sample["image"], Image)
assert isinstance(sample["label"], Label)
assert isinstance(sample["image"], features.Image)
assert isinstance(sample["label"], features.Label)
assert sample["image"].shape == (1, 16, 16)
......@@ -30,24 +30,26 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for _, file in self.datapipe:
file = (line.decode() for line in file)
lines = (line.decode() for line in file)
if self.fieldnames:
fieldnames = self.fieldnames
else:
# The first row is skipped, because it only contains the number of samples
next(file)
next(lines)
# Empty field names are filtered out, because some files have an extra white space after the header
# line, which is recognized as extra column
fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
# Some files do not include a label for the image ID column
if fieldnames[0] != "image_id":
fieldnames.insert(0, "image_id")
for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
yield line.pop("image_id"), line
file.close()
NAME = "celeba"
......
......@@ -62,7 +62,9 @@ class _CifarBase(Dataset):
def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
content = cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
file.close()
return content
def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
image_array, category_idx = data
......
......@@ -97,6 +97,8 @@ class CLEVR(Dataset):
buffer_size=INFINITE_BUFFER_SIZE,
)
else:
for _, file in scenes_dp:
file.close()
dp = Mapper(images_dp, self._add_empty_anns)
return Mapper(dp, self._prepare_sample)
......
......@@ -57,6 +57,8 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
for _ in range(stop - start):
yield read(dtype=dtype, count=count).reshape(shape)
file.close()
class _MNISTBase(Dataset):
_URL_BASE: Union[str, Sequence[str]]
......
......@@ -33,6 +33,8 @@ class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
data = data[self.key]
yield from data
handle.close()
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
......
......@@ -49,31 +49,35 @@ class SBD(Dataset):
super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
archive = HttpResource(
"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
)
extra_split = HttpResource(
"http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
)
return [archive, extra_split]
resources = [
HttpResource(
"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
)
]
if self._split == "train_noval":
resources.append(
HttpResource(
"http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
)
)
return resources # type: ignore[return-value]
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
parent, grandparent, *_ = path.parents
if parent.name == "dataset":
return 0
elif grandparent.name == "dataset":
if grandparent.name == "dataset":
if parent.name == "img":
return 1
return 0
elif parent.name == "cls":
return 2
else:
return None
else:
return None
return 1
if parent.name == "dataset" and self._split != "train_noval":
return 2
return None
def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
split_and_image_data, ann_data = data
......@@ -93,18 +97,24 @@ class SBD(Dataset):
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp, extra_split_dp = resource_dps
archive_dp = resource_dps[0]
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
if self._split == "train_noval":
split_dp = extra_split_dp
archive_dp, split_dp = resource_dps
images_dp, anns_dp = Demultiplexer(
archive_dp,
2,
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
else:
archive_dp = resource_dps[0]
images_dp, anns_dp, split_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True)
......
......@@ -94,7 +94,9 @@ class VOC(Dataset):
return None
def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
ann = cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
buffer.close()
return ann
def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
anns = self._parse_detection_ann(buffer)
......
......@@ -8,7 +8,6 @@ import torch
import torch.distributed as dist
import torch.utils.data
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision.prototype.utils._internal import fromfile
......@@ -40,10 +39,9 @@ def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
except ImportError as error:
raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error
if isinstance(buffer, StreamWrapper):
buffer = buffer.file_obj
return sio.loadmat(buffer, **kwargs)
data = sio.loadmat(buffer, **kwargs)
buffer.close()
return data
class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
......
......@@ -27,7 +27,9 @@ class EncodedData(_Feature):
@classmethod
def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
encoded_data = cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
file.close()
return encoded_data
@classmethod
def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
......
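
The new test_stream_closing test above leans on torchdata's StreamWrapper bookkeeping (debug_unclosed_streams and session_streams). A rough standalone equivalent of that check might look like this (the helper name is made up for illustration):

from torchdata.datapipes.utils import StreamWrapper


def assert_all_streams_closed(dataset):
    # Enable tracking of every StreamWrapper created in this session, iterate
    # the dataset to exhaustion, and fail if any wrapped stream is still open.
    previous = StreamWrapper.debug_unclosed_streams
    StreamWrapper.debug_unclosed_streams = True
    try:
        for _ in dataset:
            pass
        if StreamWrapper.session_streams:
            raise AssertionError(
                f"{len(StreamWrapper.session_streams)} stream(s) were left open"
            )
    finally:
        StreamWrapper.debug_unclosed_streams = previous
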