Unverified commit 181e81ce, authored by Vasilis Vryniotis and committed by GitHub

Fix type annotations (#3059)



* Correcting incorrect types

* Add missing type statement

* Fix type annotations in unittest

* Fix TypeError

* Fix TypeError

* Fix type equality judgment

* Fix recursive compile

* Use string for class name annotation.
Co-authored-by: zhiqiang <zhiqwang@outlook.com>
parent 8e244797
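The core pattern of the patch is replacing mypy-style "# type:" comments with inline Python annotations, and writing a method's return type as a string (a forward reference) when it refers to the class that is still being defined. A minimal sketch of the pattern; the class and attribute names below are illustrative only and not taken from the diff:

import torch
from torch import Tensor
from typing import List, Tuple


class PaddedBatch:
    # Inline annotations replace the old "# type: (Tensor, List[Tuple[int, int]]) -> None" comment.
    def __init__(self, tensors: Tensor, sizes: List[Tuple[int, int]]) -> None:
        self.tensors = tensors
        self.sizes = sizes

    # The return type is spelled as the string 'PaddedBatch' because the class
    # body has not finished executing when this line is evaluated.
    def to(self, device: torch.device) -> 'PaddedBatch':
        return PaddedBatch(self.tensors.to(device), self.sizes)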
@@ -70,16 +70,16 @@ class ModelTester(TestCase):
         # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
         x = torch.rand(input_shape).to(device=dev)
         out = model(x)
-        self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
+        self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
         self.assertEqual(out.shape[-1], 50)
         self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

-        if dev == "cuda":
+        if dev == torch.device("cuda"):
             with torch.cuda.amp.autocast():
                 out = model(x)
                 # See autocast_flaky_numerics comment at top of file.
                 if name not in autocast_flaky_numerics:
-                    self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
+                    self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
                 self.assertEqual(out.shape[-1], 50)

     def _test_segmentation_model(self, name, dev):
@@ -94,7 +94,7 @@ class ModelTester(TestCase):
         self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
         self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

-        if dev == "cuda":
+        if dev == torch.device("cuda"):
             with torch.cuda.amp.autocast():
                 out = model(x)
                 self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
@@ -143,7 +143,7 @@ class ModelTester(TestCase):
             output = map_nested_tensor_object(out, tensor_map_fn=compact)
             prec = 0.01
-            strip_suffix = "_" + dev
+            strip_suffix = f"_{dev}"
             try:
                 # We first try to assert the entire output if possible. This is not
                 # only the best way to assert results but also handles the cases
@@ -169,7 +169,7 @@ class ModelTester(TestCase):
         full_validation = check_out(out)
         self.check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None))

-        if dev == "cuda":
+        if dev == torch.device("cuda"):
             with torch.cuda.amp.autocast():
                 out = model(model_input)
                 # See autocast_flaky_numerics comment at top of file.
@@ -220,7 +220,7 @@ class ModelTester(TestCase):
         self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
         self.assertEqual(out.shape[-1], 50)

-        if dev == "cuda":
+        if dev == torch.device("cuda"):
             with torch.cuda.amp.autocast():
                 out = model(x)
                 self.assertEqual(out.shape[-1], 50)
@@ -380,7 +380,7 @@ class ModelTester(TestCase):
         self.assertEqual(t.__repr__(), expected_string)


-_devs = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+_devs = [torch.device("cpu"), torch.device("cuda")] if torch.cuda.is_available() else [torch.device("cpu")]


 for model_name in get_available_classification_models():
@@ -393,7 +393,7 @@ for model_name in get_available_classification_models():
                 input_shape = (1, 3, 299, 299)
             self._test_classification_model(model_name, input_shape, dev)

-        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
+        setattr(ModelTester, f"test_{model_name}_{dev}", do_test)


 for model_name in get_available_segmentation_models():
@@ -403,7 +403,7 @@ for model_name in get_available_segmentation_models():
         def do_test(self, model_name=model_name, dev=dev):
             self._test_segmentation_model(model_name, dev)

-        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
+        setattr(ModelTester, f"test_{model_name}_{dev}", do_test)


 for model_name in get_available_detection_models():
@@ -413,7 +413,7 @@ for model_name in get_available_detection_models():
         def do_test(self, model_name=model_name, dev=dev):
             self._test_detection_model(model_name, dev)

-        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
+        setattr(ModelTester, f"test_{model_name}_{dev}", do_test)

     def do_validation_test(self, model_name=model_name):
         self._test_detection_model_validation(model_name)
@@ -426,7 +426,7 @@ for model_name in get_available_video_models():
         def do_test(self, model_name=model_name, dev=dev):
             self._test_video_model(model_name, dev)

-        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
+        setattr(ModelTester, f"test_{model_name}_{dev}", do_test)


 if __name__ == '__main__':
     unittest.main()
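The test-side edits above follow from _devs now holding torch.device objects instead of plain strings: concatenating "_" + dev raises a TypeError when dev is a torch.device, while str(torch.device("cuda")) is still "cuda", so the f-strings keep the generated suffixes and test names unchanged, and the device check compares against torch.device("cuda") rather than the string. A simplified, self-contained sketch of this dynamic test-registration pattern (not the actual test file):

import unittest
import torch


class DeviceTester(unittest.TestCase):
    pass


_devs = [torch.device("cpu"), torch.device("cuda")] if torch.cuda.is_available() else [torch.device("cpu")]

for dev in _devs:
    # The default argument captures the loop variable; f-string formatting works
    # because str(torch.device("cpu")) == "cpu".
    def do_test(self, dev=dev):
        self.assertIsInstance(dev, torch.device)

    setattr(DeviceTester, f"test_dummy_{dev}", do_test)


if __name__ == "__main__":
    unittest.main()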
@@ -12,8 +12,7 @@ class ImageList(object):
     and storing in a field the original sizes of each image
     """

-    def __init__(self, tensors, image_sizes):
-        # type: (Tensor, List[Tuple[int, int]]) -> None
+    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]):
         """
         Arguments:
             tensors (tensor)
@@ -22,7 +21,6 @@ class ImageList(object):
         self.tensors = tensors
         self.image_sizes = image_sizes

-    def to(self, device):
-        # type: (Device) -> ImageList # noqa
+    def to(self, device: torch.device) -> 'ImageList':
         cast_tensor = self.tensors.to(device)
         return ImageList(cast_tensor, self.image_sizes)
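Pieced together, the annotated ImageList reads roughly as below. Imports are shown here via typing for a standalone sketch; the actual module may pull List and Tuple from elsewhere:

import torch
from torch import Tensor
from typing import List, Tuple


class ImageList(object):
    """
    Structure that holds a list of images (of possibly varying sizes) as a
    single tensor, storing the original size of each image in a field.
    """

    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]):
        self.tensors = tensors
        self.image_sizes = image_sizes

    def to(self, device: torch.device) -> 'ImageList':
        # Only the tensor payload is moved; the recorded sizes are plain ints.
        cast_tensor = self.tensors.to(device)
        return ImageList(cast_tensor, self.image_sizes)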
@@ -5,7 +5,7 @@ import warnings
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.jit.annotations import Dict, List, Tuple
+from torch.jit.annotations import Dict, List, Tuple, Optional

 from ._utils import overwrite_eps
 from ..utils import load_state_dict_from_url
......
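The extra Optional import matters for TorchScript: an argument that may be None must be annotated as Optional[...] rather than with the bare type, or scripting fails. A small illustrative example of that rule; the function below is not from the patched file, and typing is used here for a standalone snippet while the patched module imports Optional from torch.jit.annotations:

import torch
from torch import Tensor
from typing import Optional


@torch.jit.script
def add_bias(x: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    # Annotating bias simply as Tensor while defaulting to None would be rejected
    # by the TorchScript compiler; Optional makes the None case explicit, and the
    # "is not None" check refines the type back to Tensor.
    if bias is not None:
        x = x + bias
    return x


print(add_bias(torch.ones(3)))
print(add_bias(torch.ones(3), torch.full((3,), 2.0)))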