Unverified commit 90645ccd, authored by Zhiqiang Wang and committed by GitHub

Replacing all torch.jit.annotations with typing (#3174)

* Replacing all torch.jit.annotations with typing

* Replacing the remaining torch.jit.annotations usages with typing
parent 83171d6a
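The diff below applies one mechanical change across the codebase: container types such as List, Tuple, Dict, and Optional are imported from the standard typing module instead of torch.jit.annotations, and torch.jit.annotate(T, value) calls are rewritten as PEP 526 variable annotations, which TorchScript accepts on Python 3. A minimal sketch of both patterns (the Example module is hypothetical, for illustration only, not part of this commit):

from typing import List, Optional

import torch
from torch import nn, Tensor


class Example(nn.Module):  # hypothetical module, for illustration only
    def forward(self, x: Tensor) -> List[Tensor]:
        # Before: outputs = torch.jit.annotate(List[Tensor], [])
        outputs: List[Tensor] = []
        # Before: aux = torch.jit.annotate(Optional[Tensor], None)
        aux: Optional[Tensor] = None
        if aux is None:
            outputs.append(x)
        return outputs


scripted = torch.jit.script(Example())  # still compiles under TorchScript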
@@ -7,9 +7,9 @@ import numpy as np
 import torch
 from torch import Tensor
 from torch.autograd import gradcheck
-from torch.jit.annotations import Tuple
 from torch.nn.modules.utils import _pair
 from torchvision import ops
+from typing import Tuple

 class OpTester(object):
......
 import math
 import torch
-from torch.jit.annotations import List, Tuple
 from torch import Tensor
+from typing import List, Tuple

 from torchvision.ops.misc import FrozenBatchNorm2d
......
@@ -2,7 +2,7 @@
 import torch
 from torch import nn, Tensor
-from torch.jit.annotations import List, Optional, Dict
+from typing import List, Optional, Dict

 from .image_list import ImageList
@@ -148,7 +148,7 @@ class AnchorGenerator(nn.Module):
                     torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
-        anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
+        anchors: List[List[torch.Tensor]] = []
         for i in range(len(image_list.image_sizes)):
             anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
             anchors.append(anchors_in_image)
......
@@ -4,12 +4,10 @@ Implements the Generalized R-CNN framework
 """
 from collections import OrderedDict
-from typing import Union
 import torch
-from torch import nn
+from torch import nn, Tensor
 import warnings
-from torch.jit.annotations import Tuple, List, Dict, Optional
-from torch import Tensor
+from typing import Tuple, List, Dict, Optional, Union

 class GeneralizedRCNN(nn.Module):
@@ -71,7 +69,7 @@ class GeneralizedRCNN(nn.Module):
                     raise ValueError("Expected target boxes to be of type "
                                      "Tensor, got {:}.".format(type(boxes)))

-        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
+        original_image_sizes: List[Tuple[int, int]] = []
         for img in images:
             val = img.shape[-2:]
             assert len(val) == 2
......
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
-from torch.jit.annotations import List, Tuple
 from torch import Tensor
+from typing import List, Tuple

 class ImageList(object):
......
@@ -3,9 +3,8 @@ from collections import OrderedDict
 import warnings

 import torch
-import torch.nn as nn
-from torch import Tensor
-from torch.jit.annotations import Dict, List, Tuple, Optional
+from torch import nn, Tensor
+from typing import Dict, List, Tuple, Optional

 from ._utils import overwrite_eps
 from ..utils import load_state_dict_from_url
@@ -402,7 +401,7 @@ class RetinaNet(nn.Module):
         num_images = len(image_shapes)

-        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
+        detections: List[Dict[str, Tensor]] = []

         for index in range(num_images):
             box_regression_per_image = [br[index] for br in box_regression]
@@ -486,7 +485,7 @@ class RetinaNet(nn.Module):
                                      "Tensor, got {:}.".format(type(boxes)))

         # get the original image sizes
-        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
+        original_image_sizes: List[Tuple[int, int]] = []
         for img in images:
             val = img.shape[-2:]
             assert len(val) == 2
@@ -524,7 +523,7 @@ class RetinaNet(nn.Module):
         anchors = self.anchor_generator(images, features)

         losses = {}
-        detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
+        detections: List[Dict[str, Tensor]] = []
         if self.training:
             assert targets is not None
......
@@ -10,7 +10,7 @@ from torchvision.ops import roi_align

 from . import _utils as det_utils
-from torch.jit.annotations import Optional, List, Dict, Tuple
+from typing import Optional, List, Dict, Tuple

 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
@@ -379,7 +379,7 @@ def expand_masks(mask, padding):
         scale = expand_masks_tracing_scale(M, padding)
     else:
         scale = float(M + 2 * padding) / M
-    padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)
+    padded_mask = F.pad(mask, (padding,) * 4)
     return padded_mask, scale
@@ -482,7 +482,7 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):
     return ret

-class RoIHeads(torch.nn.Module):
+class RoIHeads(nn.Module):
     __annotations__ = {
         'box_coder': det_utils.BoxCoder,
         'proposal_matcher': det_utils.Matcher,
@@ -753,7 +753,7 @@ class RoIHeads(nn.Module):
         box_features = self.box_head(box_features)
         class_logits, box_regression = self.box_predictor(box_features)

-        result = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
+        result: List[Dict[str, torch.Tensor]] = []
         losses = {}
         if self.training:
             assert labels is not None and regression_targets is not None
......
@@ -9,7 +9,7 @@ from torchvision.ops import boxes as box_ops

 from . import _utils as det_utils
 from .image_list import ImageList
-from torch.jit.annotations import List, Optional, Dict, Tuple
+from typing import List, Optional, Dict, Tuple

 # Import AnchorGenerator to keep compatibility.
 from .anchor_utils import AnchorGenerator
......
@@ -4,7 +4,7 @@ import torch
 from torch import nn, Tensor
 from torch.nn import functional as F
 import torchvision
-from torch.jit.annotations import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict, Optional

 from .image_list import ImageList
 from .roi_heads import paste_masks_in_image
@@ -109,7 +109,7 @@ class GeneralizedRCNNTransform(nn.Module):
         image_sizes = [img.shape[-2:] for img in images]
         images = self.batch_images(images)
-        image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
+        image_sizes_list: List[Tuple[int, int]] = []
         for image_size in image_sizes:
             assert len(image_size) == 2
             image_sizes_list.append((image_size[0], image_size[1]))
......
@@ -162,7 +162,7 @@ class GoogLeNet(nn.Module):
         # N x 480 x 14 x 14
         x = self.inception4a(x)
         # N x 512 x 14 x 14
-        aux1 = torch.jit.annotate(Optional[Tensor], None)
+        aux1: Optional[Tensor] = None
         if self.aux1 is not None:
             if self.training:
                 aux1 = self.aux1(x)
@@ -173,7 +173,7 @@ class GoogLeNet(nn.Module):
         # N x 512 x 14 x 14
         x = self.inception4d(x)
         # N x 528 x 14 x 14
-        aux2 = torch.jit.annotate(Optional[Tensor], None)
+        aux2: Optional[Tensor] = None
         if self.aux2 is not None:
             if self.training:
                 aux2 = self.aux2(x)
......
 from collections import namedtuple
 import warnings
 import torch
-import torch.nn as nn
+from torch import nn, Tensor
 import torch.nn.functional as F
-from torch import Tensor
 from .utils import load_state_dict_from_url
 from typing import Callable, Any, Optional, Tuple, List
@@ -17,7 +16,7 @@ model_urls = {
 }

 InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
-InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}
+InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}

 # Script annotations failed with _GoogleNetOutputs = namedtuple ...
 # _InceptionOutputs set here for backwards compat
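TorchScript picks up the field types of a plain namedtuple from its __annotations__ attribute, which is why the types are assigned imperatively here rather than via a typed NamedTuple class (the adjacent comment records a scripting failure with the bare namedtuple declaration); the hunk itself only swaps torch.Tensor for the directly imported Tensor alias. A minimal sketch of the same idiom, using a hypothetical DemoOutputs type:

from collections import namedtuple
from typing import Optional

from torch import Tensor

# Hypothetical namedtuple mirroring the InceptionOutputs idiom.
DemoOutputs = namedtuple('DemoOutputs', ['logits', 'aux_logits'])
DemoOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}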
@@ -171,7 +170,7 @@ class Inception3(nn.Module):
         # N x 768 x 17 x 17
         x = self.Mixed_6e(x)
         # N x 768 x 17 x 17
-        aux = torch.jit.annotate(Optional[Tensor], None)
+        aux: Optional[Tensor] = None
         if self.AuxLogits is not None:
             if self.training:
                 aux = self.AuxLogits(x)
......
@@ -2,7 +2,6 @@ import warnings
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from torch.jit.annotations import Optional
 from torchvision.models.utils import load_state_dict_from_url
 from torchvision.models.googlenet import (
......
@@ -6,7 +6,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torchvision.models import inception as inception_module
 from torchvision.models.inception import InceptionOutputs
-from torch.jit.annotations import Optional
 from torchvision.models.utils import load_state_dict_from_url
 from .utils import _replace_relu, quantize_model
......
 import torch
-from torch.jit.annotations import Tuple
 from torch import Tensor
 import torchvision

 def _box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
......
 import torch
 from torch import Tensor
-from torch.jit.annotations import List
+from typing import List

 def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
......
 import torch
-from torch.jit.annotations import Tuple
 from torch import Tensor
+from typing import Tuple
 from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
 import torchvision
 from torchvision.extension import _assert_has_ops
......
@@ -5,7 +5,7 @@ from torch import nn, Tensor
 from torch.nn import init
 from torch.nn.parameter import Parameter
 from torch.nn.modules.utils import _pair
-from torch.jit.annotations import Optional, Tuple
+from typing import Optional, Tuple

 from torchvision.extension import _assert_has_ops
......
 from collections import OrderedDict

 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
-from torch.jit.annotations import Tuple, List, Dict, Optional
+from typing import Tuple, List, Dict, Optional

 class ExtraFPNBlock(nn.Module):
......
@@ -10,8 +10,8 @@ is implemented
 import warnings
 import torch
-from torch import Tensor, Size
-from torch.jit.annotations import List, Optional, Tuple
+from torch import Tensor
+from typing import List, Optional

 class Conv2d(torch.nn.Conv2d):
......
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-from typing import Union
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor

 import torchvision
 from torchvision.ops import roi_align
 from torchvision.ops.boxes import box_area

-from torch.jit.annotations import Optional, List, Dict, Tuple
+from typing import Optional, List, Dict, Tuple, Union
# copying result_idx_in_level to a specific index in result[]
@@ -149,7 +146,7 @@ class MultiScaleRoIAlign(nn.Module):
     def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
         # assumption: the scale is of the form 2 ** (-k), with k integer
         size = feature.shape[-2:]
-        possible_scales = torch.jit.annotate(List[float], [])
+        possible_scales: List[float] = []
         for s1, s2 in zip(size, original_size):
             approx_scale = float(s1) / float(s2)
             scale = 2 ** float(torch.tensor(approx_scale).log2().round())
......
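Since every module touched above is meant to remain scriptable, a quick sanity check for this kind of migration is to run torch.jit.script over one of the affected models; a sketch, assuming a torchvision build that includes this change:

import torch
import torchvision

# Scripting recursively compiles the detection modules (anchor_utils, rpn,
# roi_heads, transform, ...) and would fail if TorchScript could not resolve
# the typing-based annotations used after this change.
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=False)
scripted = torch.jit.script(model)
print(type(scripted))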