Unverified Commit 859a535f authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fixing mypy errors (#3335)

* Fixing mypy errors.

* Fixing typing issue.
parent b2cf6045
......@@ -41,8 +41,6 @@ class SEMEION(VisionDataset):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
self.data = []
self.labels = []
fp = os.path.join(self.root, self.filename)
data = np.loadtxt(fp)
# convert value to 8 bit unsigned integer
......
......@@ -67,7 +67,7 @@ class STL10(VisionDataset):
'You can use download=True to download it')
# now load the picked numpy arrays
self.labels: np.ndarray
self.labels: Optional[np.ndarray]
if self.split == 'train':
self.data, self.labels = self.__loadfile(
self.train_list[0][0], self.train_list[1][0])
......@@ -182,4 +182,6 @@ class STL10(VisionDataset):
with open(path_to_folds, 'r') as f:
str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
self.data = self.data[list_idx, :, :, :]
if self.labels is not None:
self.labels = self.labels[list_idx]
......@@ -57,8 +57,8 @@ class USPS(VisionDataset):
import bz2
with bz2.open(full_path) as fp:
raw_data = [line.decode().split() for line in fp.readlines()]
imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment