Unverified Commit 8e874ff8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix annotations for Python >= 3.8 (#5301)

* run mypy on Python 3.9

* appease mypy

* Revert "run mypy on Python 3.9"

This reverts commit b935c8310ca755851f523af5aeb3a6f120b95abf.
parent 460e1bd1
...@@ -58,6 +58,7 @@ if os.getenv("PYTORCH_VERSION"): ...@@ -58,6 +58,7 @@ if os.getenv("PYTORCH_VERSION"):
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION") pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
requirements = [ requirements = [
"typing_extensions",
"numpy", "numpy",
"requests", "requests",
pytorch_dep, pytorch_dep,
......
import os.path import os.path
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple, cast
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -65,10 +65,12 @@ class STL10(VisionDataset): ...@@ -65,10 +65,12 @@ class STL10(VisionDataset):
self.labels: Optional[np.ndarray] self.labels: Optional[np.ndarray]
if self.split == "train": if self.split == "train":
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0]) self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.labels = cast(np.ndarray, self.labels)
self.__load_folds(folds) self.__load_folds(folds)
elif self.split == "train+unlabeled": elif self.split == "train+unlabeled":
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0]) self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.labels = cast(np.ndarray, self.labels)
self.__load_folds(folds) self.__load_folds(folds)
unlabeled_data, _ = self.__loadfile(self.train_list[2][0]) unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
self.data = np.concatenate((self.data, unlabeled_data)) self.data = np.concatenate((self.data, unlabeled_data))
......
...@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union ...@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from PIL import Image, ImageOps, ImageEnhance from PIL import Image, ImageOps, ImageEnhance
from typing_extensions import Literal
try: try:
import accimage import accimage
...@@ -130,7 +131,7 @@ def pad( ...@@ -130,7 +131,7 @@ def pad(
img: Image.Image, img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]], padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0, fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: str = "constant", padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image: ) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
...@@ -189,7 +190,7 @@ def pad( ...@@ -189,7 +190,7 @@ def pad(
if img.mode == "P": if img.mode == "P":
palette = img.getpalette() palette = img.getpalette()
img = np.asarray(img) img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
img = Image.fromarray(img) img = Image.fromarray(img)
img.putpalette(palette) img.putpalette(palette)
return img return img
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment