Unverified Commit 43dbfd2e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Turn warnings in prototype datasets tests into errors (#5540)

* fix PCAM prototype dataset

* update Zip and Tar archive loader datapipes

* only fail on warnings from the pytorch ecosystem

* Revert "only fail on warnings from the pytorch ecosystem"

This reverts commit 2bf3aa6f2d875a4055f7f3ed0b468316fc60d4f4.
parent a26534c9
...@@ -10,6 +10,7 @@ import lzma ...@@ -10,6 +10,7 @@ import lzma
import pathlib import pathlib
import pickle import pickle
import random import random
import warnings
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from collections import defaultdict, Counter from collections import defaultdict, Counter
...@@ -470,7 +471,10 @@ def imagenet(info, root, config): ...@@ -470,7 +471,10 @@ def imagenet(info, root, config):
] ]
num_children = 1 num_children = 1
synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5)) synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
savemat(data_root / "meta.mat", dict(synsets=synsets)) with warnings.catch_warnings():
# The warning is not for savemat, but rather for some internals savemet is using
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
savemat(data_root / "meta.mat", dict(synsets=synsets))
make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz") make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
else: # config.split == "test" else: # config.split == "test"
......
...@@ -35,6 +35,7 @@ def test_coverage(): ...@@ -35,6 +35,7 @@ def test_coverage():
) )
@pytest.mark.filterwarnings("error")
class TestCommon: class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS) @parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, test_home, dataset_mock, config): def test_smoke(self, test_home, dataset_mock, config):
......
...@@ -10,7 +10,7 @@ from torchdata.datapipes.iter import ( ...@@ -10,7 +10,7 @@ from torchdata.datapipes.iter import (
Mapper, Mapper,
Filter, Filter,
Demultiplexer, Demultiplexer,
TarArchiveReader, TarArchiveLoader,
Enumerator, Enumerator,
) )
from torchvision.prototype.datasets.utils import ( from torchvision.prototype.datasets.utils import (
...@@ -158,7 +158,7 @@ class ImageNet(Dataset): ...@@ -158,7 +158,7 @@ class ImageNet(Dataset):
# the train archive is a tar of tars # the train archive is a tar of tars
if config.split == "train": if config.split == "train":
dp = TarArchiveReader(dp) dp = TarArchiveLoader(dp)
dp = hint_sharding(dp) dp = hint_sharding(dp)
dp = hint_shuffling(dp) dp = hint_shuffling(dp)
......
...@@ -99,7 +99,7 @@ class PCAM(Dataset): ...@@ -99,7 +99,7 @@ class PCAM(Dataset):
image, target = data # They're both numpy arrays at this point image, target = data # They're both numpy arrays at this point
return { return {
"image": features.Image(image), "image": features.Image(image.transpose(2, 0, 1)),
"label": Label(target.item()), "label": Label(target.item()),
} }
......
...@@ -10,8 +10,8 @@ from torchdata.datapipes.iter import ( ...@@ -10,8 +10,8 @@ from torchdata.datapipes.iter import (
FileLister, FileLister,
FileOpener, FileOpener,
IterDataPipe, IterDataPipe,
ZipArchiveReader, ZipArchiveLoader,
TarArchiveReader, TarArchiveLoader,
RarArchiveLoader, RarArchiveLoader,
) )
from torchvision.datasets.utils import ( from torchvision.datasets.utils import (
...@@ -72,8 +72,8 @@ class OnlineResource(abc.ABC): ...@@ -72,8 +72,8 @@ class OnlineResource(abc.ABC):
return dp return dp
_ARCHIVE_LOADERS = { _ARCHIVE_LOADERS = {
".tar": TarArchiveReader, ".tar": TarArchiveLoader,
".zip": ZipArchiveReader, ".zip": ZipArchiveLoader,
".rar": RarArchiveLoader, ".rar": RarArchiveLoader,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment