Unverified commit 981ccfdf, authored by Vasilis Vryniotis, committed by GitHub

Add typing in GeneralizedRCNNTransform (#4369)



* Add types in transform.

* Trace on eval mode.
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 5fb36a17
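
Note on the "Trace on eval mode" item: tracing and ONNX export of the detection models only make sense in eval mode, where forward(images) returns detections rather than losses. A minimal sketch of that tracing path follows (illustrative only, not part of this diff; the model choice, output path, and opset value are assumptions):

    import torch
    import torchvision

    # Build a detection model that uses GeneralizedRCNNTransform internally.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    model.eval()  # trace on eval mode: forward(images) returns detections, not losses

    images = [torch.rand(3, 320, 320)]
    with torch.no_grad():
        # Illustrative export call; detection models need opset 11 or newer.
        torch.onnx.export(model, (images,), "faster_rcnn.onnx", opset_version=11)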
torchvision/models/detection/transform.py

@@ -10,15 +10,13 @@ from .roi_heads import paste_masks_in_image
 @torch.jit.unused
-def _get_shape_onnx(image):
-    # type: (Tensor) -> Tensor
+def _get_shape_onnx(image: Tensor) -> Tensor:
     from torch.onnx import operators
     return operators.shape_as_tensor(image)[-2:]


 @torch.jit.unused
-def _fake_cast_onnx(v):
-    # type: (Tensor) -> float
+def _fake_cast_onnx(v: Tensor) -> float:
     # ONNX requires a tensor but here we fake its type for JIT.
     return v
@@ -74,7 +72,8 @@ class GeneralizedRCNNTransform(nn.Module):
     It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
     """

-    def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32, fixed_size=None):
+    def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float],
+                 size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None):
         super(GeneralizedRCNNTransform, self).__init__()
         if not isinstance(min_size, (list, tuple)):
             min_size = (min_size,)
@@ -86,10 +85,9 @@ class GeneralizedRCNNTransform(nn.Module):
         self.fixed_size = fixed_size

     def forward(self,
-                images,       # type: List[Tensor]
-                targets=None  # type: Optional[List[Dict[str, Tensor]]]
-                ):
-        # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
+                images: List[Tensor],
+                targets: Optional[List[Dict[str, Tensor]]] = None
+                ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
         images = [img for img in images]
         if targets is not None:
             # make a copy of targets to avoid modifying it in-place
@@ -126,7 +124,7 @@ class GeneralizedRCNNTransform(nn.Module):
         image_list = ImageList(images, image_sizes_list)
         return image_list, targets

-    def normalize(self, image):
+    def normalize(self, image: Tensor) -> Tensor:
         if not image.is_floating_point():
             raise TypeError(
                 f"Expected input images to be of floating type (in range [0, 1]), "
@@ -137,8 +135,7 @@ class GeneralizedRCNNTransform(nn.Module):
         std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
         return (image - mean[:, None, None]) / std[:, None, None]

-    def torch_choice(self, k):
-        # type: (List[int]) -> int
+    def torch_choice(self, k: List[int]) -> int:
         """
         Implements `random.choice` via torch ops so it can be compiled with
         TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
@@ -175,8 +172,7 @@ class GeneralizedRCNNTransform(nn.Module):
     # _onnx_batch_images() is an implementation of
     # batch_images() that is supported by ONNX tracing.
     @torch.jit.unused
-    def _onnx_batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
         max_size = []
         for i in range(images[0].dim()):
             max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
@@ -197,16 +193,14 @@ class GeneralizedRCNNTransform(nn.Module):
         return torch.stack(padded_imgs)

-    def max_by_axis(self, the_list):
-        # type: (List[List[int]]) -> List[int]
+    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
         maxes = the_list[0]
         for sublist in the_list[1:]:
             for index, item in enumerate(sublist):
                 maxes[index] = max(maxes[index], item)
         return maxes

-    def batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
         if torchvision._is_tracing():
             # batch_images() does not export well to ONNX
             # call _onnx_batch_images() instead
@@ -226,11 +220,10 @@ class GeneralizedRCNNTransform(nn.Module):
         return batched_imgs

     def postprocess(self,
-                    result,               # type: List[Dict[str, Tensor]]
-                    image_shapes,         # type: List[Tuple[int, int]]
-                    original_image_sizes  # type: List[Tuple[int, int]]
-                    ):
-        # type: (...) -> List[Dict[str, Tensor]]
+                    result: List[Dict[str, Tensor]],
+                    image_shapes: List[Tuple[int, int]],
+                    original_image_sizes: List[Tuple[int, int]]
+                    ) -> List[Dict[str, Tensor]]:
         if self.training:
             return result
         for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
@@ -247,7 +240,7 @@ class GeneralizedRCNNTransform(nn.Module):
                 result[i]["keypoints"] = keypoints
         return result

-    def __repr__(self):
+    def __repr__(self) -> str:
         format_string = self.__class__.__name__ + '('
         _indent = '\n    '
         format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
@@ -257,8 +250,7 @@ class GeneralizedRCNNTransform(nn.Module):
         return format_string


-def resize_keypoints(keypoints, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
@@ -276,8 +268,7 @@ def resize_keypoints(keypoints, original_size, new_size):
     return resized_data


-def resize_boxes(boxes, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=boxes.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
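
For reference, a small usage sketch of the now fully annotated class (not part of this commit; the sizes and normalization constants below are simply the torchvision detection defaults, used here for illustration):

    from typing import Dict, List, Optional
    import torch
    from torchvision.models.detection.transform import GeneralizedRCNNTransform

    transform = GeneralizedRCNNTransform(
        min_size=800, max_size=1333,           # default Faster R-CNN sizes
        image_mean=[0.485, 0.456, 0.406],      # ImageNet mean
        image_std=[0.229, 0.224, 0.225],       # ImageNet std
    )
    transform.eval()

    images: List[torch.Tensor] = [torch.rand(3, 480, 640), torch.rand(3, 600, 800)]
    targets: Optional[List[Dict[str, torch.Tensor]]] = None

    # forward() -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
    image_list, targets = transform(images, targets)
    print(image_list.tensors.shape)   # padded batch, spatial dims divisible by size_divisible=32
    print(image_list.image_sizes)     # per-image (h, w) after resizing

    # The module remains TorchScript-compatible with the annotations in place.
    scripted = torch.jit.script(transform)

TorchScript accepts both `# type:` comments and Python 3 annotations; this commit switches the module to the latter without changing behavior.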