Commit 32df87a8 authored by Gokkulnath TS's avatar Gokkulnath TS Committed by Francisco Massa
Browse files

Fixes EMNIST classes attribute is wrong #1716 (#1736)

* Fixes #1716

Fixes EMNIST classes attribute is wrong #1716

* Fixed the Classes for Letters Split

* Update mnist.py

* Move classes attribute inside init definition

* Fix Linting errors
parent a199bf62
...@@ -7,6 +7,7 @@ import os.path ...@@ -7,6 +7,7 @@ import os.path
import numpy as np import numpy as np
import torch import torch
import codecs import codecs
import string
from .utils import download_url, download_and_extract_archive, extract_archive, \ from .utils import download_url, download_and_extract_archive, extract_archive, \
makedir_exist_ok, verify_str_arg makedir_exist_ok, verify_str_arg
...@@ -239,12 +240,24 @@ class EMNIST(MNIST): ...@@ -239,12 +240,24 @@ class EMNIST(MNIST):
url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip' url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
md5 = "58c8d27c78d21e728a6bc7b3cc06412e" md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist') splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
# Merged Classes assumes Same structure for both uppercase and lowercase version
_merged_classes = set(['C', 'I', 'J', 'K', 'L', 'M', 'O', 'P', 'S', 'U', 'V', 'W', 'X', 'Y', 'Z'])
_all_classes = set(list(string.digits + string.ascii_letters))
classes_split_dict = {
'byclass': list(_all_classes),
'bymerge': sorted(list(_all_classes - _merged_classes)),
'balanced': sorted(list(_all_classes - _merged_classes)),
'letters': list(string.ascii_lowercase),
'digits': list(string.digits),
'mnist': list(string.digits),
}
def __init__(self, root, split, **kwargs): def __init__(self, root, split, **kwargs):
self.split = verify_str_arg(split, "split", self.splits) self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split) self.training_file = self._training_file(split)
self.test_file = self._test_file(split) self.test_file = self._test_file(split)
super(EMNIST, self).__init__(root, **kwargs) super(EMNIST, self).__init__(root, **kwargs)
self.classes = self.classes_split_dict[self.split]
@staticmethod @staticmethod
def _training_file(split): def _training_file(split):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment