Unverified Commit cf59d839 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add tests for USPS dataset (#3466)


Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 1f17f5fb
......@@ -23,6 +23,7 @@ import torch
import shutil
import json
import random
import bz2
import torch.nn.functional as F
import string
import io
......@@ -1173,5 +1174,24 @@ class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase):
return num_images
class USPSTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.USPS
CONFIGS = datasets_utils.combinations_grid(train=(True, False))
def inject_fake_data(self, tmpdir, config):
num_images = 2 if config["train"] else 1
images = torch.rand(num_images, 256) * 2 - 1
labels = torch.randint(1, 11, size=(num_images,))
with bz2.open(pathlib.Path(tmpdir) / f"usps{'.t' if not config['train'] else ''}.bz2", "w") as fh:
for image, label in zip(images, labels):
line = " ".join((str(label.item()), *[f"{idx}:{pixel:.6f}" for idx, pixel in enumerate(image, 1)]))
fh.write(f"{line}\n".encode())
return num_images
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment