Unverified Commit a81d99b7 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Add static type check with mypy (#2195)

* add mypy config

* fix syntax error

* fix annotations in torchvision/utils.py

* add mypy type check to CircleCI

* add mypy cache to ignore files

* try fix CI

* ignore flake8 F821 since it interferes with mypy

* add mypy type check to config generator

* explicitly set config files
parent f71316fa
...@@ -109,7 +109,19 @@ jobs: ...@@ -109,7 +109,19 @@ jobs:
- run: - run:
command: | command: |
pip install --user --progress-bar off flake8 typing pip install --user --progress-bar off flake8 typing
flake8 . flake8 --config=setup.cfg .
python_type_check:
docker:
- image: circleci/python:3.7
steps:
- checkout
- run:
command: |
pip install --user --progress-bar off numpy mypy
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off .
mypy --config-file mypy.ini
clang_format: clang_format:
docker: docker:
...@@ -702,12 +714,14 @@ workflows: ...@@ -702,12 +714,14 @@ workflows:
python_version: "3.6" python_version: "3.6"
cu_version: "cu101" cu_version: "cu101"
- python_lint - python_lint
- python_type_check
- clang_format - clang_format
nightly: nightly:
jobs: jobs:
- circleci_consistency - circleci_consistency
- python_lint - python_lint
- python_type_check
- clang_format - clang_format
- binary_linux_wheel: - binary_linux_wheel:
cu_version: cpu cu_version: cpu
......
...@@ -109,7 +109,19 @@ jobs: ...@@ -109,7 +109,19 @@ jobs:
- run: - run:
command: | command: |
pip install --user --progress-bar off flake8 typing pip install --user --progress-bar off flake8 typing
flake8 . flake8 --config=setup.cfg .
python_type_check:
docker:
- image: circleci/python:3.7
steps:
- checkout
- run:
command: |
pip install --user --progress-bar off numpy mypy
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off .
mypy --config-file mypy.ini
clang_format: clang_format:
docker: docker:
...@@ -398,6 +410,7 @@ workflows: ...@@ -398,6 +410,7 @@ workflows:
python_version: "3.6" python_version: "3.6"
cu_version: "cu101" cu_version: "cu101"
- python_lint - python_lint
- python_type_check
- clang_format - clang_format
nightly: nightly:
...@@ -405,5 +418,6 @@ workflows: ...@@ -405,5 +418,6 @@ workflows:
jobs: jobs:
- circleci_consistency - circleci_consistency
- python_lint - python_lint
- python_type_check
- clang_format - clang_format
{{ workflows(prefix="nightly_", filter_branch="nightly", upload=True) }} {{ workflows(prefix="nightly_", filter_branch="nightly", upload=True) }}
...@@ -20,3 +20,4 @@ htmlcov ...@@ -20,3 +20,4 @@ htmlcov
*.swp *.swp
*.swo *.swo
gen.yml gen.yml
.mypy_cache
[mypy]
files = torchvision
show_error_codes = True
pretty = True
[mypy-torchvision.datasets.*]
ignore_errors = True
[mypy-torchvision.io.*]
ignore_errors = True
[mypy-torchvision.models.*]
ignore_errors = True
[mypy-torchvision.ops.*]
ignore_errors = True
[mypy-torchvision.transforms.*]
ignore_errors = True
[mypy-PIL]
ignore_missing_imports = True
...@@ -9,5 +9,5 @@ max-line-length = 120 ...@@ -9,5 +9,5 @@ max-line-length = 120
[flake8] [flake8]
max-line-length = 120 max-line-length = 120
ignore = F401,E402,F403,W503,W504 ignore = F401,E402,F403,W503,W504,F821
exclude = venv exclude = venv
...@@ -84,7 +84,8 @@ class VideoMetaData(object): ...@@ -84,7 +84,8 @@ class VideoMetaData(object):
def _validate_pts(pts_range): def _validate_pts(pts_range):
# type: (List[int]) # type: (List[int]) -> None
if pts_range[1] > 0: if pts_range[1] > 0:
assert ( assert (
pts_range[0] <= pts_range[1] pts_range[0] <= pts_range[1]
......
from typing import Union, Optional, Sequence, Tuple, Text, BinaryIO from typing import Union, Optional, List, Tuple, Text, BinaryIO
import io import io
import pathlib import pathlib
import torch import torch
...@@ -7,7 +7,7 @@ irange = range ...@@ -7,7 +7,7 @@ irange = range
def make_grid( def make_grid(
tensor: Union[torch.Tensor, Sequence[torch.Tensor]], tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8, nrow: int = 8,
padding: int = 2, padding: int = 2,
normalize: bool = False, normalize: bool = False,
...@@ -91,15 +91,17 @@ def make_grid( ...@@ -91,15 +91,17 @@ def make_grid(
for x in irange(xmaps): for x in irange(xmaps):
if k >= nmaps: if k >= nmaps:
break break
grid.narrow(1, y * height + padding, height - padding)\ # Tensor.copy_() is a valid method but seems to be missing from the stubs
.narrow(2, x * width + padding, width - padding)\ # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
.copy_(tensor[k]) grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
2, x * width + padding, width - padding
).copy_(tensor[k])
k = k + 1 k = k + 1
return grid return grid
def save_image( def save_image(
tensor: Union[torch.Tensor, Sequence[torch.Tensor]], tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[Text, pathlib.Path, BinaryIO], fp: Union[Text, pathlib.Path, BinaryIO],
nrow: int = 8, nrow: int = 8,
padding: int = 2, padding: int = 2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment