Unverified Commit 181e81ce authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fixes types annotation (#3059)



* Correcting incorrect types

* Add missing type statement

* Fix type annotations in unittest

* Fix TypeError

* Fix TypeError

* Fix type equality judgment

* Fix recursive compile

* Use string for class name annotation.
Co-authored-by: default avatarzhiqiang <zhiqwang@outlook.com>
parent 8e244797
...@@ -70,16 +70,16 @@ class ModelTester(TestCase): ...@@ -70,16 +70,16 @@ class ModelTester(TestCase):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev) x = torch.rand(input_shape).to(device=dev)
out = model(x) out = model(x)
self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev) self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
self.assertEqual(out.shape[-1], 50) self.assertEqual(out.shape[-1], 50)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None)) self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
if dev == "cuda": if dev == torch.device("cuda"):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
out = model(x) out = model(x)
# See autocast_flaky_numerics comment at top of file. # See autocast_flaky_numerics comment at top of file.
if name not in autocast_flaky_numerics: if name not in autocast_flaky_numerics:
self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev) self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
self.assertEqual(out.shape[-1], 50) self.assertEqual(out.shape[-1], 50)
def _test_segmentation_model(self, name, dev): def _test_segmentation_model(self, name, dev):
...@@ -94,7 +94,7 @@ class ModelTester(TestCase): ...@@ -94,7 +94,7 @@ class ModelTester(TestCase):
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None)) self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
if dev == "cuda": if dev == torch.device("cuda"):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
out = model(x) out = model(x)
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
...@@ -143,7 +143,7 @@ class ModelTester(TestCase): ...@@ -143,7 +143,7 @@ class ModelTester(TestCase):
output = map_nested_tensor_object(out, tensor_map_fn=compact) output = map_nested_tensor_object(out, tensor_map_fn=compact)
prec = 0.01 prec = 0.01
strip_suffix = "_" + dev strip_suffix = f"_{dev}"
try: try:
# We first try to assert the entire output if possible. This is not # We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases # only the best way to assert results but also handles the cases
...@@ -169,7 +169,7 @@ class ModelTester(TestCase): ...@@ -169,7 +169,7 @@ class ModelTester(TestCase):
full_validation = check_out(out) full_validation = check_out(out)
self.check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None)) self.check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None))
if dev == "cuda": if dev == torch.device("cuda"):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
out = model(model_input) out = model(model_input)
# See autocast_flaky_numerics comment at top of file. # See autocast_flaky_numerics comment at top of file.
...@@ -220,7 +220,7 @@ class ModelTester(TestCase): ...@@ -220,7 +220,7 @@ class ModelTester(TestCase):
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None)) self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
self.assertEqual(out.shape[-1], 50) self.assertEqual(out.shape[-1], 50)
if dev == "cuda": if dev == torch.device("cuda"):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
out = model(x) out = model(x)
self.assertEqual(out.shape[-1], 50) self.assertEqual(out.shape[-1], 50)
...@@ -380,7 +380,7 @@ class ModelTester(TestCase): ...@@ -380,7 +380,7 @@ class ModelTester(TestCase):
self.assertEqual(t.__repr__(), expected_string) self.assertEqual(t.__repr__(), expected_string)
_devs = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] _devs = [torch.device("cpu"), torch.device("cuda")] if torch.cuda.is_available() else [torch.device("cpu")]
for model_name in get_available_classification_models(): for model_name in get_available_classification_models():
...@@ -393,7 +393,7 @@ for model_name in get_available_classification_models(): ...@@ -393,7 +393,7 @@ for model_name in get_available_classification_models():
input_shape = (1, 3, 299, 299) input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape, dev) self._test_classification_model(model_name, input_shape, dev)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test) setattr(ModelTester, f"test_{model_name}_{dev}", do_test)
for model_name in get_available_segmentation_models(): for model_name in get_available_segmentation_models():
...@@ -403,7 +403,7 @@ for model_name in get_available_segmentation_models(): ...@@ -403,7 +403,7 @@ for model_name in get_available_segmentation_models():
def do_test(self, model_name=model_name, dev=dev): def do_test(self, model_name=model_name, dev=dev):
self._test_segmentation_model(model_name, dev) self._test_segmentation_model(model_name, dev)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test) setattr(ModelTester, f"test_{model_name}_{dev}", do_test)
for model_name in get_available_detection_models(): for model_name in get_available_detection_models():
...@@ -413,7 +413,7 @@ for model_name in get_available_detection_models(): ...@@ -413,7 +413,7 @@ for model_name in get_available_detection_models():
def do_test(self, model_name=model_name, dev=dev): def do_test(self, model_name=model_name, dev=dev):
self._test_detection_model(model_name, dev) self._test_detection_model(model_name, dev)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test) setattr(ModelTester, f"test_{model_name}_{dev}", do_test)
def do_validation_test(self, model_name=model_name): def do_validation_test(self, model_name=model_name):
self._test_detection_model_validation(model_name) self._test_detection_model_validation(model_name)
...@@ -426,7 +426,7 @@ for model_name in get_available_video_models(): ...@@ -426,7 +426,7 @@ for model_name in get_available_video_models():
def do_test(self, model_name=model_name, dev=dev): def do_test(self, model_name=model_name, dev=dev):
self._test_video_model(model_name, dev) self._test_video_model(model_name, dev)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test) setattr(ModelTester, f"test_{model_name}_{dev}", do_test)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -12,8 +12,7 @@ class ImageList(object): ...@@ -12,8 +12,7 @@ class ImageList(object):
and storing in a field the original sizes of each image and storing in a field the original sizes of each image
""" """
def __init__(self, tensors, image_sizes): def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]):
# type: (Tensor, List[Tuple[int, int]]) -> None
""" """
Arguments: Arguments:
tensors (tensor) tensors (tensor)
...@@ -22,7 +21,6 @@ class ImageList(object): ...@@ -22,7 +21,6 @@ class ImageList(object):
self.tensors = tensors self.tensors = tensors
self.image_sizes = image_sizes self.image_sizes = image_sizes
def to(self, device): def to(self, device: torch.device) -> 'ImageList':
# type: (Device) -> ImageList # noqa
cast_tensor = self.tensors.to(device) cast_tensor = self.tensors.to(device)
return ImageList(cast_tensor, self.image_sizes) return ImageList(cast_tensor, self.image_sizes)
...@@ -5,7 +5,7 @@ import warnings ...@@ -5,7 +5,7 @@ import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from torch.jit.annotations import Dict, List, Tuple from torch.jit.annotations import Dict, List, Tuple, Optional
from ._utils import overwrite_eps from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url from ..utils import load_state_dict_from_url
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment