Unverified Commit 90645ccd authored by Zhiqiang Wang's avatar Zhiqiang Wang Committed by GitHub
Browse files

Replacing all torch.jit.annotations with typing (#3174)

* Replacing all torch.jit.annotations with typing

* Replacing remaining typing
parent 83171d6a
...@@ -7,9 +7,9 @@ import numpy as np ...@@ -7,9 +7,9 @@ import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch.jit.annotations import Tuple
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torchvision import ops from torchvision import ops
from typing import Tuple
class OpTester(object): class OpTester(object):
......
import math import math
import torch import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor from torch import Tensor
from typing import List, Tuple
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
from torch.jit.annotations import List, Optional, Dict from typing import List, Optional, Dict
from .image_list import ImageList from .image_list import ImageList
...@@ -148,7 +148,7 @@ class AnchorGenerator(nn.Module): ...@@ -148,7 +148,7 @@ class AnchorGenerator(nn.Module):
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device) self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = torch.jit.annotate(List[List[torch.Tensor]], []) anchors: List[List[torch.Tensor]] = []
for i in range(len(image_list.image_sizes)): for i in range(len(image_list.image_sizes)):
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps] anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
anchors.append(anchors_in_image) anchors.append(anchors_in_image)
......
...@@ -4,12 +4,10 @@ Implements the Generalized R-CNN framework ...@@ -4,12 +4,10 @@ Implements the Generalized R-CNN framework
""" """
from collections import OrderedDict from collections import OrderedDict
from typing import Union
import torch import torch
from torch import nn from torch import nn, Tensor
import warnings import warnings
from torch.jit.annotations import Tuple, List, Dict, Optional from typing import Tuple, List, Dict, Optional, Union
from torch import Tensor
class GeneralizedRCNN(nn.Module): class GeneralizedRCNN(nn.Module):
...@@ -71,7 +69,7 @@ class GeneralizedRCNN(nn.Module): ...@@ -71,7 +69,7 @@ class GeneralizedRCNN(nn.Module):
raise ValueError("Expected target boxes to be of type " raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes))) "Tensor, got {:}.".format(type(boxes)))
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) original_image_sizes: List[Tuple[int, int]] = []
for img in images: for img in images:
val = img.shape[-2:] val = img.shape[-2:]
assert len(val) == 2 assert len(val) == 2
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor from torch import Tensor
from typing import List, Tuple
class ImageList(object): class ImageList(object):
......
...@@ -3,9 +3,8 @@ from collections import OrderedDict ...@@ -3,9 +3,8 @@ from collections import OrderedDict
import warnings import warnings
import torch import torch
import torch.nn as nn from torch import nn, Tensor
from torch import Tensor from typing import Dict, List, Tuple, Optional
from torch.jit.annotations import Dict, List, Tuple, Optional
from ._utils import overwrite_eps from ._utils import overwrite_eps
from ..utils import load_state_dict_from_url from ..utils import load_state_dict_from_url
...@@ -402,7 +401,7 @@ class RetinaNet(nn.Module): ...@@ -402,7 +401,7 @@ class RetinaNet(nn.Module):
num_images = len(image_shapes) num_images = len(image_shapes)
detections = torch.jit.annotate(List[Dict[str, Tensor]], []) detections: List[Dict[str, Tensor]] = []
for index in range(num_images): for index in range(num_images):
box_regression_per_image = [br[index] for br in box_regression] box_regression_per_image = [br[index] for br in box_regression]
...@@ -486,7 +485,7 @@ class RetinaNet(nn.Module): ...@@ -486,7 +485,7 @@ class RetinaNet(nn.Module):
"Tensor, got {:}.".format(type(boxes))) "Tensor, got {:}.".format(type(boxes)))
# get the original image sizes # get the original image sizes
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) original_image_sizes: List[Tuple[int, int]] = []
for img in images: for img in images:
val = img.shape[-2:] val = img.shape[-2:]
assert len(val) == 2 assert len(val) == 2
...@@ -524,7 +523,7 @@ class RetinaNet(nn.Module): ...@@ -524,7 +523,7 @@ class RetinaNet(nn.Module):
anchors = self.anchor_generator(images, features) anchors = self.anchor_generator(images, features)
losses = {} losses = {}
detections = torch.jit.annotate(List[Dict[str, Tensor]], []) detections: List[Dict[str, Tensor]] = []
if self.training: if self.training:
assert targets is not None assert targets is not None
......
...@@ -10,7 +10,7 @@ from torchvision.ops import roi_align ...@@ -10,7 +10,7 @@ from torchvision.ops import roi_align
from . import _utils as det_utils from . import _utils as det_utils
from torch.jit.annotations import Optional, List, Dict, Tuple from typing import Optional, List, Dict, Tuple
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
...@@ -379,7 +379,7 @@ def expand_masks(mask, padding): ...@@ -379,7 +379,7 @@ def expand_masks(mask, padding):
scale = expand_masks_tracing_scale(M, padding) scale = expand_masks_tracing_scale(M, padding)
else: else:
scale = float(M + 2 * padding) / M scale = float(M + 2 * padding) / M
padded_mask = torch.nn.functional.pad(mask, (padding,) * 4) padded_mask = F.pad(mask, (padding,) * 4)
return padded_mask, scale return padded_mask, scale
...@@ -482,7 +482,7 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1): ...@@ -482,7 +482,7 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):
return ret return ret
class RoIHeads(torch.nn.Module): class RoIHeads(nn.Module):
__annotations__ = { __annotations__ = {
'box_coder': det_utils.BoxCoder, 'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher, 'proposal_matcher': det_utils.Matcher,
...@@ -753,7 +753,7 @@ class RoIHeads(torch.nn.Module): ...@@ -753,7 +753,7 @@ class RoIHeads(torch.nn.Module):
box_features = self.box_head(box_features) box_features = self.box_head(box_features)
class_logits, box_regression = self.box_predictor(box_features) class_logits, box_regression = self.box_predictor(box_features)
result = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) result: List[Dict[str, torch.Tensor]] = []
losses = {} losses = {}
if self.training: if self.training:
assert labels is not None and regression_targets is not None assert labels is not None and regression_targets is not None
......
...@@ -9,7 +9,7 @@ from torchvision.ops import boxes as box_ops ...@@ -9,7 +9,7 @@ from torchvision.ops import boxes as box_ops
from . import _utils as det_utils from . import _utils as det_utils
from .image_list import ImageList from .image_list import ImageList
from torch.jit.annotations import List, Optional, Dict, Tuple from typing import List, Optional, Dict, Tuple
# Import AnchorGenerator to keep compatibility. # Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator from .anchor_utils import AnchorGenerator
......
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn import functional as F from torch.nn import functional as F
import torchvision import torchvision
from torch.jit.annotations import List, Tuple, Dict, Optional from typing import List, Tuple, Dict, Optional
from .image_list import ImageList from .image_list import ImageList
from .roi_heads import paste_masks_in_image from .roi_heads import paste_masks_in_image
...@@ -109,7 +109,7 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -109,7 +109,7 @@ class GeneralizedRCNNTransform(nn.Module):
image_sizes = [img.shape[-2:] for img in images] image_sizes = [img.shape[-2:] for img in images]
images = self.batch_images(images) images = self.batch_images(images)
image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], []) image_sizes_list: List[Tuple[int, int]] = []
for image_size in image_sizes: for image_size in image_sizes:
assert len(image_size) == 2 assert len(image_size) == 2
image_sizes_list.append((image_size[0], image_size[1])) image_sizes_list.append((image_size[0], image_size[1]))
......
...@@ -162,7 +162,7 @@ class GoogLeNet(nn.Module): ...@@ -162,7 +162,7 @@ class GoogLeNet(nn.Module):
# N x 480 x 14 x 14 # N x 480 x 14 x 14
x = self.inception4a(x) x = self.inception4a(x)
# N x 512 x 14 x 14 # N x 512 x 14 x 14
aux1 = torch.jit.annotate(Optional[Tensor], None) aux1: Optional[Tensor] = None
if self.aux1 is not None: if self.aux1 is not None:
if self.training: if self.training:
aux1 = self.aux1(x) aux1 = self.aux1(x)
...@@ -173,7 +173,7 @@ class GoogLeNet(nn.Module): ...@@ -173,7 +173,7 @@ class GoogLeNet(nn.Module):
# N x 512 x 14 x 14 # N x 512 x 14 x 14
x = self.inception4d(x) x = self.inception4d(x)
# N x 528 x 14 x 14 # N x 528 x 14 x 14
aux2 = torch.jit.annotate(Optional[Tensor], None) aux2: Optional[Tensor] = None
if self.aux2 is not None: if self.aux2 is not None:
if self.training: if self.training:
aux2 = self.aux2(x) aux2 = self.aux2(x)
......
from collections import namedtuple from collections import namedtuple
import warnings import warnings
import torch import torch
import torch.nn as nn from torch import nn, Tensor
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
from typing import Callable, Any, Optional, Tuple, List from typing import Callable, Any, Optional, Tuple, List
...@@ -17,7 +16,7 @@ model_urls = { ...@@ -17,7 +16,7 @@ model_urls = {
} }
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits']) InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]} InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}
# Script annotations failed with _GoogleNetOutputs = namedtuple ... # Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat # _InceptionOutputs set here for backwards compat
...@@ -171,7 +170,7 @@ class Inception3(nn.Module): ...@@ -171,7 +170,7 @@ class Inception3(nn.Module):
# N x 768 x 17 x 17 # N x 768 x 17 x 17
x = self.Mixed_6e(x) x = self.Mixed_6e(x)
# N x 768 x 17 x 17 # N x 768 x 17 x 17
aux = torch.jit.annotate(Optional[Tensor], None) aux: Optional[Tensor] = None
if self.AuxLogits is not None: if self.AuxLogits is not None:
if self.training: if self.training:
aux = self.AuxLogits(x) aux = self.AuxLogits(x)
......
...@@ -2,7 +2,6 @@ import warnings ...@@ -2,7 +2,6 @@ import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from torch.jit.annotations import Optional
from torchvision.models.utils import load_state_dict_from_url from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.googlenet import ( from torchvision.models.googlenet import (
......
...@@ -6,7 +6,6 @@ import torch.nn as nn ...@@ -6,7 +6,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision.models import inception as inception_module from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs from torchvision.models.inception import InceptionOutputs
from torch.jit.annotations import Optional
from torchvision.models.utils import load_state_dict_from_url from torchvision.models.utils import load_state_dict_from_url
from .utils import _replace_relu, quantize_model from .utils import _replace_relu, quantize_model
......
import torch import torch
from torch.jit.annotations import Tuple
from torch import Tensor from torch import Tensor
import torchvision
def _box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor: def _box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
......
import torch import torch
from torch import Tensor from torch import Tensor
from torch.jit.annotations import List from typing import List
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor: def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
......
import torch import torch
from torch.jit.annotations import Tuple
from torch import Tensor from torch import Tensor
from typing import Tuple
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
import torchvision import torchvision
from torchvision.extension import _assert_has_ops from torchvision.extension import _assert_has_ops
......
...@@ -5,7 +5,7 @@ from torch import nn, Tensor ...@@ -5,7 +5,7 @@ from torch import nn, Tensor
from torch.nn import init from torch.nn import init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.jit.annotations import Optional, Tuple from typing import Optional, Tuple
from torchvision.extension import _assert_has_ops from torchvision.extension import _assert_has_ops
......
from collections import OrderedDict from collections import OrderedDict
import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, Tensor from torch import nn, Tensor
from torch.jit.annotations import Tuple, List, Dict, Optional from typing import Tuple, List, Dict, Optional
class ExtraFPNBlock(nn.Module): class ExtraFPNBlock(nn.Module):
......
...@@ -10,8 +10,8 @@ is implemented ...@@ -10,8 +10,8 @@ is implemented
import warnings import warnings
import torch import torch
from torch import Tensor, Size from torch import Tensor
from torch.jit.annotations import List, Optional, Tuple from typing import List, Optional
class Conv2d(torch.nn.Conv2d): class Conv2d(torch.nn.Conv2d):
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Union
import torch import torch
import torch.nn.functional as F
from torch import nn, Tensor from torch import nn, Tensor
import torchvision
from torchvision.ops import roi_align from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area from torchvision.ops.boxes import box_area
from torch.jit.annotations import Optional, List, Dict, Tuple from typing import Optional, List, Dict, Tuple, Union
import torchvision
# copying result_idx_in_level to a specific index in result[] # copying result_idx_in_level to a specific index in result[]
...@@ -149,7 +146,7 @@ class MultiScaleRoIAlign(nn.Module): ...@@ -149,7 +146,7 @@ class MultiScaleRoIAlign(nn.Module):
def infer_scale(self, feature: Tensor, original_size: List[int]) -> float: def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
# assumption: the scale is of the form 2 ** (-k), with k integer # assumption: the scale is of the form 2 ** (-k), with k integer
size = feature.shape[-2:] size = feature.shape[-2:]
possible_scales = torch.jit.annotate(List[float], []) possible_scales: List[float] = []
for s1, s2 in zip(size, original_size): for s1, s2 in zip(size, original_size):
approx_scale = float(s1) / float(s2) approx_scale = float(s1) / float(s2)
scale = 2 ** float(torch.tensor(approx_scale).log2().round()) scale = 2 ** float(torch.tensor(approx_scale).log2().round())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment