Unverified Commit 1f17f5fb authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add tests for SEMEION dataset (#3465)



* add tests for SEMEION dataset

* add missing imports
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent b5f29cc3
...@@ -23,6 +23,7 @@ import torch ...@@ -23,6 +23,7 @@ import torch
import shutil import shutil
import json import json
import random import random
import torch.nn.functional as F
import string import string
import io import io
...@@ -1155,5 +1156,22 @@ class SBUTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1155,5 +1156,22 @@ class SBUTestCase(datasets_utils.ImageDatasetTestCase):
fh.write(f"{datasets_utils.create_random_string(10)}\n") fh.write(f"{datasets_utils.create_random_string(10)}\n")
class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SEMEION
def inject_fake_data(self, tmpdir, config):
num_images = 3
images = torch.rand(num_images, 256)
labels = F.one_hot(torch.randint(10, size=(num_images,)))
with open(pathlib.Path(tmpdir) / "semeion.data", "w") as fh:
for image, one_hot_labels in zip(images, labels):
image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image])
labels_columns = " ".join([str(label.item()) for label in one_hot_labels])
fh.write(f"{image_columns} {labels_columns}\n")
return num_images
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment